"""Tests for paperlib database manager."""

import shutil
import tempfile
from pathlib import Path

import pytest

from paperlib.config import LibraryPaths
from paperlib.index import DatabaseManager
from paperlib.models import ConversionStatus, PaperMetadata, SourceType, SummaryStatus

# Project-local scratch area. Every test creates its own uniquely named
# directory below it so that tests never share on-disk state.
_TMP_ROOT = Path("./.tmp")


class TestDatabaseManager:
    """Test DatabaseManager functionality."""

    @pytest.fixture
    def temp_library(self):
        """Create a temporary library for testing.

        Yields a ``LibraryPaths`` rooted at a freshly created, uniquely named
        directory under ``./.tmp``. The directory is removed after the test,
        including when fixture setup itself fails part-way through.
        """
        _TMP_ROOT.mkdir(parents=True, exist_ok=True)
        # mkdtemp guarantees a brand-new directory; a name derived from
        # hash(self) is address-based and can be reused by another process or
        # collide with leftovers from a crashed run.
        temp_dir = Path(tempfile.mkdtemp(prefix="test_db_", dir=_TMP_ROOT))
        try:
            library_paths = LibraryPaths.from_root(temp_dir)
            library_paths.create_directories()

            yield library_paths
        finally:
            # Cleanup
            if temp_dir.exists():
                shutil.rmtree(temp_dir)

    @pytest.fixture
    def db_manager(self, temp_library):
        """Create a database manager with an initialized database."""
        manager = DatabaseManager(temp_library)
        manager.initialize_database()
        return manager

    @pytest.fixture
    def sample_metadata(self):
        """Create sample paper metadata for testing."""
        return PaperMetadata(
            paper_id="test-paper-1",
            source_type=SourceType.LOCAL,
            source_id=None,
            title="A Test Paper on Machine Learning",
            authors=["Alice Smith", "Bob Jones", "Charlie Brown"],
            categories=["cs.AI", "stat.ML"],
            tags=["machine-learning", "neural-networks", "test"],
            notes="This is a test paper for unit testing.",
            pdf_path="papers/local/test-paper-1/source.pdf",
            paper_md_path="papers/local/test-paper-1/paper.md",
            summary_json_path="papers/local/test-paper-1/summary.json",
            summary_md_path="papers/local/test-paper-1/summary.md",
        )

    def test_initialize_database(self, temp_library):
        """Test database initialization."""
        db_manager = DatabaseManager(temp_library)

        # Database file shouldn't exist initially
        assert not db_manager.db_path.exists()

        # Initialize database
        db_manager.initialize_database()

        # Database file should now exist
        assert db_manager.db_path.exists()

        # Should be able to connect and query
        with db_manager._get_connection() as conn:
            cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = [row[0] for row in cursor.fetchall()]

        assert "papers" in tables
        assert "papers_fts" in tables

    def test_index_paper(self, db_manager, sample_metadata):
        """Test indexing a paper."""
        # Index the paper
        db_manager.index_paper(sample_metadata)

        # Verify it was indexed
        paper = db_manager.get_paper(sample_metadata.paper_id)
        assert paper is not None
        assert paper["paper_id"] == "test-paper-1"
        assert paper["title"] == "A Test Paper on Machine Learning"
        assert paper["source_type"] == "local"

    def test_get_paper(self, db_manager, sample_metadata):
        """Test getting a paper by ID."""
        # Initially not found
        paper = db_manager.get_paper("nonexistent")
        assert paper is None

        # Index a paper
        db_manager.index_paper(sample_metadata)

        # Now it should be found
        paper = db_manager.get_paper(sample_metadata.paper_id)
        assert paper is not None
        assert paper["paper_id"] == sample_metadata.paper_id
        assert paper["title"] == sample_metadata.title

    def test_remove_paper(self, db_manager, sample_metadata):
        """Test removing a paper from index."""
        # Index a paper
        db_manager.index_paper(sample_metadata)
        assert db_manager.get_paper(sample_metadata.paper_id) is not None

        # Remove it
        result = db_manager.remove_paper(sample_metadata.paper_id)
        assert result is True

        # Verify it's gone
        assert db_manager.get_paper(sample_metadata.paper_id) is None

        # Removing again should return False
        result = db_manager.remove_paper(sample_metadata.paper_id)
        assert result is False

    def test_list_papers(self, db_manager):
        """Test listing papers with filtering."""
        # Create multiple test papers
        paper1 = PaperMetadata(
            paper_id="paper-1",
            source_type=SourceType.LOCAL,
            title="Local Paper",
            conversion_status=ConversionStatus.PENDING,
            summary_status=SummaryStatus.NOT_REQUESTED,
        )
        paper2 = PaperMetadata(
            paper_id="paper-2",
            source_type=SourceType.ARXIV,
            title="ArXiv Paper",
            conversion_status=ConversionStatus.SUCCESS,
            summary_status=SummaryStatus.PENDING,
        )

        # Index papers
        db_manager.index_paper(paper1)
        db_manager.index_paper(paper2)

        # List all papers
        all_papers = list(db_manager.list_papers())
        assert len(all_papers) == 2

        # Filter by source type
        local_papers = list(db_manager.list_papers(source_type=SourceType.LOCAL))
        assert len(local_papers) == 1
        assert local_papers[0]["source_type"] == "local"

        arxiv_papers = list(db_manager.list_papers(source_type=SourceType.ARXIV))
        assert len(arxiv_papers) == 1
        assert arxiv_papers[0]["source_type"] == "arxiv"

        # Filter by conversion status
        pending_papers = list(
            db_manager.list_papers(conversion_status=ConversionStatus.PENDING)
        )
        assert len(pending_papers) == 1
        assert pending_papers[0]["conversion_status"] == "pending"

        # Test limit (offset is not exercised by this test)
        limited_papers = list(db_manager.list_papers(limit=1))
        assert len(limited_papers) == 1

    def test_search_papers_fts(self, db_manager, sample_metadata):
        """Test full-text search."""
        # Index a paper
        db_manager.index_paper(sample_metadata)

        # Search by title words
        results = list(db_manager.search_papers("Machine Learning"))
        assert len(results) == 1
        assert results[0]["paper_id"] == sample_metadata.paper_id

        # Search by author
        results = list(db_manager.search_papers("Alice Smith"))
        assert len(results) == 1

        # Search by tag (quoted for FTS)
        results = list(db_manager.search_papers('"neural-networks"'))
        assert len(results) == 1

        # Search for non-existent term
        results = list(db_manager.search_papers("nonexistent"))
        assert len(results) == 0

    def test_search_by_field(self, db_manager, sample_metadata):
        """Test searching by specific field."""
        # Index a paper
        db_manager.index_paper(sample_metadata)

        # Search by title
        results = list(db_manager.search_by_field("title", "Machine Learning"))
        assert len(results) == 1

        # Search by author list
        results = list(db_manager.search_by_field("author_list", "Alice"))
        assert len(results) == 1

        # Exact match
        results = list(
            db_manager.search_by_field(
                "title", "A Test Paper on Machine Learning", exact_match=True
            )
        )
        assert len(results) == 1

        results = list(
            db_manager.search_by_field("title", "Partial Title", exact_match=True)
        )
        assert len(results) == 0

        # Invalid field should raise error
        with pytest.raises(ValueError):
            list(db_manager.search_by_field("invalid_field", "test"))

    def test_get_statistics(self, db_manager):
        """Test getting library statistics."""
        # Initially empty
        stats = db_manager.get_statistics()
        assert stats["total_papers"] == 0
        assert stats["by_source_type"] == {}

        # Add some papers
        paper1 = PaperMetadata(
            paper_id="paper-1",
            source_type=SourceType.LOCAL,
            title="Local Paper",
            conversion_status=ConversionStatus.PENDING,
        )
        paper2 = PaperMetadata(
            paper_id="paper-2",
            source_type=SourceType.ARXIV,
            title="ArXiv Paper 1",
            conversion_status=ConversionStatus.SUCCESS,
        )
        paper3 = PaperMetadata(
            paper_id="paper-3",
            source_type=SourceType.ARXIV,
            title="ArXiv Paper 2",
            conversion_status=ConversionStatus.FAILED,
        )

        db_manager.index_paper(paper1)
        db_manager.index_paper(paper2)
        db_manager.index_paper(paper3)

        # Check updated statistics
        stats = db_manager.get_statistics()
        assert stats["total_papers"] == 3
        assert stats["by_source_type"]["local"] == 1
        assert stats["by_source_type"]["arxiv"] == 2
        assert stats["by_conversion_status"]["pending"] == 1
        assert stats["by_conversion_status"]["success"] == 1
        assert stats["by_conversion_status"]["failed"] == 1

    def test_reindex_from_storage(self, db_manager, temp_library):
        """Test reindexing from storage files."""
        from paperlib.storage import PaperStorageManager

        # Create storage manager and add some papers
        storage_manager = PaperStorageManager(temp_library)

        # Create a mock PDF file in its own scratch directory. A fixed shared
        # path such as ./.tmp/test.pdf can be clobbered or unlinked by another
        # test process; the context manager also removes it on every exit path.
        _TMP_ROOT.mkdir(parents=True, exist_ok=True)
        with tempfile.TemporaryDirectory(prefix="test_pdf_", dir=_TMP_ROOT) as pdf_dir:
            pdf_file = Path(pdf_dir) / "test.pdf"
            pdf_file.write_bytes(b"%PDF-1.4\n%%EOF\n")

            # Store papers in storage
            metadata1 = storage_manager.store_paper(
                pdf_path=pdf_file, source_type=SourceType.LOCAL, title="Paper 1"
            )
            metadata2 = storage_manager.store_paper(
                pdf_path=pdf_file,
                source_type=SourceType.ARXIV,
                source_id="2212.06340",
                title="Paper 2",
            )

            # Database should initially be empty
            stats = db_manager.get_statistics()
            assert stats["total_papers"] == 0

            # Reindex from storage
            success_count, error_count = db_manager.reindex_from_storage(
                storage_manager
            )

            # Check results
            assert success_count == 2
            assert error_count == 0

            # Verify papers are now in database
            stats = db_manager.get_statistics()
            assert stats["total_papers"] == 2

            paper1 = db_manager.get_paper(metadata1.paper_id)
            assert paper1 is not None
            assert paper1["title"] == "Paper 1"

            paper2 = db_manager.get_paper(metadata2.paper_id)
            assert paper2 is not None
            assert paper2["title"] == "Paper 2"