Files
paperlib/tests/test_database.py
2026-04-17 15:56:04 -04:00

313 lines
11 KiB
Python

"""Tests for paperlib database manager."""
import shutil
from pathlib import Path
import pytest
from paperlib.config import LibraryPaths
from paperlib.index import DatabaseManager
from paperlib.models import ConversionStatus, PaperMetadata, SourceType, SummaryStatus
class TestDatabaseManager:
    """Test DatabaseManager functionality."""

    @pytest.fixture
    def temp_library(self, tmp_path):
        """Create a fresh, isolated library for a single test.

        The library lives in a ``library`` subdirectory of pytest's per-test
        ``tmp_path``. That directory is unique per test, independent of the
        current working directory and safe under parallel workers, and it is
        cleaned up by pytest. This way no stale database from an earlier or
        interrupted run can leak into a test.
        """
        temp_dir = tmp_path / "library"
        temp_dir.mkdir(parents=True, exist_ok=True)
        library_paths = LibraryPaths.from_root(temp_dir)
        library_paths.create_directories()
        return library_paths

    @pytest.fixture
    def db_manager(self, temp_library):
        """Create a database manager with an initialized database for testing."""
        manager = DatabaseManager(temp_library)
        manager.initialize_database()
        return manager

    @pytest.fixture
    def sample_metadata(self):
        """Create sample paper metadata for testing."""
        return PaperMetadata(
            paper_id="test-paper-1",
            source_type=SourceType.LOCAL,
            source_id=None,
            title="A Test Paper on Machine Learning",
            authors=["Alice Smith", "Bob Jones", "Charlie Brown"],
            categories=["cs.AI", "stat.ML"],
            tags=["machine-learning", "neural-networks", "test"],
            notes="This is a test paper for unit testing.",
            pdf_path="papers/local/test-paper-1/source.pdf",
            paper_md_path="papers/local/test-paper-1/paper.md",
            summary_json_path="papers/local/test-paper-1/summary.json",
            summary_md_path="papers/local/test-paper-1/summary.md",
        )

    def test_initialize_database(self, temp_library):
        """Test database initialization."""
        db_manager = DatabaseManager(temp_library)

        # Database file shouldn't exist initially (the library is brand new).
        assert not db_manager.db_path.exists()

        # Initialize database
        db_manager.initialize_database()

        # Database file should now exist
        assert db_manager.db_path.exists()

        # Should be able to connect and query
        with db_manager._get_connection() as conn:
            cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = [row[0] for row in cursor.fetchall()]
            assert "papers" in tables
            assert "papers_fts" in tables

    def test_index_paper(self, db_manager, sample_metadata):
        """Test indexing a paper."""
        # Index the paper
        db_manager.index_paper(sample_metadata)

        # Verify it was indexed
        paper = db_manager.get_paper(sample_metadata.paper_id)
        assert paper is not None
        assert paper["paper_id"] == "test-paper-1"
        assert paper["title"] == "A Test Paper on Machine Learning"
        assert paper["source_type"] == "local"

    def test_get_paper(self, db_manager, sample_metadata):
        """Test getting a paper by ID."""
        # Initially not found
        paper = db_manager.get_paper("nonexistent")
        assert paper is None

        # Index a paper
        db_manager.index_paper(sample_metadata)

        # Now it should be found
        paper = db_manager.get_paper(sample_metadata.paper_id)
        assert paper is not None
        assert paper["paper_id"] == sample_metadata.paper_id
        assert paper["title"] == sample_metadata.title

    def test_remove_paper(self, db_manager, sample_metadata):
        """Test removing a paper from index."""
        # Index a paper
        db_manager.index_paper(sample_metadata)
        assert db_manager.get_paper(sample_metadata.paper_id) is not None

        # Remove it
        result = db_manager.remove_paper(sample_metadata.paper_id)
        assert result is True

        # Verify it's gone
        assert db_manager.get_paper(sample_metadata.paper_id) is None

        # Removing again should return False
        result = db_manager.remove_paper(sample_metadata.paper_id)
        assert result is False

    def test_list_papers(self, db_manager):
        """Test listing papers with filtering."""
        # Create multiple test papers
        paper1 = PaperMetadata(
            paper_id="paper-1",
            source_type=SourceType.LOCAL,
            title="Local Paper",
            conversion_status=ConversionStatus.PENDING,
            summary_status=SummaryStatus.NOT_REQUESTED,
        )
        paper2 = PaperMetadata(
            paper_id="paper-2",
            source_type=SourceType.ARXIV,
            title="ArXiv Paper",
            conversion_status=ConversionStatus.SUCCESS,
            summary_status=SummaryStatus.PENDING,
        )

        # Index papers
        db_manager.index_paper(paper1)
        db_manager.index_paper(paper2)

        # List all papers
        all_papers = list(db_manager.list_papers())
        assert len(all_papers) == 2

        # Filter by source type
        local_papers = list(db_manager.list_papers(source_type=SourceType.LOCAL))
        assert len(local_papers) == 1
        assert local_papers[0]["source_type"] == "local"

        arxiv_papers = list(db_manager.list_papers(source_type=SourceType.ARXIV))
        assert len(arxiv_papers) == 1
        assert arxiv_papers[0]["source_type"] == "arxiv"

        # Filter by conversion status
        pending_papers = list(
            db_manager.list_papers(conversion_status=ConversionStatus.PENDING)
        )
        assert len(pending_papers) == 1
        assert pending_papers[0]["conversion_status"] == "pending"

        # Test limit (offset is not exercised here)
        limited_papers = list(db_manager.list_papers(limit=1))
        assert len(limited_papers) == 1

    def test_search_papers_fts(self, db_manager, sample_metadata):
        """Test full-text search."""
        # Index a paper
        db_manager.index_paper(sample_metadata)

        # Search by title words
        results = list(db_manager.search_papers("Machine Learning"))
        assert len(results) == 1
        assert results[0]["paper_id"] == sample_metadata.paper_id

        # Search by author
        results = list(db_manager.search_papers("Alice Smith"))
        assert len(results) == 1

        # Search by tag (quoted for FTS, since '-' is otherwise query syntax)
        results = list(db_manager.search_papers('"neural-networks"'))
        assert len(results) == 1

        # Search for non-existent term
        results = list(db_manager.search_papers("nonexistent"))
        assert len(results) == 0

    def test_search_by_field(self, db_manager, sample_metadata):
        """Test searching by specific field."""
        # Index a paper
        db_manager.index_paper(sample_metadata)

        # Search by title
        results = list(db_manager.search_by_field("title", "Machine Learning"))
        assert len(results) == 1

        # Search by author list
        results = list(db_manager.search_by_field("author_list", "Alice"))
        assert len(results) == 1

        # Exact match
        results = list(
            db_manager.search_by_field(
                "title", "A Test Paper on Machine Learning", exact_match=True
            )
        )
        assert len(results) == 1

        results = list(
            db_manager.search_by_field("title", "Partial Title", exact_match=True)
        )
        assert len(results) == 0

        # Invalid field should raise error
        with pytest.raises(ValueError):
            list(db_manager.search_by_field("invalid_field", "test"))

    def test_get_statistics(self, db_manager):
        """Test getting library statistics."""
        # Initially empty
        stats = db_manager.get_statistics()
        assert stats["total_papers"] == 0
        assert stats["by_source_type"] == {}

        # Add some papers
        paper1 = PaperMetadata(
            paper_id="paper-1",
            source_type=SourceType.LOCAL,
            title="Local Paper",
            conversion_status=ConversionStatus.PENDING,
        )
        paper2 = PaperMetadata(
            paper_id="paper-2",
            source_type=SourceType.ARXIV,
            title="ArXiv Paper 1",
            conversion_status=ConversionStatus.SUCCESS,
        )
        paper3 = PaperMetadata(
            paper_id="paper-3",
            source_type=SourceType.ARXIV,
            title="ArXiv Paper 2",
            conversion_status=ConversionStatus.FAILED,
        )

        db_manager.index_paper(paper1)
        db_manager.index_paper(paper2)
        db_manager.index_paper(paper3)

        # Check updated statistics
        stats = db_manager.get_statistics()
        assert stats["total_papers"] == 3
        assert stats["by_source_type"]["local"] == 1
        assert stats["by_source_type"]["arxiv"] == 2
        assert stats["by_conversion_status"]["pending"] == 1
        assert stats["by_conversion_status"]["success"] == 1
        assert stats["by_conversion_status"]["failed"] == 1

    def test_reindex_from_storage(self, db_manager, temp_library, tmp_path):
        """Test reindexing from storage files."""
        from paperlib.storage import PaperStorageManager

        # Create storage manager and add some papers
        storage_manager = PaperStorageManager(temp_library)

        # Create a minimal mock PDF in this test's own tmp_path, outside the
        # library root, so it cannot collide with other runs and is cleaned
        # up by pytest.
        pdf_file = tmp_path / "test.pdf"
        pdf_file.write_bytes(b"%PDF-1.4\n%%EOF\n")

        # Store papers in storage
        metadata1 = storage_manager.store_paper(
            pdf_path=pdf_file, source_type=SourceType.LOCAL, title="Paper 1"
        )
        metadata2 = storage_manager.store_paper(
            pdf_path=pdf_file,
            source_type=SourceType.ARXIV,
            source_id="2212.06340",
            title="Paper 2",
        )

        # Database should initially be empty
        stats = db_manager.get_statistics()
        assert stats["total_papers"] == 0

        # Reindex from storage
        success_count, error_count = db_manager.reindex_from_storage(storage_manager)

        # Check results
        assert success_count == 2
        assert error_count == 0

        # Verify papers are now in database
        stats = db_manager.get_statistics()
        assert stats["total_papers"] == 2

        paper1 = db_manager.get_paper(metadata1.paper_id)
        assert paper1 is not None
        assert paper1["title"] == "Paper 1"

        paper2 = db_manager.get_paper(metadata2.paper_id)
        assert paper2 is not None
        assert paper2["title"] == "Paper 2"