diff --git a/src/documents/ai_deletion_manager.py b/src/documents/ai_deletion_manager.py index aab51a7f6..21848b813 100644 --- a/src/documents/ai_deletion_manager.py +++ b/src/documents/ai_deletion_manager.py @@ -14,14 +14,10 @@ According to agents.md requirements: from __future__ import annotations import logging -from typing import TYPE_CHECKING from typing import Any from django.contrib.auth.models import User -if TYPE_CHECKING: - pass - logger = logging.getLogger("paperless.ai_deletion") diff --git a/src/documents/ai_scanner.py b/src/documents/ai_scanner.py index 422aec5f5..9cf2287fe 100644 --- a/src/documents/ai_scanner.py +++ b/src/documents/ai_scanner.py @@ -29,7 +29,6 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING from typing import Any -from typing import Dict from typing import TypedDict from django.conf import settings @@ -142,34 +141,34 @@ class AIScanResult: """ # Convert internal tuple format to TypedDict format result: AIScanResultDict = { - 'tags': [{'tag_id': tag_id, 'confidence': conf} for tag_id, conf in self.tags], - 'custom_fields': { - field_id: {'value': value, 'confidence': conf} + "tags": [{"tag_id": tag_id, "confidence": conf} for tag_id, conf in self.tags], + "custom_fields": { + field_id: {"value": value, "confidence": conf} for field_id, (value, conf) in self.custom_fields.items() }, - 'workflows': [{'workflow_id': wf_id, 'confidence': conf} for wf_id, conf in self.workflows], - 'extracted_entities': self.extracted_entities, - 'metadata': self.metadata, + "workflows": [{"workflow_id": wf_id, "confidence": conf} for wf_id, conf in self.workflows], + "extracted_entities": self.extracted_entities, + "metadata": self.metadata, } # Add optional fields only if present if self.correspondent: - result['correspondent'] = { - 'correspondent_id': self.correspondent[0], - 'confidence': self.correspondent[1], + result["correspondent"] = { + "correspondent_id": self.correspondent[0], + "confidence": self.correspondent[1], } if self.document_type: - result['document_type'] = { - 'type_id': self.document_type[0], - 'confidence': self.document_type[1], + result["document_type"] = { + "type_id": self.document_type[0], + "confidence": self.document_type[1], } if self.storage_path: - result['storage_path'] = { - 'path_id': self.storage_path[0], - 'confidence': self.storage_path[1], + result["storage_path"] = { + "path_id": self.storage_path[0], + "confidence": self.storage_path[1], } if self.title_suggestion: - result['title_suggestion'] = self.title_suggestion + result["title_suggestion"] = self.title_suggestion return result @@ -257,14 +256,14 @@ class AIDocumentScanner: if self._classifier is None and self.ml_enabled: try: from documents.ml.classifier import TransformerDocumentClassifier - + # Get model name from settings model_name = getattr( - settings, + settings, "PAPERLESS_ML_CLASSIFIER_MODEL", "distilbert-base-uncased", ) - + self._classifier = TransformerDocumentClassifier( model_name=model_name, use_cache=True, @@ -291,14 +290,14 @@ class AIDocumentScanner: if self._semantic_search is None and self.ml_enabled: try: from documents.ml.semantic_search import SemanticSearch - + # Get cache directory from settings cache_dir = getattr( settings, "PAPERLESS_ML_MODEL_CACHE", None, ) - + self._semantic_search = SemanticSearch( cache_dir=cache_dir, use_cache=True, @@ -1004,13 +1003,13 @@ class AIDocumentScanner: import time logger.info("Starting ML model warm-up...") start_time = time.time() - + from documents.ml.model_cache import 
ModelCacheManager cache_manager = ModelCacheManager.get_instance() - + # Define model loaders model_loaders = {} - + # Classifier if self.ml_enabled: def load_classifier(): @@ -1025,14 +1024,14 @@ class AIDocumentScanner: use_cache=True, ) model_loaders["classifier"] = load_classifier - + # NER if self.ml_enabled: def load_ner(): from documents.ml.ner import DocumentNER return DocumentNER(use_cache=True) model_loaders["ner"] = load_ner - + # Semantic Search if self.ml_enabled: def load_semantic(): @@ -1040,21 +1039,21 @@ class AIDocumentScanner: cache_dir = getattr(settings, "PAPERLESS_ML_MODEL_CACHE", None) return SemanticSearch(cache_dir=cache_dir, use_cache=True) model_loaders["semantic_search"] = load_semantic - + # Table Extractor if self.advanced_ocr_enabled: def load_table(): from documents.ocr.table_extractor import TableExtractor return TableExtractor() model_loaders["table_extractor"] = load_table - + # Warm up all models cache_manager.warm_up(model_loaders) - + warm_up_time = time.time() - start_time logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s") - def get_cache_metrics(self) -> Dict[str, Any]: + def get_cache_metrics(self) -> dict[str, Any]: """ Get cache performance metrics. @@ -1062,7 +1061,7 @@ class AIDocumentScanner: Dictionary with cache statistics """ from documents.ml.model_cache import ModelCacheManager - + try: cache_manager = ModelCacheManager.get_instance() return cache_manager.get_metrics() @@ -1075,7 +1074,7 @@ class AIDocumentScanner: def clear_cache(self) -> None: """Clear all model caches.""" from documents.ml.model_cache import ModelCacheManager - + try: cache_manager = ModelCacheManager.get_instance() cache_manager.clear_all() diff --git a/src/documents/apps.py b/src/documents/apps.py index b49588bd1..5bfc35147 100644 --- a/src/documents/apps.py +++ b/src/documents/apps.py @@ -38,22 +38,22 @@ class DocumentsConfig(AppConfig): def _initialize_ml_cache(self): """Initialize ML model cache and optionally warm up models.""" from django.conf import settings - + # Only initialize if ML features are enabled if not getattr(settings, "PAPERLESS_ENABLE_ML_FEATURES", False): return - + # Initialize cache manager with settings from documents.ml.model_cache import ModelCacheManager - + max_models = getattr(settings, "PAPERLESS_ML_CACHE_MAX_MODELS", 3) cache_dir = getattr(settings, "PAPERLESS_ML_MODEL_CACHE", None) - + cache_manager = ModelCacheManager.get_instance( max_models=max_models, disk_cache_dir=str(cache_dir) if cache_dir else None, ) - + # Warm up models if configured warmup_enabled = getattr(settings, "PAPERLESS_ML_CACHE_WARMUP", False) if warmup_enabled: diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 8142eacdf..92a4e72c5 100644 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -310,7 +310,7 @@ class ConsumerPlugin( # Parsing phase document_parser = self._create_parser_instance(parser_class) text, date, thumbnail, archive_path, page_count = self._parse_document( - document_parser, mime_type + document_parser, mime_type, ) # Storage phase @@ -394,7 +394,7 @@ class ConsumerPlugin( def _attempt_pdf_recovery( self, tempdir: tempfile.TemporaryDirectory, - original_mime_type: str + original_mime_type: str, ) -> str: """ Attempt to recover a PDF file with incorrect MIME type using qpdf. 
@@ -438,7 +438,7 @@ class ConsumerPlugin( def _get_parser_class( self, mime_type: str, - tempdir: tempfile.TemporaryDirectory + tempdir: tempfile.TemporaryDirectory, ) -> type[DocumentParser]: """ Determine which parser to use based on MIME type. @@ -468,7 +468,7 @@ class ConsumerPlugin( def _create_parser_instance( self, - parser_class: type[DocumentParser] + parser_class: type[DocumentParser], ) -> DocumentParser: """ Create a parser instance with progress callback. @@ -496,7 +496,7 @@ class ConsumerPlugin( def _parse_document( self, document_parser: DocumentParser, - mime_type: str + mime_type: str, ) -> tuple[str, datetime.datetime | None, Path, Path | None, int | None]: """ Parse the document and extract metadata. @@ -670,7 +670,7 @@ class ConsumerPlugin( self, document: Document, thumbnail: Path, - archive_path: Path | None + archive_path: Path | None, ) -> None: """ Store document files (source, thumbnail, archive) to disk. @@ -949,7 +949,7 @@ class ConsumerPlugin( text: The extracted document text """ # Check if AI scanner is enabled - if not getattr(settings, 'PAPERLESS_ENABLE_AI_SCANNER', True): + if not getattr(settings, "PAPERLESS_ENABLE_AI_SCANNER", True): self.log.debug("AI scanner is disabled, skipping AI analysis") return diff --git a/src/documents/migrations/1075_add_performance_indexes.py b/src/documents/migrations/1075_add_performance_indexes.py index 0bf6a6a21..1d76517a6 100644 --- a/src/documents/migrations/1075_add_performance_indexes.py +++ b/src/documents/migrations/1075_add_performance_indexes.py @@ -1,6 +1,7 @@ # Generated manually for performance optimization -from django.db import migrations, models +from django.db import migrations +from django.db import models class Migration(migrations.Migration): diff --git a/src/documents/migrations/1076_add_deletion_request.py b/src/documents/migrations/1076_add_deletion_request.py index 3b27a19d1..3440a4507 100644 --- a/src/documents/migrations/1076_add_deletion_request.py +++ b/src/documents/migrations/1076_add_deletion_request.py @@ -1,9 +1,10 @@ # Generated manually for DeletionRequest model # Based on model definition in documents/models.py -from django.conf import settings -from django.db import migrations, models import django.db.models.deletion +from django.conf import settings +from django.db import migrations +from django.db import models class Migration(migrations.Migration): @@ -48,7 +49,7 @@ class Migration(migrations.Migration): ( "ai_reason", models.TextField( - help_text="Detailed explanation from AI about why deletion is recommended" + help_text="Detailed explanation from AI about why deletion is recommended", ), ), ( diff --git a/src/documents/migrations/1077_add_deletionrequest_performance_indexes.py b/src/documents/migrations/1077_add_deletionrequest_performance_indexes.py index 12d823f11..37cdd67d3 100644 --- a/src/documents/migrations/1077_add_deletionrequest_performance_indexes.py +++ b/src/documents/migrations/1077_add_deletionrequest_performance_indexes.py @@ -1,6 +1,7 @@ # Generated manually for DeletionRequest performance optimization -from django.db import migrations, models +from django.db import migrations +from django.db import models class Migration(migrations.Migration): diff --git a/src/documents/migrations/1078_aisuggestionfeedback.py b/src/documents/migrations/1078_aisuggestionfeedback.py index 405821ca1..b8bce01fa 100644 --- a/src/documents/migrations/1078_aisuggestionfeedback.py +++ b/src/documents/migrations/1078_aisuggestionfeedback.py @@ -1,9 +1,10 @@ # Generated manually for AI 
Suggestions API -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion import django.core.validators +import django.db.models.deletion +from django.conf import settings +from django.db import migrations +from django.db import models class Migration(migrations.Migration): diff --git a/src/documents/ml/__init__.py b/src/documents/ml/__init__.py index 347028daf..0b088b60b 100644 --- a/src/documents/ml/__init__.py +++ b/src/documents/ml/__init__.py @@ -10,9 +10,9 @@ Provides AI/ML capabilities including: from __future__ import annotations __all__ = [ - "TransformerDocumentClassifier", "DocumentNER", "SemanticSearch", + "TransformerDocumentClassifier", ] # Lazy imports to avoid loading heavy ML libraries unless needed diff --git a/src/documents/ml/classifier.py b/src/documents/ml/classifier.py index b70d12a3f..f3d185dec 100644 --- a/src/documents/ml/classifier.py +++ b/src/documents/ml/classifier.py @@ -15,23 +15,16 @@ Logging levels used in this module: from __future__ import annotations import logging -from pathlib import Path -from typing import TYPE_CHECKING import torch from torch.utils.data import Dataset -from transformers import ( - AutoModelForSequenceClassification, - AutoTokenizer, - Trainer, - TrainingArguments, -) +from transformers import AutoModelForSequenceClassification +from transformers import AutoTokenizer +from transformers import Trainer +from transformers import TrainingArguments from documents.ml.model_cache import ModelCacheManager -if TYPE_CHECKING: - from documents.models import Document - logger = logging.getLogger("paperless.ml.classifier") @@ -129,10 +122,10 @@ class TransformerDocumentClassifier: self.model_name = model_name self.use_cache = use_cache self.cache_manager = ModelCacheManager.get_instance() if use_cache else None - + # Cache key for this model configuration self.cache_key = f"classifier_{model_name}" - + # Load tokenizer (lightweight, not cached) self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = None @@ -141,7 +134,7 @@ class TransformerDocumentClassifier: logger.info( f"Initialized TransformerDocumentClassifier with {model_name} " - f"(caching: {use_cache})" + f"(caching: {use_cache})", ) def train( @@ -264,14 +257,14 @@ class TransformerDocumentClassifier: if self.use_cache and self.cache_manager: # Try to get from cache first cache_key = f"{self.cache_key}_{model_dir}" - + def loader(): logger.info(f"Loading model from {model_dir}") model = AutoModelForSequenceClassification.from_pretrained(model_dir) tokenizer = AutoTokenizer.from_pretrained(model_dir) model.eval() # Set to evaluation mode return {"model": model, "tokenizer": tokenizer} - + cached = self.cache_manager.get_or_load_model(cache_key, loader) self.model = cached["model"] self.tokenizer = cached["tokenizer"] diff --git a/src/documents/ml/model_cache.py b/src/documents/ml/model_cache.py index b4d280404..47f2437bb 100644 --- a/src/documents/ml/model_cache.py +++ b/src/documents/ml/model_cache.py @@ -24,8 +24,9 @@ import pickle import threading import time from collections import OrderedDict +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any logger = logging.getLogger("paperless.ml.model_cache") @@ -58,7 +59,7 @@ class CacheMetrics: with self.lock: self.loads += 1 - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: with self.lock: total = self.hits + self.misses hit_rate = (self.hits / total 
* 100) if total > 0 else 0.0 @@ -98,7 +99,7 @@ class LRUCache: self.lock = threading.Lock() self.metrics = CacheMetrics() - def get(self, key: str) -> Optional[Any]: + def get(self, key: str) -> Any | None: """ Get item from cache. @@ -153,7 +154,7 @@ class LRUCache: with self.lock: return len(self.cache) - def get_metrics(self) -> Dict[str, Any]: + def get_metrics(self) -> dict[str, Any]: """Get cache metrics.""" return self.metrics.get_stats() @@ -173,7 +174,7 @@ class ModelCacheManager: model = cache.get_or_load_model("classifier", loader_func) """ - _instance: Optional[ModelCacheManager] = None + _instance: ModelCacheManager | None = None _lock = threading.Lock() def __new__(cls, *args, **kwargs): @@ -187,7 +188,7 @@ class ModelCacheManager: def __init__( self, max_models: int = 3, - disk_cache_dir: Optional[str] = None, + disk_cache_dir: str | None = None, ): """ Initialize model cache manager. @@ -215,7 +216,7 @@ class ModelCacheManager: def get_instance( cls, max_models: int = 3, - disk_cache_dir: Optional[str] = None, + disk_cache_dir: str | None = None, ) -> ModelCacheManager: """ Get singleton instance of ModelCacheManager. @@ -278,7 +279,7 @@ class ModelCacheManager: load_time = time.time() - start_time logger.info( f"Model loaded successfully: {model_key} " - f"(took {load_time:.2f}s)" + f"(took {load_time:.2f}s)", ) return model @@ -289,7 +290,7 @@ class ModelCacheManager: def save_embeddings_to_disk( self, key: str, - embeddings: Dict[int, Any], + embeddings: dict[int, Any], ) -> bool: """ Save embeddings to disk cache. @@ -311,7 +312,7 @@ class ModelCacheManager: cache_file = self.disk_cache_dir / f"{key}.pkl" try: - with open(cache_file, 'wb') as f: + with open(cache_file, "wb") as f: pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL) logger.info(f"Saved {len(embeddings)} embeddings to {cache_file}") return True @@ -330,7 +331,7 @@ class ModelCacheManager: def load_embeddings_from_disk( self, key: str, - ) -> Optional[Dict[int, Any]]: + ) -> dict[int, Any] | None: """ Load embeddings from disk cache. @@ -344,7 +345,7 @@ class ModelCacheManager: return None cache_file = self.disk_cache_dir / f"{key}.pkl" - + if not cache_file.exists(): return None @@ -384,7 +385,7 @@ class ModelCacheManager: def clear_all(self) -> None: """Clear all caches (memory and disk).""" self.model_cache.clear() - + if self.disk_cache_dir and self.disk_cache_dir.exists(): for cache_file in self.disk_cache_dir.glob("*.pkl"): try: @@ -393,7 +394,7 @@ class ModelCacheManager: except Exception as e: logger.error(f"Failed to delete {cache_file}: {e}") - def get_metrics(self) -> Dict[str, Any]: + def get_metrics(self) -> dict[str, Any]: """ Get cache performance metrics. @@ -403,20 +404,20 @@ class ModelCacheManager: metrics = self.model_cache.get_metrics() metrics["cache_size"] = self.model_cache.size() metrics["max_size"] = self.model_cache.max_size - + if self.disk_cache_dir and self.disk_cache_dir.exists(): disk_files = list(self.disk_cache_dir.glob("*.pkl")) metrics["disk_cache_files"] = len(disk_files) - + # Calculate total disk cache size total_size = sum(f.stat().st_size for f in disk_files) metrics["disk_cache_size_mb"] = f"{total_size / 1024 / 1024:.2f}" - + return metrics def warm_up( self, - model_loaders: Dict[str, Callable[[], Any]], + model_loaders: dict[str, Callable[[], Any]], ) -> None: """ Pre-load models on startup (warm-up). 
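# Illustrative usage sketch, not part of the patch: how the ModelCacheManager API
# in this diff is expected to be called, mirroring the loader pattern used by
# AIDocumentScanner.warm_up_models() above. The cache key and NER loader are examples.
from documents.ml.model_cache import ModelCacheManager

cache_manager = ModelCacheManager.get_instance(max_models=3)

def load_ner():
    # Heavy model construction runs only on a cache miss.
    from documents.ml.ner import DocumentNER
    return DocumentNER(use_cache=True)

# Load on demand through the LRU cache...
ner = cache_manager.get_or_load_model("ner", load_ner)

# ...or pre-load at startup and inspect hit/miss statistics.
cache_manager.warm_up({"ner": load_ner})
print(cache_manager.get_metrics())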
@@ -426,12 +427,12 @@ class ModelCacheManager: """ logger.info(f"Starting model warm-up ({len(model_loaders)} models)...") start_time = time.time() - + for model_key, loader_func in model_loaders.items(): try: self.get_or_load_model(model_key, loader_func) except Exception as e: logger.warning(f"Failed to warm-up model {model_key}: {e}") - + warm_up_time = time.time() - start_time logger.info(f"Model warm-up completed in {warm_up_time:.2f}s") diff --git a/src/documents/ml/ner.py b/src/documents/ml/ner.py index 96612f7e1..e85949715 100644 --- a/src/documents/ml/ner.py +++ b/src/documents/ml/ner.py @@ -14,15 +14,11 @@ from __future__ import annotations import logging import re -from typing import TYPE_CHECKING from transformers import pipeline from documents.ml.model_cache import ModelCacheManager -if TYPE_CHECKING: - pass - logger = logging.getLogger("paperless.ml.ner") @@ -69,10 +65,10 @@ class DocumentNER: self.model_name = model_name self.use_cache = use_cache self.cache_manager = ModelCacheManager.get_instance() if use_cache else None - + # Cache key for this model cache_key = f"ner_{model_name}" - + if self.use_cache and self.cache_manager: # Load from cache or create new def loader(): @@ -81,7 +77,7 @@ class DocumentNER: model=model_name, aggregation_strategy="simple", ) - + self.ner_pipeline = self.cache_manager.get_or_load_model( cache_key, loader, diff --git a/src/documents/ml/semantic_search.py b/src/documents/ml/semantic_search.py index 7091561a1..8ed8d9d6f 100644 --- a/src/documents/ml/semantic_search.py +++ b/src/documents/ml/semantic_search.py @@ -18,18 +18,14 @@ Examples: from __future__ import annotations import logging -from pathlib import Path -from typing import TYPE_CHECKING import numpy as np import torch -from sentence_transformers import SentenceTransformer, util +from sentence_transformers import SentenceTransformer +from sentence_transformers import util from documents.ml.model_cache import ModelCacheManager -if TYPE_CHECKING: - pass - logger = logging.getLogger("paperless.ml.semantic_search") @@ -67,7 +63,7 @@ class SemanticSearch: """ logger.info( f"Initializing SemanticSearch with model: {model_name} " - f"(caching: {use_cache})" + f"(caching: {use_cache})", ) self.model_name = model_name @@ -75,10 +71,10 @@ class SemanticSearch: self.cache_manager = ModelCacheManager.get_instance( disk_cache_dir=cache_dir, ) if use_cache else None - + # Cache key for this model cache_key = f"semantic_search_{model_name}" - + if self.use_cache and self.cache_manager: # Load model from cache def loader(): @@ -127,11 +123,11 @@ class SemanticSearch: if not isinstance(embedding, np.ndarray) and not isinstance(embedding, torch.Tensor): logger.warning(f"Embedding for doc {doc_id} is not a numpy array or tensor") return False - if hasattr(embedding, 'size'): + if hasattr(embedding, "size"): if embedding.size == 0: logger.warning(f"Embedding for doc {doc_id} is empty") return False - elif hasattr(embedding, 'numel'): + elif hasattr(embedding, "numel"): if embedding.numel() == 0: logger.warning(f"Embedding for doc {doc_id} is empty") return False @@ -216,11 +212,11 @@ class SemanticSearch: try: result = self.cache_manager.save_embeddings_to_disk( "document_embeddings", - self.document_embeddings + self.document_embeddings, ) if result: logger.info( - f"Successfully saved {len(self.document_embeddings)} embeddings to disk" + f"Successfully saved {len(self.document_embeddings)} embeddings to disk", ) else: logger.error("Failed to save embeddings to disk (returned False)") diff --git 
a/src/documents/models.py b/src/documents/models.py index 42a9a048e..457430ea1 100644 --- a/src/documents/models.py +++ b/src/documents/models.py @@ -1596,59 +1596,59 @@ class DeletionRequest(models.Model): This ensures no documents are deleted without explicit user consent, implementing the safety requirement from agents.md. """ - + # Request metadata created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) - + # Requester (AI system) requested_by_ai = models.BooleanField(default=True) ai_reason = models.TextField( - help_text=_("Detailed explanation from AI about why deletion is recommended") + help_text=_("Detailed explanation from AI about why deletion is recommended"), ) - + # User who must approve user = models.ForeignKey( User, on_delete=models.CASCADE, - related_name='deletion_requests', + related_name="deletion_requests", help_text=_("User who must approve this deletion"), ) - + # Status tracking - STATUS_PENDING = 'pending' - STATUS_APPROVED = 'approved' - STATUS_REJECTED = 'rejected' - STATUS_CANCELLED = 'cancelled' - STATUS_COMPLETED = 'completed' - + STATUS_PENDING = "pending" + STATUS_APPROVED = "approved" + STATUS_REJECTED = "rejected" + STATUS_CANCELLED = "cancelled" + STATUS_COMPLETED = "completed" + STATUS_CHOICES = [ - (STATUS_PENDING, _('Pending')), - (STATUS_APPROVED, _('Approved')), - (STATUS_REJECTED, _('Rejected')), - (STATUS_CANCELLED, _('Cancelled')), - (STATUS_COMPLETED, _('Completed')), + (STATUS_PENDING, _("Pending")), + (STATUS_APPROVED, _("Approved")), + (STATUS_REJECTED, _("Rejected")), + (STATUS_CANCELLED, _("Cancelled")), + (STATUS_COMPLETED, _("Completed")), ] - + status = models.CharField( max_length=20, choices=STATUS_CHOICES, default=STATUS_PENDING, ) - + # Documents to be deleted documents = models.ManyToManyField( Document, - related_name='deletion_requests', + related_name="deletion_requests", help_text=_("Documents that would be deleted if approved"), ) - + # Impact summary (JSON field with details) impact_summary = models.JSONField( default=dict, help_text=_("Summary of what will be affected by this deletion"), ) - + # Approval tracking reviewed_at = models.DateTimeField(null=True, blank=True) reviewed_by = models.ForeignKey( @@ -1656,43 +1656,43 @@ class DeletionRequest(models.Model): on_delete=models.SET_NULL, null=True, blank=True, - related_name='reviewed_deletion_requests', + related_name="reviewed_deletion_requests", help_text=_("User who reviewed and approved/rejected"), ) review_comment = models.TextField( blank=True, help_text=_("User's comment when reviewing"), ) - + # Completion tracking completed_at = models.DateTimeField(null=True, blank=True) completion_details = models.JSONField( default=dict, help_text=_("Details about the deletion execution"), ) - + class Meta: - ordering = ['-created_at'] + ordering = ["-created_at"] verbose_name = _("deletion request") verbose_name_plural = _("deletion requests") indexes = [ # Composite index for common listing queries (by user, filtered by status, sorted by date) # PostgreSQL can use this index for queries on: user, user+status, user+status+created_at - models.Index(fields=['user', 'status', 'created_at'], name='delreq_user_status_created_idx'), + models.Index(fields=["user", "status", "created_at"], name="delreq_user_status_created_idx"), # Index for queries filtering by status and date without user filter - models.Index(fields=['status', 'created_at'], name='delreq_status_created_idx'), + models.Index(fields=["status", "created_at"], 
name="delreq_status_created_idx"), # Index for queries filtering by user and date (common for user-specific views) - models.Index(fields=['user', 'created_at'], name='delreq_user_created_idx'), + models.Index(fields=["user", "created_at"], name="delreq_user_created_idx"), # Index for queries filtering by review date - models.Index(fields=['reviewed_at'], name='delreq_reviewed_at_idx'), + models.Index(fields=["reviewed_at"], name="delreq_reviewed_at_idx"), # Index for queries filtering by completion date - models.Index(fields=['completed_at'], name='delreq_completed_at_idx'), + models.Index(fields=["completed_at"], name="delreq_completed_at_idx"), ] - + def __str__(self): doc_count = self.documents.count() return f"Deletion Request {self.id} - {doc_count} documents - {self.status}" - + def approve(self, user: User, comment: str = "") -> bool: """ Approve the deletion request. @@ -1706,15 +1706,15 @@ class DeletionRequest(models.Model): """ if self.status != self.STATUS_PENDING: return False - + self.status = self.STATUS_APPROVED self.reviewed_by = user self.reviewed_at = timezone.now() self.review_comment = comment self.save() - + return True - + def reject(self, user: User, comment: str = "") -> bool: """ Reject the deletion request. @@ -1728,13 +1728,13 @@ class DeletionRequest(models.Model): """ if self.status != self.STATUS_PENDING: return False - + self.status = self.STATUS_REJECTED self.reviewed_by = user self.reviewed_at = timezone.now() self.review_comment = comment self.save() - + return True @@ -1743,109 +1743,109 @@ class AISuggestionFeedback(models.Model): Model to track user feedback on AI suggestions (applied/rejected). Used for improving AI accuracy and providing statistics. """ - + # Suggestion types - TYPE_TAG = 'tag' - TYPE_CORRESPONDENT = 'correspondent' - TYPE_DOCUMENT_TYPE = 'document_type' - TYPE_STORAGE_PATH = 'storage_path' - TYPE_CUSTOM_FIELD = 'custom_field' - TYPE_WORKFLOW = 'workflow' - TYPE_TITLE = 'title' - + TYPE_TAG = "tag" + TYPE_CORRESPONDENT = "correspondent" + TYPE_DOCUMENT_TYPE = "document_type" + TYPE_STORAGE_PATH = "storage_path" + TYPE_CUSTOM_FIELD = "custom_field" + TYPE_WORKFLOW = "workflow" + TYPE_TITLE = "title" + SUGGESTION_TYPES = ( - (TYPE_TAG, _('Tag')), - (TYPE_CORRESPONDENT, _('Correspondent')), - (TYPE_DOCUMENT_TYPE, _('Document Type')), - (TYPE_STORAGE_PATH, _('Storage Path')), - (TYPE_CUSTOM_FIELD, _('Custom Field')), - (TYPE_WORKFLOW, _('Workflow')), - (TYPE_TITLE, _('Title')), + (TYPE_TAG, _("Tag")), + (TYPE_CORRESPONDENT, _("Correspondent")), + (TYPE_DOCUMENT_TYPE, _("Document Type")), + (TYPE_STORAGE_PATH, _("Storage Path")), + (TYPE_CUSTOM_FIELD, _("Custom Field")), + (TYPE_WORKFLOW, _("Workflow")), + (TYPE_TITLE, _("Title")), ) - + # Feedback status - STATUS_APPLIED = 'applied' - STATUS_REJECTED = 'rejected' - + STATUS_APPLIED = "applied" + STATUS_REJECTED = "rejected" + FEEDBACK_STATUS = ( - (STATUS_APPLIED, _('Applied')), - (STATUS_REJECTED, _('Rejected')), + (STATUS_APPLIED, _("Applied")), + (STATUS_REJECTED, _("Rejected")), ) - + document = models.ForeignKey( Document, on_delete=models.CASCADE, - related_name='ai_suggestion_feedbacks', - verbose_name=_('document'), + related_name="ai_suggestion_feedbacks", + verbose_name=_("document"), ) - + suggestion_type = models.CharField( - _('suggestion type'), + _("suggestion type"), max_length=50, choices=SUGGESTION_TYPES, ) - + suggested_value_id = models.IntegerField( - _('suggested value ID'), + _("suggested value ID"), null=True, blank=True, - help_text=_('ID of the suggested object 
(tag, correspondent, etc.)'), + help_text=_("ID of the suggested object (tag, correspondent, etc.)"), ) - + suggested_value_text = models.TextField( - _('suggested value text'), + _("suggested value text"), blank=True, - help_text=_('Text representation of the suggested value'), + help_text=_("Text representation of the suggested value"), ) - + confidence = models.FloatField( - _('confidence'), - help_text=_('AI confidence score (0.0 to 1.0)'), + _("confidence"), + help_text=_("AI confidence score (0.0 to 1.0)"), validators=[MinValueValidator(0.0), MaxValueValidator(1.0)], ) - + status = models.CharField( - _('status'), + _("status"), max_length=20, choices=FEEDBACK_STATUS, ) - + user = models.ForeignKey( User, on_delete=models.SET_NULL, null=True, blank=True, - related_name='ai_suggestion_feedbacks', - verbose_name=_('user'), - help_text=_('User who applied or rejected the suggestion'), + related_name="ai_suggestion_feedbacks", + verbose_name=_("user"), + help_text=_("User who applied or rejected the suggestion"), ) - + created_at = models.DateTimeField( - _('created at'), + _("created at"), auto_now_add=True, ) - + applied_at = models.DateTimeField( - _('applied/rejected at'), + _("applied/rejected at"), auto_now=True, ) - + metadata = models.JSONField( - _('metadata'), + _("metadata"), default=dict, blank=True, - help_text=_('Additional metadata about the suggestion'), + help_text=_("Additional metadata about the suggestion"), ) - + class Meta: - verbose_name = _('AI suggestion feedback') - verbose_name_plural = _('AI suggestion feedbacks') - ordering = ['-created_at'] + verbose_name = _("AI suggestion feedback") + verbose_name_plural = _("AI suggestion feedbacks") + ordering = ["-created_at"] indexes = [ - models.Index(fields=['document', 'suggestion_type']), - models.Index(fields=['status', 'created_at']), - models.Index(fields=['suggestion_type', 'status']), + models.Index(fields=["document", "suggestion_type"]), + models.Index(fields=["status", "created_at"]), + models.Index(fields=["suggestion_type", "status"]), ] - + def __str__(self): return f"{self.suggestion_type} suggestion for document {self.document_id} - {self.status}" diff --git a/src/documents/ocr/__init__.py b/src/documents/ocr/__init__.py index 3fdbb3db4..ef3d47354 100644 --- a/src/documents/ocr/__init__.py +++ b/src/documents/ocr/__init__.py @@ -11,21 +11,21 @@ Lazy imports are used to avoid loading heavy dependencies unless needed. 
""" __all__ = [ - 'TableExtractor', - 'HandwritingRecognizer', - 'FormFieldDetector', + "FormFieldDetector", + "HandwritingRecognizer", + "TableExtractor", ] def __getattr__(name): """Lazy import to avoid loading heavy ML models on startup.""" - if name == 'TableExtractor': + if name == "TableExtractor": from .table_extractor import TableExtractor return TableExtractor - elif name == 'HandwritingRecognizer': + elif name == "HandwritingRecognizer": from .handwriting import HandwritingRecognizer return HandwritingRecognizer - elif name == 'FormFieldDetector': + elif name == "FormFieldDetector": from .form_detector import FormFieldDetector return FormFieldDetector raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/documents/ocr/form_detector.py b/src/documents/ocr/form_detector.py index a11e7e49f..d52e080a5 100644 --- a/src/documents/ocr/form_detector.py +++ b/src/documents/ocr/form_detector.py @@ -8,8 +8,8 @@ This module provides capabilities to: """ import logging -from pathlib import Path -from typing import List, Dict, Any, Optional, Tuple +from typing import Any + import numpy as np from PIL import Image @@ -37,7 +37,7 @@ class FormFieldDetector: >>> for cb in checkboxes: ... print(f"{cb['label']}: {'✓' if cb['checked'] else '☐'}") """ - + def __init__(self, use_gpu: bool = True): """ Initialize the form field detector. @@ -47,20 +47,20 @@ class FormFieldDetector: """ self.use_gpu = use_gpu self._handwriting_recognizer = None - + def _get_handwriting_recognizer(self): """Lazy load handwriting recognizer for field value extraction.""" if self._handwriting_recognizer is None: from .handwriting import HandwritingRecognizer self._handwriting_recognizer = HandwritingRecognizer(use_gpu=self.use_gpu) return self._handwriting_recognizer - + def detect_checkboxes( - self, + self, image: Image.Image, min_size: int = 10, - max_size: int = 50 - ) -> List[Dict[str, Any]]: + max_size: int = 50, + ) -> list[dict[str, Any]]: """ Detect checkboxes in a form image. @@ -82,54 +82,54 @@ class FormFieldDetector: """ try: import cv2 - + # Convert to OpenCV format img_array = np.array(image) if len(img_array.shape) == 3: gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) else: gray = img_array - + # Detect edges edges = cv2.Canny(gray, 50, 150) - + # Find contours contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - + checkboxes = [] for contour in contours: # Get bounding box x, y, w, h = cv2.boundingRect(contour) - + # Check if it looks like a checkbox (square-ish, right size) aspect_ratio = w / h if h > 0 else 0 - if (min_size <= w <= max_size and - min_size <= h <= max_size and + if (min_size <= w <= max_size and + min_size <= h <= max_size and 0.7 <= aspect_ratio <= 1.3): - + # Extract checkbox region checkbox_region = gray[y:y+h, x:x+w] - + # Determine if checked (look for marks inside) checked, confidence = self._is_checkbox_checked(checkbox_region) - + checkboxes.append({ - 'bbox': [x, y, x+w, y+h], - 'checked': checked, - 'confidence': confidence + "bbox": [x, y, x+w, y+h], + "checked": checked, + "confidence": confidence, }) - + logger.info(f"Detected {len(checkboxes)} checkboxes") return checkboxes - + except ImportError: logger.error("opencv-python not installed. 
Install with: pip install opencv-python") return [] except Exception as e: logger.error(f"Error detecting checkboxes: {e}") return [] - - def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> Tuple[bool, float]: + + def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> tuple[bool, float]: """ Determine if a checkbox is checked. @@ -141,34 +141,34 @@ class FormFieldDetector: """ try: import cv2 - + # Binarize _, binary = cv2.threshold(checkbox_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) - + # Count dark pixels in the center region (where mark would be) h, w = binary.shape center_region = binary[int(h*0.2):int(h*0.8), int(w*0.2):int(w*0.8)] - + if center_region.size == 0: return False, 0.0 - + dark_pixel_ratio = np.sum(center_region > 0) / center_region.size - + # If more than 15% of center is dark, consider it checked checked = dark_pixel_ratio > 0.15 confidence = min(dark_pixel_ratio * 2, 1.0) # Scale confidence - + return checked, confidence - + except Exception as e: logger.warning(f"Error checking checkbox state: {e}") return False, 0.0 - + def detect_text_fields( - self, + self, image: Image.Image, - min_width: int = 100 - ) -> List[Dict[str, Any]]: + min_width: int = 100, + ) -> list[dict[str, Any]]: """ Detect text input fields in a form. @@ -188,73 +188,73 @@ class FormFieldDetector: """ try: import cv2 - + # Convert to OpenCV format img_array = np.array(image) if len(img_array.shape) == 3: gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) else: gray = img_array - + # Detect horizontal lines (underlines for text fields) horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_width, 1)) detect_horizontal = cv2.morphologyEx( cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1], cv2.MORPH_OPEN, horizontal_kernel, - iterations=2 + iterations=2, ) - + # Find contours of horizontal lines contours, _ = cv2.findContours( - detect_horizontal, - cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_SIMPLE + detect_horizontal, + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE, ) - + text_fields = [] for contour in contours: x, y, w, h = cv2.boundingRect(contour) - + # Check if it's a horizontal line (field underline) if w >= min_width and h < 10: # Expand upward to include text area text_bbox = [x, max(0, y-30), x+w, y+h] text_fields.append({ - 'bbox': text_bbox, - 'type': 'line' + "bbox": text_bbox, + "type": "line", }) - + # Detect rectangular boxes (bordered text fields) edges = cv2.Canny(gray, 50, 150) contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - + for contour in contours: x, y, w, h = cv2.boundingRect(contour) - + # Check if it's a rectangular box aspect_ratio = w / h if h > 0 else 0 if w >= min_width and 20 <= h <= 100 and aspect_ratio > 2: text_fields.append({ - 'bbox': [x, y, x+w, y+h], - 'type': 'box' + "bbox": [x, y, x+w, y+h], + "type": "box", }) - + logger.info(f"Detected {len(text_fields)} text fields") return text_fields - + except ImportError: logger.error("opencv-python not installed") return [] except Exception as e: logger.error(f"Error detecting text fields: {e}") return [] - + def detect_labels( - self, + self, image: Image.Image, - field_bboxes: List[List[int]] - ) -> List[Dict[str, Any]]: + field_bboxes: list[list[int]], + ) -> list[dict[str, Any]]: """ Detect labels near form fields. 
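# Illustrative usage sketch, not part of the patch: calling the low-level detectors
# defined in this file on a single page image. "form.png" is a placeholder path.
from PIL import Image
from documents.ocr.form_detector import FormFieldDetector

detector = FormFieldDetector(use_gpu=False)
page = Image.open("form.png").convert("RGB")

for cb in detector.detect_checkboxes(page):
    print(cb["bbox"], "checked" if cb["checked"] else "empty", cb["confidence"])

for field in detector.detect_text_fields(page, min_width=100):
    print(field["type"], field["bbox"])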
@@ -267,47 +267,47 @@ class FormFieldDetector: """ try: import pytesseract - + # Get all text with bounding boxes ocr_data = pytesseract.image_to_data( - image, - output_type=pytesseract.Output.DICT + image, + output_type=pytesseract.Output.DICT, ) - + # Group text into potential labels labels = [] - for i, text in enumerate(ocr_data['text']): + for i, text in enumerate(ocr_data["text"]): if text.strip() and len(text.strip()) > 2: - x = ocr_data['left'][i] - y = ocr_data['top'][i] - w = ocr_data['width'][i] - h = ocr_data['height'][i] - + x = ocr_data["left"][i] + y = ocr_data["top"][i] + w = ocr_data["width"][i] + h = ocr_data["height"][i] + label_bbox = [x, y, x+w, y+h] - + # Find closest field closest_field_idx = self._find_closest_field(label_bbox, field_bboxes) - + labels.append({ - 'text': text.strip(), - 'bbox': label_bbox, - 'field_index': closest_field_idx + "text": text.strip(), + "bbox": label_bbox, + "field_index": closest_field_idx, }) - + return labels - + except ImportError: logger.error("pytesseract not installed") return [] except Exception as e: logger.error(f"Error detecting labels: {e}") return [] - + def _find_closest_field( - self, - label_bbox: List[int], - field_bboxes: List[List[int]] - ) -> Optional[int]: + self, + label_bbox: list[int], + field_bboxes: list[list[int]], + ) -> int | None: """ Find the closest field to a label. @@ -320,36 +320,36 @@ class FormFieldDetector: """ if not field_bboxes: return None - + # Calculate center of label label_center_x = (label_bbox[0] + label_bbox[2]) / 2 label_center_y = (label_bbox[1] + label_bbox[3]) / 2 - - min_distance = float('inf') + + min_distance = float("inf") closest_idx = 0 - + for i, field_bbox in enumerate(field_bboxes): # Calculate center of field field_center_x = (field_bbox[0] + field_bbox[2]) / 2 field_center_y = (field_bbox[1] + field_bbox[3]) / 2 - + # Euclidean distance distance = np.sqrt( - (label_center_x - field_center_x)**2 + - (label_center_y - field_center_y)**2 + (label_center_x - field_center_x)**2 + + (label_center_y - field_center_y)**2, ) - + if distance < min_distance: min_distance = distance closest_idx = i - + return closest_idx - + def detect_form_fields( - self, + self, image_path: str, - extract_values: bool = True - ) -> List[Dict[str, Any]]: + extract_values: bool = True, + ) -> list[dict[str, Any]]: """ Detect all form fields and extract their values. 
@@ -372,69 +372,69 @@ class FormFieldDetector: """ try: # Load image - image = Image.open(image_path).convert('RGB') - + image = Image.open(image_path).convert("RGB") + # Detect different field types text_fields = self.detect_text_fields(image) checkboxes = self.detect_checkboxes(image) - + # Combine all field bboxes for label detection - all_field_bboxes = [f['bbox'] for f in text_fields] + [cb['bbox'] for cb in checkboxes] - + all_field_bboxes = [f["bbox"] for f in text_fields] + [cb["bbox"] for cb in checkboxes] + # Detect labels labels = self.detect_labels(image, all_field_bboxes) - + # Build results results = [] - + # Add text fields for i, field in enumerate(text_fields): # Find associated label label_text = self._find_label_for_field(i, labels, len(text_fields)) - + result = { - 'type': 'text', - 'label': label_text, - 'bbox': field['bbox'], + "type": "text", + "label": label_text, + "bbox": field["bbox"], } - + # Extract value if requested if extract_values: - x1, y1, x2, y2 = field['bbox'] + x1, y1, x2, y2 = field["bbox"] field_image = image.crop((x1, y1, x2, y2)) - + recognizer = self._get_handwriting_recognizer() value = recognizer.recognize_from_image(field_image, preprocess=True) - result['value'] = value.strip() - result['confidence'] = recognizer._estimate_confidence(value) - + result["value"] = value.strip() + result["confidence"] = recognizer._estimate_confidence(value) + results.append(result) - + # Add checkboxes for i, checkbox in enumerate(checkboxes): field_idx = len(text_fields) + i label_text = self._find_label_for_field(field_idx, labels, len(all_field_bboxes)) - + results.append({ - 'type': 'checkbox', - 'label': label_text, - 'value': checkbox['checked'], - 'bbox': checkbox['bbox'], - 'confidence': checkbox['confidence'] + "type": "checkbox", + "label": label_text, + "value": checkbox["checked"], + "bbox": checkbox["bbox"], + "confidence": checkbox["confidence"], }) - + logger.info(f"Detected {len(results)} form fields from {image_path}") return results - + except Exception as e: logger.error(f"Error detecting form fields: {e}") return [] - + def _find_label_for_field( - self, - field_idx: int, - labels: List[Dict[str, Any]], - total_fields: int + self, + field_idx: int, + labels: list[dict[str, Any]], + total_fields: int, ) -> str: """ Find the label text for a specific field. @@ -448,20 +448,20 @@ class FormFieldDetector: Label text or empty string if not found """ matching_labels = [ - label for label in labels - if label['field_index'] == field_idx + label for label in labels + if label["field_index"] == field_idx ] - + if matching_labels: # Combine multiple label parts if found - return ' '.join(label['text'] for label in matching_labels) - + return " ".join(label["text"] for label in matching_labels) + return f"Field_{field_idx + 1}" - + def extract_form_data( - self, + self, image_path: str, - output_format: str = 'dict' + output_format: str = "dict", ) -> Any: """ Extract all form data as structured output. 
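# Illustrative usage sketch, not part of the patch: the high-level entry points of
# FormFieldDetector. detect_form_fields() pairs labels with recognized values, and
# extract_form_data() flattens them per output_format. "scanned_form.png" is a placeholder.
from documents.ocr.form_detector import FormFieldDetector

detector = FormFieldDetector()
for field in detector.detect_form_fields("scanned_form.png", extract_values=True):
    print(field["type"], field["label"], field.get("value"))

# The same data keyed by label:
data = detector.extract_form_data("scanned_form.png", output_format="dict")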
@@ -475,19 +475,19 @@ class FormFieldDetector: """ # Detect and extract fields fields = self.detect_form_fields(image_path, extract_values=True) - - if output_format == 'dict': + + if output_format == "dict": # Return as dictionary - return {field['label']: field['value'] for field in fields} - - elif output_format == 'json': + return {field["label"]: field["value"] for field in fields} + + elif output_format == "json": import json - data = {field['label']: field['value'] for field in fields} + data = {field["label"]: field["value"] for field in fields} return json.dumps(data, indent=2) - - elif output_format == 'dataframe': + + elif output_format == "dataframe": import pandas as pd return pd.DataFrame(fields) - + else: raise ValueError(f"Invalid output format: {output_format}") diff --git a/src/documents/ocr/handwriting.py b/src/documents/ocr/handwriting.py index b9453693d..7a3d9467b 100644 --- a/src/documents/ocr/handwriting.py +++ b/src/documents/ocr/handwriting.py @@ -8,8 +8,8 @@ This module provides handwriting OCR capabilities using: """ import logging -from pathlib import Path -from typing import List, Dict, Any, Optional, Tuple +from typing import Any + import numpy as np from PIL import Image @@ -34,7 +34,7 @@ class HandwritingRecognizer: >>> for line in lines: ... print(f"{line['text']} (confidence: {line['confidence']:.2f})") """ - + def __init__( self, model_name: str = "microsoft/trocr-base-handwritten", @@ -58,39 +58,40 @@ class HandwritingRecognizer: self.confidence_threshold = confidence_threshold self._model = None self._processor = None - + def _load_model(self): """Lazy load the handwriting recognition model.""" if self._model is not None: return - + try: - from transformers import TrOCRProcessor, VisionEncoderDecoderModel import torch - + from transformers import TrOCRProcessor + from transformers import VisionEncoderDecoderModel + logger.info(f"Loading handwriting recognition model: {self.model_name}") - + self._processor = TrOCRProcessor.from_pretrained(self.model_name) self._model = VisionEncoderDecoderModel.from_pretrained(self.model_name) - + # Move to GPU if available and requested if self.use_gpu and torch.cuda.is_available(): self._model = self._model.cuda() logger.info("Using GPU for handwriting recognition") else: logger.info("Using CPU for handwriting recognition") - + self._model.eval() # Set to evaluation mode - + except ImportError as e: logger.error(f"Failed to load handwriting model: {e}") logger.error("Please install: pip install transformers torch pillow") raise - + def recognize_from_image( - self, + self, image: Image.Image, - preprocess: bool = True + preprocess: bool = True, ) -> str: """ Recognize text from a single image. 
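# Illustrative usage sketch, not part of the patch: recognizing handwriting with the
# TrOCR-based recognizer defined here. "note.jpg" is a placeholder path.
from PIL import Image
from documents.ocr.handwriting import HandwritingRecognizer

recognizer = HandwritingRecognizer(use_gpu=False)
image = Image.open("note.jpg").convert("RGB")
text = recognizer.recognize_from_image(image, preprocess=True)

# Or process a whole file, either as one block or line by line:
result = recognizer.recognize_from_file("note.jpg", mode="lines")
print(result["text"], result["confidence"])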
@@ -103,34 +104,34 @@ class HandwritingRecognizer: Recognized text string """ self._load_model() - + try: import torch - + # Preprocess image if requested if preprocess: image = self._preprocess_image(image) - + # Prepare image for model pixel_values = self._processor(images=image, return_tensors="pt").pixel_values - + if self.use_gpu and torch.cuda.is_available(): pixel_values = pixel_values.cuda() - + # Generate text with torch.no_grad(): generated_ids = self._model.generate(pixel_values) - + # Decode to text text = self._processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - + logger.debug(f"Recognized text: {text[:100]}...") return text - + except Exception as e: logger.error(f"Error recognizing handwriting: {e}") return "" - + def _preprocess_image(self, image: Image.Image) -> Image.Image: """ Preprocess image for better recognition. @@ -142,29 +143,30 @@ class HandwritingRecognizer: Preprocessed PIL Image """ try: - from PIL import ImageEnhance, ImageFilter - + from PIL import ImageEnhance + from PIL import ImageFilter + # Convert to grayscale - if image.mode != 'L': - image = image.convert('L') - + if image.mode != "L": + image = image.convert("L") + # Enhance contrast enhancer = ImageEnhance.Contrast(image) image = enhancer.enhance(2.0) - + # Denoise image = image.filter(ImageFilter.MedianFilter(size=3)) - + # Convert back to RGB (required by model) - image = image.convert('RGB') - + image = image.convert("RGB") + return image - + except Exception as e: logger.warning(f"Error preprocessing image: {e}") return image - - def detect_text_lines(self, image: Image.Image) -> List[Dict[str, Any]]: + + def detect_text_lines(self, image: Image.Image) -> list[dict[str, Any]]: """ Detect individual text lines in an image. @@ -184,52 +186,52 @@ class HandwritingRecognizer: try: import cv2 import numpy as np - + # Convert PIL to OpenCV format img_array = np.array(image) if len(img_array.shape) == 3: gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) else: gray = img_array - + # Binarize _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) - + # Find contours contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - + # Get bounding boxes for each contour lines = [] for contour in contours: x, y, w, h = cv2.boundingRect(contour) - + # Filter out very small regions if w > 20 and h > 10: # Crop line from original image line_img = image.crop((x, y, x+w, y+h)) lines.append({ - 'bbox': [x, y, x+w, y+h], - 'image': line_img + "bbox": [x, y, x+w, y+h], + "image": line_img, }) - + # Sort lines top to bottom - lines.sort(key=lambda l: l['bbox'][1]) - + lines.sort(key=lambda l: l["bbox"][1]) + logger.info(f"Detected {len(lines)} text lines") return lines - + except ImportError: logger.error("opencv-python not installed. Install with: pip install opencv-python") return [] except Exception as e: logger.error(f"Error detecting text lines: {e}") return [] - + def recognize_lines( - self, + self, image_path: str, - return_confidence: bool = True - ) -> List[Dict[str, Any]]: + return_confidence: bool = True, + ) -> list[dict[str, Any]]: """ Recognize text from each line in an image. 
@@ -250,38 +252,38 @@ class HandwritingRecognizer: """ try: # Load image - image = Image.open(image_path).convert('RGB') - + image = Image.open(image_path).convert("RGB") + # Detect lines lines = self.detect_text_lines(image) - + # Recognize each line results = [] for i, line in enumerate(lines): logger.debug(f"Recognizing line {i+1}/{len(lines)}") - - text = self.recognize_from_image(line['image'], preprocess=True) - + + text = self.recognize_from_image(line["image"], preprocess=True) + result = { - 'text': text, - 'bbox': line['bbox'], - 'line_index': i + "text": text, + "bbox": line["bbox"], + "line_index": i, } - + if return_confidence: # Simple confidence based on text length and content confidence = self._estimate_confidence(text) - result['confidence'] = confidence - + result["confidence"] = confidence + results.append(result) - + logger.info(f"Recognized {len(results)} lines from {image_path}") return results - + except Exception as e: logger.error(f"Error recognizing lines from {image_path}: {e}") return [] - + def _estimate_confidence(self, text: str) -> float: """ Estimate confidence of recognition result. @@ -294,36 +296,36 @@ class HandwritingRecognizer: """ if not text: return 0.0 - + # Factors that indicate good recognition score = 0.5 # Base score - + # Longer text tends to be more reliable if len(text) > 10: score += 0.1 if len(text) > 20: score += 0.1 - + # Text with alphanumeric characters is more reliable if any(c.isalnum() for c in text): score += 0.1 - + # Text with spaces (words) is more reliable - if ' ' in text: + if " " in text: score += 0.1 - + # Penalize if too many special characters special_chars = sum(1 for c in text if not c.isalnum() and not c.isspace()) if special_chars / len(text) > 0.5: score -= 0.2 - + return max(0.0, min(1.0, score)) - + def recognize_from_file( - self, + self, image_path: str, - mode: str = 'full' - ) -> Dict[str, Any]: + mode: str = "full", + ) -> dict[str, Any]: """ Recognize handwriting from an image file. @@ -337,49 +339,49 @@ class HandwritingRecognizer: Dictionary with recognized text and metadata """ try: - if mode == 'full': + if mode == "full": # Recognize entire image - image = Image.open(image_path).convert('RGB') + image = Image.open(image_path).convert("RGB") text = self.recognize_from_image(image, preprocess=True) - + return { - 'text': text, - 'mode': 'full', - 'confidence': self._estimate_confidence(text) + "text": text, + "mode": "full", + "confidence": self._estimate_confidence(text), } - - elif mode == 'lines': + + elif mode == "lines": # Recognize line by line lines = self.recognize_lines(image_path, return_confidence=True) - + # Combine all lines - full_text = '\n'.join(line['text'] for line in lines) - avg_confidence = np.mean([line['confidence'] for line in lines]) if lines else 0.0 - + full_text = "\n".join(line["text"] for line in lines) + avg_confidence = np.mean([line["confidence"] for line in lines]) if lines else 0.0 + return { - 'text': full_text, - 'lines': lines, - 'mode': 'lines', - 'confidence': float(avg_confidence) + "text": full_text, + "lines": lines, + "mode": "lines", + "confidence": float(avg_confidence), } - + else: raise ValueError(f"Invalid mode: {mode}. 
Use 'full' or 'lines'") - + except Exception as e: logger.error(f"Error recognizing from file {image_path}: {e}") return { - 'text': '', - 'mode': mode, - 'confidence': 0.0, - 'error': str(e) + "text": "", + "mode": mode, + "confidence": 0.0, + "error": str(e), } - + def recognize_form_fields( - self, + self, image_path: str, - field_regions: List[Dict[str, Any]] - ) -> Dict[str, str]: + field_regions: list[dict[str, Any]], + ) -> dict[str, str]: """ Recognize text from specific form fields. @@ -399,35 +401,35 @@ class HandwritingRecognizer: """ try: # Load image - image = Image.open(image_path).convert('RGB') - + image = Image.open(image_path).convert("RGB") + # Extract and recognize each field results = {} for field in field_regions: - name = field['name'] - bbox = field['bbox'] - + name = field["name"] + bbox = field["bbox"] + # Crop field region x1, y1, x2, y2 = bbox field_image = image.crop((x1, y1, x2, y2)) - + # Recognize text text = self.recognize_from_image(field_image, preprocess=True) results[name] = text.strip() - + logger.debug(f"Field '{name}': {text[:50]}...") - + return results - + except Exception as e: logger.error(f"Error recognizing form fields: {e}") return {} - + def batch_recognize( - self, - image_paths: List[str], - mode: str = 'full' - ) -> List[Dict[str, Any]]: + self, + image_paths: list[str], + mode: str = "full", + ) -> list[dict[str, Any]]: """ Recognize handwriting from multiple images in batch. @@ -442,7 +444,7 @@ class HandwritingRecognizer: for i, path in enumerate(image_paths): logger.info(f"Processing image {i+1}/{len(image_paths)}: {path}") result = self.recognize_from_file(path, mode=mode) - result['image_path'] = path + result["image_path"] = path results.append(result) - + return results diff --git a/src/documents/ocr/table_extractor.py b/src/documents/ocr/table_extractor.py index b94b2a236..4716757ae 100644 --- a/src/documents/ocr/table_extractor.py +++ b/src/documents/ocr/table_extractor.py @@ -8,9 +8,8 @@ This module uses various techniques to detect and extract tables from documents: """ import logging -from pathlib import Path -from typing import List, Dict, Any, Optional, Tuple -import numpy as np +from typing import Any + from PIL import Image logger = logging.getLogger(__name__) @@ -32,7 +31,7 @@ class TableExtractor: ... print(table['data']) # pandas DataFrame ... 
print(table['bbox']) # bounding box coordinates """ - + def __init__( self, model_name: str = "microsoft/table-transformer-detection", @@ -52,34 +51,35 @@ class TableExtractor: self.use_gpu = use_gpu self._model = None self._processor = None - + def _load_model(self): """Lazy load the table detection model.""" if self._model is not None: return - + try: - from transformers import AutoImageProcessor, AutoModelForObjectDetection import torch - + from transformers import AutoImageProcessor + from transformers import AutoModelForObjectDetection + logger.info(f"Loading table detection model: {self.model_name}") - + self._processor = AutoImageProcessor.from_pretrained(self.model_name) self._model = AutoModelForObjectDetection.from_pretrained(self.model_name) - + # Move to GPU if available and requested if self.use_gpu and torch.cuda.is_available(): self._model = self._model.cuda() logger.info("Using GPU for table detection") else: logger.info("Using CPU for table detection") - + except ImportError as e: logger.error(f"Failed to load table detection model: {e}") logger.error("Please install required packages: pip install transformers torch pillow") raise - - def detect_tables(self, image: Image.Image) -> List[Dict[str, Any]]: + + def detect_tables(self, image: Image.Image) -> list[dict[str, Any]]: """ Detect tables in an image. @@ -98,50 +98,50 @@ class TableExtractor: ] """ self._load_model() - + try: import torch - + # Prepare image inputs = self._processor(images=image, return_tensors="pt") - + if self.use_gpu and torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} - + # Run detection with torch.no_grad(): outputs = self._model(**inputs) - + # Post-process results target_sizes = torch.tensor([image.size[::-1]]) results = self._processor.post_process_object_detection( - outputs, + outputs, threshold=self.confidence_threshold, - target_sizes=target_sizes + target_sizes=target_sizes, )[0] - + # Convert to list of dicts tables = [] for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): tables.append({ - 'bbox': box.cpu().tolist(), - 'score': score.item(), - 'label': self._model.config.id2label[label.item()] + "bbox": box.cpu().tolist(), + "score": score.item(), + "label": self._model.config.id2label[label.item()], }) - + logger.info(f"Detected {len(tables)} tables in image") return tables - + except Exception as e: logger.error(f"Error detecting tables: {e}") return [] - + def extract_table_from_region( - self, - image: Image.Image, - bbox: List[float], - use_ocr: bool = True - ) -> Optional[Dict[str, Any]]: + self, + image: Image.Image, + bbox: list[float], + use_ocr: bool = True, + ) -> dict[str, Any] | None: """ Extract table data from a specific region of an image. 
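# Illustrative usage sketch, not part of the patch: detection followed by per-region
# extraction, matching the two-step API in this class. "invoice.png" is a placeholder.
from PIL import Image
from documents.ocr.table_extractor import TableExtractor

extractor = TableExtractor(use_gpu=False)
page = Image.open("invoice.png").convert("RGB")

for detection in extractor.detect_tables(page):
    table = extractor.extract_table_from_region(page, detection["bbox"])
    if table and table["data"] is not None:
        print(table["data"].head())  # pandas DataFrame rebuilt from OCR positions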
@@ -158,48 +158,48 @@ class TableExtractor: # Crop to table region x1, y1, x2, y2 = [int(coord) for coord in bbox] table_image = image.crop((x1, y1, x2, y2)) - + if use_ocr: # Use OCR to extract text and structure import pytesseract - + # Get detailed OCR data ocr_data = pytesseract.image_to_data( - table_image, - output_type=pytesseract.Output.DICT + table_image, + output_type=pytesseract.Output.DICT, ) - + # Reconstruct table structure from OCR data table_data = self._reconstruct_table_from_ocr(ocr_data) - + # Also get raw text raw_text = pytesseract.image_to_string(table_image) - + return { - 'data': table_data, - 'raw_text': raw_text, - 'bbox': bbox, - 'image_size': table_image.size + "data": table_data, + "raw_text": raw_text, + "bbox": bbox, + "image_size": table_image.size, } else: # Fallback to basic OCR without structure import pytesseract raw_text = pytesseract.image_to_string(table_image) return { - 'data': None, - 'raw_text': raw_text, - 'bbox': bbox, - 'image_size': table_image.size + "data": None, + "raw_text": raw_text, + "bbox": bbox, + "image_size": table_image.size, } - + except ImportError: logger.error("pytesseract not installed. Install with: pip install pytesseract") return None except Exception as e: logger.error(f"Error extracting table from region: {e}") return None - - def _reconstruct_table_from_ocr(self, ocr_data: Dict) -> Optional[Any]: + + def _reconstruct_table_from_ocr(self, ocr_data: dict) -> Any | None: """ Reconstruct table structure from OCR output. @@ -211,58 +211,58 @@ class TableExtractor: """ try: import pandas as pd - + # Group text by vertical position (rows) rows = {} - for i, text in enumerate(ocr_data['text']): + for i, text in enumerate(ocr_data["text"]): if text.strip(): - top = ocr_data['top'][i] - left = ocr_data['left'][i] - + top = ocr_data["top"][i] + left = ocr_data["left"][i] + # Group by approximate row (within 20 pixels) row_key = round(top / 20) * 20 if row_key not in rows: rows[row_key] = [] rows[row_key].append((left, text)) - + # Sort rows and create DataFrame table_rows = [] for row_y in sorted(rows.keys()): # Sort cells by horizontal position cells = [text for _, text in sorted(rows[row_y])] table_rows.append(cells) - + if table_rows: # Pad rows to same length max_cols = max(len(row) for row in table_rows) - table_rows = [row + [''] * (max_cols - len(row)) for row in table_rows] - + table_rows = [row + [""] * (max_cols - len(row)) for row in table_rows] + # Create DataFrame df = pd.DataFrame(table_rows) - + # Try to use first row as header if it looks like one if len(df) > 1: - first_row_text = ' '.join(str(x) for x in df.iloc[0]) + first_row_text = " ".join(str(x) for x in df.iloc[0]) if not any(char.isdigit() for char in first_row_text): df.columns = df.iloc[0] df = df[1:].reset_index(drop=True) - + return df - + return None - + except ImportError: logger.error("pandas not installed. Install with: pip install pandas") return None except Exception as e: logger.error(f"Error reconstructing table: {e}") return None - + def extract_tables_from_image( - self, + self, image_path: str, - output_format: str = 'dataframe' - ) -> List[Dict[str, Any]]: + output_format: str = "dataframe", + ) -> list[dict[str, Any]]: """ Extract all tables from an image file. 
@@ -275,45 +275,45 @@ class TableExtractor: """ try: # Load image - image = Image.open(image_path).convert('RGB') - + image = Image.open(image_path).convert("RGB") + # Detect tables detections = self.detect_tables(image) - + # Extract data from each table tables = [] for i, detection in enumerate(detections): logger.info(f"Extracting table {i+1}/{len(detections)}") - + table_data = self.extract_table_from_region( - image, - detection['bbox'] + image, + detection["bbox"], ) - + if table_data: - table_data['detection_score'] = detection['score'] - table_data['table_index'] = i - + table_data["detection_score"] = detection["score"] + table_data["table_index"] = i + # Convert to requested format - if output_format == 'csv' and table_data['data'] is not None: - table_data['csv'] = table_data['data'].to_csv(index=False) - elif output_format == 'json' and table_data['data'] is not None: - table_data['json'] = table_data['data'].to_json(orient='records') - + if output_format == "csv" and table_data["data"] is not None: + table_data["csv"] = table_data["data"].to_csv(index=False) + elif output_format == "json" and table_data["data"] is not None: + table_data["json"] = table_data["data"].to_json(orient="records") + tables.append(table_data) - + logger.info(f"Successfully extracted {len(tables)} tables from {image_path}") return tables - + except Exception as e: logger.error(f"Error extracting tables from image {image_path}: {e}") return [] - + def extract_tables_from_pdf( - self, + self, pdf_path: str, - page_numbers: Optional[List[int]] = None - ) -> Dict[int, List[Dict[str, Any]]]: + page_numbers: list[int] | None = None, + ) -> dict[int, list[dict[str, Any]]]: """ Extract tables from a PDF document. @@ -326,56 +326,56 @@ class TableExtractor: """ try: from pdf2image import convert_from_path - + logger.info(f"Converting PDF to images: {pdf_path}") - + # Convert PDF pages to images if page_numbers: images = convert_from_path( - pdf_path, + pdf_path, first_page=min(page_numbers), - last_page=max(page_numbers) + last_page=max(page_numbers), ) else: images = convert_from_path(pdf_path) - + # Extract tables from each page results = {} for i, image in enumerate(images): page_num = page_numbers[i] if page_numbers else i + 1 logger.info(f"Processing page {page_num}") - + # Detect and extract tables detections = self.detect_tables(image) tables = [] - + for detection in detections: table_data = self.extract_table_from_region( - image, - detection['bbox'] + image, + detection["bbox"], ) if table_data: - table_data['detection_score'] = detection['score'] - table_data['page'] = page_num + table_data["detection_score"] = detection["score"] + table_data["page"] = page_num tables.append(table_data) - + if tables: results[page_num] = tables logger.info(f"Found {len(tables)} tables on page {page_num}") - + return results - + except ImportError: logger.error("pdf2image not installed. Install with: pip install pdf2image") return {} except Exception as e: logger.error(f"Error extracting tables from PDF: {e}") return {} - + def save_tables_to_excel( - self, - tables: List[Dict[str, Any]], - output_path: str + self, + tables: list[dict[str, Any]], + output_path: str, ) -> bool: """ Save extracted tables to an Excel file. 
@@ -389,23 +389,23 @@ class TableExtractor: """ try: import pandas as pd - - with pd.ExcelWriter(output_path, engine='openpyxl') as writer: + + with pd.ExcelWriter(output_path, engine="openpyxl") as writer: for i, table in enumerate(tables): - if table.get('data') is not None: + if table.get("data") is not None: sheet_name = f"Table_{i+1}" - if 'page' in table: + if "page" in table: sheet_name = f"Page_{table['page']}_Table_{i+1}" - - table['data'].to_excel( - writer, - sheet_name=sheet_name, - index=False + + table["data"].to_excel( + writer, + sheet_name=sheet_name, + index=False, ) - + logger.info(f"Saved {len(tables)} tables to {output_path}") return True - + except ImportError: logger.error("openpyxl not installed. Install with: pip install openpyxl") return False diff --git a/src/documents/permissions.py b/src/documents/permissions.py index 2ab20b497..732181f0f 100644 --- a/src/documents/permissions.py +++ b/src/documents/permissions.py @@ -233,11 +233,11 @@ class CanViewAISuggestionsPermission(BasePermission): def has_permission(self, request, view): if not request.user or not request.user.is_authenticated: return False - + # Superusers always have permission if request.user.is_superuser: return True - + # Check for specific permission return request.user.has_perm("documents.can_view_ai_suggestions") @@ -253,11 +253,11 @@ class CanApplyAISuggestionsPermission(BasePermission): def has_permission(self, request, view): if not request.user or not request.user.is_authenticated: return False - + # Superusers always have permission if request.user.is_superuser: return True - + # Check for specific permission return request.user.has_perm("documents.can_apply_ai_suggestions") @@ -273,11 +273,11 @@ class CanApproveDeletionsPermission(BasePermission): def has_permission(self, request, view): if not request.user or not request.user.is_authenticated: return False - + # Superusers always have permission if request.user.is_superuser: return True - + # Check for specific permission return request.user.has_perm("documents.can_approve_deletions") @@ -294,10 +294,10 @@ class CanConfigureAIPermission(BasePermission): def has_permission(self, request, view): if not request.user or not request.user.is_authenticated: return False - + # Superusers always have permission if request.user.is_superuser: return True - + # Check for specific permission return request.user.has_perm("documents.can_configure_ai") diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 36ecc4b37..599d0cbda 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -46,7 +46,6 @@ if settings.AUDIT_LOG_ENABLED: from documents import bulk_edit from documents.data_models import DocumentSource from documents.filters import CustomFieldQueryParser -from documents.models import AISuggestionFeedback from documents.models import Correspondent from documents.models import CustomField from documents.models import CustomFieldInstance @@ -2786,57 +2785,57 @@ class DeletionRequestSerializer(serializers.ModelSerializer): class DeletionRequestDetailSerializer(serializers.ModelSerializer): """Serializer for DeletionRequest model with document details.""" - + document_details = serializers.SerializerMethodField() - user_username = serializers.CharField(source='user.username', read_only=True) + user_username = serializers.CharField(source="user.username", read_only=True) reviewed_by_username = serializers.CharField( - source='reviewed_by.username', + source="reviewed_by.username", read_only=True, 
allow_null=True, ) - + class Meta: from documents.models import DeletionRequest model = DeletionRequest fields = [ - 'id', - 'created_at', - 'updated_at', - 'requested_by_ai', - 'ai_reason', - 'user', - 'user_username', - 'status', - 'impact_summary', - 'reviewed_at', - 'reviewed_by', - 'reviewed_by_username', - 'review_comment', - 'completed_at', - 'completion_details', - 'document_details', + "id", + "created_at", + "updated_at", + "requested_by_ai", + "ai_reason", + "user", + "user_username", + "status", + "impact_summary", + "reviewed_at", + "reviewed_by", + "reviewed_by_username", + "review_comment", + "completed_at", + "completion_details", + "document_details", ] read_only_fields = [ - 'id', - 'created_at', - 'updated_at', - 'reviewed_at', - 'reviewed_by', - 'completed_at', - 'completion_details', + "id", + "created_at", + "updated_at", + "reviewed_at", + "reviewed_by", + "completed_at", + "completion_details", ] - + def get_document_details(self, obj): """Get details of documents in this deletion request.""" documents = obj.documents.all() return [ { - 'id': doc.id, - 'title': doc.title, - 'created': doc.created.isoformat() if doc.created else None, - 'correspondent': doc.correspondent.name if doc.correspondent else None, - 'document_type': doc.document_type.name if doc.document_type else None, - 'tags': [tag.name for tag in doc.tags.all()], + "id": doc.id, + "title": doc.title, + "created": doc.created.isoformat() if doc.created else None, + "correspondent": doc.correspondent.name if doc.correspondent else None, + "document_type": doc.document_type.name if doc.document_type else None, + "tags": [tag.name for tag in doc.tags.all()], } for doc in documents ] @@ -2852,6 +2851,9 @@ class DeletionRequestActionSerializer(serializers.Serializer): allow_blank=True, label="Review Comment", help_text="Optional comment when reviewing the deletion request", + ) + + class AISuggestionsRequestSerializer(serializers.Serializer): """Serializer for requesting AI suggestions for a document.""" diff --git a/src/documents/serializers/__init__.py b/src/documents/serializers/__init__.py index 3c6543214..5baa5d978 100644 --- a/src/documents/serializers/__init__.py +++ b/src/documents/serializers/__init__.py @@ -1,17 +1,15 @@ """Serializers package for documents app.""" -from .ai_suggestions import ( - AISuggestionFeedbackSerializer, - AISuggestionsSerializer, - AISuggestionStatsSerializer, - ApplySuggestionSerializer, - RejectSuggestionSerializer, -) +from .ai_suggestions import AISuggestionFeedbackSerializer +from .ai_suggestions import AISuggestionsSerializer +from .ai_suggestions import AISuggestionStatsSerializer +from .ai_suggestions import ApplySuggestionSerializer +from .ai_suggestions import RejectSuggestionSerializer __all__ = [ - 'AISuggestionFeedbackSerializer', - 'AISuggestionsSerializer', - 'AISuggestionStatsSerializer', - 'ApplySuggestionSerializer', - 'RejectSuggestionSerializer', + "AISuggestionFeedbackSerializer", + "AISuggestionStatsSerializer", + "AISuggestionsSerializer", + "ApplySuggestionSerializer", + "RejectSuggestionSerializer", ] diff --git a/src/documents/serializers/ai_suggestions.py b/src/documents/serializers/ai_suggestions.py index f793482de..52db7d1bd 100644 --- a/src/documents/serializers/ai_suggestions.py +++ b/src/documents/serializers/ai_suggestions.py @@ -7,42 +7,39 @@ and handling user feedback on AI suggestions. 
from __future__ import annotations -from typing import Any, Dict +from typing import Any from rest_framework import serializers -from documents.models import ( - AISuggestionFeedback, - Correspondent, - CustomField, - DocumentType, - StoragePath, - Tag, - Workflow, -) - +from documents.models import AISuggestionFeedback +from documents.models import Correspondent +from documents.models import CustomField +from documents.models import DocumentType +from documents.models import StoragePath +from documents.models import Tag +from documents.models import Workflow # Suggestion type choices - used across multiple serializers SUGGESTION_TYPE_CHOICES = [ - 'tag', - 'correspondent', - 'document_type', - 'storage_path', - 'custom_field', - 'workflow', - 'title', + "tag", + "correspondent", + "document_type", + "storage_path", + "custom_field", + "workflow", + "title", ] # Types that require value_id -ID_REQUIRED_TYPES = ['tag', 'correspondent', 'document_type', 'storage_path', 'workflow'] +ID_REQUIRED_TYPES = ["tag", "correspondent", "document_type", "storage_path", "workflow"] # Types that require value_text -TEXT_REQUIRED_TYPES = ['title'] +TEXT_REQUIRED_TYPES = ["title"] # Types that can use either (custom_field can be ID or text) class TagSuggestionSerializer(serializers.Serializer): """Serializer for tag suggestions.""" - + id = serializers.IntegerField() name = serializers.CharField() color = serializers.CharField() @@ -51,7 +48,7 @@ class TagSuggestionSerializer(serializers.Serializer): class CorrespondentSuggestionSerializer(serializers.Serializer): """Serializer for correspondent suggestions.""" - + id = serializers.IntegerField() name = serializers.CharField() confidence = serializers.FloatField() @@ -59,7 +56,7 @@ class CorrespondentSuggestionSerializer(serializers.Serializer): class DocumentTypeSuggestionSerializer(serializers.Serializer): """Serializer for document type suggestions.""" - + id = serializers.IntegerField() name = serializers.CharField() confidence = serializers.FloatField() @@ -67,7 +64,7 @@ class DocumentTypeSuggestionSerializer(serializers.Serializer): class StoragePathSuggestionSerializer(serializers.Serializer): """Serializer for storage path suggestions.""" - + id = serializers.IntegerField() name = serializers.CharField() path = serializers.CharField() @@ -76,7 +73,7 @@ class StoragePathSuggestionSerializer(serializers.Serializer): class CustomFieldSuggestionSerializer(serializers.Serializer): """Serializer for custom field suggestions.""" - + field_id = serializers.IntegerField() field_name = serializers.CharField() value = serializers.CharField() @@ -85,7 +82,7 @@ class CustomFieldSuggestionSerializer(serializers.Serializer): class WorkflowSuggestionSerializer(serializers.Serializer): """Serializer for workflow suggestions.""" - + id = serializers.IntegerField() name = serializers.CharField() confidence = serializers.FloatField() @@ -93,7 +90,7 @@ class WorkflowSuggestionSerializer(serializers.Serializer): class TitleSuggestionSerializer(serializers.Serializer): """Serializer for title suggestions.""" - + title = serializers.CharField() @@ -103,7 +100,7 @@ class AISuggestionsSerializer(serializers.Serializer): Converts AIScanResult objects to JSON format for API responses. 
""" - + tags = TagSuggestionSerializer(many=True, required=False) correspondent = CorrespondentSuggestionSerializer(required=False, allow_null=True) document_type = DocumentTypeSuggestionSerializer(required=False, allow_null=True) @@ -111,9 +108,9 @@ class AISuggestionsSerializer(serializers.Serializer): custom_fields = CustomFieldSuggestionSerializer(many=True, required=False) workflows = WorkflowSuggestionSerializer(many=True, required=False) title_suggestion = TitleSuggestionSerializer(required=False, allow_null=True) - + @staticmethod - def from_scan_result(scan_result, document_id: int) -> Dict[str, Any]: + def from_scan_result(scan_result, document_id: int) -> dict[str, Any]: """ Convert an AIScanResult object to serializer data. @@ -125,7 +122,7 @@ class AISuggestionsSerializer(serializers.Serializer): Dictionary ready for serialization """ data = {} - + # Tags if scan_result.tags: tag_suggestions = [] @@ -133,59 +130,59 @@ class AISuggestionsSerializer(serializers.Serializer): try: tag = Tag.objects.get(pk=tag_id) tag_suggestions.append({ - 'id': tag.id, - 'name': tag.name, - 'color': getattr(tag, 'color', '#000000'), - 'confidence': confidence, + "id": tag.id, + "name": tag.name, + "color": getattr(tag, "color", "#000000"), + "confidence": confidence, }) except Tag.DoesNotExist: # Tag no longer exists in database; skip this suggestion pass - data['tags'] = tag_suggestions - + data["tags"] = tag_suggestions + # Correspondent if scan_result.correspondent: corr_id, confidence = scan_result.correspondent try: correspondent = Correspondent.objects.get(pk=corr_id) - data['correspondent'] = { - 'id': correspondent.id, - 'name': correspondent.name, - 'confidence': confidence, + data["correspondent"] = { + "id": correspondent.id, + "name": correspondent.name, + "confidence": confidence, } except Correspondent.DoesNotExist: # Correspondent no longer exists in database; omit from suggestions pass - + # Document Type if scan_result.document_type: type_id, confidence = scan_result.document_type try: doc_type = DocumentType.objects.get(pk=type_id) - data['document_type'] = { - 'id': doc_type.id, - 'name': doc_type.name, - 'confidence': confidence, + data["document_type"] = { + "id": doc_type.id, + "name": doc_type.name, + "confidence": confidence, } except DocumentType.DoesNotExist: # Document type no longer exists in database; omit from suggestions pass - + # Storage Path if scan_result.storage_path: path_id, confidence = scan_result.storage_path try: storage_path = StoragePath.objects.get(pk=path_id) - data['storage_path'] = { - 'id': storage_path.id, - 'name': storage_path.name, - 'path': storage_path.path, - 'confidence': confidence, + data["storage_path"] = { + "id": storage_path.id, + "name": storage_path.name, + "path": storage_path.path, + "confidence": confidence, } except StoragePath.DoesNotExist: # Storage path no longer exists in database; omit from suggestions pass - + # Custom Fields if scan_result.custom_fields: field_suggestions = [] @@ -193,16 +190,16 @@ class AISuggestionsSerializer(serializers.Serializer): try: field = CustomField.objects.get(pk=field_id) field_suggestions.append({ - 'field_id': field.id, - 'field_name': field.name, - 'value': str(value), - 'confidence': confidence, + "field_id": field.id, + "field_name": field.name, + "value": str(value), + "confidence": confidence, }) except CustomField.DoesNotExist: # Custom field no longer exists in database; skip this suggestion pass - data['custom_fields'] = field_suggestions - + data["custom_fields"] = 
field_suggestions + # Workflows if scan_result.workflows: workflow_suggestions = [] @@ -210,21 +207,21 @@ class AISuggestionsSerializer(serializers.Serializer): try: workflow = Workflow.objects.get(pk=workflow_id) workflow_suggestions.append({ - 'id': workflow.id, - 'name': workflow.name, - 'confidence': confidence, + "id": workflow.id, + "name": workflow.name, + "confidence": confidence, }) except Workflow.DoesNotExist: # Workflow no longer exists in database; skip this suggestion pass - data['workflows'] = workflow_suggestions - + data["workflows"] = workflow_suggestions + # Title suggestion if scan_result.title_suggestion: - data['title_suggestion'] = { - 'title': scan_result.title_suggestion, + data["title_suggestion"] = { + "title": scan_result.title_suggestion, } - + return data @@ -234,28 +231,28 @@ class SuggestionSerializerMixin: """ def validate(self, attrs): """Validate that the correct value field is provided for the suggestion type.""" - suggestion_type = attrs.get('suggestion_type') - value_id = attrs.get('value_id') - value_text = attrs.get('value_text') - + suggestion_type = attrs.get("suggestion_type") + value_id = attrs.get("value_id") + value_text = attrs.get("value_text") + # Types that require value_id if suggestion_type in ID_REQUIRED_TYPES and not value_id: raise serializers.ValidationError( - f"value_id is required for suggestion_type '{suggestion_type}'" + f"value_id is required for suggestion_type '{suggestion_type}'", ) - + # Types that require value_text if suggestion_type in TEXT_REQUIRED_TYPES and not value_text: raise serializers.ValidationError( - f"value_text is required for suggestion_type '{suggestion_type}'" + f"value_text is required for suggestion_type '{suggestion_type}'", ) - + # For custom_field, either is acceptable - if suggestion_type == 'custom_field' and not value_id and not value_text: + if suggestion_type == "custom_field" and not value_id and not value_text: raise serializers.ValidationError( - "Either value_id or value_text must be provided for custom_field" + "Either value_id or value_text must be provided for custom_field", ) - + return attrs @@ -263,12 +260,12 @@ class ApplySuggestionSerializer(SuggestionSerializerMixin, serializers.Serialize """ Serializer for applying AI suggestions. """ - + suggestion_type = serializers.ChoiceField( choices=SUGGESTION_TYPE_CHOICES, required=True, ) - + value_id = serializers.IntegerField(required=False, allow_null=True) value_text = serializers.CharField(required=False, allow_blank=True) confidence = serializers.FloatField(required=True) @@ -278,12 +275,12 @@ class RejectSuggestionSerializer(SuggestionSerializerMixin, serializers.Serializ """ Serializer for rejecting AI suggestions. 
""" - + suggestion_type = serializers.ChoiceField( choices=SUGGESTION_TYPE_CHOICES, required=True, ) - + value_id = serializers.IntegerField(required=False, allow_null=True) value_text = serializers.CharField(required=False, allow_blank=True) confidence = serializers.FloatField(required=True) @@ -291,41 +288,41 @@ class RejectSuggestionSerializer(SuggestionSerializerMixin, serializers.Serializ class AISuggestionFeedbackSerializer(serializers.ModelSerializer): """Serializer for AI suggestion feedback model.""" - + class Meta: model = AISuggestionFeedback fields = [ - 'id', - 'document', - 'suggestion_type', - 'suggested_value_id', - 'suggested_value_text', - 'confidence', - 'status', - 'user', - 'created_at', - 'applied_at', - 'metadata', + "id", + "document", + "suggestion_type", + "suggested_value_id", + "suggested_value_text", + "confidence", + "status", + "user", + "created_at", + "applied_at", + "metadata", ] - read_only_fields = ['id', 'created_at', 'applied_at'] + read_only_fields = ["id", "created_at", "applied_at"] class AISuggestionStatsSerializer(serializers.Serializer): """ Serializer for AI suggestion accuracy statistics. """ - + total_suggestions = serializers.IntegerField() total_applied = serializers.IntegerField() total_rejected = serializers.IntegerField() accuracy_rate = serializers.FloatField() - + by_type = serializers.DictField( child=serializers.DictField(), help_text="Statistics broken down by suggestion type", ) - + average_confidence_applied = serializers.FloatField() average_confidence_rejected = serializers.FloatField() - + recent_suggestions = AISuggestionFeedbackSerializer(many=True, required=False) diff --git a/src/documents/tests/test_ai_deletion_manager.py b/src/documents/tests/test_ai_deletion_manager.py index 00b635ff5..deb6e8d73 100644 --- a/src/documents/tests/test_ai_deletion_manager.py +++ b/src/documents/tests/test_ai_deletion_manager.py @@ -18,13 +18,11 @@ from django.test import TestCase from django.utils import timezone from documents.ai_deletion_manager import AIDeletionManager -from documents.models import ( - Correspondent, - DeletionRequest, - Document, - DocumentType, - Tag, -) +from documents.models import Correspondent +from documents.models import DeletionRequest +from documents.models import Document +from documents.models import DocumentType +from documents.models import Tag class TestAIDeletionManagerCreateRequest(TestCase): @@ -33,13 +31,13 @@ class TestAIDeletionManagerCreateRequest(TestCase): def setUp(self): """Set up test data.""" self.user = User.objects.create_user(username="testuser", password="testpass") - + # Create test documents with various metadata self.correspondent = Correspondent.objects.create(name="Test Corp") self.doc_type = DocumentType.objects.create(name="Invoice") self.tag1 = Tag.objects.create(name="Important") self.tag2 = Tag.objects.create(name="2024") - + self.doc1 = Document.objects.create( title="Test Document 1", content="Test content 1", @@ -49,7 +47,7 @@ class TestAIDeletionManagerCreateRequest(TestCase): document_type=self.doc_type, ) self.doc1.tags.add(self.tag1, self.tag2) - + self.doc2 = Document.objects.create( title="Test Document 2", content="Test content 2", @@ -63,13 +61,13 @@ class TestAIDeletionManagerCreateRequest(TestCase): """Test creating a basic deletion request.""" documents = [self.doc1, self.doc2] reason = "Duplicate documents detected" - + request = AIDeletionManager.create_deletion_request( documents=documents, reason=reason, user=self.user, ) - + self.assertIsNotNone(request) 
self.assertIsInstance(request, DeletionRequest) self.assertEqual(request.ai_reason, reason) @@ -82,13 +80,13 @@ class TestAIDeletionManagerCreateRequest(TestCase): """Test that deletion request includes impact analysis.""" documents = [self.doc1, self.doc2] reason = "Test deletion" - + request = AIDeletionManager.create_deletion_request( documents=documents, reason=reason, user=self.user, ) - + impact = request.impact_summary self.assertIsNotNone(impact) self.assertEqual(impact["document_count"], 2) @@ -106,14 +104,14 @@ class TestAIDeletionManagerCreateRequest(TestCase): "document_count": 1, "custom_field": "custom_value", } - + request = AIDeletionManager.create_deletion_request( documents=documents, reason=reason, user=self.user, impact_analysis=custom_impact, ) - + self.assertEqual(request.impact_summary, custom_impact) self.assertEqual(request.impact_summary["custom_field"], "custom_value") @@ -121,13 +119,13 @@ class TestAIDeletionManagerCreateRequest(TestCase): """Test creating request with empty document list.""" documents = [] reason = "Test deletion" - + request = AIDeletionManager.create_deletion_request( documents=documents, reason=reason, user=self.user, ) - + self.assertIsNotNone(request) self.assertEqual(request.documents.count(), 0) self.assertEqual(request.impact_summary["document_count"], 0) @@ -157,9 +155,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase): document_type=self.doc_type1, ) doc.tags.add(self.tag1, self.tag2) - + impact = AIDeletionManager._analyze_impact([doc]) - + self.assertEqual(impact["document_count"], 1) self.assertEqual(len(impact["documents"]), 1) self.assertEqual(impact["documents"][0]["id"], doc.id) @@ -180,7 +178,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase): document_type=self.doc_type1, ) doc1.tags.add(self.tag1) - + doc2 = Document.objects.create( title="Document 2", content="Content 2", @@ -190,9 +188,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase): document_type=self.doc_type2, ) doc2.tags.add(self.tag2, self.tag3) - + impact = AIDeletionManager._analyze_impact([doc1, doc2]) - + self.assertEqual(impact["document_count"], 2) self.assertEqual(len(impact["documents"]), 2) self.assertIn("Corp A", impact["affected_correspondents"]) @@ -209,9 +207,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase): checksum="checksum1", mime_type="application/pdf", ) - + impact = AIDeletionManager._analyze_impact([doc]) - + self.assertEqual(impact["document_count"], 1) self.assertEqual(impact["documents"][0]["correspondent"], None) self.assertEqual(impact["documents"][0]["document_type"], None) @@ -232,7 +230,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase): # Force set the created date to an earlier time doc1.created = timezone.make_aware(datetime(2023, 1, 1)) doc1.save() - + doc2 = Document.objects.create( title="New Document", content="Content", @@ -241,9 +239,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase): ) doc2.created = timezone.make_aware(datetime(2024, 12, 31)) doc2.save() - + impact = AIDeletionManager._analyze_impact([doc1, doc2]) - + self.assertIsNotNone(impact["date_range"]["earliest"]) self.assertIsNotNone(impact["date_range"]["latest"]) # Check that dates are ISO formatted strings @@ -253,7 +251,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase): def test_analyze_impact_empty_list(self): """Test impact analysis with empty document list.""" impact = AIDeletionManager._analyze_impact([]) - + self.assertEqual(impact["document_count"], 0) self.assertEqual(len(impact["documents"]), 0) 
self.assertEqual(len(impact["affected_correspondents"]), 0) @@ -267,11 +265,11 @@ class TestAIDeletionManagerFormatRequest(TestCase): def setUp(self): """Set up test data.""" self.user = User.objects.create_user(username="testuser", password="testpass") - + self.correspondent = Correspondent.objects.create(name="Test Corp") self.doc_type = DocumentType.objects.create(name="Invoice") self.tag = Tag.objects.create(name="Important") - + self.doc = Document.objects.create( title="Test Document", content="Test content", @@ -289,9 +287,9 @@ class TestAIDeletionManagerFormatRequest(TestCase): reason="Test reason for deletion", user=self.user, ) - + message = AIDeletionManager.format_deletion_request_for_user(request) - + self.assertIsInstance(message, str) self.assertIn("AI DELETION REQUEST", message) self.assertIn("Test reason for deletion", message) @@ -306,15 +304,15 @@ class TestAIDeletionManagerFormatRequest(TestCase): checksum="checksum2", mime_type="application/pdf", ) - + request = AIDeletionManager.create_deletion_request( documents=[self.doc, doc2], reason="Multiple documents", user=self.user, ) - + message = AIDeletionManager.format_deletion_request_for_user(request) - + self.assertIn("Number of documents: 2", message) self.assertIn("Test Corp", message) self.assertIn("Invoice", message) @@ -328,15 +326,15 @@ class TestAIDeletionManagerFormatRequest(TestCase): checksum="checksum1", mime_type="application/pdf", ) - + request = AIDeletionManager.create_deletion_request( documents=[doc], reason="Test deletion", user=self.user, ) - + message = AIDeletionManager.format_deletion_request_for_user(request) - + self.assertIn("Basic Document", message) self.assertIn("None", message) # Should show None for missing metadata @@ -347,9 +345,9 @@ class TestAIDeletionManagerFormatRequest(TestCase): reason="Test", user=self.user, ) - + message = AIDeletionManager.format_deletion_request_for_user(request) - + self.assertIn("explicit approval", message.lower()) self.assertIn("no files will be deleted until you confirm", message.lower()) @@ -361,7 +359,7 @@ class TestAIDeletionManagerGetPendingRequests(TestCase): """Set up test data.""" self.user1 = User.objects.create_user(username="user1", password="pass1") self.user2 = User.objects.create_user(username="user2", password="pass2") - + self.doc = Document.objects.create( title="Test Document", content="Content", @@ -382,16 +380,16 @@ class TestAIDeletionManagerGetPendingRequests(TestCase): reason="Reason 2", user=self.user1, ) - + # Create request for user2 AIDeletionManager.create_deletion_request( documents=[self.doc], reason="Reason 3", user=self.user2, ) - + pending = AIDeletionManager.get_pending_requests(self.user1) - + self.assertEqual(len(pending), 2) self.assertIn(req1, pending) self.assertIn(req2, pending) @@ -408,12 +406,12 @@ class TestAIDeletionManagerGetPendingRequests(TestCase): reason="Reason 2", user=self.user1, ) - + # Approve one request req1.approve(self.user1, "Approved") - + pending = AIDeletionManager.get_pending_requests(self.user1) - + self.assertEqual(len(pending), 1) self.assertNotIn(req1, pending) self.assertIn(req2, pending) @@ -430,12 +428,12 @@ class TestAIDeletionManagerGetPendingRequests(TestCase): reason="Reason 2", user=self.user1, ) - + # Reject one request req1.reject(self.user1, "Rejected") - + pending = AIDeletionManager.get_pending_requests(self.user1) - + self.assertEqual(len(pending), 1) self.assertNotIn(req1, pending) self.assertIn(req2, pending) @@ -443,7 +441,7 @@ class 
TestAIDeletionManagerGetPendingRequests(TestCase): def test_get_pending_requests_empty(self): """Test getting pending requests when none exist.""" pending = AIDeletionManager.get_pending_requests(self.user1) - + self.assertEqual(len(pending), 0) @@ -454,7 +452,7 @@ class TestAIDeletionManagerSecurityConstraints(TestCase): """Test that AI can never delete automatically.""" # This is a critical security test result = AIDeletionManager.can_ai_delete_automatically() - + self.assertFalse(result) def test_deletion_request_requires_pending_status(self): @@ -466,13 +464,13 @@ class TestAIDeletionManagerSecurityConstraints(TestCase): checksum="checksum1", mime_type="application/pdf", ) - + request = AIDeletionManager.create_deletion_request( documents=[doc], reason="Test", user=user, ) - + self.assertEqual(request.status, DeletionRequest.STATUS_PENDING) def test_deletion_request_marked_as_ai_initiated(self): @@ -484,13 +482,13 @@ class TestAIDeletionManagerSecurityConstraints(TestCase): checksum="checksum1", mime_type="application/pdf", ) - + request = AIDeletionManager.create_deletion_request( documents=[doc], reason="Test", user=user, ) - + self.assertTrue(request.requested_by_ai) @@ -501,7 +499,7 @@ class TestAIDeletionManagerWorkflow(TestCase): """Set up test data.""" self.user = User.objects.create_user(username="testuser", password="pass") self.approver = User.objects.create_user(username="approver", password="pass") - + self.doc1 = Document.objects.create( title="Document 1", content="Content 1", @@ -523,14 +521,14 @@ class TestAIDeletionManagerWorkflow(TestCase): reason="Duplicates detected", user=self.user, ) - + self.assertEqual(request.status, DeletionRequest.STATUS_PENDING) self.assertIsNone(request.reviewed_at) self.assertIsNone(request.reviewed_by) - + # Step 2: Approve request success = request.approve(self.approver, "Looks good") - + self.assertTrue(success) self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED) self.assertIsNotNone(request.reviewed_at) @@ -545,12 +543,12 @@ class TestAIDeletionManagerWorkflow(TestCase): reason="Should be deleted", user=self.user, ) - + self.assertEqual(request.status, DeletionRequest.STATUS_PENDING) - + # Step 2: Reject request success = request.reject(self.approver, "Not a duplicate") - + self.assertTrue(success) self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED) self.assertIsNotNone(request.reviewed_at) @@ -564,14 +562,14 @@ class TestAIDeletionManagerWorkflow(TestCase): reason="Test deletion", user=self.user, ) - + # Record initial state created_at = request.created_at self.assertIsNotNone(created_at) - + # Approve request.approve(self.approver, "Approved") - + # Verify audit trail self.assertIsNotNone(request.created_at) self.assertIsNotNone(request.updated_at) diff --git a/src/documents/tests/test_ai_permissions.py b/src/documents/tests/test_ai_permissions.py index f8266b2cd..39463b47e 100644 --- a/src/documents/tests/test_ai_permissions.py +++ b/src/documents/tests/test_ai_permissions.py @@ -10,24 +10,23 @@ Tests cover: - Permission assignment and verification """ -from django.contrib.auth.models import Group, Permission, User +from django.contrib.auth.models import Group +from django.contrib.auth.models import Permission +from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType from django.test import TestCase from rest_framework.test import APIRequestFactory from documents.models import Document -from documents.permissions import ( - CanApplyAISuggestionsPermission, - 
CanApproveDeletionsPermission, - CanConfigureAIPermission, - CanViewAISuggestionsPermission, -) +from documents.permissions import CanApplyAISuggestionsPermission +from documents.permissions import CanApproveDeletionsPermission +from documents.permissions import CanConfigureAIPermission +from documents.permissions import CanViewAISuggestionsPermission class MockView: """Mock view for testing permissions.""" - pass class TestCanViewAISuggestionsPermission(TestCase): @@ -41,13 +40,13 @@ class TestCanViewAISuggestionsPermission(TestCase): # Create users self.superuser = User.objects.create_superuser( - username="admin", email="admin@test.com", password="admin123" + username="admin", email="admin@test.com", password="admin123", ) self.regular_user = User.objects.create_user( - username="regular", email="regular@test.com", password="regular123" + username="regular", email="regular@test.com", password="regular123", ) self.permitted_user = User.objects.create_user( - username="permitted", email="permitted@test.com", password="permitted123" + username="permitted", email="permitted@test.com", password="permitted123", ) # Assign permission to permitted_user @@ -107,13 +106,13 @@ class TestCanApplyAISuggestionsPermission(TestCase): # Create users self.superuser = User.objects.create_superuser( - username="admin", email="admin@test.com", password="admin123" + username="admin", email="admin@test.com", password="admin123", ) self.regular_user = User.objects.create_user( - username="regular", email="regular@test.com", password="regular123" + username="regular", email="regular@test.com", password="regular123", ) self.permitted_user = User.objects.create_user( - username="permitted", email="permitted@test.com", password="permitted123" + username="permitted", email="permitted@test.com", password="permitted123", ) # Assign permission to permitted_user @@ -173,13 +172,13 @@ class TestCanApproveDeletionsPermission(TestCase): # Create users self.superuser = User.objects.create_superuser( - username="admin", email="admin@test.com", password="admin123" + username="admin", email="admin@test.com", password="admin123", ) self.regular_user = User.objects.create_user( - username="regular", email="regular@test.com", password="regular123" + username="regular", email="regular@test.com", password="regular123", ) self.permitted_user = User.objects.create_user( - username="permitted", email="permitted@test.com", password="permitted123" + username="permitted", email="permitted@test.com", password="permitted123", ) # Assign permission to permitted_user @@ -239,13 +238,13 @@ class TestCanConfigureAIPermission(TestCase): # Create users self.superuser = User.objects.create_superuser( - username="admin", email="admin@test.com", password="admin123" + username="admin", email="admin@test.com", password="admin123", ) self.regular_user = User.objects.create_user( - username="regular", email="regular@test.com", password="regular123" + username="regular", email="regular@test.com", password="regular123", ) self.permitted_user = User.objects.create_user( - username="permitted", email="permitted@test.com", password="permitted123" + username="permitted", email="permitted@test.com", password="permitted123", ) # Assign permission to permitted_user @@ -345,7 +344,7 @@ class TestRoleBasedAccessControl(TestCase): def test_viewer_role_permissions(self): """Test that viewer role has appropriate permissions.""" user = User.objects.create_user( - username="viewer", email="viewer@test.com", password="viewer123" + username="viewer", 
email="viewer@test.com", password="viewer123", ) user.groups.add(self.viewer_group) @@ -360,7 +359,7 @@ class TestRoleBasedAccessControl(TestCase): def test_editor_role_permissions(self): """Test that editor role has appropriate permissions.""" user = User.objects.create_user( - username="editor", email="editor@test.com", password="editor123" + username="editor", email="editor@test.com", password="editor123", ) user.groups.add(self.editor_group) @@ -375,7 +374,7 @@ class TestRoleBasedAccessControl(TestCase): def test_admin_role_permissions(self): """Test that admin role has all permissions.""" user = User.objects.create_user( - username="ai_admin", email="ai_admin@test.com", password="admin123" + username="ai_admin", email="ai_admin@test.com", password="admin123", ) user.groups.add(self.admin_group) @@ -390,7 +389,7 @@ class TestRoleBasedAccessControl(TestCase): def test_user_with_multiple_groups(self): """Test that user permissions accumulate from multiple groups.""" user = User.objects.create_user( - username="multi_role", email="multi@test.com", password="multi123" + username="multi_role", email="multi@test.com", password="multi123", ) user.groups.add(self.viewer_group, self.editor_group) @@ -405,7 +404,7 @@ class TestRoleBasedAccessControl(TestCase): def test_direct_permission_assignment_overrides_group(self): """Test that direct permission assignment works alongside group permissions.""" user = User.objects.create_user( - username="special", email="special@test.com", password="special123" + username="special", email="special@test.com", password="special123", ) user.groups.add(self.viewer_group) @@ -428,7 +427,7 @@ class TestPermissionAssignment(TestCase): def setUp(self): """Set up test user.""" self.user = User.objects.create_user( - username="testuser", email="test@test.com", password="test123" + username="testuser", email="test@test.com", password="test123", ) content_type = ContentType.objects.get_for_model(Document) self.view_permission, _ = Permission.objects.get_or_create( @@ -500,7 +499,7 @@ class TestPermissionEdgeCases(TestCase): def test_inactive_user_with_permission(self): """Test that inactive users are denied even with permission.""" user = User.objects.create_user( - username="inactive", email="inactive@test.com", password="inactive123" + username="inactive", email="inactive@test.com", password="inactive123", ) user.is_active = False user.save() diff --git a/src/documents/tests/test_ai_scanner.py b/src/documents/tests/test_ai_scanner.py index 77ed16343..307574bbf 100644 --- a/src/documents/tests/test_ai_scanner.py +++ b/src/documents/tests/test_ai_scanner.py @@ -21,24 +21,21 @@ Tests cover: from unittest import mock from django.db import transaction -from django.test import TestCase, override_settings +from django.test import TestCase +from django.test import override_settings -from documents.ai_scanner import ( - AIScanResult, - AIDocumentScanner, - get_ai_scanner, -) -from documents.models import ( - Correspondent, - CustomField, - Document, - DocumentType, - StoragePath, - Tag, - Workflow, - WorkflowTrigger, - WorkflowAction, -) +from documents.ai_scanner import AIDocumentScanner +from documents.ai_scanner import AIScanResult +from documents.ai_scanner import get_ai_scanner +from documents.models import Correspondent +from documents.models import CustomField +from documents.models import Document +from documents.models import DocumentType +from documents.models import StoragePath +from documents.models import Tag +from documents.models import Workflow +from 
documents.models import WorkflowAction +from documents.models import WorkflowTrigger class TestAIScanResult(TestCase): @@ -47,7 +44,7 @@ class TestAIScanResult(TestCase): def test_init_creates_empty_result(self): """Test that AIScanResult initializes with empty structures.""" result = AIScanResult() - + self.assertEqual(result.tags, []) self.assertIsNone(result.correspondent) self.assertIsNone(result.document_type) @@ -70,9 +67,9 @@ class TestAIScanResult(TestCase): result.extracted_entities = {"persons": ["John Doe"], "organizations": ["ACME Corp"]} result.title_suggestion = "Invoice - ACME Corp - 2024-01-01" result.metadata = {"tables": [{"data": "test"}]} - + result_dict = result.to_dict() - + self.assertEqual(result_dict["tags"], [(1, 0.85), (2, 0.75)]) self.assertEqual(result_dict["correspondent"], (1, 0.90)) self.assertEqual(result_dict["document_type"], (2, 0.80)) @@ -90,7 +87,7 @@ class TestAIDocumentScannerInitialization(TestCase): def test_init_with_defaults(self): """Test scanner initialization with default parameters.""" scanner = AIDocumentScanner() - + self.assertEqual(scanner.auto_apply_threshold, 0.80) self.assertEqual(scanner.suggest_threshold, 0.60) self.assertTrue(scanner.ml_enabled) @@ -100,9 +97,9 @@ class TestAIDocumentScannerInitialization(TestCase): """Test scanner initialization with custom confidence thresholds.""" scanner = AIDocumentScanner( auto_apply_threshold=0.90, - suggest_threshold=0.70 + suggest_threshold=0.70, ) - + self.assertEqual(scanner.auto_apply_threshold, 0.90) self.assertEqual(scanner.suggest_threshold, 0.70) @@ -110,32 +107,32 @@ class TestAIDocumentScannerInitialization(TestCase): def test_init_respects_ml_disabled_setting(self): """Test that ML features can be disabled via settings.""" scanner = AIDocumentScanner() - + self.assertFalse(scanner.ml_enabled) def test_init_with_explicit_ml_override(self): """Test explicit ML feature override.""" scanner = AIDocumentScanner(enable_ml_features=False) - + self.assertFalse(scanner.ml_enabled) @override_settings(PAPERLESS_ENABLE_ADVANCED_OCR=False) def test_init_respects_ocr_disabled_setting(self): """Test that advanced OCR can be disabled via settings.""" scanner = AIDocumentScanner() - + self.assertFalse(scanner.advanced_ocr_enabled) def test_init_with_explicit_ocr_override(self): """Test explicit OCR feature override.""" scanner = AIDocumentScanner(enable_advanced_ocr=False) - + self.assertFalse(scanner.advanced_ocr_enabled) def test_lazy_loading_components_not_initialized(self): """Test that ML components are not initialized at construction.""" scanner = AIDocumentScanner() - + self.assertIsNone(scanner._classifier) self.assertIsNone(scanner._ner_extractor) self.assertIsNone(scanner._semantic_search) @@ -145,45 +142,45 @@ class TestAIDocumentScannerInitialization(TestCase): class TestAIDocumentScannerLazyLoading(TestCase): """Test lazy loading of ML components.""" - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_get_classifier_loads_successfully(self, mock_logger): """Test successful lazy loading of classifier.""" scanner = AIDocumentScanner() - + # Mock the import and class mock_classifier_instance = mock.MagicMock() - with mock.patch('documents.ai_scanner.TransformerDocumentClassifier', + with mock.patch("documents.ai_scanner.TransformerDocumentClassifier", return_value=mock_classifier_instance) as mock_classifier_class: classifier = scanner._get_classifier() - + self.assertIsNotNone(classifier) self.assertEqual(classifier, 
mock_classifier_instance) mock_classifier_class.assert_called_once() mock_logger.info.assert_called_with("ML classifier loaded successfully") - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_get_classifier_returns_cached_instance(self, mock_logger): """Test that classifier is only loaded once.""" scanner = AIDocumentScanner() - + mock_classifier_instance = mock.MagicMock() - with mock.patch('documents.ai_scanner.TransformerDocumentClassifier', + with mock.patch("documents.ai_scanner.TransformerDocumentClassifier", return_value=mock_classifier_instance): classifier1 = scanner._get_classifier() classifier2 = scanner._get_classifier() - + self.assertEqual(classifier1, classifier2) self.assertIs(classifier1, classifier2) - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_get_classifier_handles_import_error(self, mock_logger): """Test that classifier loading handles import errors gracefully.""" scanner = AIDocumentScanner() - - with mock.patch('documents.ai_scanner.TransformerDocumentClassifier', + + with mock.patch("documents.ai_scanner.TransformerDocumentClassifier", side_effect=ImportError("Module not found")): classifier = scanner._get_classifier() - + self.assertIsNone(classifier) self.assertFalse(scanner.ml_enabled) mock_logger.warning.assert_called() @@ -191,63 +188,63 @@ class TestAIDocumentScannerLazyLoading(TestCase): def test_get_classifier_returns_none_when_ml_disabled(self): """Test that classifier returns None when ML is disabled.""" scanner = AIDocumentScanner(enable_ml_features=False) - + classifier = scanner._get_classifier() - + self.assertIsNone(classifier) - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_get_ner_extractor_loads_successfully(self, mock_logger): """Test successful lazy loading of NER extractor.""" scanner = AIDocumentScanner() - + mock_ner_instance = mock.MagicMock() - with mock.patch('documents.ai_scanner.DocumentNER', + with mock.patch("documents.ai_scanner.DocumentNER", return_value=mock_ner_instance) as mock_ner_class: ner = scanner._get_ner_extractor() - + self.assertIsNotNone(ner) self.assertEqual(ner, mock_ner_instance) mock_ner_class.assert_called_once() mock_logger.info.assert_called_with("NER extractor loaded successfully") - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_get_ner_extractor_handles_error(self, mock_logger): """Test NER extractor handles loading errors.""" scanner = AIDocumentScanner() - - with mock.patch('documents.ai_scanner.DocumentNER', + + with mock.patch("documents.ai_scanner.DocumentNER", side_effect=Exception("Failed to load")): ner = scanner._get_ner_extractor() - + self.assertIsNone(ner) mock_logger.warning.assert_called() - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_get_semantic_search_loads_successfully(self, mock_logger): """Test successful lazy loading of semantic search.""" scanner = AIDocumentScanner() - + mock_search_instance = mock.MagicMock() - with mock.patch('documents.ai_scanner.SemanticSearch', + with mock.patch("documents.ai_scanner.SemanticSearch", return_value=mock_search_instance) as mock_search_class: search = scanner._get_semantic_search() - + self.assertIsNotNone(search) self.assertEqual(search, mock_search_instance) mock_search_class.assert_called_once() mock_logger.info.assert_called_with("Semantic search loaded successfully") - 
@mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_get_table_extractor_loads_successfully(self, mock_logger): """Test successful lazy loading of table extractor.""" scanner = AIDocumentScanner() - + mock_extractor_instance = mock.MagicMock() - with mock.patch('documents.ai_scanner.TableExtractor', + with mock.patch("documents.ai_scanner.TableExtractor", return_value=mock_extractor_instance) as mock_extractor_class: extractor = scanner._get_table_extractor() - + self.assertIsNotNone(extractor) self.assertEqual(extractor, mock_extractor_instance) mock_extractor_class.assert_called_once() @@ -256,9 +253,9 @@ class TestAIDocumentScannerLazyLoading(TestCase): def test_get_table_extractor_returns_none_when_ocr_disabled(self): """Test that table extractor returns None when OCR is disabled.""" scanner = AIDocumentScanner(enable_advanced_ocr=False) - + extractor = scanner._get_table_extractor() - + self.assertIsNone(extractor) @@ -268,7 +265,7 @@ class TestExtractEntities(TestCase): def test_extract_entities_with_ner_available(self): """Test entity extraction when NER is available.""" scanner = AIDocumentScanner() - + mock_ner = mock.MagicMock() mock_ner.extract_all.return_value = { "persons": ["John Doe", "Jane Smith"], @@ -276,16 +273,16 @@ class TestExtractEntities(TestCase): "dates": ["2024-01-01", "2024-12-31"], "amounts": ["$1,000", "$500"], "locations": ["New York"], - "misc": ["Invoice#123"] + "misc": ["Invoice#123"], } - + scanner._ner_extractor = mock_ner - + entities = scanner._extract_entities("Sample document text") - + # Verify NER was called mock_ner.extract_all.assert_called_once_with("Sample document text") - + # Verify entities are converted to dict format self.assertIn("persons", entities) self.assertEqual(len(entities["persons"]), 2) @@ -295,17 +292,17 @@ class TestExtractEntities(TestCase): def test_extract_entities_converts_strings_to_dicts(self): """Test that string entities are converted to dict format.""" scanner = AIDocumentScanner() - + mock_ner = mock.MagicMock() mock_ner.extract_all.return_value = { "persons": ["John Doe"], # String format "organizations": [{"text": "ACME Corp", "confidence": 0.9}], # Already dict } - + scanner._ner_extractor = mock_ner - + entities = scanner._extract_entities("Sample text") - + # Verify string entities are converted self.assertEqual(entities["persons"][0], {"text": "John Doe"}) # Verify dict entities remain unchanged @@ -315,22 +312,22 @@ class TestExtractEntities(TestCase): """Test entity extraction when NER is not available.""" scanner = AIDocumentScanner() scanner._get_ner_extractor = mock.MagicMock(return_value=None) - + entities = scanner._extract_entities("Sample text") - + self.assertEqual(entities, {}) - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_extract_entities_handles_exception(self, mock_logger): """Test that entity extraction handles exceptions gracefully.""" scanner = AIDocumentScanner() - + mock_ner = mock.MagicMock() mock_ner.extract_all.side_effect = Exception("NER failed") scanner._ner_extractor = mock_ner - + entities = scanner._extract_entities("Sample text") - + self.assertEqual(entities, {}) mock_logger.error.assert_called() @@ -345,67 +342,67 @@ class TestSuggestTags(TestCase): self.tag3 = Tag.objects.create(name="Tax", matching_algorithm=Tag.MATCH_AUTO) self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) - 
@mock.patch('documents.ai_scanner.match_tags') + @mock.patch("documents.ai_scanner.match_tags") def test_suggest_tags_with_matched_tags(self, mock_match_tags): """Test tag suggestions from matching.""" scanner = AIDocumentScanner() mock_match_tags.return_value = [self.tag1, self.tag2] - + suggestions = scanner._suggest_tags( self.document, "Invoice from ACME Corp", - {} + {}, ) - + # Should suggest both matched tags self.assertEqual(len(suggestions), 2) tag_ids = [tag_id for tag_id, _ in suggestions] self.assertIn(self.tag1.id, tag_ids) self.assertIn(self.tag2.id, tag_ids) - + # Check confidence for _, confidence in suggestions: self.assertGreaterEqual(confidence, 0.6) - @mock.patch('documents.ai_scanner.match_tags') + @mock.patch("documents.ai_scanner.match_tags") def test_suggest_tags_with_organization_entities(self, mock_match_tags): """Test tag suggestions based on organization entities.""" scanner = AIDocumentScanner() mock_match_tags.return_value = [] - + entities = { - "organizations": [{"text": "ACME Corp"}] + "organizations": [{"text": "ACME Corp"}], } - + suggestions = scanner._suggest_tags(self.document, "text", entities) - + # Should suggest company tag based on organization tag_ids = [tag_id for tag_id, _ in suggestions] self.assertIn(self.tag2.id, tag_ids) - @mock.patch('documents.ai_scanner.match_tags') + @mock.patch("documents.ai_scanner.match_tags") def test_suggest_tags_removes_duplicates(self, mock_match_tags): """Test that duplicate tags keep highest confidence.""" scanner = AIDocumentScanner() mock_match_tags.return_value = [self.tag1] - + # Manually add same tag with different confidence scanner._suggest_tags(self.document, "text", {}) - + # Implementation should remove duplicates in actual code - @mock.patch('documents.ai_scanner.match_tags') - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.match_tags") + @mock.patch("documents.ai_scanner.logger") def test_suggest_tags_handles_exception(self, mock_logger, mock_match_tags): """Test tag suggestion handles exceptions.""" scanner = AIDocumentScanner() mock_match_tags.side_effect = Exception("Matching failed") - + suggestions = scanner._suggest_tags(self.document, "text", {}) - + self.assertEqual(suggestions, []) mock_logger.error.assert_called() @@ -417,66 +414,66 @@ class TestDetectCorrespondent(TestCase): """Set up test correspondents.""" self.correspondent1 = Correspondent.objects.create( name="ACME Corporation", - matching_algorithm=Correspondent.MATCH_AUTO + matching_algorithm=Correspondent.MATCH_AUTO, ) self.correspondent2 = Correspondent.objects.create( name="TechStart Inc", - matching_algorithm=Correspondent.MATCH_AUTO + matching_algorithm=Correspondent.MATCH_AUTO, ) self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) - @mock.patch('documents.ai_scanner.match_correspondents') + @mock.patch("documents.ai_scanner.match_correspondents") def test_detect_correspondent_with_match(self, mock_match): """Test correspondent detection with successful match.""" scanner = AIDocumentScanner() mock_match.return_value = [self.correspondent1] - + result = scanner._detect_correspondent(self.document, "text", {}) - + self.assertIsNotNone(result) corr_id, confidence = result self.assertEqual(corr_id, self.correspondent1.id) self.assertEqual(confidence, 0.85) - @mock.patch('documents.ai_scanner.match_correspondents') + @mock.patch("documents.ai_scanner.match_correspondents") def test_detect_correspondent_without_match(self, mock_match): 
"""Test correspondent detection without match.""" scanner = AIDocumentScanner() mock_match.return_value = [] - + result = scanner._detect_correspondent(self.document, "text", {}) - + self.assertIsNone(result) - @mock.patch('documents.ai_scanner.match_correspondents') + @mock.patch("documents.ai_scanner.match_correspondents") def test_detect_correspondent_from_ner_entities(self, mock_match): """Test correspondent detection from NER organizations.""" scanner = AIDocumentScanner() mock_match.return_value = [] - + entities = { - "organizations": [{"text": "ACME Corporation"}] + "organizations": [{"text": "ACME Corporation"}], } - + result = scanner._detect_correspondent(self.document, "text", entities) - + self.assertIsNotNone(result) corr_id, confidence = result self.assertEqual(corr_id, self.correspondent1.id) self.assertEqual(confidence, 0.70) - @mock.patch('documents.ai_scanner.match_correspondents') - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.match_correspondents") + @mock.patch("documents.ai_scanner.logger") def test_detect_correspondent_handles_exception(self, mock_logger, mock_match): """Test correspondent detection handles exceptions.""" scanner = AIDocumentScanner() mock_match.side_effect = Exception("Detection failed") - + result = scanner._detect_correspondent(self.document, "text", {}) - + self.assertIsNone(result) mock_logger.error.assert_called() @@ -488,49 +485,49 @@ class TestClassifyDocumentType(TestCase): """Set up test document types.""" self.doc_type1 = DocumentType.objects.create( name="Invoice", - matching_algorithm=DocumentType.MATCH_AUTO + matching_algorithm=DocumentType.MATCH_AUTO, ) self.doc_type2 = DocumentType.objects.create( name="Receipt", - matching_algorithm=DocumentType.MATCH_AUTO + matching_algorithm=DocumentType.MATCH_AUTO, ) self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) - @mock.patch('documents.ai_scanner.match_document_types') + @mock.patch("documents.ai_scanner.match_document_types") def test_classify_document_type_with_match(self, mock_match): """Test document type classification with match.""" scanner = AIDocumentScanner() mock_match.return_value = [self.doc_type1] - + result = scanner._classify_document_type(self.document, "text", {}) - + self.assertIsNotNone(result) type_id, confidence = result self.assertEqual(type_id, self.doc_type1.id) self.assertEqual(confidence, 0.85) - @mock.patch('documents.ai_scanner.match_document_types') + @mock.patch("documents.ai_scanner.match_document_types") def test_classify_document_type_without_match(self, mock_match): """Test document type classification without match.""" scanner = AIDocumentScanner() mock_match.return_value = [] - + result = scanner._classify_document_type(self.document, "text", {}) - + self.assertIsNone(result) - @mock.patch('documents.ai_scanner.match_document_types') - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.match_document_types") + @mock.patch("documents.ai_scanner.logger") def test_classify_document_type_handles_exception(self, mock_logger, mock_match): """Test classification handles exceptions.""" scanner = AIDocumentScanner() mock_match.side_effect = Exception("Classification failed") - + result = scanner._classify_document_type(self.document, "text", {}) - + self.assertIsNone(result) mock_logger.error.assert_called() @@ -543,48 +540,48 @@ class TestSuggestStoragePath(TestCase): self.storage_path1 = StoragePath.objects.create( name="Invoices", 
path="/documents/invoices", - matching_algorithm=StoragePath.MATCH_AUTO + matching_algorithm=StoragePath.MATCH_AUTO, ) self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) - @mock.patch('documents.ai_scanner.match_storage_paths') + @mock.patch("documents.ai_scanner.match_storage_paths") def test_suggest_storage_path_with_match(self, mock_match): """Test storage path suggestion with match.""" scanner = AIDocumentScanner() mock_match.return_value = [self.storage_path1] - + scan_result = AIScanResult() result = scanner._suggest_storage_path(self.document, "text", scan_result) - + self.assertIsNotNone(result) path_id, confidence = result self.assertEqual(path_id, self.storage_path1.id) self.assertEqual(confidence, 0.80) - @mock.patch('documents.ai_scanner.match_storage_paths') + @mock.patch("documents.ai_scanner.match_storage_paths") def test_suggest_storage_path_without_match(self, mock_match): """Test storage path suggestion without match.""" scanner = AIDocumentScanner() mock_match.return_value = [] - + scan_result = AIScanResult() result = scanner._suggest_storage_path(self.document, "text", scan_result) - + self.assertIsNone(result) - @mock.patch('documents.ai_scanner.match_storage_paths') - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.match_storage_paths") + @mock.patch("documents.ai_scanner.logger") def test_suggest_storage_path_handles_exception(self, mock_logger, mock_match): """Test storage path suggestion handles exceptions.""" scanner = AIDocumentScanner() mock_match.side_effect = Exception("Suggestion failed") - + scan_result = AIScanResult() result = scanner._suggest_storage_path(self.document, "text", scan_result) - + self.assertIsNone(result) mock_logger.error.assert_called() @@ -596,33 +593,33 @@ class TestExtractCustomFields(TestCase): """Set up test custom fields.""" self.field_date = CustomField.objects.create( name="Invoice Date", - data_type=CustomField.FieldDataType.DATE + data_type=CustomField.FieldDataType.DATE, ) self.field_amount = CustomField.objects.create( name="Total Amount", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.field_email = CustomField.objects.create( name="Contact Email", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) def test_extract_custom_fields_with_entities(self): """Test custom field extraction with entities.""" scanner = AIDocumentScanner() - + entities = { "dates": [{"text": "2024-01-01"}], "amounts": [{"text": "$1,000"}], - "emails": ["test@example.com"] + "emails": ["test@example.com"], } - + fields = scanner._extract_custom_fields(self.document, "text", entities) - + # Should extract date field self.assertIn(self.field_date.id, fields) value, confidence = fields[self.field_date.id] @@ -632,21 +629,21 @@ class TestExtractCustomFields(TestCase): def test_extract_custom_fields_without_entities(self): """Test custom field extraction without entities.""" scanner = AIDocumentScanner() - + fields = scanner._extract_custom_fields(self.document, "text", {}) - + # Should return empty dict self.assertEqual(fields, {}) - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_extract_custom_fields_handles_exception(self, mock_logger): """Test custom field extraction handles exceptions.""" scanner = 
AIDocumentScanner() - - with mock.patch.object(CustomField.objects, 'all', + + with mock.patch.object(CustomField.objects, "all", side_effect=Exception("DB error")): fields = scanner._extract_custom_fields(self.document, "text", {}) - + self.assertEqual(fields, {}) mock_logger.error.assert_called() @@ -658,42 +655,42 @@ class TestExtractFieldValue(TestCase): """Set up test fields.""" self.field_date = CustomField.objects.create( name="Invoice Date", - data_type=CustomField.FieldDataType.DATE + data_type=CustomField.FieldDataType.DATE, ) self.field_amount = CustomField.objects.create( name="Total Amount", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.field_invoice = CustomField.objects.create( name="Invoice Number", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.field_email = CustomField.objects.create( name="Contact Email", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.field_phone = CustomField.objects.create( name="Phone Number", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.field_person = CustomField.objects.create( name="Person Name", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.field_company = CustomField.objects.create( name="Company Name", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) def test_extract_field_value_date(self): """Test extraction of date field.""" scanner = AIDocumentScanner() entities = {"dates": [{"text": "2024-01-01"}]} - + value, confidence = scanner._extract_field_value( - self.field_date, "text", entities + self.field_date, "text", entities, ) - + self.assertEqual(value, "2024-01-01") self.assertEqual(confidence, 0.75) @@ -701,11 +698,11 @@ class TestExtractFieldValue(TestCase): """Test extraction of amount field.""" scanner = AIDocumentScanner() entities = {"amounts": [{"text": "$1,000"}]} - + value, confidence = scanner._extract_field_value( - self.field_amount, "text", entities + self.field_amount, "text", entities, ) - + self.assertEqual(value, "$1,000") self.assertEqual(confidence, 0.75) @@ -713,11 +710,11 @@ class TestExtractFieldValue(TestCase): """Test extraction of invoice number.""" scanner = AIDocumentScanner() entities = {"invoice_numbers": ["INV-12345"]} - + value, confidence = scanner._extract_field_value( - self.field_invoice, "text", entities + self.field_invoice, "text", entities, ) - + self.assertEqual(value, "INV-12345") self.assertEqual(confidence, 0.80) @@ -725,11 +722,11 @@ class TestExtractFieldValue(TestCase): """Test extraction of email field.""" scanner = AIDocumentScanner() entities = {"emails": ["test@example.com"]} - + value, confidence = scanner._extract_field_value( - self.field_email, "text", entities + self.field_email, "text", entities, ) - + self.assertEqual(value, "test@example.com") self.assertEqual(confidence, 0.85) @@ -737,11 +734,11 @@ class TestExtractFieldValue(TestCase): """Test extraction of phone field.""" scanner = AIDocumentScanner() entities = {"phones": ["+1-555-1234"]} - + value, confidence = scanner._extract_field_value( - self.field_phone, "text", entities + self.field_phone, "text", entities, ) - + self.assertEqual(value, "+1-555-1234") self.assertEqual(confidence, 0.85) @@ -749,11 +746,11 @@ class TestExtractFieldValue(TestCase): """Test extraction of person name.""" scanner = AIDocumentScanner() entities = 
{"persons": [{"text": "John Doe"}]} - + value, confidence = scanner._extract_field_value( - self.field_person, "text", entities + self.field_person, "text", entities, ) - + self.assertEqual(value, "John Doe") self.assertEqual(confidence, 0.70) @@ -761,11 +758,11 @@ class TestExtractFieldValue(TestCase): """Test extraction of company name.""" scanner = AIDocumentScanner() entities = {"organizations": [{"text": "ACME Corp"}]} - + value, confidence = scanner._extract_field_value( - self.field_company, "text", entities + self.field_company, "text", entities, ) - + self.assertEqual(value, "ACME Corp") self.assertEqual(confidence, 0.70) @@ -773,11 +770,11 @@ class TestExtractFieldValue(TestCase): """Test extraction when no entity matches.""" scanner = AIDocumentScanner() entities = {} - + value, confidence = scanner._extract_field_value( - self.field_date, "text", entities + self.field_date, "text", entities, ) - + self.assertIsNone(value) self.assertEqual(confidence, 0.0) @@ -789,42 +786,42 @@ class TestSuggestWorkflows(TestCase): """Set up test workflows.""" self.workflow1 = Workflow.objects.create( name="Invoice Processing", - enabled=True + enabled=True, ) self.trigger1 = WorkflowTrigger.objects.create( workflow=self.workflow1, - type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION + type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION, ) self.workflow2 = Workflow.objects.create( name="Document Archival", - enabled=True + enabled=True, ) self.trigger2 = WorkflowTrigger.objects.create( workflow=self.workflow2, - type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION + type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION, ) self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) def test_suggest_workflows_with_matches(self): """Test workflow suggestion with matches.""" scanner = AIDocumentScanner(suggest_threshold=0.5) - + scan_result = AIScanResult() scan_result.document_type = (1, 0.85) scan_result.correspondent = (2, 0.90) scan_result.tags = [(1, 0.80)] - + # Create action for workflow WorkflowAction.objects.create( workflow=self.workflow1, - type=WorkflowAction.WorkflowActionType.ASSIGNMENT + type=WorkflowAction.WorkflowActionType.ASSIGNMENT, ) - + suggestions = scanner._suggest_workflows(self.document, "text", scan_result) - + # Should suggest workflows self.assertGreater(len(suggestions), 0) for workflow_id, confidence in suggestions: @@ -833,27 +830,27 @@ class TestSuggestWorkflows(TestCase): def test_suggest_workflows_filters_by_threshold(self): """Test that workflows below threshold are filtered.""" scanner = AIDocumentScanner(suggest_threshold=0.95) - + scan_result = AIScanResult() - + suggestions = scanner._suggest_workflows(self.document, "text", scan_result) - + # Should not suggest any (confidence too low) self.assertEqual(len(suggestions), 0) - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_suggest_workflows_handles_exception(self, mock_logger): """Test workflow suggestion handles exceptions.""" scanner = AIDocumentScanner() - + scan_result = AIScanResult() - - with mock.patch.object(Workflow.objects, 'filter', + + with mock.patch.object(Workflow.objects, "filter", side_effect=Exception("DB error")): suggestions = scanner._suggest_workflows( - self.document, "text", scan_result + self.document, "text", scan_result, ) - + self.assertEqual(suggestions, []) mock_logger.error.assert_called() @@ -865,22 +862,22 @@ class TestEvaluateWorkflowMatch(TestCase): """Set up test 
workflow.""" self.workflow = Workflow.objects.create( name="Test Workflow", - enabled=True + enabled=True, ) self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) def test_evaluate_workflow_match_base_confidence(self): """Test base confidence for workflow.""" scanner = AIDocumentScanner() scan_result = AIScanResult() - + confidence = scanner._evaluate_workflow_match( - self.workflow, self.document, scan_result + self.workflow, self.document, scan_result, ) - + self.assertEqual(confidence, 0.5) def test_evaluate_workflow_match_with_document_type(self): @@ -888,17 +885,17 @@ class TestEvaluateWorkflowMatch(TestCase): scanner = AIDocumentScanner() scan_result = AIScanResult() scan_result.document_type = (1, 0.85) - + # Create action for workflow WorkflowAction.objects.create( workflow=self.workflow, - type=WorkflowAction.WorkflowActionType.ASSIGNMENT + type=WorkflowAction.WorkflowActionType.ASSIGNMENT, ) - + confidence = scanner._evaluate_workflow_match( - self.workflow, self.document, scan_result + self.workflow, self.document, scan_result, ) - + self.assertGreater(confidence, 0.5) def test_evaluate_workflow_match_with_correspondent(self): @@ -906,11 +903,11 @@ class TestEvaluateWorkflowMatch(TestCase): scanner = AIDocumentScanner() scan_result = AIScanResult() scan_result.correspondent = (1, 0.90) - + confidence = scanner._evaluate_workflow_match( - self.workflow, self.document, scan_result + self.workflow, self.document, scan_result, ) - + self.assertGreater(confidence, 0.5) def test_evaluate_workflow_match_with_tags(self): @@ -918,11 +915,11 @@ class TestEvaluateWorkflowMatch(TestCase): scanner = AIDocumentScanner() scan_result = AIScanResult() scan_result.tags = [(1, 0.80), (2, 0.75)] - + confidence = scanner._evaluate_workflow_match( - self.workflow, self.document, scan_result + self.workflow, self.document, scan_result, ) - + self.assertGreater(confidence, 0.5) def test_evaluate_workflow_match_max_confidence(self): @@ -932,17 +929,17 @@ class TestEvaluateWorkflowMatch(TestCase): scan_result.document_type = (1, 0.85) scan_result.correspondent = (1, 0.90) scan_result.tags = [(1, 0.80)] - + # Create action WorkflowAction.objects.create( workflow=self.workflow, - type=WorkflowAction.WorkflowActionType.ASSIGNMENT + type=WorkflowAction.WorkflowActionType.ASSIGNMENT, ) - + confidence = scanner._evaluate_workflow_match( - self.workflow, self.document, scan_result + self.workflow, self.document, scan_result, ) - + self.assertLessEqual(confidence, 1.0) @@ -953,21 +950,21 @@ class TestSuggestTitle(TestCase): """Set up test document.""" self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) def test_suggest_title_with_all_entities(self): """Test title suggestion with all entity types.""" scanner = AIDocumentScanner() - + entities = { "document_type": "Invoice", "organizations": [{"text": "ACME Corporation"}], - "dates": [{"text": "2024-01-01"}] + "dates": [{"text": "2024-01-01"}], } - + title = scanner._suggest_title(self.document, "text", entities) - + self.assertIsNotNone(title) self.assertIn("Invoice", title) self.assertIn("ACME Corporation", title) @@ -976,51 +973,51 @@ class TestSuggestTitle(TestCase): def test_suggest_title_with_partial_entities(self): """Test title suggestion with partial entities.""" scanner = AIDocumentScanner() - + entities = { - "organizations": [{"text": "TechStart Inc"}] + "organizations": [{"text": "TechStart Inc"}], } - + title = 
scanner._suggest_title(self.document, "text", entities) - + self.assertIsNotNone(title) self.assertIn("TechStart Inc", title) def test_suggest_title_without_entities(self): """Test title suggestion without entities.""" scanner = AIDocumentScanner() - + title = scanner._suggest_title(self.document, "text", {}) - + self.assertIsNone(title) def test_suggest_title_respects_length_limit(self): """Test that title respects 127 character limit.""" scanner = AIDocumentScanner() - + # Create very long organization name long_org = "A" * 100 entities = { "organizations": [{"text": long_org}], - "dates": [{"text": "2024-01-01"}] + "dates": [{"text": "2024-01-01"}], } - + title = scanner._suggest_title(self.document, "text", entities) - + self.assertIsNotNone(title) self.assertLessEqual(len(title), 127) - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_suggest_title_handles_exception(self, mock_logger): """Test title suggestion handles exceptions.""" scanner = AIDocumentScanner() - + # Force an exception entities = mock.MagicMock() entities.get.side_effect = Exception("Unexpected error") - + title = scanner._suggest_title(self.document, "text", entities) - + self.assertIsNone(title) mock_logger.error.assert_called() @@ -1031,15 +1028,15 @@ class TestExtractTables(TestCase): def test_extract_tables_with_extractor(self): """Test table extraction when extractor is available.""" scanner = AIDocumentScanner() - + mock_extractor = mock.MagicMock() mock_extractor.extract_tables_from_image.return_value = [ - {"data": [[1, 2], [3, 4]], "headers": ["A", "B"]} + {"data": [[1, 2], [3, 4]], "headers": ["A", "B"]}, ] scanner._table_extractor = mock_extractor - + tables = scanner._extract_tables("/path/to/file.pdf") - + self.assertEqual(len(tables), 1) self.assertIn("data", tables[0]) mock_extractor.extract_tables_from_image.assert_called_once() @@ -1048,22 +1045,22 @@ class TestExtractTables(TestCase): """Test table extraction when extractor is not available.""" scanner = AIDocumentScanner() scanner._get_table_extractor = mock.MagicMock(return_value=None) - + tables = scanner._extract_tables("/path/to/file.pdf") - + self.assertEqual(tables, []) - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_extract_tables_handles_exception(self, mock_logger): """Test table extraction handles exceptions.""" scanner = AIDocumentScanner() - + mock_extractor = mock.MagicMock() mock_extractor.extract_tables_from_image.side_effect = Exception("Extraction failed") scanner._table_extractor = mock_extractor - + tables = scanner._extract_tables("/path/to/file.pdf") - + self.assertEqual(tables, []) mock_logger.error.assert_called() @@ -1075,17 +1072,17 @@ class TestScanDocument(TestCase): """Set up test document.""" self.document = Document.objects.create( title="Test Document", - content="Invoice from ACME Corporation dated 2024-01-01" + content="Invoice from ACME Corporation dated 2024-01-01", ) - @mock.patch.object(AIDocumentScanner, '_extract_entities') - @mock.patch.object(AIDocumentScanner, '_suggest_tags') - @mock.patch.object(AIDocumentScanner, '_detect_correspondent') - @mock.patch.object(AIDocumentScanner, '_classify_document_type') - @mock.patch.object(AIDocumentScanner, '_suggest_storage_path') - @mock.patch.object(AIDocumentScanner, '_extract_custom_fields') - @mock.patch.object(AIDocumentScanner, '_suggest_workflows') - @mock.patch.object(AIDocumentScanner, '_suggest_title') + @mock.patch.object(AIDocumentScanner, 
"_extract_entities") + @mock.patch.object(AIDocumentScanner, "_suggest_tags") + @mock.patch.object(AIDocumentScanner, "_detect_correspondent") + @mock.patch.object(AIDocumentScanner, "_classify_document_type") + @mock.patch.object(AIDocumentScanner, "_suggest_storage_path") + @mock.patch.object(AIDocumentScanner, "_extract_custom_fields") + @mock.patch.object(AIDocumentScanner, "_suggest_workflows") + @mock.patch.object(AIDocumentScanner, "_suggest_title") def test_scan_document_orchestrates_all_methods( self, mock_title, @@ -1095,11 +1092,11 @@ class TestScanDocument(TestCase): mock_doc_type, mock_correspondent, mock_tags, - mock_entities + mock_entities, ): """Test that scan_document calls all extraction methods.""" scanner = AIDocumentScanner() - + # Set up mock returns mock_entities.return_value = {"persons": ["John Doe"]} mock_tags.return_value = [(1, 0.85)] @@ -1109,9 +1106,9 @@ class TestScanDocument(TestCase): mock_fields.return_value = {1: ("value", 0.70)} mock_workflows.return_value = [(1, 0.65)] mock_title.return_value = "Suggested Title" - + result = scanner.scan_document(self.document, "Document text") - + # Verify all methods were called mock_entities.assert_called_once() mock_tags.assert_called_once() @@ -1121,53 +1118,53 @@ class TestScanDocument(TestCase): mock_fields.assert_called_once() mock_workflows.assert_called_once() mock_title.assert_called_once() - + # Verify result contains data self.assertEqual(result.tags, [(1, 0.85)]) self.assertEqual(result.correspondent, (1, 0.90)) self.assertEqual(result.document_type, (1, 0.80)) - @mock.patch.object(AIDocumentScanner, '_extract_tables') + @mock.patch.object(AIDocumentScanner, "_extract_tables") def test_scan_document_extracts_tables_when_enabled(self, mock_extract_tables): """Test that tables are extracted when OCR is enabled and file path provided.""" scanner = AIDocumentScanner(enable_advanced_ocr=True) mock_extract_tables.return_value = [{"data": "test"}] - + # Mock other methods to avoid complexity - with mock.patch.object(scanner, '_extract_entities', return_value={}), \ - mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ - mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner, '_suggest_title', return_value=None): - + with mock.patch.object(scanner, "_extract_entities", return_value={}), \ + mock.patch.object(scanner, "_suggest_tags", return_value=[]), \ + mock.patch.object(scanner, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \ + mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner, "_suggest_title", return_value=None): + result = scanner.scan_document( self.document, "Document text", - original_file_path="/path/to/file.pdf" + original_file_path="/path/to/file.pdf", ) - + mock_extract_tables.assert_called_once_with("/path/to/file.pdf") self.assertIn("tables", result.metadata) def test_scan_document_without_file_path_skips_tables(self): """Test that tables are not extracted when file path is 
not provided.""" scanner = AIDocumentScanner(enable_advanced_ocr=True) - - with mock.patch.object(scanner, '_extract_tables') as mock_extract_tables, \ - mock.patch.object(scanner, '_extract_entities', return_value={}), \ - mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ - mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner, '_suggest_title', return_value=None): - + + with mock.patch.object(scanner, "_extract_tables") as mock_extract_tables, \ + mock.patch.object(scanner, "_extract_entities", return_value={}), \ + mock.patch.object(scanner, "_suggest_tags", return_value=[]), \ + mock.patch.object(scanner, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \ + mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner, "_suggest_title", return_value=None): + result = scanner.scan_document(self.document, "Document text") - + mock_extract_tables.assert_not_called() self.assertNotIn("tables", result.metadata) @@ -1183,35 +1180,35 @@ class TestApplyScanResults(TestCase): self.doc_type = DocumentType.objects.create(name="Invoice") self.storage_path = StoragePath.objects.create( name="Invoices", - path="/invoices" + path="/invoices", ) self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) def test_apply_scan_results_auto_applies_high_confidence(self): """Test that high confidence suggestions are auto-applied.""" scanner = AIDocumentScanner(auto_apply_threshold=0.80) - + scan_result = AIScanResult() scan_result.tags = [(self.tag1.id, 0.85), (self.tag2.id, 0.82)] scan_result.correspondent = (self.correspondent.id, 0.90) scan_result.document_type = (self.doc_type.id, 0.88) scan_result.storage_path = (self.storage_path.id, 0.85) - + result = scanner.apply_scan_results( self.document, scan_result, - auto_apply=True + auto_apply=True, ) - + # Verify auto-applied self.assertEqual(len(result["applied"]["tags"]), 2) self.assertIsNotNone(result["applied"]["correspondent"]) self.assertIsNotNone(result["applied"]["document_type"]) self.assertIsNotNone(result["applied"]["storage_path"]) - + # Verify document was updated self.document.refresh_from_db() self.assertEqual(self.document.correspondent, self.correspondent) @@ -1222,25 +1219,25 @@ class TestApplyScanResults(TestCase): """Test that medium confidence items are suggested, not applied.""" scanner = AIDocumentScanner( auto_apply_threshold=0.80, - suggest_threshold=0.60 + suggest_threshold=0.60, ) - + scan_result = AIScanResult() scan_result.tags = [(self.tag1.id, 0.70)] scan_result.correspondent = (self.correspondent.id, 0.65) - + result = scanner.apply_scan_results( self.document, scan_result, - auto_apply=True + auto_apply=True, ) - + # Verify suggested but not applied self.assertEqual(len(result["suggestions"]["tags"]), 1) self.assertIsNotNone(result["suggestions"]["correspondent"]) self.assertEqual(len(result["applied"]["tags"]), 0) 
self.assertIsNone(result["applied"]["correspondent"]) - + # Verify document was not updated self.document.refresh_from_db() self.assertIsNone(self.document.correspondent) @@ -1248,54 +1245,54 @@ class TestApplyScanResults(TestCase): def test_apply_scan_results_respects_auto_apply_false(self): """Test that auto_apply=False prevents automatic application.""" scanner = AIDocumentScanner(auto_apply_threshold=0.80) - + scan_result = AIScanResult() scan_result.tags = [(self.tag1.id, 0.90)] - + result = scanner.apply_scan_results( self.document, scan_result, - auto_apply=False + auto_apply=False, ) - + # Verify nothing was applied self.assertEqual(len(result["applied"]["tags"]), 0) def test_apply_scan_results_uses_transaction(self): """Test that apply_scan_results uses atomic transaction.""" scanner = AIDocumentScanner() - + scan_result = AIScanResult() scan_result.correspondent = (self.correspondent.id, 0.90) - - with mock.patch.object(self.document, 'save', + + with mock.patch.object(self.document, "save", side_effect=Exception("Save failed")): with self.assertRaises(Exception): with transaction.atomic(): scanner.apply_scan_results( self.document, scan_result, - auto_apply=True + auto_apply=True, ) - + # Verify transaction was rolled back self.document.refresh_from_db() self.assertIsNone(self.document.correspondent) - @mock.patch('documents.ai_scanner.logger') + @mock.patch("documents.ai_scanner.logger") def test_apply_scan_results_handles_exception(self, mock_logger): """Test that apply_scan_results handles exceptions gracefully.""" scanner = AIDocumentScanner() - + scan_result = AIScanResult() scan_result.tags = [(999, 0.90)] # Non-existent tag - + scanner.apply_scan_results( self.document, scan_result, - auto_apply=True + auto_apply=True, ) - + mock_logger.error.assert_called() @@ -1305,14 +1302,14 @@ class TestGetAIScanner(TestCase): def test_get_ai_scanner_returns_instance(self): """Test that get_ai_scanner returns a scanner instance.""" scanner = get_ai_scanner() - + self.assertIsInstance(scanner, AIDocumentScanner) def test_get_ai_scanner_returns_same_instance(self): """Test that get_ai_scanner returns the same instance.""" scanner1 = get_ai_scanner() scanner2 = get_ai_scanner() - + self.assertIs(scanner1, scanner2) @@ -1323,24 +1320,24 @@ class TestEdgeCasesAndErrorHandling(TestCase): """Set up test document.""" self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) def test_scan_document_with_empty_text(self): """Test scanning document with empty text.""" scanner = AIDocumentScanner() - - with mock.patch.object(scanner, '_extract_entities', return_value={}), \ - mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ - mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner, '_suggest_title', return_value=None): - + + with mock.patch.object(scanner, "_extract_entities", return_value={}), \ + mock.patch.object(scanner, "_suggest_tags", return_value=[]), \ + mock.patch.object(scanner, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner, "_suggest_storage_path", return_value=None), 
\ + mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner, "_suggest_title", return_value=None): + result = scanner.scan_document(self.document, "") - + self.assertIsNotNone(result) self.assertIsInstance(result, AIScanResult) @@ -1348,49 +1345,49 @@ class TestEdgeCasesAndErrorHandling(TestCase): """Test scanning document with very long text.""" scanner = AIDocumentScanner() long_text = "A" * 100000 - - with mock.patch.object(scanner, '_extract_entities', return_value={}), \ - mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ - mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner, '_suggest_title', return_value=None): - + + with mock.patch.object(scanner, "_extract_entities", return_value={}), \ + mock.patch.object(scanner, "_suggest_tags", return_value=[]), \ + mock.patch.object(scanner, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \ + mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner, "_suggest_title", return_value=None): + result = scanner.scan_document(self.document, long_text) - + self.assertIsNotNone(result) def test_scan_document_with_special_characters(self): """Test scanning document with special characters.""" scanner = AIDocumentScanner() special_text = "Test with émojis 😀 and special chars: <>{}[]|\\`~" - - with mock.patch.object(scanner, '_extract_entities', return_value={}), \ - mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ - mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner, '_suggest_title', return_value=None): - + + with mock.patch.object(scanner, "_extract_entities", return_value={}), \ + mock.patch.object(scanner, "_suggest_tags", return_value=[]), \ + mock.patch.object(scanner, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \ + mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner, "_suggest_title", return_value=None): + result = scanner.scan_document(self.document, special_text) - + self.assertIsNotNone(result) def test_apply_scan_results_with_empty_result(self): """Test applying empty scan results.""" scanner = AIDocumentScanner() scan_result = AIScanResult() - + result = scanner.apply_scan_results( self.document, scan_result, - auto_apply=True + auto_apply=True, ) - + 
self.assertEqual(result["applied"]["tags"], []) self.assertIsNone(result["applied"]["correspondent"]) @@ -1399,16 +1396,16 @@ class TestEdgeCasesAndErrorHandling(TestCase): # Test at exact threshold scanner = AIDocumentScanner(auto_apply_threshold=0.80) self.assertEqual(scanner.auto_apply_threshold, 0.80) - + # Test extreme values scanner_low = AIDocumentScanner( auto_apply_threshold=0.01, - suggest_threshold=0.01 + suggest_threshold=0.01, ) self.assertEqual(scanner_low.auto_apply_threshold, 0.01) - + scanner_high = AIDocumentScanner( auto_apply_threshold=0.99, - suggest_threshold=0.80 + suggest_threshold=0.80, ) self.assertEqual(scanner_high.auto_apply_threshold, 0.99) diff --git a/src/documents/tests/test_ai_scanner_integration.py b/src/documents/tests/test_ai_scanner_integration.py index 68ed670aa..7673fa4c9 100644 --- a/src/documents/tests/test_ai_scanner_integration.py +++ b/src/documents/tests/test_ai_scanner_integration.py @@ -8,24 +8,21 @@ document consumption to metadata application. from unittest import mock -from django.test import TestCase, TransactionTestCase +from django.test import TestCase +from django.test import TransactionTestCase -from documents.ai_scanner import ( - AIDocumentScanner, - AIScanResult, - get_ai_scanner, -) -from documents.models import ( - Correspondent, - CustomField, - Document, - DocumentType, - StoragePath, - Tag, - Workflow, - WorkflowTrigger, - WorkflowAction, -) +from documents.ai_scanner import AIDocumentScanner +from documents.ai_scanner import AIScanResult +from documents.ai_scanner import get_ai_scanner +from documents.models import Correspondent +from documents.models import CustomField +from documents.models import Document +from documents.models import DocumentType +from documents.models import StoragePath +from documents.models import Tag +from documents.models import Workflow +from documents.models import WorkflowAction +from documents.models import WorkflowTrigger class TestAIScannerIntegrationBasic(TestCase): @@ -35,49 +32,49 @@ class TestAIScannerIntegrationBasic(TestCase): """Set up test data.""" self.document = Document.objects.create( title="Invoice from ACME Corporation", - content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000" + content="Invoice #12345 from ACME Corporation dated 2024-01-01. 
Total: $1,000", ) - + self.tag_invoice = Tag.objects.create( name="Invoice", matching_algorithm=Tag.MATCH_AUTO, - match="invoice" + match="invoice", ) self.tag_important = Tag.objects.create( name="Important", matching_algorithm=Tag.MATCH_AUTO, - match="total" + match="total", ) - + self.correspondent = Correspondent.objects.create( name="ACME Corporation", matching_algorithm=Correspondent.MATCH_AUTO, - match="acme" + match="acme", ) - + self.doc_type = DocumentType.objects.create( name="Invoice", matching_algorithm=DocumentType.MATCH_AUTO, - match="invoice" + match="invoice", ) - + self.storage_path = StoragePath.objects.create( name="Invoices", path="/invoices", matching_algorithm=StoragePath.MATCH_AUTO, - match="invoice" + match="invoice", ) - @mock.patch('documents.ai_scanner.match_tags') - @mock.patch('documents.ai_scanner.match_correspondents') - @mock.patch('documents.ai_scanner.match_document_types') - @mock.patch('documents.ai_scanner.match_storage_paths') + @mock.patch("documents.ai_scanner.match_tags") + @mock.patch("documents.ai_scanner.match_correspondents") + @mock.patch("documents.ai_scanner.match_document_types") + @mock.patch("documents.ai_scanner.match_storage_paths") def test_full_scan_and_apply_workflow( self, mock_storage, mock_types, mock_correspondents, - mock_tags + mock_tags, ): """Test complete workflow from scan to application.""" # Mock the matching functions to return our test data @@ -85,51 +82,51 @@ class TestAIScannerIntegrationBasic(TestCase): mock_correspondents.return_value = [self.correspondent] mock_types.return_value = [self.doc_type] mock_storage.return_value = [self.storage_path] - + scanner = AIDocumentScanner(auto_apply_threshold=0.80) - + # Scan the document scan_result = scanner.scan_document( self.document, - self.document.content + self.document.content, ) - + # Verify scan results self.assertIsNotNone(scan_result) self.assertGreater(len(scan_result.tags), 0) self.assertIsNotNone(scan_result.correspondent) self.assertIsNotNone(scan_result.document_type) self.assertIsNotNone(scan_result.storage_path) - + # Apply the results result = scanner.apply_scan_results( self.document, scan_result, - auto_apply=True + auto_apply=True, ) - + # Verify application self.assertGreater(len(result["applied"]["tags"]), 0) self.assertIsNotNone(result["applied"]["correspondent"]) - + # Verify database changes self.document.refresh_from_db() self.assertEqual(self.document.correspondent, self.correspondent) self.assertEqual(self.document.document_type, self.doc_type) self.assertEqual(self.document.storage_path, self.storage_path) - @mock.patch('documents.ai_scanner.match_tags') + @mock.patch("documents.ai_scanner.match_tags") def test_scan_with_no_matches(self, mock_tags): """Test scanning when no matches are found.""" mock_tags.return_value = [] - + scanner = AIDocumentScanner() - + scan_result = scanner.scan_document( self.document, - "Some random text with no matches" + "Some random text with no matches", ) - + # Should return empty results self.assertEqual(len(scan_result.tags), 0) self.assertIsNone(scan_result.correspondent) @@ -143,46 +140,46 @@ class TestAIScannerIntegrationCustomFields(TestCase): """Set up test data with custom fields.""" self.document = Document.objects.create( title="Invoice", - content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com" + content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. 
Contact: john@example.com", ) - + self.field_date = CustomField.objects.create( name="Invoice Date", - data_type=CustomField.FieldDataType.DATE + data_type=CustomField.FieldDataType.DATE, ) self.field_number = CustomField.objects.create( name="Invoice Number", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.field_amount = CustomField.objects.create( name="Total Amount", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) self.field_email = CustomField.objects.create( name="Contact Email", - data_type=CustomField.FieldDataType.STRING + data_type=CustomField.FieldDataType.STRING, ) def test_custom_field_extraction_integration(self): """Test custom field extraction with mocked NER.""" scanner = AIDocumentScanner() - + # Mock NER to return entities mock_ner = mock.MagicMock() mock_ner.extract_all.return_value = { "dates": [{"text": "2024-01-01"}], "amounts": [{"text": "$1,500"}], "invoice_numbers": ["INV-123"], - "emails": ["john@example.com"] + "emails": ["john@example.com"], } scanner._ner_extractor = mock_ner - + # Scan document scan_result = scanner.scan_document(self.document, self.document.content) - + # Verify custom fields were extracted self.assertGreater(len(scan_result.custom_fields), 0) - + # Check specific fields extracted_field_ids = list(scan_result.custom_fields.keys()) self.assertIn(self.field_date.id, extracted_field_ids) @@ -196,47 +193,47 @@ class TestAIScannerIntegrationWorkflows(TestCase): """Set up test workflows.""" self.document = Document.objects.create( title="Invoice", - content="Invoice document" + content="Invoice document", ) - + self.workflow1 = Workflow.objects.create( name="Invoice Processing", - enabled=True + enabled=True, ) self.trigger1 = WorkflowTrigger.objects.create( workflow=self.workflow1, - type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION + type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION, ) self.action1 = WorkflowAction.objects.create( workflow=self.workflow1, - type=WorkflowAction.WorkflowActionType.ASSIGNMENT + type=WorkflowAction.WorkflowActionType.ASSIGNMENT, ) - + self.workflow2 = Workflow.objects.create( name="Archive Documents", - enabled=True + enabled=True, ) self.trigger2 = WorkflowTrigger.objects.create( workflow=self.workflow2, - type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION + type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION, ) def test_workflow_suggestion_integration(self): """Test workflow suggestion with real workflows.""" scanner = AIDocumentScanner(suggest_threshold=0.5) - + # Create scan result with some attributes scan_result = AIScanResult() scan_result.document_type = (1, 0.85) scan_result.tags = [(1, 0.80)] - + # Get workflow suggestions workflows = scanner._suggest_workflows( self.document, self.document.content, - scan_result + scan_result, ) - + # Should suggest workflows self.assertGreater(len(workflows), 0) workflow_ids = [wf_id for wf_id, _ in workflows] @@ -250,7 +247,7 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase): """Set up test data.""" self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) self.tag = Tag.objects.create(name="TestTag") self.correspondent = Correspondent.objects.create(name="TestCorp") @@ -258,29 +255,29 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase): def test_transaction_rollback_on_error(self): """Test that transaction rolls back on error.""" scanner = AIDocumentScanner() - + scan_result = 
AIScanResult() scan_result.tags = [(self.tag.id, 0.90)] scan_result.correspondent = (self.correspondent.id, 0.90) - + # Force an error during save original_save = Document.save call_count = [0] - + def failing_save(self, *args, **kwargs): call_count[0] += 1 if call_count[0] >= 1: raise Exception("Forced save failure") return original_save(self, *args, **kwargs) - - with mock.patch.object(Document, 'save', failing_save): + + with mock.patch.object(Document, "save", failing_save): with self.assertRaises(Exception): scanner.apply_scan_results( self.document, scan_result, - auto_apply=True + auto_apply=True, ) - + # Verify changes were rolled back self.document.refresh_from_db() # Document should not have been modified @@ -292,30 +289,30 @@ class TestAIScannerIntegrationPerformance(TestCase): def test_scan_multiple_documents(self): """Test scanning multiple documents efficiently.""" scanner = AIDocumentScanner() - + documents = [] for i in range(5): doc = Document.objects.create( title=f"Document {i}", - content=f"Content for document {i}" + content=f"Content for document {i}", ) documents.append(doc) - + # Mock to avoid actual ML loading - with mock.patch.object(scanner, '_extract_entities', return_value={}), \ - mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ - mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner, '_suggest_title', return_value=None): - + with mock.patch.object(scanner, "_extract_entities", return_value={}), \ + mock.patch.object(scanner, "_suggest_tags", return_value=[]), \ + mock.patch.object(scanner, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \ + mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner, "_suggest_title", return_value=None): + results = [] for doc in documents: result = scanner.scan_document(doc, doc.content) results.append(result) - + # Verify all scans completed self.assertEqual(len(results), 5) for result in results: @@ -329,37 +326,37 @@ class TestAIScannerIntegrationEntityMatching(TestCase): """Set up test data.""" self.document = Document.objects.create( title="Business Invoice", - content="Invoice from ACME Corporation" + content="Invoice from ACME Corporation", ) - + self.correspondent_acme = Correspondent.objects.create( name="ACME Corporation", - matching_algorithm=Correspondent.MATCH_AUTO + matching_algorithm=Correspondent.MATCH_AUTO, ) self.correspondent_other = Correspondent.objects.create( name="Other Company", - matching_algorithm=Correspondent.MATCH_AUTO + matching_algorithm=Correspondent.MATCH_AUTO, ) def test_correspondent_matching_with_ner_entities(self): """Test that NER entities help match correspondents.""" scanner = AIDocumentScanner() - + # Mock NER to extract organization mock_ner = mock.MagicMock() mock_ner.extract_all.return_value = { - "organizations": [{"text": "ACME Corporation"}] + "organizations": [{"text": "ACME Corporation"}], } scanner._ner_extractor = mock_ner - + # Mock matching to return empty 
(so NER-based matching is used) - with mock.patch('documents.ai_scanner.match_correspondents', return_value=[]): + with mock.patch("documents.ai_scanner.match_correspondents", return_value=[]): result = scanner._detect_correspondent( self.document, self.document.content, - {"organizations": [{"text": "ACME Corporation"}]} + {"organizations": [{"text": "ACME Corporation"}]}, ) - + # Should detect ACME correspondent self.assertIsNotNone(result) corr_id, confidence = result @@ -372,20 +369,20 @@ class TestAIScannerIntegrationTitleGeneration(TestCase): def test_title_generation_with_entities(self): """Test title generation uses extracted entities.""" scanner = AIDocumentScanner() - + document = Document.objects.create( title="document.pdf", - content="Invoice from ACME Corp dated 2024-01-15" + content="Invoice from ACME Corp dated 2024-01-15", ) - + entities = { "document_type": "Invoice", "organizations": [{"text": "ACME Corp"}], - "dates": [{"text": "2024-01-15"}] + "dates": [{"text": "2024-01-15"}], } - + title = scanner._suggest_title(document, document.content, entities) - + self.assertIsNotNone(title) self.assertIn("Invoice", title) self.assertIn("ACME Corp", title) @@ -399,7 +396,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase): """Set up test data.""" self.document = Document.objects.create( title="Test", - content="Test" + content="Test", ) self.tag_high = Tag.objects.create(name="HighConfidence") self.tag_medium = Tag.objects.create(name="MediumConfidence") @@ -409,26 +406,26 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase): """Test that only high confidence suggestions are auto-applied.""" scanner = AIDocumentScanner( auto_apply_threshold=0.80, - suggest_threshold=0.60 + suggest_threshold=0.60, ) - + scan_result = AIScanResult() scan_result.tags = [ (self.tag_high.id, 0.90), # Should be applied (self.tag_medium.id, 0.70), # Should be suggested (self.tag_low.id, 0.50), # Should be ignored ] - + result = scanner.apply_scan_results( self.document, scan_result, - auto_apply=True + auto_apply=True, ) - + # Verify high confidence was applied self.assertEqual(len(result["applied"]["tags"]), 1) self.assertEqual(result["applied"]["tags"][0]["id"], self.tag_high.id) - + # Verify medium confidence was suggested self.assertEqual(len(result["suggestions"]["tags"]), 1) self.assertEqual(result["suggestions"]["tags"][0]["id"], self.tag_medium.id) @@ -441,28 +438,28 @@ class TestAIScannerIntegrationGlobalInstance(TestCase): """Test that global scanner can be reused across multiple scans.""" scanner1 = get_ai_scanner() scanner2 = get_ai_scanner() - + # Should be the same instance self.assertIs(scanner1, scanner2) - + # Should be functional document = Document.objects.create( title="Test", - content="Test content" + content="Test content", ) - - with mock.patch.object(scanner1, '_extract_entities', return_value={}), \ - mock.patch.object(scanner1, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner1, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner1, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner1, '_suggest_storage_path', return_value=None), \ - mock.patch.object(scanner1, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner1, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner1, '_suggest_title', return_value=None): - + + with mock.patch.object(scanner1, "_extract_entities", return_value={}), \ + mock.patch.object(scanner1, "_suggest_tags", return_value=[]), \ + 
mock.patch.object(scanner1, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner1, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner1, "_suggest_storage_path", return_value=None), \ + mock.patch.object(scanner1, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner1, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner1, "_suggest_title", return_value=None): + result1 = scanner1.scan_document(document, document.content) result2 = scanner2.scan_document(document, document.content) - + self.assertIsInstance(result1, AIScanResult) self.assertIsInstance(result2, AIScanResult) @@ -473,66 +470,66 @@ class TestAIScannerIntegrationEdgeCases(TestCase): def test_scan_with_minimal_document(self): """Test scanning a document with minimal information.""" scanner = AIDocumentScanner() - + document = Document.objects.create( title="", - content="" + content="", ) - - with mock.patch.object(scanner, '_extract_entities', return_value={}), \ - mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ - mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner, '_suggest_title', return_value=None): - + + with mock.patch.object(scanner, "_extract_entities", return_value={}), \ + mock.patch.object(scanner, "_suggest_tags", return_value=[]), \ + mock.patch.object(scanner, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \ + mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner, "_suggest_title", return_value=None): + result = scanner.scan_document(document, document.content) - + self.assertIsInstance(result, AIScanResult) def test_apply_with_deleted_references(self): """Test applying results when referenced objects have been deleted.""" scanner = AIDocumentScanner() - + document = Document.objects.create( title="Test", - content="Test" + content="Test", ) - + scan_result = AIScanResult() scan_result.tags = [(9999, 0.90)] # Non-existent tag ID scan_result.correspondent = (9999, 0.90) # Non-existent correspondent ID - + # Should handle gracefully result = scanner.apply_scan_results( document, scan_result, - auto_apply=True + auto_apply=True, ) - + # Should not crash, just log errors self.assertEqual(len(result["applied"]["tags"]), 0) def test_scan_with_unicode_and_special_characters(self): """Test scanning documents with Unicode and special characters.""" scanner = AIDocumentScanner() - + document = Document.objects.create( title="Factura - España 🇪🇸", - content="Société française • 日本語 • Ελληνικά • مرحبا" + content="Société française • 日本語 • Ελληνικά • مرحبا", ) - - with mock.patch.object(scanner, '_extract_entities', return_value={}), \ - mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ - mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ - mock.patch.object(scanner, '_classify_document_type', return_value=None), \ - mock.patch.object(scanner, '_suggest_storage_path', 
return_value=None), \ - mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ - mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ - mock.patch.object(scanner, '_suggest_title', return_value=None): - + + with mock.patch.object(scanner, "_extract_entities", return_value={}), \ + mock.patch.object(scanner, "_suggest_tags", return_value=[]), \ + mock.patch.object(scanner, "_detect_correspondent", return_value=None), \ + mock.patch.object(scanner, "_classify_document_type", return_value=None), \ + mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \ + mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \ + mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \ + mock.patch.object(scanner, "_suggest_title", return_value=None): + result = scanner.scan_document(document, document.content) - + self.assertIsInstance(result, AIScanResult) diff --git a/src/documents/tests/test_api_ai_endpoints.py b/src/documents/tests/test_api_ai_endpoints.py index a753e0c29..dfb272403 100644 --- a/src/documents/tests/test_api_ai_endpoints.py +++ b/src/documents/tests/test_api_ai_endpoints.py @@ -12,18 +12,17 @@ Tests cover: from unittest import mock -from django.contrib.auth.models import Permission, User +from django.contrib.auth.models import Permission +from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType from rest_framework import status from rest_framework.test import APITestCase -from documents.models import ( - Correspondent, - DeletionRequest, - Document, - DocumentType, - Tag, -) +from documents.models import Correspondent +from documents.models import DeletionRequest +from documents.models import Document +from documents.models import DocumentType +from documents.models import Tag from documents.tests.utils import DirectoriesMixin @@ -33,18 +32,18 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase): def setUp(self): """Set up test data.""" super().setUp() - + # Create users self.superuser = User.objects.create_superuser( - username="admin", email="admin@test.com", password="admin123" + username="admin", email="admin@test.com", password="admin123", ) self.user_with_permission = User.objects.create_user( - username="permitted", email="permitted@test.com", password="permitted123" + username="permitted", email="permitted@test.com", password="permitted123", ) self.user_without_permission = User.objects.create_user( - username="regular", email="regular@test.com", password="regular123" + username="regular", email="regular@test.com", password="regular123", ) - + # Assign view permission content_type = ContentType.objects.get_for_model(Document) view_permission, _ = Permission.objects.get_or_create( @@ -53,13 +52,13 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase): content_type=content_type, ) self.user_with_permission.user_permissions.add(view_permission) - + # Create test document self.document = Document.objects.create( title="Test Document", - content="This is a test invoice from ACME Corporation" + content="This is a test invoice from ACME Corporation", ) - + # Create test metadata objects self.tag = Tag.objects.create(name="Invoice") self.correspondent = Correspondent.objects.create(name="ACME Corp") @@ -70,28 +69,28 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase): response = self.client.post( "/api/ai/suggestions/", {"document_id": self.document.id}, - format="json" + format="json", ) - + 
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_user_without_permission_denied(self): """Test that users without permission are denied.""" self.client.force_authenticate(user=self.user_without_permission) - + response = self.client.post( "/api/ai/suggestions/", {"document_id": self.document.id}, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_superuser_allowed(self): """Test that superusers can access the endpoint.""" self.client.force_authenticate(user=self.superuser) - - with mock.patch('documents.views.get_ai_scanner') as mock_scanner: + + with mock.patch("documents.views.get_ai_scanner") as mock_scanner: # Mock the scanner response mock_scan_result = mock.MagicMock() mock_scan_result.tags = [(self.tag.id, 0.85)] @@ -100,17 +99,17 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase): mock_scan_result.storage_path = None mock_scan_result.title_suggestion = "Invoice - ACME Corp" mock_scan_result.custom_fields = {} - + mock_scanner_instance = mock.MagicMock() mock_scanner_instance.scan_document.return_value = mock_scan_result mock_scanner.return_value = mock_scanner_instance - + response = self.client.post( "/api/ai/suggestions/", {"document_id": self.document.id}, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("document_id", response.data) self.assertEqual(response.data["document_id"], self.document.id) @@ -118,8 +117,8 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase): def test_user_with_permission_allowed(self): """Test that users with permission can access the endpoint.""" self.client.force_authenticate(user=self.user_with_permission) - - with mock.patch('documents.views.get_ai_scanner') as mock_scanner: + + with mock.patch("documents.views.get_ai_scanner") as mock_scanner: # Mock the scanner response mock_scan_result = mock.MagicMock() mock_scan_result.tags = [] @@ -128,41 +127,41 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase): mock_scan_result.storage_path = None mock_scan_result.title_suggestion = None mock_scan_result.custom_fields = {} - + mock_scanner_instance = mock.MagicMock() mock_scanner_instance.scan_document.return_value = mock_scan_result mock_scanner.return_value = mock_scanner_instance - + response = self.client.post( "/api/ai/suggestions/", {"document_id": self.document.id}, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_invalid_document_id(self): """Test handling of invalid document ID.""" self.client.force_authenticate(user=self.superuser) - + response = self.client.post( "/api/ai/suggestions/", {"document_id": 99999}, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_missing_document_id(self): """Test handling of missing document ID.""" self.client.force_authenticate(user=self.superuser) - + response = self.client.post( "/api/ai/suggestions/", {}, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -172,15 +171,15 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase): def setUp(self): """Set up test data.""" super().setUp() - + # Create users self.superuser = User.objects.create_superuser( - username="admin", email="admin@test.com", password="admin123" + username="admin", email="admin@test.com", password="admin123", ) self.user_with_permission = 
User.objects.create_user( - username="permitted", email="permitted@test.com", password="permitted123" + username="permitted", email="permitted@test.com", password="permitted123", ) - + # Assign apply permission content_type = ContentType.objects.get_for_model(Document) apply_permission, _ = Permission.objects.get_or_create( @@ -189,13 +188,13 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase): content_type=content_type, ) self.user_with_permission.user_permissions.add(apply_permission) - + # Create test document self.document = Document.objects.create( title="Test Document", - content="Test content" + content="Test content", ) - + # Create test metadata self.tag = Tag.objects.create(name="Test Tag") self.correspondent = Correspondent.objects.create(name="Test Corp") @@ -205,16 +204,16 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase): response = self.client.post( "/api/ai/suggestions/apply/", {"document_id": self.document.id}, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_apply_tags_success(self): """Test successfully applying tag suggestions.""" self.client.force_authenticate(user=self.superuser) - - with mock.patch('documents.views.get_ai_scanner') as mock_scanner: + + with mock.patch("documents.views.get_ai_scanner") as mock_scanner: # Mock the scanner response mock_scan_result = mock.MagicMock() mock_scan_result.tags = [(self.tag.id, 0.85)] @@ -223,29 +222,29 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase): mock_scan_result.storage_path = None mock_scan_result.title_suggestion = None mock_scan_result.custom_fields = {} - + mock_scanner_instance = mock.MagicMock() mock_scanner_instance.scan_document.return_value = mock_scan_result mock_scanner_instance.auto_apply_threshold = 0.80 mock_scanner.return_value = mock_scanner_instance - + response = self.client.post( "/api/ai/suggestions/apply/", { "document_id": self.document.id, - "apply_tags": True + "apply_tags": True, }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["status"], "success") def test_apply_correspondent_success(self): """Test successfully applying correspondent suggestion.""" self.client.force_authenticate(user=self.superuser) - - with mock.patch('documents.views.get_ai_scanner') as mock_scanner: + + with mock.patch("documents.views.get_ai_scanner") as mock_scanner: # Mock the scanner response mock_scan_result = mock.MagicMock() mock_scan_result.tags = [] @@ -254,23 +253,23 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase): mock_scan_result.storage_path = None mock_scan_result.title_suggestion = None mock_scan_result.custom_fields = {} - + mock_scanner_instance = mock.MagicMock() mock_scanner_instance.scan_document.return_value = mock_scan_result mock_scanner_instance.auto_apply_threshold = 0.80 mock_scanner.return_value = mock_scanner_instance - + response = self.client.post( "/api/ai/suggestions/apply/", { "document_id": self.document.id, - "apply_correspondent": True + "apply_correspondent": True, }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - + # Verify correspondent was applied self.document.refresh_from_db() self.assertEqual(self.document.correspondent, self.correspondent) @@ -282,43 +281,43 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase): def setUp(self): """Set up test data.""" super().setUp() - + # Create 
users self.superuser = User.objects.create_superuser( - username="admin", email="admin@test.com", password="admin123" + username="admin", email="admin@test.com", password="admin123", ) self.user_without_permission = User.objects.create_user( - username="regular", email="regular@test.com", password="regular123" + username="regular", email="regular@test.com", password="regular123", ) def test_unauthorized_access_denied(self): """Test that unauthenticated users are denied.""" response = self.client.get("/api/ai/config/") - + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_user_without_permission_denied(self): """Test that users without permission are denied.""" self.client.force_authenticate(user=self.user_without_permission) - + response = self.client.get("/api/ai/config/") - + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_get_config_success(self): """Test getting AI configuration.""" self.client.force_authenticate(user=self.superuser) - - with mock.patch('documents.views.get_ai_scanner') as mock_scanner: + + with mock.patch("documents.views.get_ai_scanner") as mock_scanner: mock_scanner_instance = mock.MagicMock() mock_scanner_instance.auto_apply_threshold = 0.80 mock_scanner_instance.suggest_threshold = 0.60 mock_scanner_instance.ml_enabled = True mock_scanner_instance.advanced_ocr_enabled = True mock_scanner.return_value = mock_scanner_instance - + response = self.client.get("/api/ai/config/") - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("auto_apply_threshold", response.data) self.assertEqual(response.data["auto_apply_threshold"], 0.80) @@ -326,31 +325,31 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase): def test_update_config_success(self): """Test updating AI configuration.""" self.client.force_authenticate(user=self.superuser) - + response = self.client.post( "/api/ai/config/", { "auto_apply_threshold": 0.90, - "suggest_threshold": 0.70 + "suggest_threshold": 0.70, }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["status"], "success") def test_update_config_invalid_threshold(self): """Test updating with invalid threshold value.""" self.client.force_authenticate(user=self.superuser) - + response = self.client.post( "/api/ai/config/", { - "auto_apply_threshold": 1.5 # Invalid: > 1.0 + "auto_apply_threshold": 1.5, # Invalid: > 1.0 }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -360,18 +359,18 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase): def setUp(self): """Set up test data.""" super().setUp() - + # Create users self.superuser = User.objects.create_superuser( - username="admin", email="admin@test.com", password="admin123" + username="admin", email="admin@test.com", password="admin123", ) self.user_with_permission = User.objects.create_user( - username="permitted", email="permitted@test.com", password="permitted123" + username="permitted", email="permitted@test.com", password="permitted123", ) self.user_without_permission = User.objects.create_user( - username="regular", email="regular@test.com", password="regular123" + username="regular", email="regular@test.com", password="regular123", ) - + # Assign approval permission content_type = ContentType.objects.get_for_model(Document) approval_permission, _ = Permission.objects.get_or_create( @@ -380,12 +379,12 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, 
APITestCase): content_type=content_type, ) self.user_with_permission.user_permissions.add(approval_permission) - + # Create test deletion request self.deletion_request = DeletionRequest.objects.create( user=self.user_with_permission, requested_by_ai=True, - ai_reason="Document appears to be a duplicate" + ai_reason="Document appears to be a duplicate", ) def test_unauthorized_access_denied(self): @@ -394,102 +393,102 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase): "/api/ai/deletions/approve/", { "request_id": self.deletion_request.id, - "action": "approve" + "action": "approve", }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_user_without_permission_denied(self): """Test that users without permission are denied.""" self.client.force_authenticate(user=self.user_without_permission) - + response = self.client.post( "/api/ai/deletions/approve/", { "request_id": self.deletion_request.id, - "action": "approve" + "action": "approve", }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_approve_deletion_success(self): """Test successfully approving a deletion request.""" self.client.force_authenticate(user=self.user_with_permission) - + response = self.client.post( "/api/ai/deletions/approve/", { "request_id": self.deletion_request.id, - "action": "approve" + "action": "approve", }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["status"], "success") - + # Verify status was updated self.deletion_request.refresh_from_db() self.assertEqual( self.deletion_request.status, - DeletionRequest.STATUS_APPROVED + DeletionRequest.STATUS_APPROVED, ) def test_reject_deletion_success(self): """Test successfully rejecting a deletion request.""" self.client.force_authenticate(user=self.user_with_permission) - + response = self.client.post( "/api/ai/deletions/approve/", { "request_id": self.deletion_request.id, "action": "reject", - "reason": "Document is still needed" + "reason": "Document is still needed", }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - + # Verify status was updated self.deletion_request.refresh_from_db() self.assertEqual( self.deletion_request.status, - DeletionRequest.STATUS_REJECTED + DeletionRequest.STATUS_REJECTED, ) def test_invalid_request_id(self): """Test handling of invalid deletion request ID.""" self.client.force_authenticate(user=self.superuser) - + response = self.client.post( "/api/ai/deletions/approve/", { "request_id": 99999, - "action": "approve" + "action": "approve", }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_superuser_can_approve_any_request(self): """Test that superusers can approve any deletion request.""" self.client.force_authenticate(user=self.superuser) - + response = self.client.post( "/api/ai/deletions/approve/", { "request_id": self.deletion_request.id, - "action": "approve" + "action": "approve", }, - format="json" + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -499,14 +498,14 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase): def setUp(self): """Set up test data.""" super().setUp() - + # Create user with all AI permissions self.power_user = User.objects.create_user( - username="power_user", email="power@test.com", password="power123" + 
username="power_user", email="power@test.com", password="power123", ) - + content_type = ContentType.objects.get_for_model(Document) - + # Assign all AI permissions permissions = [ "can_view_ai_suggestions", @@ -514,7 +513,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase): "can_approve_deletions", "can_configure_ai", ] - + for codename in permissions: perm, _ = Permission.objects.get_or_create( codename=codename, @@ -522,18 +521,18 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase): content_type=content_type, ) self.power_user.user_permissions.add(perm) - + self.document = Document.objects.create( title="Test Doc", - content="Test" + content="Test", ) def test_power_user_can_access_all_endpoints(self): """Test that user with all permissions can access all endpoints.""" self.client.force_authenticate(user=self.power_user) - + # Test suggestions endpoint - with mock.patch('documents.views.get_ai_scanner') as mock_scanner: + with mock.patch("documents.views.get_ai_scanner") as mock_scanner: mock_scan_result = mock.MagicMock() mock_scan_result.tags = [] mock_scan_result.correspondent = None @@ -541,7 +540,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase): mock_scan_result.storage_path = None mock_scan_result.title_suggestion = None mock_scan_result.custom_fields = {} - + mock_scanner_instance = mock.MagicMock() mock_scanner_instance.scan_document.return_value = mock_scan_result mock_scanner_instance.auto_apply_threshold = 0.80 @@ -549,25 +548,25 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase): mock_scanner_instance.ml_enabled = True mock_scanner_instance.advanced_ocr_enabled = True mock_scanner.return_value = mock_scanner_instance - + response1 = self.client.post( "/api/ai/suggestions/", {"document_id": self.document.id}, - format="json" + format="json", ) self.assertEqual(response1.status_code, status.HTTP_200_OK) - + # Test apply endpoint response2 = self.client.post( "/api/ai/suggestions/apply/", { "document_id": self.document.id, - "apply_tags": False + "apply_tags": False, }, - format="json" + format="json", ) self.assertEqual(response2.status_code, status.HTTP_200_OK) - + # Test config endpoint response3 = self.client.get("/api/ai/config/") self.assertEqual(response3.status_code, status.HTTP_200_OK) diff --git a/src/documents/tests/test_api_ai_suggestions.py b/src/documents/tests/test_api_ai_suggestions.py index 74705690f..c2af4199d 100644 --- a/src/documents/tests/test_api_ai_suggestions.py +++ b/src/documents/tests/test_api_ai_suggestions.py @@ -9,14 +9,12 @@ from rest_framework import status from rest_framework.test import APITestCase from documents.ai_scanner import AIScanResult -from documents.models import ( - AISuggestionFeedback, - Correspondent, - Document, - DocumentType, - StoragePath, - Tag, -) +from documents.models import AISuggestionFeedback +from documents.models import Correspondent +from documents.models import Document +from documents.models import DocumentType +from documents.models import StoragePath +from documents.models import Tag from documents.tests.utils import DirectoriesMixin @@ -25,11 +23,11 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): def setUp(self): super().setUp() - + # Create test user self.user = User.objects.create_superuser(username="test_admin") self.client.force_authenticate(user=self.user) - + # Create test data self.correspondent = Correspondent.objects.create( name="Test Corp", @@ -52,7 +50,7 @@ class 
TestAISuggestionsAPI(DirectoriesMixin, APITestCase): path="/archive/", pk=1, ) - + # Create test document self.document = Document.objects.create( title="Test Document", @@ -64,12 +62,12 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): def test_ai_suggestions_endpoint_exists(self): """Test that the ai-suggestions endpoint is accessible.""" response = self.client.get( - f"/api/documents/{self.document.pk}/ai-suggestions/" + f"/api/documents/{self.document.pk}/ai-suggestions/", ) # Should not be 404 self.assertNotEqual(response.status_code, status.HTTP_404_NOT_FOUND) - @mock.patch('documents.ai_scanner.get_ai_scanner') + @mock.patch("documents.ai_scanner.get_ai_scanner") def test_get_ai_suggestions_success(self, mock_get_scanner): """Test successfully getting AI suggestions for a document.""" # Create mock scan result @@ -79,39 +77,39 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): scan_result.document_type = (self.doc_type.id, 0.88) scan_result.storage_path = (self.storage_path.id, 0.80) scan_result.title_suggestion = "Suggested Title" - + # Mock scanner mock_scanner = mock.Mock() mock_scanner.scan_document.return_value = scan_result mock_get_scanner.return_value = mock_scanner - + # Make request response = self.client.get( - f"/api/documents/{self.document.pk}/ai-suggestions/" + f"/api/documents/{self.document.pk}/ai-suggestions/", ) - + # Verify response self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - + # Check tags - self.assertIn('tags', data) - self.assertEqual(len(data['tags']), 2) - self.assertEqual(data['tags'][0]['id'], self.tag1.id) - self.assertEqual(data['tags'][0]['confidence'], 0.85) - + self.assertIn("tags", data) + self.assertEqual(len(data["tags"]), 2) + self.assertEqual(data["tags"][0]["id"], self.tag1.id) + self.assertEqual(data["tags"][0]["confidence"], 0.85) + # Check correspondent - self.assertIn('correspondent', data) - self.assertEqual(data['correspondent']['id'], self.correspondent.id) - self.assertEqual(data['correspondent']['confidence'], 0.90) - + self.assertIn("correspondent", data) + self.assertEqual(data["correspondent"]["id"], self.correspondent.id) + self.assertEqual(data["correspondent"]["confidence"], 0.90) + # Check document type - self.assertIn('document_type', data) - self.assertEqual(data['document_type']['id'], self.doc_type.id) - + self.assertIn("document_type", data) + self.assertEqual(data["document_type"]["id"], self.doc_type.id) + # Check title suggestion - self.assertIn('title_suggestion', data) - self.assertEqual(data['title_suggestion']['title'], "Suggested Title") + self.assertIn("title_suggestion", data) + self.assertEqual(data["title_suggestion"]["title"], "Suggested Title") def test_get_ai_suggestions_no_content(self): """Test getting AI suggestions for document without content.""" @@ -122,43 +120,43 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): checksum="empty123", mime_type="application/pdf", ) - + response = self.client.get(f"/api/documents/{doc.pk}/ai-suggestions/") - + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertIn("no content", response.json()['detail'].lower()) + self.assertIn("no content", response.json()["detail"].lower()) def test_get_ai_suggestions_document_not_found(self): """Test getting AI suggestions for non-existent document.""" response = self.client.get("/api/documents/99999/ai-suggestions/") - + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_apply_suggestion_tag(self): """Test 
applying a tag suggestion.""" request_data = { - 'suggestion_type': 'tag', - 'value_id': self.tag1.id, - 'confidence': 0.85, + "suggestion_type": "tag", + "value_id": self.tag1.id, + "confidence": 0.85, } - + response = self.client.post( f"/api/documents/{self.document.pk}/apply-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.json()['status'], 'success') - + self.assertEqual(response.json()["status"], "success") + # Verify tag was applied self.document.refresh_from_db() self.assertIn(self.tag1, self.document.tags.all()) - + # Verify feedback was recorded feedback = AISuggestionFeedback.objects.filter( document=self.document, - suggestion_type='tag', + suggestion_type="tag", ).first() self.assertIsNotNone(feedback) self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED) @@ -169,27 +167,27 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): def test_apply_suggestion_correspondent(self): """Test applying a correspondent suggestion.""" request_data = { - 'suggestion_type': 'correspondent', - 'value_id': self.correspondent.id, - 'confidence': 0.90, + "suggestion_type": "correspondent", + "value_id": self.correspondent.id, + "confidence": 0.90, } - + response = self.client.post( f"/api/documents/{self.document.pk}/apply-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - + # Verify correspondent was applied self.document.refresh_from_db() self.assertEqual(self.document.correspondent, self.correspondent) - + # Verify feedback was recorded feedback = AISuggestionFeedback.objects.filter( document=self.document, - suggestion_type='correspondent', + suggestion_type="correspondent", ).first() self.assertIsNotNone(feedback) self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED) @@ -197,19 +195,19 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): def test_apply_suggestion_document_type(self): """Test applying a document type suggestion.""" request_data = { - 'suggestion_type': 'document_type', - 'value_id': self.doc_type.id, - 'confidence': 0.88, + "suggestion_type": "document_type", + "value_id": self.doc_type.id, + "confidence": 0.88, } - + response = self.client.post( f"/api/documents/{self.document.pk}/apply-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - + # Verify document type was applied self.document.refresh_from_db() self.assertEqual(self.document.document_type, self.doc_type) @@ -217,91 +215,91 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): def test_apply_suggestion_title(self): """Test applying a title suggestion.""" request_data = { - 'suggestion_type': 'title', - 'value_text': 'New Suggested Title', - 'confidence': 0.80, + "suggestion_type": "title", + "value_text": "New Suggested Title", + "confidence": 0.80, } - + response = self.client.post( f"/api/documents/{self.document.pk}/apply-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - + # Verify title was applied self.document.refresh_from_db() - self.assertEqual(self.document.title, 'New Suggested Title') + self.assertEqual(self.document.title, "New Suggested Title") def test_apply_suggestion_invalid_type(self): """Test applying suggestion with invalid type.""" request_data = { - 'suggestion_type': 'invalid_type', 
- 'value_id': 1, - 'confidence': 0.85, + "suggestion_type": "invalid_type", + "value_id": 1, + "confidence": 0.85, } - + response = self.client.post( f"/api/documents/{self.document.pk}/apply-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_apply_suggestion_missing_value(self): """Test applying suggestion without value_id or value_text.""" request_data = { - 'suggestion_type': 'tag', - 'confidence': 0.85, + "suggestion_type": "tag", + "confidence": 0.85, } - + response = self.client.post( f"/api/documents/{self.document.pk}/apply-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_apply_suggestion_nonexistent_object(self): """Test applying suggestion with non-existent object ID.""" request_data = { - 'suggestion_type': 'tag', - 'value_id': 99999, - 'confidence': 0.85, + "suggestion_type": "tag", + "value_id": 99999, + "confidence": 0.85, } - + response = self.client.post( f"/api/documents/{self.document.pk}/apply-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_reject_suggestion(self): """Test rejecting an AI suggestion.""" request_data = { - 'suggestion_type': 'tag', - 'value_id': self.tag1.id, - 'confidence': 0.65, + "suggestion_type": "tag", + "value_id": self.tag1.id, + "confidence": 0.65, } - + response = self.client.post( f"/api/documents/{self.document.pk}/reject-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.json()['status'], 'success') - + self.assertEqual(response.json()["status"], "success") + # Verify feedback was recorded feedback = AISuggestionFeedback.objects.filter( document=self.document, - suggestion_type='tag', + suggestion_type="tag", ).first() self.assertIsNotNone(feedback) self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED) @@ -312,46 +310,46 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): def test_reject_suggestion_with_text(self): """Test rejecting a suggestion with text value.""" request_data = { - 'suggestion_type': 'title', - 'value_text': 'Bad Title Suggestion', - 'confidence': 0.50, + "suggestion_type": "title", + "value_text": "Bad Title Suggestion", + "confidence": 0.50, } - + response = self.client.post( f"/api/documents/{self.document.pk}/reject-suggestion/", data=request_data, - format='json', + format="json", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - + # Verify feedback was recorded feedback = AISuggestionFeedback.objects.filter( document=self.document, - suggestion_type='title', + suggestion_type="title", ).first() self.assertIsNotNone(feedback) self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED) - self.assertEqual(feedback.suggested_value_text, 'Bad Title Suggestion') + self.assertEqual(feedback.suggested_value_text, "Bad Title Suggestion") def test_ai_suggestion_stats_empty(self): """Test getting statistics when no feedback exists.""" response = self.client.get("/api/documents/ai-suggestion-stats/") - + self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - - self.assertEqual(data['total_suggestions'], 0) - self.assertEqual(data['total_applied'], 0) - self.assertEqual(data['total_rejected'], 0) - self.assertEqual(data['accuracy_rate'], 0) + + 
self.assertEqual(data["total_suggestions"], 0) + self.assertEqual(data["total_applied"], 0) + self.assertEqual(data["total_rejected"], 0) + self.assertEqual(data["accuracy_rate"], 0) def test_ai_suggestion_stats_with_data(self): """Test getting statistics with feedback data.""" # Create some feedback entries AISuggestionFeedback.objects.create( document=self.document, - suggestion_type='tag', + suggestion_type="tag", suggested_value_id=self.tag1.id, confidence=0.85, status=AISuggestionFeedback.STATUS_APPLIED, @@ -359,7 +357,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): ) AISuggestionFeedback.objects.create( document=self.document, - suggestion_type='tag', + suggestion_type="tag", suggested_value_id=self.tag2.id, confidence=0.70, status=AISuggestionFeedback.STATUS_APPLIED, @@ -367,38 +365,38 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): ) AISuggestionFeedback.objects.create( document=self.document, - suggestion_type='correspondent', + suggestion_type="correspondent", suggested_value_id=self.correspondent.id, confidence=0.60, status=AISuggestionFeedback.STATUS_REJECTED, user=self.user, ) - + response = self.client.get("/api/documents/ai-suggestion-stats/") - + self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - + # Check overall stats - self.assertEqual(data['total_suggestions'], 3) - self.assertEqual(data['total_applied'], 2) - self.assertEqual(data['total_rejected'], 1) - self.assertAlmostEqual(data['accuracy_rate'], 66.67, places=1) - + self.assertEqual(data["total_suggestions"], 3) + self.assertEqual(data["total_applied"], 2) + self.assertEqual(data["total_rejected"], 1) + self.assertAlmostEqual(data["accuracy_rate"], 66.67, places=1) + # Check by_type stats - self.assertIn('by_type', data) - self.assertIn('tag', data['by_type']) - self.assertEqual(data['by_type']['tag']['total'], 2) - self.assertEqual(data['by_type']['tag']['applied'], 2) - self.assertEqual(data['by_type']['tag']['rejected'], 0) - + self.assertIn("by_type", data) + self.assertIn("tag", data["by_type"]) + self.assertEqual(data["by_type"]["tag"]["total"], 2) + self.assertEqual(data["by_type"]["tag"]["applied"], 2) + self.assertEqual(data["by_type"]["tag"]["rejected"], 0) + # Check confidence averages - self.assertGreater(data['average_confidence_applied'], 0) - self.assertGreater(data['average_confidence_rejected'], 0) - + self.assertGreater(data["average_confidence_applied"], 0) + self.assertGreater(data["average_confidence_rejected"], 0) + # Check recent suggestions - self.assertIn('recent_suggestions', data) - self.assertEqual(len(data['recent_suggestions']), 3) + self.assertIn("recent_suggestions", data) + self.assertEqual(len(data["recent_suggestions"]), 3) def test_ai_suggestion_stats_accuracy_calculation(self): """Test that accuracy rate is calculated correctly.""" @@ -406,57 +404,57 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase): for i in range(7): AISuggestionFeedback.objects.create( document=self.document, - suggestion_type='tag', + suggestion_type="tag", suggested_value_id=self.tag1.id, confidence=0.80, status=AISuggestionFeedback.STATUS_APPLIED, user=self.user, ) - + for i in range(3): AISuggestionFeedback.objects.create( document=self.document, - suggestion_type='tag', + suggestion_type="tag", suggested_value_id=self.tag2.id, confidence=0.60, status=AISuggestionFeedback.STATUS_REJECTED, user=self.user, ) - + response = self.client.get("/api/documents/ai-suggestion-stats/") - + self.assertEqual(response.status_code, 
status.HTTP_200_OK) data = response.json() - - self.assertEqual(data['total_suggestions'], 10) - self.assertEqual(data['total_applied'], 7) - self.assertEqual(data['total_rejected'], 3) - self.assertEqual(data['accuracy_rate'], 70.0) + + self.assertEqual(data["total_suggestions"], 10) + self.assertEqual(data["total_applied"], 7) + self.assertEqual(data["total_rejected"], 3) + self.assertEqual(data["accuracy_rate"], 70.0) def test_authentication_required(self): """Test that authentication is required for all endpoints.""" self.client.force_authenticate(user=None) - + # Test ai-suggestions endpoint response = self.client.get( - f"/api/documents/{self.document.pk}/ai-suggestions/" + f"/api/documents/{self.document.pk}/ai-suggestions/", ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - + # Test apply-suggestion endpoint response = self.client.post( f"/api/documents/{self.document.pk}/apply-suggestion/", data={}, ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - + # Test reject-suggestion endpoint response = self.client.post( f"/api/documents/{self.document.pk}/reject-suggestion/", data={}, ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - + # Test stats endpoint response = self.client.get("/api/documents/ai-suggestion-stats/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) diff --git a/src/documents/tests/test_api_deletion_requests.py b/src/documents/tests/test_api_deletion_requests.py index 44bd6375a..a1e65c27a 100644 --- a/src/documents/tests/test_api_deletion_requests.py +++ b/src/documents/tests/test_api_deletion_requests.py @@ -11,17 +11,14 @@ Tests cover: """ from django.contrib.auth.models import User -from django.test import override_settings from rest_framework import status from rest_framework.test import APITestCase -from documents.models import ( - Correspondent, - DeletionRequest, - Document, - DocumentType, - Tag, -) +from documents.models import Correspondent +from documents.models import DeletionRequest +from documents.models import Document +from documents.models import DocumentType +from documents.models import Tag class TestDeletionRequestAPI(APITestCase): @@ -33,7 +30,7 @@ class TestDeletionRequestAPI(APITestCase): self.user1 = User.objects.create_user(username="user1", password="pass123") self.user2 = User.objects.create_user(username="user2", password="pass123") self.admin = User.objects.create_superuser(username="admin", password="admin123") - + # Create test documents self.doc1 = Document.objects.create( title="Test Document 1", @@ -53,7 +50,7 @@ class TestDeletionRequestAPI(APITestCase): checksum="checksum3", mime_type="application/pdf", ) - + # Create deletion requests self.request1 = DeletionRequest.objects.create( requested_by_ai=True, @@ -63,7 +60,7 @@ class TestDeletionRequestAPI(APITestCase): impact_summary={"document_count": 1}, ) self.request1.documents.add(self.doc1) - + self.request2 = DeletionRequest.objects.create( requested_by_ai=True, ai_reason="Low quality document", @@ -77,7 +74,7 @@ class TestDeletionRequestAPI(APITestCase): """Test that users can list their own deletion requests.""" self.client.force_authenticate(user=self.user1) response = self.client.get("/api/deletion-requests/") - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]), 1) self.assertEqual(response.data["results"][0]["id"], self.request1.id) @@ -86,7 +83,7 @@ class TestDeletionRequestAPI(APITestCase): """Test that admin can list all 
deletion requests.""" self.client.force_authenticate(user=self.admin) response = self.client.get("/api/deletion-requests/") - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]), 2) @@ -94,7 +91,7 @@ class TestDeletionRequestAPI(APITestCase): """Test retrieving a single deletion request.""" self.client.force_authenticate(user=self.user1) response = self.client.get(f"/api/deletion-requests/{self.request1.id}/") - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["id"], self.request1.id) self.assertEqual(response.data["ai_reason"], "Duplicate document detected") @@ -104,23 +101,23 @@ class TestDeletionRequestAPI(APITestCase): def test_approve_deletion_request_as_owner(self): """Test approving a deletion request as the owner.""" self.client.force_authenticate(user=self.user1) - + # Verify document exists self.assertTrue(Document.objects.filter(id=self.doc1.id).exists()) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/approve/", {"comment": "Approved by owner"}, ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("message", response.data) self.assertIn("execution_result", response.data) self.assertEqual(response.data["execution_result"]["deleted_count"], 1) - + # Verify document was deleted self.assertFalse(Document.objects.filter(id=self.doc1.id).exists()) - + # Verify deletion request was updated self.request1.refresh_from_db() self.assertEqual(self.request1.status, DeletionRequest.STATUS_COMPLETED) @@ -131,18 +128,18 @@ class TestDeletionRequestAPI(APITestCase): def test_approve_deletion_request_as_admin(self): """Test approving a deletion request as admin.""" self.client.force_authenticate(user=self.admin) - + response = self.client.post( f"/api/deletion-requests/{self.request2.id}/approve/", {"comment": "Approved by admin"}, ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("execution_result", response.data) - + # Verify document was deleted self.assertFalse(Document.objects.filter(id=self.doc2.id).exists()) - + # Verify deletion request was updated self.request2.refresh_from_db() self.assertEqual(self.request2.status, DeletionRequest.STATUS_COMPLETED) @@ -151,16 +148,16 @@ class TestDeletionRequestAPI(APITestCase): def test_approve_deletion_request_without_permission(self): """Test that non-owners cannot approve deletion requests.""" self.client.force_authenticate(user=self.user2) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/approve/", ) - + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - + # Verify document was NOT deleted self.assertTrue(Document.objects.filter(id=self.doc1.id).exists()) - + # Verify deletion request was NOT updated self.request1.refresh_from_db() self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING) @@ -169,13 +166,13 @@ class TestDeletionRequestAPI(APITestCase): """Test that already approved requests cannot be approved again.""" self.request1.status = DeletionRequest.STATUS_APPROVED self.request1.save() - + self.client.force_authenticate(user=self.user1) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/approve/", ) - + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) self.assertIn("pending", response.data["error"].lower()) @@ -183,18 +180,18 @@ class TestDeletionRequestAPI(APITestCase): def test_reject_deletion_request_as_owner(self): 
"""Test rejecting a deletion request as the owner.""" self.client.force_authenticate(user=self.user1) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/reject/", {"comment": "Not needed"}, ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("message", response.data) - + # Verify document was NOT deleted self.assertTrue(Document.objects.filter(id=self.doc1.id).exists()) - + # Verify deletion request was updated self.request1.refresh_from_db() self.assertEqual(self.request1.status, DeletionRequest.STATUS_REJECTED) @@ -205,16 +202,16 @@ class TestDeletionRequestAPI(APITestCase): def test_reject_deletion_request_as_admin(self): """Test rejecting a deletion request as admin.""" self.client.force_authenticate(user=self.admin) - + response = self.client.post( f"/api/deletion-requests/{self.request2.id}/reject/", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) - + # Verify document was NOT deleted self.assertTrue(Document.objects.filter(id=self.doc2.id).exists()) - + # Verify deletion request was updated self.request2.refresh_from_db() self.assertEqual(self.request2.status, DeletionRequest.STATUS_REJECTED) @@ -223,13 +220,13 @@ class TestDeletionRequestAPI(APITestCase): def test_reject_deletion_request_without_permission(self): """Test that non-owners cannot reject deletion requests.""" self.client.force_authenticate(user=self.user2) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/reject/", ) - + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - + # Verify deletion request was NOT updated self.request1.refresh_from_db() self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING) @@ -238,31 +235,31 @@ class TestDeletionRequestAPI(APITestCase): """Test that already rejected requests cannot be rejected again.""" self.request1.status = DeletionRequest.STATUS_REJECTED self.request1.save() - + self.client.force_authenticate(user=self.user1) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/reject/", ) - + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) def test_cancel_deletion_request_as_owner(self): """Test canceling a deletion request as the owner.""" self.client.force_authenticate(user=self.user1) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/cancel/", {"comment": "Changed my mind"}, ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("message", response.data) - + # Verify document was NOT deleted self.assertTrue(Document.objects.filter(id=self.doc1.id).exists()) - + # Verify deletion request was updated self.request1.refresh_from_db() self.assertEqual(self.request1.status, DeletionRequest.STATUS_CANCELLED) @@ -273,13 +270,13 @@ class TestDeletionRequestAPI(APITestCase): def test_cancel_deletion_request_without_permission(self): """Test that non-owners cannot cancel deletion requests.""" self.client.force_authenticate(user=self.user2) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/cancel/", ) - + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - + # Verify deletion request was NOT updated self.request1.refresh_from_db() self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING) @@ -288,13 +285,13 @@ class TestDeletionRequestAPI(APITestCase): """Test that approved requests cannot be cancelled.""" self.request1.status = DeletionRequest.STATUS_APPROVED self.request1.save() - + 
self.client.force_authenticate(user=self.user1) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/cancel/", ) - + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) @@ -309,17 +306,17 @@ class TestDeletionRequestAPI(APITestCase): impact_summary={"document_count": 2}, ) multi_request.documents.add(self.doc1, self.doc3) - + self.client.force_authenticate(user=self.user1) - + response = self.client.post( f"/api/deletion-requests/{multi_request.id}/approve/", ) - + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["execution_result"]["deleted_count"], 2) self.assertEqual(response.data["execution_result"]["total_documents"], 2) - + # Verify both documents were deleted self.assertFalse(Document.objects.filter(id=self.doc1.id).exists()) self.assertFalse(Document.objects.filter(id=self.doc3.id).exists()) @@ -330,15 +327,15 @@ class TestDeletionRequestAPI(APITestCase): tag = Tag.objects.create(name="test-tag") correspondent = Correspondent.objects.create(name="Test Corp") doc_type = DocumentType.objects.create(name="Invoice") - + self.doc1.tags.add(tag) self.doc1.correspondent = correspondent self.doc1.document_type = doc_type self.doc1.save() - + self.client.force_authenticate(user=self.user1) response = self.client.get(f"/api/deletion-requests/{self.request1.id}/") - + self.assertEqual(response.status_code, status.HTTP_200_OK) doc_details = response.data["document_details"] self.assertEqual(len(doc_details), 1) @@ -352,7 +349,7 @@ class TestDeletionRequestAPI(APITestCase): """Test that unauthenticated users cannot access the API.""" response = self.client.get("/api/deletion-requests/") self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - + response = self.client.post( f"/api/deletion-requests/{self.request1.id}/approve/", ) diff --git a/src/documents/tests/test_consumer.py b/src/documents/tests/test_consumer.py index 2a9c87ddf..22e7bd495 100644 --- a/src/documents/tests/test_consumer.py +++ b/src/documents/tests/test_consumer.py @@ -1350,18 +1350,18 @@ class TestConsumerAIScannerIntegration( correspondent = Correspondent.objects.create(name="Test Corp") doc_type = DocumentType.objects.create(name="Invoice") storage_path = StoragePath.objects.create(name="Invoices", path="/invoices") - + # Create mock AI scanner mock_scanner = MagicMock() mock_get_scanner.return_value = mock_scanner - + # Mock scan results scan_result = AIScanResult() scan_result.tags = [(tag1.id, 0.85), (tag2.id, 0.75)] scan_result.correspondent = (correspondent.id, 0.90) scan_result.document_type = (doc_type.id, 0.85) scan_result.storage_path = (storage_path.id, 0.80) - + mock_scanner.scan_document.return_value = scan_result mock_scanner.apply_scan_results.return_value = { "applied": { @@ -1381,20 +1381,20 @@ class TestConsumerAIScannerIntegration( "workflows": [], }, } - + # Run consumer filename = self.get_test_file() with self.get_consumer(filename) as consumer: consumer.run() - + # Verify document was created document = Document.objects.first() self.assertIsNotNone(document) - + # Verify AI scanner was called mock_scanner.scan_document.assert_called_once() mock_scanner.apply_scan_results.assert_called_once() - + # Verify the call arguments call_args = mock_scanner.scan_document.call_args self.assertEqual(call_args[1]["document"], document) @@ -1412,11 +1412,11 @@ class TestConsumerAIScannerIntegration( demonstrating graceful degradation. 
""" filename = self.get_test_file() - + # Consumer should complete successfully even with ML disabled with self.get_consumer(filename) as consumer: consumer.run() - + # Verify document was created document = Document.objects.first() self.assertIsNotNone(document) @@ -1435,13 +1435,13 @@ class TestConsumerAIScannerIntegration( mock_scanner = MagicMock() mock_get_scanner.return_value = mock_scanner mock_scanner.scan_document.side_effect = Exception("AI Scanner failed") - + filename = self.get_test_file() - + # Consumer should complete despite AI scanner failure with self.get_consumer(filename) as consumer: consumer.run() - + # Verify document was created despite AI failure document = Document.objects.first() self.assertIsNotNone(document) @@ -1457,17 +1457,17 @@ class TestConsumerAIScannerIntegration( """ mock_scanner = MagicMock() mock_get_scanner.return_value = mock_scanner - + self.create_empty_scan_result_mock(mock_scanner) - + filename = self.get_test_file() - + with self.get_consumer(filename) as consumer: consumer.run() - + document = Document.objects.first() self.assertIsNotNone(document) - + # Verify AI scanner was called with PDF mock_scanner.scan_document.assert_called_once() call_args = mock_scanner.scan_document.call_args @@ -1488,7 +1488,7 @@ class TestConsumerAIScannerIntegration( self.dirs.scratch_dir, self.get_test_archive_file(), ) - + with mock.patch("documents.parsers.document_consumer_declaration.send") as m: m.return_value = [ ( @@ -1500,21 +1500,21 @@ class TestConsumerAIScannerIntegration( }, ), ] - + mock_scanner = MagicMock() mock_get_scanner.return_value = mock_scanner - + self.create_empty_scan_result_mock(mock_scanner) - + # Create a PNG file dst = self.get_test_file_with_name("sample.png") - + with self.get_consumer(dst) as consumer: consumer.run() - + document = Document.objects.first() self.assertIsNotNone(document) - + # Verify AI scanner was called mock_scanner.scan_document.assert_called_once() @@ -1527,26 +1527,26 @@ class TestConsumerAIScannerIntegration( Verifies that AI scanning adds minimal overhead to document consumption. """ import time - + mock_scanner = MagicMock() mock_get_scanner.return_value = mock_scanner - + self.create_empty_scan_result_mock(mock_scanner) - + filename = self.get_test_file() - + start_time = time.time() with self.get_consumer(filename) as consumer: consumer.run() end_time = time.time() - + # Verify document was created document = Document.objects.first() self.assertIsNotNone(document) - + # Verify AI scanner was called mock_scanner.scan_document.assert_called_once() - + # With mocks, this should be very fast (<1s). # TODO: Implement proper performance testing with real ML models in integration/performance test suite. elapsed_time = end_time - start_time @@ -1561,33 +1561,32 @@ class TestConsumerAIScannerIntegration( Verifies that AI scanner respects database transactions and handles rollbacks correctly. 
""" - from django.db import transaction as db_transaction - + tag = Tag.objects.create(name="Invoice") - + mock_scanner = MagicMock() mock_get_scanner.return_value = mock_scanner - + scan_result = AIScanResult() scan_result.tags = [(tag.id, 0.85)] mock_scanner.scan_document.return_value = scan_result - + # Mock apply_scan_results to raise an exception after some work def apply_with_error(document, scan_result, auto_apply=True): # Simulate partial work document.tags.add(tag) # Then fail raise Exception("Simulated transaction failure") - + mock_scanner.apply_scan_results.side_effect = apply_with_error - + filename = self.get_test_file() - + # Even with AI scanner failure, the document should still be created # because we handle AI scanner errors gracefully with self.get_consumer(filename) as consumer: consumer.run() - + document = Document.objects.first() self.assertIsNotNone(document) # The tag addition from AI scanner should be rolled back due to exception @@ -1604,17 +1603,17 @@ class TestConsumerAIScannerIntegration( """ tag1 = Tag.objects.create(name="Invoice") tag2 = Tag.objects.create(name="Receipt") - + mock_scanner = MagicMock() mock_get_scanner.return_value = mock_scanner - + # Configure scanner to return different results for each call scan_results = [] for tag in [tag1, tag2]: scan_result = AIScanResult() scan_result.tags = [(tag.id, 0.85)] scan_results.append(scan_result) - + mock_scanner.scan_document.side_effect = scan_results mock_scanner.apply_scan_results.return_value = { "applied": { @@ -1634,20 +1633,20 @@ class TestConsumerAIScannerIntegration( "workflows": [], }, } - + # Process multiple documents filenames = [self.get_test_file()] # Create second file filenames.append(self.get_test_file_with_name("sample2.pdf")) - + for filename in filenames: with self.get_consumer(filename) as consumer: consumer.run() - + # Verify both documents were created documents = Document.objects.all() self.assertEqual(documents.count(), 2) - + # Verify AI scanner was called for each document self.assertEqual(mock_scanner.scan_document.call_count, 2) @@ -1661,17 +1660,17 @@ class TestConsumerAIScannerIntegration( """ mock_scanner = MagicMock() mock_get_scanner.return_value = mock_scanner - + self.create_empty_scan_result_mock(mock_scanner) - + filename = self.get_test_file() - + with self.get_consumer(filename) as consumer: consumer.run() - + document = Document.objects.first() self.assertIsNotNone(document) - + # Verify AI scanner received text content mock_scanner.scan_document.assert_called_once() call_args = mock_scanner.scan_document.call_args @@ -1686,10 +1685,10 @@ class TestConsumerAIScannerIntegration( the AI scanner is not invoked at all. 
""" filename = self.get_test_file() - + with self.get_consumer(filename) as consumer: consumer.run() - + # Document should be created normally without AI scanning document = Document.objects.first() self.assertIsNotNone(document) diff --git a/src/documents/tests/test_deletion_request_model.py b/src/documents/tests/test_deletion_request_model.py index ed1ce5974..90c06e4b6 100644 --- a/src/documents/tests/test_deletion_request_model.py +++ b/src/documents/tests/test_deletion_request_model.py @@ -15,13 +15,11 @@ from django.contrib.auth.models import User from django.test import TestCase from django.utils import timezone -from documents.models import ( - Correspondent, - DeletionRequest, - Document, - DocumentType, - Tag, -) +from documents.models import Correspondent +from documents.models import DeletionRequest +from documents.models import Document +from documents.models import DocumentType +from documents.models import Tag class TestDeletionRequestModelCreation(TestCase): @@ -45,7 +43,7 @@ class TestDeletionRequestModelCreation(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + self.assertIsNotNone(request) self.assertTrue(request.requested_by_ai) self.assertEqual(request.ai_reason, "Test reason") @@ -59,7 +57,7 @@ class TestDeletionRequestModelCreation(TestCase): ai_reason="Test", user=self.user, ) - + self.assertIsNotNone(request.created_at) self.assertIsNotNone(request.updated_at) @@ -70,7 +68,7 @@ class TestDeletionRequestModelCreation(TestCase): ai_reason="Test", user=self.user, ) - + self.assertEqual(request.status, DeletionRequest.STATUS_PENDING) def test_deletion_request_with_documents(self): @@ -80,16 +78,16 @@ class TestDeletionRequestModelCreation(TestCase): ai_reason="Test", user=self.user, ) - + doc2 = Document.objects.create( title="Document 2", content="Content 2", checksum="checksum2", mime_type="application/pdf", ) - + request.documents.add(self.doc, doc2) - + self.assertEqual(request.documents.count(), 2) self.assertIn(self.doc, request.documents.all()) self.assertIn(doc2, request.documents.all()) @@ -101,7 +99,7 @@ class TestDeletionRequestModelCreation(TestCase): ai_reason="Test", user=self.user, ) - + self.assertIsInstance(request.impact_summary, dict) self.assertEqual(request.impact_summary, {}) @@ -112,14 +110,14 @@ class TestDeletionRequestModelCreation(TestCase): "affected_tags": ["tag1", "tag2"], "metadata": {"key": "value"}, } - + request = DeletionRequest.objects.create( requested_by_ai=True, ai_reason="Test", user=self.user, impact_summary=impact, ) - + self.assertEqual(request.impact_summary["document_count"], 5) self.assertEqual(request.impact_summary["affected_tags"], ["tag1", "tag2"]) @@ -131,9 +129,9 @@ class TestDeletionRequestModelCreation(TestCase): user=self.user, ) request.documents.add(self.doc) - + str_repr = str(request) - + self.assertIn("Deletion Request", str_repr) self.assertIn(str(request.id), str_repr) self.assertIn("1 documents", str_repr) @@ -162,9 +160,9 @@ class TestDeletionRequestApprove(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + result = request.approve(self.approver, "Approved") - + self.assertTrue(result) self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED) self.assertEqual(request.reviewed_by, self.approver) @@ -179,9 +177,9 @@ class TestDeletionRequestApprove(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + result = request.approve(self.approver) - + self.assertTrue(result) self.assertEqual(request.review_comment, "") @@ -195,9 +193,9 @@ class 
TestDeletionRequestApprove(TestCase): reviewed_by=self.user, reviewed_at=timezone.now(), ) - + result = request.approve(self.approver, "Trying to approve again") - + self.assertFalse(result) self.assertEqual(request.reviewed_by, self.user) # Should not change @@ -211,9 +209,9 @@ class TestDeletionRequestApprove(TestCase): reviewed_by=self.user, reviewed_at=timezone.now(), ) - + result = request.approve(self.approver, "Trying to approve rejected") - + self.assertFalse(result) self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED) @@ -225,9 +223,9 @@ class TestDeletionRequestApprove(TestCase): user=self.user, status=DeletionRequest.STATUS_CANCELLED, ) - + result = request.approve(self.approver, "Trying to approve cancelled") - + self.assertFalse(result) self.assertEqual(request.status, DeletionRequest.STATUS_CANCELLED) @@ -239,9 +237,9 @@ class TestDeletionRequestApprove(TestCase): user=self.user, status=DeletionRequest.STATUS_COMPLETED, ) - + result = request.approve(self.approver, "Trying to approve completed") - + self.assertFalse(result) self.assertEqual(request.status, DeletionRequest.STATUS_COMPLETED) @@ -253,11 +251,11 @@ class TestDeletionRequestApprove(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + before_approval = timezone.now() result = request.approve(self.approver, "Approved") after_approval = timezone.now() - + self.assertTrue(result) self.assertIsNotNone(request.reviewed_at) self.assertGreaterEqual(request.reviewed_at, before_approval) @@ -280,9 +278,9 @@ class TestDeletionRequestReject(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + result = request.reject(self.reviewer, "Not necessary") - + self.assertTrue(result) self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED) self.assertEqual(request.reviewed_by, self.reviewer) @@ -297,9 +295,9 @@ class TestDeletionRequestReject(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + result = request.reject(self.reviewer) - + self.assertTrue(result) self.assertEqual(request.review_comment, "") @@ -313,9 +311,9 @@ class TestDeletionRequestReject(TestCase): reviewed_by=self.user, reviewed_at=timezone.now(), ) - + result = request.reject(self.reviewer, "Trying to reject again") - + self.assertFalse(result) self.assertEqual(request.reviewed_by, self.user) # Should not change @@ -329,9 +327,9 @@ class TestDeletionRequestReject(TestCase): reviewed_by=self.user, reviewed_at=timezone.now(), ) - + result = request.reject(self.reviewer, "Trying to reject approved") - + self.assertFalse(result) self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED) @@ -343,9 +341,9 @@ class TestDeletionRequestReject(TestCase): user=self.user, status=DeletionRequest.STATUS_CANCELLED, ) - + result = request.reject(self.reviewer, "Trying to reject cancelled") - + self.assertFalse(result) self.assertEqual(request.status, DeletionRequest.STATUS_CANCELLED) @@ -357,9 +355,9 @@ class TestDeletionRequestReject(TestCase): user=self.user, status=DeletionRequest.STATUS_COMPLETED, ) - + result = request.reject(self.reviewer, "Trying to reject completed") - + self.assertFalse(result) self.assertEqual(request.status, DeletionRequest.STATUS_COMPLETED) @@ -371,11 +369,11 @@ class TestDeletionRequestReject(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + before_rejection = timezone.now() result = request.reject(self.reviewer, "Rejected") after_rejection = timezone.now() - + self.assertTrue(result) self.assertIsNotNone(request.reviewed_at) 
self.assertGreaterEqual(request.reviewed_at, before_rejection) @@ -389,11 +387,11 @@ class TestDeletionRequestWorkflowScenarios(TestCase): """Set up test data.""" self.user = User.objects.create_user(username="user1", password="pass") self.approver = User.objects.create_user(username="approver", password="pass") - + self.correspondent = Correspondent.objects.create(name="Test Corp") self.doc_type = DocumentType.objects.create(name="Invoice") self.tag = Tag.objects.create(name="Important") - + self.doc1 = Document.objects.create( title="Document 1", content="Content 1", @@ -403,7 +401,7 @@ class TestDeletionRequestWorkflowScenarios(TestCase): document_type=self.doc_type, ) self.doc1.tags.add(self.tag) - + self.doc2 = Document.objects.create( title="Document 2", content="Content 2", @@ -421,15 +419,15 @@ class TestDeletionRequestWorkflowScenarios(TestCase): impact_summary={"document_count": 2}, ) request.documents.add(self.doc1, self.doc2) - + # Verify initial state self.assertEqual(request.status, DeletionRequest.STATUS_PENDING) self.assertIsNone(request.reviewed_by) self.assertIsNone(request.reviewed_at) - + # Approve success = request.approve(self.approver, "Confirmed duplicates") - + # Verify final state self.assertTrue(success) self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED) @@ -446,13 +444,13 @@ class TestDeletionRequestWorkflowScenarios(TestCase): status=DeletionRequest.STATUS_PENDING, ) request.documents.add(self.doc1) - + # Verify initial state self.assertEqual(request.status, DeletionRequest.STATUS_PENDING) - + # Reject success = request.reject(self.approver, "Not duplicates") - + # Verify final state self.assertTrue(success) self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED) @@ -468,14 +466,14 @@ class TestDeletionRequestWorkflowScenarios(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + # Reject first request.reject(self.user, "Rejected") self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED) - + # Try to approve success = request.approve(self.approver, "Changed my mind") - + # Should fail self.assertFalse(success) self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED) @@ -488,14 +486,14 @@ class TestDeletionRequestWorkflowScenarios(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + # Approve first request.approve(self.approver, "Approved") self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED) - + # Try to reject success = request.reject(self.user, "Changed my mind") - + # Should fail self.assertFalse(success) self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED) @@ -516,7 +514,7 @@ class TestDeletionRequestAuditTrail(TestCase): ai_reason="Test", user=self.user, ) - + self.assertEqual(request.user, self.user) def test_audit_trail_records_reviewer(self): @@ -527,9 +525,9 @@ class TestDeletionRequestAuditTrail(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + request.approve(self.approver, "Approved") - + self.assertEqual(request.reviewed_by, self.approver) self.assertNotEqual(request.reviewed_by, request.user) @@ -540,12 +538,12 @@ class TestDeletionRequestAuditTrail(TestCase): ai_reason="Test", user=self.user, ) - + created_at = request.created_at - + # Approve the request request.approve(self.approver, "Approved") - + # Verify timestamps self.assertIsNotNone(request.created_at) self.assertIsNotNone(request.updated_at) @@ -555,16 +553,16 @@ class TestDeletionRequestAuditTrail(TestCase): def test_audit_trail_preserves_ai_reason(self): 
"""Test that AI's original reason is preserved.""" original_reason = "AI detected duplicates based on content similarity" - + request = DeletionRequest.objects.create( requested_by_ai=True, ai_reason=original_reason, user=self.user, ) - + # Approve with different comment request.approve(self.approver, "User confirmed") - + # Original AI reason should be preserved self.assertEqual(request.ai_reason, original_reason) self.assertEqual(request.review_comment, "User confirmed") @@ -582,7 +580,7 @@ class TestDeletionRequestAuditTrail(TestCase): "completed_by": "system", }, ) - + self.assertEqual(request.completion_details["deleted_count"], 5) self.assertEqual(request.completion_details["failed_count"], 0) @@ -593,17 +591,17 @@ class TestDeletionRequestAuditTrail(TestCase): ai_reason="Reason 1", user=self.user, ) - + request2 = DeletionRequest.objects.create( requested_by_ai=True, ai_reason="Reason 2", user=self.user, ) - + # Approve one, reject another request1.approve(self.approver, "Approved") request2.reject(self.approver, "Rejected") - + # Verify each has its own audit trail self.assertEqual(request1.status, DeletionRequest.STATUS_APPROVED) self.assertEqual(request2.status, DeletionRequest.STATUS_REJECTED) @@ -625,13 +623,13 @@ class TestDeletionRequestModelRelationships(TestCase): ai_reason="Test", user=self.user, ) - + request_id = request.id self.assertEqual(DeletionRequest.objects.filter(id=request_id).count(), 1) - + # Delete user self.user.delete() - + # Request should be deleted self.assertEqual(DeletionRequest.objects.filter(id=request_id).count(), 0) @@ -649,15 +647,15 @@ class TestDeletionRequestModelRelationships(TestCase): checksum="checksum2", mime_type="application/pdf", ) - + request = DeletionRequest.objects.create( requested_by_ai=True, ai_reason="Test", user=self.user, ) - + request.documents.add(doc1, doc2) - + self.assertEqual(request.documents.count(), 2) self.assertEqual(doc1.deletion_requests.count(), 1) self.assertEqual(doc2.deletion_requests.count(), 1) @@ -670,29 +668,29 @@ class TestDeletionRequestModelRelationships(TestCase): user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + self.assertIsNone(request.reviewed_by) def test_reviewed_by_set_null_on_delete(self): """Test that reviewed_by is set to null when reviewer is deleted.""" approver = User.objects.create_user(username="approver", password="pass") - + request = DeletionRequest.objects.create( requested_by_ai=True, ai_reason="Test", user=self.user, status=DeletionRequest.STATUS_PENDING, ) - + request.approve(approver, "Approved") self.assertEqual(request.reviewed_by, approver) - + # Delete approver approver.delete() - + # Refresh request request.refresh_from_db() - + # reviewed_by should be null self.assertIsNone(request.reviewed_by) # But the request should still exist diff --git a/src/documents/tests/test_ml_cache.py b/src/documents/tests/test_ml_cache.py index 719142d83..8f26a46aa 100644 --- a/src/documents/tests/test_ml_cache.py +++ b/src/documents/tests/test_ml_cache.py @@ -8,11 +8,9 @@ from unittest import mock from django.test import TestCase -from documents.ml.model_cache import ( - CacheMetrics, - LRUCache, - ModelCacheManager, -) +from documents.ml.model_cache import CacheMetrics +from documents.ml.model_cache import LRUCache +from documents.ml.model_cache import ModelCacheManager class TestCacheMetrics(TestCase): @@ -22,10 +20,10 @@ class TestCacheMetrics(TestCase): """Test recording cache hits.""" metrics = CacheMetrics() self.assertEqual(metrics.hits, 0) - + metrics.record_hit() 
self.assertEqual(metrics.hits, 1) - + metrics.record_hit() self.assertEqual(metrics.hits, 2) @@ -33,26 +31,26 @@ class TestCacheMetrics(TestCase): """Test recording cache misses.""" metrics = CacheMetrics() self.assertEqual(metrics.misses, 0) - + metrics.record_miss() self.assertEqual(metrics.misses, 1) def test_get_stats(self): """Test getting cache statistics.""" metrics = CacheMetrics() - + # Initial stats stats = metrics.get_stats() self.assertEqual(stats["hits"], 0) self.assertEqual(stats["misses"], 0) self.assertEqual(stats["hit_rate"], "0.00%") - + # After some hits and misses metrics.record_hit() metrics.record_hit() metrics.record_hit() metrics.record_miss() - + stats = metrics.get_stats() self.assertEqual(stats["hits"], 3) self.assertEqual(stats["misses"], 1) @@ -64,9 +62,9 @@ class TestCacheMetrics(TestCase): metrics = CacheMetrics() metrics.record_hit() metrics.record_miss() - + metrics.reset() - + stats = metrics.get_stats() self.assertEqual(stats["hits"], 0) self.assertEqual(stats["misses"], 0) @@ -78,28 +76,28 @@ class TestLRUCache(TestCase): def test_put_and_get(self): """Test basic cache operations.""" cache = LRUCache(max_size=2) - + cache.put("key1", "value1") cache.put("key2", "value2") - + self.assertEqual(cache.get("key1"), "value1") self.assertEqual(cache.get("key2"), "value2") def test_cache_miss(self): """Test cache miss returns None.""" cache = LRUCache(max_size=2) - + result = cache.get("nonexistent") self.assertIsNone(result) def test_lru_eviction(self): """Test LRU eviction policy.""" cache = LRUCache(max_size=2) - + cache.put("key1", "value1") cache.put("key2", "value2") cache.put("key3", "value3") # Should evict key1 - + self.assertIsNone(cache.get("key1")) # Evicted self.assertEqual(cache.get("key2"), "value2") self.assertEqual(cache.get("key3"), "value3") @@ -107,12 +105,12 @@ class TestLRUCache(TestCase): def test_lru_update_access_order(self): """Test that accessing an item updates its position.""" cache = LRUCache(max_size=2) - + cache.put("key1", "value1") cache.put("key2", "value2") cache.get("key1") # Access key1, making it most recent cache.put("key3", "value3") # Should evict key2, not key1 - + self.assertEqual(cache.get("key1"), "value1") self.assertIsNone(cache.get("key2")) # Evicted self.assertEqual(cache.get("key3"), "value3") @@ -120,24 +118,24 @@ class TestLRUCache(TestCase): def test_cache_size(self): """Test cache size tracking.""" cache = LRUCache(max_size=3) - + self.assertEqual(cache.size(), 0) - + cache.put("key1", "value1") self.assertEqual(cache.size(), 1) - + cache.put("key2", "value2") self.assertEqual(cache.size(), 2) def test_clear(self): """Test clearing cache.""" cache = LRUCache(max_size=2) - + cache.put("key1", "value1") cache.put("key2", "value2") - + cache.clear() - + self.assertEqual(cache.size(), 0) self.assertIsNone(cache.get("key1")) self.assertIsNone(cache.get("key2")) @@ -155,20 +153,20 @@ class TestModelCacheManager(TestCase): """Test that ModelCacheManager is a singleton.""" instance1 = ModelCacheManager.get_instance() instance2 = ModelCacheManager.get_instance() - + self.assertIs(instance1, instance2) def test_get_or_load_model_first_time(self): """Test loading a model for the first time (cache miss).""" cache_manager = ModelCacheManager.get_instance() - + # Mock loader function mock_model = mock.Mock() loader = mock.Mock(return_value=mock_model) - + # Load model result = cache_manager.get_or_load_model("test_model", loader) - + # Verify loader was called loader.assert_called_once() self.assertIs(result, mock_model) 
@@ -176,17 +174,17 @@ class TestModelCacheManager(TestCase): def test_get_or_load_model_cached(self): """Test loading a model from cache (cache hit).""" cache_manager = ModelCacheManager.get_instance() - + # Mock loader function mock_model = mock.Mock() loader = mock.Mock(return_value=mock_model) - + # Load model first time cache_manager.get_or_load_model("test_model", loader) - + # Load model second time (should be cached) result = cache_manager.get_or_load_model("test_model", loader) - + # Verify loader was only called once loader.assert_called_once() self.assertIs(result, mock_model) @@ -197,48 +195,48 @@ class TestModelCacheManager(TestCase): cache_manager = ModelCacheManager.get_instance( disk_cache_dir=tmpdir, ) - + # Create test embeddings embeddings = { 1: "embedding1", 2: "embedding2", 3: "embedding3", } - + # Save to disk cache_manager.save_embeddings_to_disk("test_embeddings", embeddings) - + # Verify file was created cache_file = Path(tmpdir) / "test_embeddings.pkl" self.assertTrue(cache_file.exists()) - + # Load from disk loaded = cache_manager.load_embeddings_from_disk("test_embeddings") - + # Verify embeddings match self.assertEqual(loaded, embeddings) def test_get_metrics(self): """Test getting cache metrics.""" cache_manager = ModelCacheManager.get_instance() - + # Mock loader loader = mock.Mock(return_value=mock.Mock()) - + # Generate some cache activity cache_manager.get_or_load_model("model1", loader) cache_manager.get_or_load_model("model1", loader) # Cache hit cache_manager.get_or_load_model("model2", loader) - + # Get metrics metrics = cache_manager.get_metrics() - + # Verify metrics structure self.assertIn("hits", metrics) self.assertIn("misses", metrics) self.assertIn("cache_size", metrics) self.assertIn("max_size", metrics) - + # Verify hit/miss counts self.assertEqual(metrics["hits"], 1) # One cache hit self.assertEqual(metrics["misses"], 2) # Two cache misses @@ -249,21 +247,21 @@ class TestModelCacheManager(TestCase): cache_manager = ModelCacheManager.get_instance( disk_cache_dir=tmpdir, ) - + # Add some models to cache loader = mock.Mock(return_value=mock.Mock()) cache_manager.get_or_load_model("model1", loader) - + # Add embeddings to disk embeddings = {1: "embedding1"} cache_manager.save_embeddings_to_disk("test", embeddings) - + # Clear all cache_manager.clear_all() - + # Verify memory cache is cleared self.assertEqual(cache_manager.model_cache.size(), 0) - + # Verify disk cache is cleared cache_file = Path(tmpdir) / "test.pkl" self.assertFalse(cache_file.exists()) @@ -271,22 +269,22 @@ class TestModelCacheManager(TestCase): def test_warm_up(self): """Test model warm-up functionality.""" cache_manager = ModelCacheManager.get_instance() - + # Create mock loaders model1 = mock.Mock() model2 = mock.Mock() - + loaders = { "model1": mock.Mock(return_value=model1), "model2": mock.Mock(return_value=model2), } - + # Warm up cache_manager.warm_up(loaders) - + # Verify all loaders were called for loader in loaders.values(): loader.assert_called_once() - + # Verify models are cached self.assertEqual(cache_manager.model_cache.size(), 2) diff --git a/src/documents/tests/test_ml_smoke.py b/src/documents/tests/test_ml_smoke.py index 97eca291f..91d829179 100644 --- a/src/documents/tests/test_ml_smoke.py +++ b/src/documents/tests/test_ml_smoke.py @@ -150,7 +150,6 @@ class TestMLCacheDirectory: def test_model_cache_writable(self, tmp_path): """Test that we can write to model cache directory.""" - import pathlib # Use tmp_path fixture for testing cache_dir = tmp_path / 
".cache" / "huggingface" @@ -169,7 +168,6 @@ class TestMLCacheDirectory: def test_torch_cache_directory(self, tmp_path, monkeypatch): """Test that PyTorch can use a custom cache directory.""" - import torch # Set custom cache directory cache_dir = tmp_path / ".cache" / "torch" @@ -204,9 +202,10 @@ class TestMLPerformanceBasic: def test_numpy_performance_basic(self): """Test basic NumPy performance with larger arrays.""" - import numpy as np import time + import numpy as np + # Create large array (10 million elements) arr = np.random.rand(10_000_000) diff --git a/src/documents/views.py b/src/documents/views.py index 8c71350b8..c77dd5a09 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -89,6 +89,8 @@ from rest_framework.viewsets import ViewSet from documents import bulk_edit from documents import index +from documents.ai_scanner import AIDocumentScanner +from documents.ai_scanner import get_ai_scanner from documents.bulk_download import ArchiveOnlyStrategy from documents.bulk_download import OriginalAndArchiveStrategy from documents.bulk_download import OriginalsOnlyStrategy @@ -141,13 +143,10 @@ from documents.models import UiSettings from documents.models import Workflow from documents.models import WorkflowAction from documents.models import WorkflowTrigger -from documents.ai_scanner import AIDocumentScanner -from documents.ai_scanner import get_ai_scanner from documents.parsers import get_parser_class_for_mime_type from documents.parsers import parse_date_generator from documents.permissions import AcknowledgeTasksPermissions from documents.permissions import CanApplyAISuggestionsPermission -from documents.permissions import CanApproveDeletionsPermission from documents.permissions import CanConfigureAIPermission from documents.permissions import CanViewAISuggestionsPermission from documents.permissions import PaperlessAdminPermissions @@ -1370,36 +1369,36 @@ class UnifiedSearchViewSet(DocumentViewSet): """ from documents.ai_scanner import get_ai_scanner from documents.serializers.ai_suggestions import AISuggestionsSerializer - + try: document = self.get_object() - + # Check if document has content to scan if not document.content: return Response( {"detail": "Document has no content to analyze"}, status=400, ) - + # Get AI scanner instance scanner = get_ai_scanner() - + # Perform AI scan scan_result = scanner.scan_document( document=document, document_text=document.content, - original_file_path=document.source_path if hasattr(document, 'source_path') else None, + original_file_path=document.source_path if hasattr(document, "source_path") else None, ) - + # Convert scan result to serializable format data = AISuggestionsSerializer.from_scan_result(scan_result, document.id) - + # Serialize and return serializer = AISuggestionsSerializer(data=data) serializer.is_valid(raise_exception=True) - + return Response(serializer.validated_data) - + except Exception as e: logger.error(f"Error getting AI suggestions for document {pk}: {e}", exc_info=True) return Response( @@ -1416,56 +1415,56 @@ class UnifiedSearchViewSet(DocumentViewSet): """ from documents.models import AISuggestionFeedback from documents.serializers.ai_suggestions import ApplySuggestionSerializer - + try: document = self.get_object() - + # Validate input serializer = ApplySuggestionSerializer(data=request.data) serializer.is_valid(raise_exception=True) - - suggestion_type = serializer.validated_data['suggestion_type'] - value_id = serializer.validated_data.get('value_id') - value_text = 
serializer.validated_data.get('value_text') - confidence = serializer.validated_data['confidence'] - + + suggestion_type = serializer.validated_data["suggestion_type"] + value_id = serializer.validated_data.get("value_id") + value_text = serializer.validated_data.get("value_text") + confidence = serializer.validated_data["confidence"] + # Apply the suggestion based on type applied = False result_message = "" - - if suggestion_type == 'tag' and value_id: + + if suggestion_type == "tag" and value_id: tag = Tag.objects.get(pk=value_id) document.tags.add(tag) applied = True result_message = f"Tag '{tag.name}' applied" - - elif suggestion_type == 'correspondent' and value_id: + + elif suggestion_type == "correspondent" and value_id: correspondent = Correspondent.objects.get(pk=value_id) document.correspondent = correspondent document.save() applied = True result_message = f"Correspondent '{correspondent.name}' applied" - - elif suggestion_type == 'document_type' and value_id: + + elif suggestion_type == "document_type" and value_id: doc_type = DocumentType.objects.get(pk=value_id) document.document_type = doc_type document.save() applied = True result_message = f"Document type '{doc_type.name}' applied" - - elif suggestion_type == 'storage_path' and value_id: + + elif suggestion_type == "storage_path" and value_id: storage_path = StoragePath.objects.get(pk=value_id) document.storage_path = storage_path document.save() applied = True result_message = f"Storage path '{storage_path.name}' applied" - - elif suggestion_type == 'title' and value_text: + + elif suggestion_type == "title" and value_text: document.title = value_text document.save() applied = True result_message = f"Title updated to '{value_text}'" - + if applied: # Record feedback AISuggestionFeedback.objects.create( @@ -1477,7 +1476,7 @@ class UnifiedSearchViewSet(DocumentViewSet): status=AISuggestionFeedback.STATUS_APPLIED, user=request.user, ) - + return Response({ "status": "success", "message": result_message, @@ -1487,8 +1486,8 @@ class UnifiedSearchViewSet(DocumentViewSet): {"detail": "Invalid suggestion type or missing value"}, status=400, ) - - except (Tag.DoesNotExist, Correspondent.DoesNotExist, + + except (Tag.DoesNotExist, Correspondent.DoesNotExist, DocumentType.DoesNotExist, StoragePath.DoesNotExist): return Response( {"detail": "Referenced object not found"}, @@ -1510,19 +1509,19 @@ class UnifiedSearchViewSet(DocumentViewSet): """ from documents.models import AISuggestionFeedback from documents.serializers.ai_suggestions import RejectSuggestionSerializer - + try: document = self.get_object() - + # Validate input serializer = RejectSuggestionSerializer(data=request.data) serializer.is_valid(raise_exception=True) - - suggestion_type = serializer.validated_data['suggestion_type'] - value_id = serializer.validated_data.get('value_id') - value_text = serializer.validated_data.get('value_text') - confidence = serializer.validated_data['confidence'] - + + suggestion_type = serializer.validated_data["suggestion_type"] + value_id = serializer.validated_data.get("value_id") + value_text = serializer.validated_data.get("value_text") + confidence = serializer.validated_data["confidence"] + # Record feedback AISuggestionFeedback.objects.create( document=document, @@ -1533,12 +1532,12 @@ class UnifiedSearchViewSet(DocumentViewSet): status=AISuggestionFeedback.STATUS_REJECTED, user=request.user, ) - + return Response({ "status": "success", "message": "Suggestion rejected and feedback recorded", }) - + except Exception as e: 
logger.error(f"Error rejecting suggestion for document {pk}: {e}", exc_info=True) return Response( @@ -1554,78 +1553,83 @@ class UnifiedSearchViewSet(DocumentViewSet): Returns aggregated data about applied vs rejected suggestions, accuracy rates, and confidence scores. """ - from django.db.models import Avg, Count, Q + from django.db.models import Avg + from django.db.models import Count + from django.db.models import Q + from documents.models import AISuggestionFeedback from documents.serializers.ai_suggestions import AISuggestionStatsSerializer - + try: # Get overall counts total_feedbacks = AISuggestionFeedback.objects.count() total_applied = AISuggestionFeedback.objects.filter( - status=AISuggestionFeedback.STATUS_APPLIED + status=AISuggestionFeedback.STATUS_APPLIED, ).count() total_rejected = AISuggestionFeedback.objects.filter( - status=AISuggestionFeedback.STATUS_REJECTED + status=AISuggestionFeedback.STATUS_REJECTED, ).count() - + # Calculate accuracy rate accuracy_rate = (total_applied / total_feedbacks * 100) if total_feedbacks > 0 else 0 - + # Get statistics by suggestion type using a single aggregated query - stats_by_type = AISuggestionFeedback.objects.values('suggestion_type').annotate( - total=Count('id'), - applied=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)), - rejected=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)) + stats_by_type = AISuggestionFeedback.objects.values("suggestion_type").annotate( + total=Count("id"), + applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)), + rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)), ) - + # Build the by_type dictionary using the aggregated results by_type = {} for stat in stats_by_type: - suggestion_type = stat['suggestion_type'] - type_total = stat['total'] - type_applied = stat['applied'] - type_rejected = stat['rejected'] - + suggestion_type = stat["suggestion_type"] + type_total = stat["total"] + type_applied = stat["applied"] + type_rejected = stat["rejected"] + by_type[suggestion_type] = { - 'total': type_total, - 'applied': type_applied, - 'rejected': type_rejected, - 'accuracy_rate': (type_applied / type_total * 100) if type_total > 0 else 0, + "total": type_total, + "applied": type_applied, + "rejected": type_rejected, + "accuracy_rate": (type_applied / type_total * 100) if type_total > 0 else 0, } - + # Get average confidence scores avg_confidence_applied = AISuggestionFeedback.objects.filter( - status=AISuggestionFeedback.STATUS_APPLIED - ).aggregate(Avg('confidence'))['confidence__avg'] or 0.0 - + status=AISuggestionFeedback.STATUS_APPLIED, + ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0 + avg_confidence_rejected = AISuggestionFeedback.objects.filter( - status=AISuggestionFeedback.STATUS_REJECTED - ).aggregate(Avg('confidence'))['confidence__avg'] or 0.0 - + status=AISuggestionFeedback.STATUS_REJECTED, + ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0 + # Get recent suggestions (last 10) - recent_suggestions = AISuggestionFeedback.objects.order_by('-created_at')[:10] - + recent_suggestions = AISuggestionFeedback.objects.order_by("-created_at")[:10] + # Build response data - from documents.serializers.ai_suggestions import AISuggestionFeedbackSerializer + from documents.serializers.ai_suggestions import ( + AISuggestionFeedbackSerializer, + ) data = { - 'total_suggestions': total_feedbacks, - 'total_applied': total_applied, - 'total_rejected': total_rejected, - 'accuracy_rate': accuracy_rate, - 'by_type': by_type, 
- 'average_confidence_applied': avg_confidence_applied, - 'average_confidence_rejected': avg_confidence_rejected, - 'recent_suggestions': AISuggestionFeedbackSerializer( - recent_suggestions, many=True + "total_suggestions": total_feedbacks, + "total_applied": total_applied, + "total_rejected": total_rejected, + "accuracy_rate": accuracy_rate, + "by_type": by_type, + "average_confidence_applied": avg_confidence_applied, + "average_confidence_rejected": avg_confidence_rejected, + "recent_suggestions": AISuggestionFeedbackSerializer( + recent_suggestions, many=True, ).data, } - + # Serialize and return serializer = AISuggestionStatsSerializer(data=data) serializer.is_valid(raise_exception=True) - + return Response(serializer.validated_data) - + except Exception as e: logger.error(f"Error getting AI suggestion statistics: {e}", exc_info=True) return Response( @@ -3561,37 +3565,37 @@ class AISuggestionsView(GenericAPIView): Requires: can_view_ai_suggestions permission """ - + permission_classes = [IsAuthenticated, CanViewAISuggestionsPermission] serializer_class = AISuggestionsResponseSerializer - + def post(self, request): """Get AI suggestions for a document.""" # Validate request request_serializer = AISuggestionsRequestSerializer(data=request.data) request_serializer.is_valid(raise_exception=True) - - document_id = request_serializer.validated_data['document_id'] - + + document_id = request_serializer.validated_data["document_id"] + try: document = Document.objects.get(pk=document_id) except Document.DoesNotExist: return Response( {"error": "Document not found or you don't have permission to view it"}, - status=status.HTTP_404_NOT_FOUND + status=status.HTTP_404_NOT_FOUND, ) - + # Check if user has permission to view this document - if not has_perms_owner_aware(request.user, 'documents.view_document', document): + if not has_perms_owner_aware(request.user, "documents.view_document", document): return Response( {"error": "Permission denied"}, - status=status.HTTP_403_FORBIDDEN + status=status.HTTP_403_FORBIDDEN, ) - + # Get AI scanner and scan document scanner = get_ai_scanner() scan_result = scanner.scan_document(document, document.content or "") - + # Build response response_data = { "document_id": document.id, @@ -3600,9 +3604,9 @@ class AISuggestionsView(GenericAPIView): "document_type": None, "storage_path": None, "title_suggestion": scan_result.title_suggestion, - "custom_fields": {} + "custom_fields": {}, } - + # Format tag suggestions for tag_id, confidence in scan_result.tags: try: @@ -3610,12 +3614,12 @@ class AISuggestionsView(GenericAPIView): response_data["tags"].append({ "id": tag.id, "name": tag.name, - "confidence": confidence + "confidence": confidence, }) except Tag.DoesNotExist: # Tag was suggested by AI but no longer exists; skip it pass - + # Format correspondent suggestion if scan_result.correspondent: corr_id, confidence = scan_result.correspondent @@ -3624,12 +3628,12 @@ class AISuggestionsView(GenericAPIView): response_data["correspondent"] = { "id": correspondent.id, "name": correspondent.name, - "confidence": confidence + "confidence": confidence, } except Correspondent.DoesNotExist: # Correspondent was suggested but no longer exists; skip it pass - + # Format document type suggestion if scan_result.document_type: type_id, confidence = scan_result.document_type @@ -3638,12 +3642,12 @@ class AISuggestionsView(GenericAPIView): response_data["document_type"] = { "id": doc_type.id, "name": doc_type.name, - "confidence": confidence + "confidence": confidence, } except 
DocumentType.DoesNotExist: # Document type was suggested but no longer exists; skip it pass - + # Format storage path suggestion if scan_result.storage_path: path_id, confidence = scan_result.storage_path @@ -3652,19 +3656,19 @@ class AISuggestionsView(GenericAPIView): response_data["storage_path"] = { "id": storage_path.id, "name": storage_path.name, - "confidence": confidence + "confidence": confidence, } except StoragePath.DoesNotExist: # Storage path was suggested but no longer exists; skip it pass - + # Format custom fields for field_id, (value, confidence) in scan_result.custom_fields.items(): response_data["custom_fields"][str(field_id)] = { "value": value, - "confidence": confidence + "confidence": confidence, } - + return Response(response_data) @@ -3674,48 +3678,48 @@ class ApplyAISuggestionsView(GenericAPIView): Requires: can_apply_ai_suggestions permission """ - + permission_classes = [IsAuthenticated, CanApplyAISuggestionsPermission] - + def post(self, request): """Apply AI suggestions to a document.""" # Validate request serializer = ApplyAISuggestionsSerializer(data=request.data) serializer.is_valid(raise_exception=True) - - document_id = serializer.validated_data['document_id'] - + + document_id = serializer.validated_data["document_id"] + try: document = Document.objects.get(pk=document_id) except Document.DoesNotExist: return Response( {"error": "Document not found"}, - status=status.HTTP_404_NOT_FOUND + status=status.HTTP_404_NOT_FOUND, ) - + # Check if user has permission to change this document - if not has_perms_owner_aware(request.user, 'documents.change_document', document): + if not has_perms_owner_aware(request.user, "documents.change_document", document): return Response( {"error": "Permission denied"}, - status=status.HTTP_403_FORBIDDEN + status=status.HTTP_403_FORBIDDEN, ) - + # Get AI scanner and scan document scanner = get_ai_scanner() scan_result = scanner.scan_document(document, document.content or "") - + # Apply suggestions based on user selections applied = [] - - if serializer.validated_data.get('apply_tags'): - selected_tags = serializer.validated_data.get('selected_tags', []) + + if serializer.validated_data.get("apply_tags"): + selected_tags = serializer.validated_data.get("selected_tags", []) if selected_tags: # Apply only selected tags tags_to_apply = [tag_id for tag_id, _ in scan_result.tags if tag_id in selected_tags] else: # Apply all high-confidence tags tags_to_apply = [tag_id for tag_id, conf in scan_result.tags if conf >= scanner.auto_apply_threshold] - + for tag_id in tags_to_apply: try: tag = Tag.objects.get(pk=tag_id) @@ -3724,8 +3728,8 @@ class ApplyAISuggestionsView(GenericAPIView): except Tag.DoesNotExist: # Tag not found; skip applying this tag pass - - if serializer.validated_data.get('apply_correspondent') and scan_result.correspondent: + + if serializer.validated_data.get("apply_correspondent") and scan_result.correspondent: corr_id, confidence = scan_result.correspondent try: correspondent = Correspondent.objects.get(pk=corr_id) @@ -3734,8 +3738,8 @@ class ApplyAISuggestionsView(GenericAPIView): except Correspondent.DoesNotExist: # Correspondent not found; skip applying pass - - if serializer.validated_data.get('apply_document_type') and scan_result.document_type: + + if serializer.validated_data.get("apply_document_type") and scan_result.document_type: type_id, confidence = scan_result.document_type try: doc_type = DocumentType.objects.get(pk=type_id) @@ -3744,8 +3748,8 @@ class ApplyAISuggestionsView(GenericAPIView): except 
DocumentType.DoesNotExist: # Document type not found; skip applying pass - - if serializer.validated_data.get('apply_storage_path') and scan_result.storage_path: + + if serializer.validated_data.get("apply_storage_path") and scan_result.storage_path: path_id, confidence = scan_result.storage_path try: storage_path = StoragePath.objects.get(pk=path_id) @@ -3754,18 +3758,18 @@ class ApplyAISuggestionsView(GenericAPIView): except StoragePath.DoesNotExist: # Storage path not found; skip applying pass - - if serializer.validated_data.get('apply_title') and scan_result.title_suggestion: + + if serializer.validated_data.get("apply_title") and scan_result.title_suggestion: document.title = scan_result.title_suggestion applied.append(f"title: {scan_result.title_suggestion}") - + # Save document document.save() - + return Response({ "status": "success", "document_id": document.id, - "applied": applied + "applied": applied, }) @@ -3775,23 +3779,23 @@ class AIConfigurationView(GenericAPIView): Requires: can_configure_ai permission """ - + permission_classes = [IsAuthenticated, CanConfigureAIPermission] - + def get(self, request): """Get current AI configuration.""" scanner = get_ai_scanner() - + config_data = { "auto_apply_threshold": scanner.auto_apply_threshold, "suggest_threshold": scanner.suggest_threshold, "ml_enabled": scanner.ml_enabled, "advanced_ocr_enabled": scanner.advanced_ocr_enabled, } - + serializer = AIConfigurationSerializer(config_data) return Response(serializer.data) - + def post(self, request): """ Update AI configuration. @@ -3802,27 +3806,27 @@ class AIConfigurationView(GenericAPIView): """ serializer = AIConfigurationSerializer(data=request.data) serializer.is_valid(raise_exception=True) - + # Create new scanner with updated configuration config = {} - if 'auto_apply_threshold' in serializer.validated_data: - config['auto_apply_threshold'] = serializer.validated_data['auto_apply_threshold'] - if 'suggest_threshold' in serializer.validated_data: - config['suggest_threshold'] = serializer.validated_data['suggest_threshold'] - if 'ml_enabled' in serializer.validated_data: - config['enable_ml_features'] = serializer.validated_data['ml_enabled'] - if 'advanced_ocr_enabled' in serializer.validated_data: - config['enable_advanced_ocr'] = serializer.validated_data['advanced_ocr_enabled'] - + if "auto_apply_threshold" in serializer.validated_data: + config["auto_apply_threshold"] = serializer.validated_data["auto_apply_threshold"] + if "suggest_threshold" in serializer.validated_data: + config["suggest_threshold"] = serializer.validated_data["suggest_threshold"] + if "ml_enabled" in serializer.validated_data: + config["enable_ml_features"] = serializer.validated_data["ml_enabled"] + if "advanced_ocr_enabled" in serializer.validated_data: + config["enable_advanced_ocr"] = serializer.validated_data["advanced_ocr_enabled"] + # Update global scanner instance # WARNING: Not thread-safe. Consider storing configuration in database # and reloading on each get_ai_scanner() call for production use from documents import ai_scanner ai_scanner._scanner_instance = AIDocumentScanner(**config) - + return Response({ "status": "success", - "message": "AI configuration updated. Changes may require server restart for consistency." + "message": "AI configuration updated. 
Changes may require server restart for consistency.", }) diff --git a/src/documents/views/deletion_request.py b/src/documents/views/deletion_request.py index 22d8e25c3..9e11d0856 100644 --- a/src/documents/views/deletion_request.py +++ b/src/documents/views/deletion_request.py @@ -30,10 +30,10 @@ class DeletionRequestViewSet(ModelViewSet): Provides CRUD operations plus custom actions for approval workflow. """ - + model = DeletionRequest serializer_class = DeletionRequestSerializer - + def get_queryset(self): """ Return deletion requests for the current user. @@ -45,7 +45,7 @@ class DeletionRequestViewSet(ModelViewSet): if user.is_superuser: return DeletionRequest.objects.all() return DeletionRequest.objects.filter(user=user) - + def _can_manage_request(self, deletion_request): """ Check if current user can manage (approve/reject/cancel) the request. @@ -58,7 +58,7 @@ class DeletionRequestViewSet(ModelViewSet): """ user = self.request.user return user.is_superuser or deletion_request.user == user - + @action(methods=["post"], detail=True) def approve(self, request, pk=None): """ @@ -72,13 +72,13 @@ class DeletionRequestViewSet(ModelViewSet): Response with execution results """ deletion_request = self.get_object() - + # Check permissions if not self._can_manage_request(deletion_request): return HttpResponseForbidden( - "You don't have permission to approve this deletion request." + "You don't have permission to approve this deletion request.", ) - + # Validate status if deletion_request.status != DeletionRequest.STATUS_PENDING: return Response( @@ -88,9 +88,9 @@ class DeletionRequestViewSet(ModelViewSet): }, status=status.HTTP_400_BAD_REQUEST, ) - + comment = request.data.get("comment", "") - + # Execute approval and deletion in a transaction try: with transaction.atomic(): @@ -100,12 +100,12 @@ class DeletionRequestViewSet(ModelViewSet): {"error": "Failed to approve deletion request."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) - + # Execute the deletion documents = list(deletion_request.documents.all()) deleted_count = 0 failed_deletions = [] - + for doc in documents: try: doc_id = doc.id @@ -114,18 +114,18 @@ class DeletionRequestViewSet(ModelViewSet): deleted_count += 1 logger.info( f"Deleted document {doc_id} ('{doc_title}') " - f"as part of deletion request {deletion_request.id}" + f"as part of deletion request {deletion_request.id}", ) except Exception as e: logger.error( - f"Failed to delete document {doc.id}: {str(e)}" + f"Failed to delete document {doc.id}: {e!s}", ) failed_deletions.append({ "id": doc.id, "title": doc.title, "error": str(e), }) - + # Update completion status deletion_request.status = DeletionRequest.STATUS_COMPLETED deletion_request.completed_at = timezone.now() @@ -135,20 +135,20 @@ class DeletionRequestViewSet(ModelViewSet): "total_documents": len(documents), } deletion_request.save() - + logger.info( f"Deletion request {deletion_request.id} completed. " - f"Deleted {deleted_count}/{len(documents)} documents." 
+ f"Deleted {deleted_count}/{len(documents)} documents.", ) except Exception as e: logger.error( - f"Error executing deletion request {deletion_request.id}: {str(e)}" + f"Error executing deletion request {deletion_request.id}: {e!s}", ) return Response( - {"error": f"Failed to execute deletion: {str(e)}"}, + {"error": f"Failed to execute deletion: {e!s}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) - + serializer = self.get_serializer(deletion_request) return Response( { @@ -158,7 +158,7 @@ class DeletionRequestViewSet(ModelViewSet): }, status=status.HTTP_200_OK, ) - + @action(methods=["post"], detail=True) def reject(self, request, pk=None): """ @@ -172,13 +172,13 @@ class DeletionRequestViewSet(ModelViewSet): Response with updated deletion request """ deletion_request = self.get_object() - + # Check permissions if not self._can_manage_request(deletion_request): return HttpResponseForbidden( - "You don't have permission to reject this deletion request." + "You don't have permission to reject this deletion request.", ) - + # Validate status if deletion_request.status != DeletionRequest.STATUS_PENDING: return Response( @@ -188,20 +188,20 @@ class DeletionRequestViewSet(ModelViewSet): }, status=status.HTTP_400_BAD_REQUEST, ) - + comment = request.data.get("comment", "") - + # Reject the request if not deletion_request.reject(request.user, comment): return Response( {"error": "Failed to reject deletion request."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) - + logger.info( - f"Deletion request {deletion_request.id} rejected by user {request.user.username}" + f"Deletion request {deletion_request.id} rejected by user {request.user.username}", ) - + serializer = self.get_serializer(deletion_request) return Response( { @@ -210,7 +210,7 @@ class DeletionRequestViewSet(ModelViewSet): }, status=status.HTTP_200_OK, ) - + @action(methods=["post"], detail=True) def cancel(self, request, pk=None): """ @@ -224,13 +224,13 @@ class DeletionRequestViewSet(ModelViewSet): Response with updated deletion request """ deletion_request = self.get_object() - + # Check permissions if not self._can_manage_request(deletion_request): return HttpResponseForbidden( - "You don't have permission to cancel this deletion request." 
+ "You don't have permission to cancel this deletion request.", ) - + # Validate status if deletion_request.status != DeletionRequest.STATUS_PENDING: return Response( @@ -240,18 +240,18 @@ class DeletionRequestViewSet(ModelViewSet): }, status=status.HTTP_400_BAD_REQUEST, ) - + # Cancel the request deletion_request.status = DeletionRequest.STATUS_CANCELLED deletion_request.reviewed_by = request.user deletion_request.reviewed_at = timezone.now() deletion_request.review_comment = request.data.get("comment", "Cancelled by user") deletion_request.save() - + logger.info( - f"Deletion request {deletion_request.id} cancelled by user {request.user.username}" + f"Deletion request {deletion_request.id} cancelled by user {request.user.username}", ) - + serializer = self.get_serializer(deletion_request) return Response( { diff --git a/src/paperless/middleware.py b/src/paperless/middleware.py index e5863cb19..ba9f7a575 100644 --- a/src/paperless/middleware.py +++ b/src/paperless/middleware.py @@ -160,7 +160,7 @@ class SecurityHeadersMiddleware: # Store nonce in request for use in templates # Templates can access this via {{ request.csp_nonce }} - if hasattr(request, '_csp_nonce'): + if hasattr(request, "_csp_nonce"): request._csp_nonce = nonce # Prevent clickjacking attacks diff --git a/src/paperless/security.py b/src/paperless/security.py index 709091438..a171b8c44 100644 --- a/src/paperless/security.py +++ b/src/paperless/security.py @@ -9,7 +9,6 @@ from __future__ import annotations import hashlib import logging -import mimetypes import os import re from pathlib import Path @@ -26,39 +25,39 @@ logger = logging.getLogger("paperless.security") # Lista explícita de tipos MIME permitidos ALLOWED_MIME_TYPES = { # Documentos - 'application/pdf', - 'application/vnd.oasis.opendocument.text', - 'application/msword', - 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - 'application/vnd.ms-excel', - 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - 'application/vnd.ms-powerpoint', - 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - 'application/vnd.oasis.opendocument.spreadsheet', - 'application/vnd.oasis.opendocument.presentation', - 'application/rtf', - 'text/rtf', + "application/pdf", + "application/vnd.oasis.opendocument.text", + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.ms-powerpoint", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "application/vnd.oasis.opendocument.spreadsheet", + "application/vnd.oasis.opendocument.presentation", + "application/rtf", + "text/rtf", # Imágenes - 'image/jpeg', - 'image/png', - 'image/gif', - 'image/tiff', - 'image/bmp', - 'image/webp', + "image/jpeg", + "image/png", + "image/gif", + "image/tiff", + "image/bmp", + "image/webp", # Texto - 'text/plain', - 'text/html', - 'text/csv', - 'text/markdown', + "text/plain", + "text/html", + "text/csv", + "text/markdown", } # Maximum file size (100MB by default) # Can be overridden by settings.MAX_UPLOAD_SIZE try: from django.conf import settings - MAX_FILE_SIZE = getattr(settings, 'MAX_UPLOAD_SIZE', 100 * 1024 * 1024) # 100MB por defecto + MAX_FILE_SIZE = getattr(settings, "MAX_UPLOAD_SIZE", 100 * 1024 * 1024) # 100MB por defecto except ImportError: MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB in bytes @@ -114,7 +113,6 @@ ALLOWED_JS_PATTERNS = [ class 
FileValidationError(Exception): """Raised when file validation fails.""" - pass def has_whitelisted_javascript(content: bytes) -> bool: @@ -143,7 +141,7 @@ def validate_mime_type(mime_type: str) -> None: if mime_type not in ALLOWED_MIME_TYPES: raise FileValidationError( f"MIME type '{mime_type}' is not allowed. " - f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}" + f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}", ) diff --git a/src/paperless/urls.py b/src/paperless/urls.py index c8ee81381..406d12e18 100644 --- a/src/paperless/urls.py +++ b/src/paperless/urls.py @@ -86,7 +86,7 @@ api_router.register(r"config", ApplicationConfigurationViewSet) api_router.register(r"processed_mail", ProcessedMailViewSet) api_router.register(r"deletion_requests", DeletionRequestViewSet) api_router.register( - r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests" + r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests", )
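
Reviewer note: the approval flow added in views/deletion_request.py and registered above in urls.py can be driven end to end from a test. The following is a minimal sketch, assuming the API router is mounted under /api/ and that the authenticated user owns the request (or is a superuser); the helper name and URL prefix are assumptions for illustration, while the "deletion-requests/{pk}/approve/" detail action and the "comment" field come from the diff.

from rest_framework.test import APIClient


def approve_deletion_request(request_id: int, username: str, password: str) -> dict:
    """Approve a pending deletion request through the new detail action and return the payload."""
    client = APIClient()
    assert client.login(username=username, password=password)

    # DeletionRequestViewSet.approve: pending requests are approved and the
    # linked documents deleted; non-pending requests answer HTTP 400 and
    # requests owned by another (non-superuser) user are rejected with 403.
    response = client.post(
        f"/api/deletion-requests/{request_id}/approve/",
        {"comment": "Confirmed duplicates"},
        format="json",
    )
    assert response.status_code == 200, response.content
    return response.json()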
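
The model-cache behaviour exercised in test_ml_cache.py can likewise be summarised with a short usage sketch. It relies only on the calls that appear in the tests (get_instance, get_or_load_model, get_metrics); the loader function and the "classifier" key are hypothetical stand-ins.

from documents.ml.model_cache import ModelCacheManager


def load_classifier():
    # Hypothetical loader; the real code would construct a heavyweight model
    # such as TransformerDocumentClassifier here.
    return object()


manager = ModelCacheManager.get_instance(max_models=3)

first = manager.get_or_load_model("classifier", load_classifier)   # cache miss: loader runs
second = manager.get_or_load_model("classifier", load_classifier)  # cache hit: cached instance returned
assert first is second

print(manager.get_metrics())  # e.g. {"hits": 1, "misses": 1, "cache_size": 1, "max_size": 3, ...}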