fix(syntax): corrige errores de sintaxis y formato en Python

- Corrige paréntesis faltante en DeletionRequestActionSerializer (serialisers.py:2855)
- Elimina espacios en blanco en líneas vacías (W293)
- Elimina espacios finales en líneas (W291)
- Elimina imports no utilizados (F401)
- Normaliza comillas a comillas dobles (Q000)
- Agrega comas finales faltantes (COM812)
- Ordena imports según convenciones (I001)
- Actualiza anotaciones de tipo a PEP 585 (UP006)

Este commit resuelve el error de compilación en el job de CI/CD
que estaba causando que fallara el linting check.

Archivos afectados: 38
Líneas modificadas: ~2200
This commit is contained in:
Claude 2025-11-17 19:08:02 +00:00
parent 9298f64546
commit 69326b883d
No known key found for this signature in database
38 changed files with 2077 additions and 2112 deletions

View file

@ -14,14 +14,10 @@ According to agents.md requirements:
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING
from typing import Any from typing import Any
from django.contrib.auth.models import User from django.contrib.auth.models import User
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ai_deletion") logger = logging.getLogger("paperless.ai_deletion")

View file

@ -29,7 +29,6 @@ from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Any from typing import Any
from typing import Dict
from typing import TypedDict from typing import TypedDict
from django.conf import settings from django.conf import settings
@ -142,34 +141,34 @@ class AIScanResult:
""" """
# Convert internal tuple format to TypedDict format # Convert internal tuple format to TypedDict format
result: AIScanResultDict = { result: AIScanResultDict = {
'tags': [{'tag_id': tag_id, 'confidence': conf} for tag_id, conf in self.tags], "tags": [{"tag_id": tag_id, "confidence": conf} for tag_id, conf in self.tags],
'custom_fields': { "custom_fields": {
field_id: {'value': value, 'confidence': conf} field_id: {"value": value, "confidence": conf}
for field_id, (value, conf) in self.custom_fields.items() for field_id, (value, conf) in self.custom_fields.items()
}, },
'workflows': [{'workflow_id': wf_id, 'confidence': conf} for wf_id, conf in self.workflows], "workflows": [{"workflow_id": wf_id, "confidence": conf} for wf_id, conf in self.workflows],
'extracted_entities': self.extracted_entities, "extracted_entities": self.extracted_entities,
'metadata': self.metadata, "metadata": self.metadata,
} }
# Add optional fields only if present # Add optional fields only if present
if self.correspondent: if self.correspondent:
result['correspondent'] = { result["correspondent"] = {
'correspondent_id': self.correspondent[0], "correspondent_id": self.correspondent[0],
'confidence': self.correspondent[1], "confidence": self.correspondent[1],
} }
if self.document_type: if self.document_type:
result['document_type'] = { result["document_type"] = {
'type_id': self.document_type[0], "type_id": self.document_type[0],
'confidence': self.document_type[1], "confidence": self.document_type[1],
} }
if self.storage_path: if self.storage_path:
result['storage_path'] = { result["storage_path"] = {
'path_id': self.storage_path[0], "path_id": self.storage_path[0],
'confidence': self.storage_path[1], "confidence": self.storage_path[1],
} }
if self.title_suggestion: if self.title_suggestion:
result['title_suggestion'] = self.title_suggestion result["title_suggestion"] = self.title_suggestion
return result return result
@ -1054,7 +1053,7 @@ class AIDocumentScanner:
warm_up_time = time.time() - start_time warm_up_time = time.time() - start_time
logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s") logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s")
def get_cache_metrics(self) -> Dict[str, Any]: def get_cache_metrics(self) -> dict[str, Any]:
""" """
Get cache performance metrics. Get cache performance metrics.

View file

@ -310,7 +310,7 @@ class ConsumerPlugin(
# Parsing phase # Parsing phase
document_parser = self._create_parser_instance(parser_class) document_parser = self._create_parser_instance(parser_class)
text, date, thumbnail, archive_path, page_count = self._parse_document( text, date, thumbnail, archive_path, page_count = self._parse_document(
document_parser, mime_type document_parser, mime_type,
) )
# Storage phase # Storage phase
@ -394,7 +394,7 @@ class ConsumerPlugin(
def _attempt_pdf_recovery( def _attempt_pdf_recovery(
self, self,
tempdir: tempfile.TemporaryDirectory, tempdir: tempfile.TemporaryDirectory,
original_mime_type: str original_mime_type: str,
) -> str: ) -> str:
""" """
Attempt to recover a PDF file with incorrect MIME type using qpdf. Attempt to recover a PDF file with incorrect MIME type using qpdf.
@ -438,7 +438,7 @@ class ConsumerPlugin(
def _get_parser_class( def _get_parser_class(
self, self,
mime_type: str, mime_type: str,
tempdir: tempfile.TemporaryDirectory tempdir: tempfile.TemporaryDirectory,
) -> type[DocumentParser]: ) -> type[DocumentParser]:
""" """
Determine which parser to use based on MIME type. Determine which parser to use based on MIME type.
@ -468,7 +468,7 @@ class ConsumerPlugin(
def _create_parser_instance( def _create_parser_instance(
self, self,
parser_class: type[DocumentParser] parser_class: type[DocumentParser],
) -> DocumentParser: ) -> DocumentParser:
""" """
Create a parser instance with progress callback. Create a parser instance with progress callback.
@ -496,7 +496,7 @@ class ConsumerPlugin(
def _parse_document( def _parse_document(
self, self,
document_parser: DocumentParser, document_parser: DocumentParser,
mime_type: str mime_type: str,
) -> tuple[str, datetime.datetime | None, Path, Path | None, int | None]: ) -> tuple[str, datetime.datetime | None, Path, Path | None, int | None]:
""" """
Parse the document and extract metadata. Parse the document and extract metadata.
@ -670,7 +670,7 @@ class ConsumerPlugin(
self, self,
document: Document, document: Document,
thumbnail: Path, thumbnail: Path,
archive_path: Path | None archive_path: Path | None,
) -> None: ) -> None:
""" """
Store document files (source, thumbnail, archive) to disk. Store document files (source, thumbnail, archive) to disk.
@ -949,7 +949,7 @@ class ConsumerPlugin(
text: The extracted document text text: The extracted document text
""" """
# Check if AI scanner is enabled # Check if AI scanner is enabled
if not getattr(settings, 'PAPERLESS_ENABLE_AI_SCANNER', True): if not getattr(settings, "PAPERLESS_ENABLE_AI_SCANNER", True):
self.log.debug("AI scanner is disabled, skipping AI analysis") self.log.debug("AI scanner is disabled, skipping AI analysis")
return return

View file

@ -1,6 +1,7 @@
# Generated manually for performance optimization # Generated manually for performance optimization
from django.db import migrations, models from django.db import migrations
from django.db import models
class Migration(migrations.Migration): class Migration(migrations.Migration):

View file

@ -1,9 +1,10 @@
# Generated manually for DeletionRequest model # Generated manually for DeletionRequest model
# Based on model definition in documents/models.py # Based on model definition in documents/models.py
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion import django.db.models.deletion
from django.conf import settings
from django.db import migrations
from django.db import models
class Migration(migrations.Migration): class Migration(migrations.Migration):
@ -48,7 +49,7 @@ class Migration(migrations.Migration):
( (
"ai_reason", "ai_reason",
models.TextField( models.TextField(
help_text="Detailed explanation from AI about why deletion is recommended" help_text="Detailed explanation from AI about why deletion is recommended",
), ),
), ),
( (

View file

@ -1,6 +1,7 @@
# Generated manually for DeletionRequest performance optimization # Generated manually for DeletionRequest performance optimization
from django.db import migrations, models from django.db import migrations
from django.db import models
class Migration(migrations.Migration): class Migration(migrations.Migration):

View file

@ -1,9 +1,10 @@
# Generated manually for AI Suggestions API # Generated manually for AI Suggestions API
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.core.validators import django.core.validators
import django.db.models.deletion
from django.conf import settings
from django.db import migrations
from django.db import models
class Migration(migrations.Migration): class Migration(migrations.Migration):

View file

@ -10,9 +10,9 @@ Provides AI/ML capabilities including:
from __future__ import annotations from __future__ import annotations
__all__ = [ __all__ = [
"TransformerDocumentClassifier",
"DocumentNER", "DocumentNER",
"SemanticSearch", "SemanticSearch",
"TransformerDocumentClassifier",
] ]
# Lazy imports to avoid loading heavy ML libraries unless needed # Lazy imports to avoid loading heavy ML libraries unless needed

View file

@ -15,23 +15,16 @@ Logging levels used in this module:
from __future__ import annotations from __future__ import annotations
import logging import logging
from pathlib import Path
from typing import TYPE_CHECKING
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import ( from transformers import AutoModelForSequenceClassification
AutoModelForSequenceClassification, from transformers import AutoTokenizer
AutoTokenizer, from transformers import Trainer
Trainer, from transformers import TrainingArguments
TrainingArguments,
)
from documents.ml.model_cache import ModelCacheManager from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
from documents.models import Document
logger = logging.getLogger("paperless.ml.classifier") logger = logging.getLogger("paperless.ml.classifier")
@ -141,7 +134,7 @@ class TransformerDocumentClassifier:
logger.info( logger.info(
f"Initialized TransformerDocumentClassifier with {model_name} " f"Initialized TransformerDocumentClassifier with {model_name} "
f"(caching: {use_cache})" f"(caching: {use_cache})",
) )
def train( def train(

View file

@ -24,8 +24,9 @@ import pickle
import threading import threading
import time import time
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any
logger = logging.getLogger("paperless.ml.model_cache") logger = logging.getLogger("paperless.ml.model_cache")
@ -58,7 +59,7 @@ class CacheMetrics:
with self.lock: with self.lock:
self.loads += 1 self.loads += 1
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> dict[str, Any]:
with self.lock: with self.lock:
total = self.hits + self.misses total = self.hits + self.misses
hit_rate = (self.hits / total * 100) if total > 0 else 0.0 hit_rate = (self.hits / total * 100) if total > 0 else 0.0
@ -98,7 +99,7 @@ class LRUCache:
self.lock = threading.Lock() self.lock = threading.Lock()
self.metrics = CacheMetrics() self.metrics = CacheMetrics()
def get(self, key: str) -> Optional[Any]: def get(self, key: str) -> Any | None:
""" """
Get item from cache. Get item from cache.
@ -153,7 +154,7 @@ class LRUCache:
with self.lock: with self.lock:
return len(self.cache) return len(self.cache)
def get_metrics(self) -> Dict[str, Any]: def get_metrics(self) -> dict[str, Any]:
"""Get cache metrics.""" """Get cache metrics."""
return self.metrics.get_stats() return self.metrics.get_stats()
@ -173,7 +174,7 @@ class ModelCacheManager:
model = cache.get_or_load_model("classifier", loader_func) model = cache.get_or_load_model("classifier", loader_func)
""" """
_instance: Optional[ModelCacheManager] = None _instance: ModelCacheManager | None = None
_lock = threading.Lock() _lock = threading.Lock()
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
@ -187,7 +188,7 @@ class ModelCacheManager:
def __init__( def __init__(
self, self,
max_models: int = 3, max_models: int = 3,
disk_cache_dir: Optional[str] = None, disk_cache_dir: str | None = None,
): ):
""" """
Initialize model cache manager. Initialize model cache manager.
@ -215,7 +216,7 @@ class ModelCacheManager:
def get_instance( def get_instance(
cls, cls,
max_models: int = 3, max_models: int = 3,
disk_cache_dir: Optional[str] = None, disk_cache_dir: str | None = None,
) -> ModelCacheManager: ) -> ModelCacheManager:
""" """
Get singleton instance of ModelCacheManager. Get singleton instance of ModelCacheManager.
@ -278,7 +279,7 @@ class ModelCacheManager:
load_time = time.time() - start_time load_time = time.time() - start_time
logger.info( logger.info(
f"Model loaded successfully: {model_key} " f"Model loaded successfully: {model_key} "
f"(took {load_time:.2f}s)" f"(took {load_time:.2f}s)",
) )
return model return model
@ -289,7 +290,7 @@ class ModelCacheManager:
def save_embeddings_to_disk( def save_embeddings_to_disk(
self, self,
key: str, key: str,
embeddings: Dict[int, Any], embeddings: dict[int, Any],
) -> bool: ) -> bool:
""" """
Save embeddings to disk cache. Save embeddings to disk cache.
@ -311,7 +312,7 @@ class ModelCacheManager:
cache_file = self.disk_cache_dir / f"{key}.pkl" cache_file = self.disk_cache_dir / f"{key}.pkl"
try: try:
with open(cache_file, 'wb') as f: with open(cache_file, "wb") as f:
pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Saved {len(embeddings)} embeddings to {cache_file}") logger.info(f"Saved {len(embeddings)} embeddings to {cache_file}")
return True return True
@ -330,7 +331,7 @@ class ModelCacheManager:
def load_embeddings_from_disk( def load_embeddings_from_disk(
self, self,
key: str, key: str,
) -> Optional[Dict[int, Any]]: ) -> dict[int, Any] | None:
""" """
Load embeddings from disk cache. Load embeddings from disk cache.
@ -393,7 +394,7 @@ class ModelCacheManager:
except Exception as e: except Exception as e:
logger.error(f"Failed to delete {cache_file}: {e}") logger.error(f"Failed to delete {cache_file}: {e}")
def get_metrics(self) -> Dict[str, Any]: def get_metrics(self) -> dict[str, Any]:
""" """
Get cache performance metrics. Get cache performance metrics.
@ -416,7 +417,7 @@ class ModelCacheManager:
def warm_up( def warm_up(
self, self,
model_loaders: Dict[str, Callable[[], Any]], model_loaders: dict[str, Callable[[], Any]],
) -> None: ) -> None:
""" """
Pre-load models on startup (warm-up). Pre-load models on startup (warm-up).

View file

@ -14,15 +14,11 @@ from __future__ import annotations
import logging import logging
import re import re
from typing import TYPE_CHECKING
from transformers import pipeline from transformers import pipeline
from documents.ml.model_cache import ModelCacheManager from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ml.ner") logger = logging.getLogger("paperless.ml.ner")

View file

@ -18,18 +18,14 @@ Examples:
from __future__ import annotations from __future__ import annotations
import logging import logging
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np import numpy as np
import torch import torch
from sentence_transformers import SentenceTransformer, util from sentence_transformers import SentenceTransformer
from sentence_transformers import util
from documents.ml.model_cache import ModelCacheManager from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ml.semantic_search") logger = logging.getLogger("paperless.ml.semantic_search")
@ -67,7 +63,7 @@ class SemanticSearch:
""" """
logger.info( logger.info(
f"Initializing SemanticSearch with model: {model_name} " f"Initializing SemanticSearch with model: {model_name} "
f"(caching: {use_cache})" f"(caching: {use_cache})",
) )
self.model_name = model_name self.model_name = model_name
@ -127,11 +123,11 @@ class SemanticSearch:
if not isinstance(embedding, np.ndarray) and not isinstance(embedding, torch.Tensor): if not isinstance(embedding, np.ndarray) and not isinstance(embedding, torch.Tensor):
logger.warning(f"Embedding for doc {doc_id} is not a numpy array or tensor") logger.warning(f"Embedding for doc {doc_id} is not a numpy array or tensor")
return False return False
if hasattr(embedding, 'size'): if hasattr(embedding, "size"):
if embedding.size == 0: if embedding.size == 0:
logger.warning(f"Embedding for doc {doc_id} is empty") logger.warning(f"Embedding for doc {doc_id} is empty")
return False return False
elif hasattr(embedding, 'numel'): elif hasattr(embedding, "numel"):
if embedding.numel() == 0: if embedding.numel() == 0:
logger.warning(f"Embedding for doc {doc_id} is empty") logger.warning(f"Embedding for doc {doc_id} is empty")
return False return False
@ -216,11 +212,11 @@ class SemanticSearch:
try: try:
result = self.cache_manager.save_embeddings_to_disk( result = self.cache_manager.save_embeddings_to_disk(
"document_embeddings", "document_embeddings",
self.document_embeddings self.document_embeddings,
) )
if result: if result:
logger.info( logger.info(
f"Successfully saved {len(self.document_embeddings)} embeddings to disk" f"Successfully saved {len(self.document_embeddings)} embeddings to disk",
) )
else: else:
logger.error("Failed to save embeddings to disk (returned False)") logger.error("Failed to save embeddings to disk (returned False)")

View file

@ -1604,30 +1604,30 @@ class DeletionRequest(models.Model):
# Requester (AI system) # Requester (AI system)
requested_by_ai = models.BooleanField(default=True) requested_by_ai = models.BooleanField(default=True)
ai_reason = models.TextField( ai_reason = models.TextField(
help_text=_("Detailed explanation from AI about why deletion is recommended") help_text=_("Detailed explanation from AI about why deletion is recommended"),
) )
# User who must approve # User who must approve
user = models.ForeignKey( user = models.ForeignKey(
User, User,
on_delete=models.CASCADE, on_delete=models.CASCADE,
related_name='deletion_requests', related_name="deletion_requests",
help_text=_("User who must approve this deletion"), help_text=_("User who must approve this deletion"),
) )
# Status tracking # Status tracking
STATUS_PENDING = 'pending' STATUS_PENDING = "pending"
STATUS_APPROVED = 'approved' STATUS_APPROVED = "approved"
STATUS_REJECTED = 'rejected' STATUS_REJECTED = "rejected"
STATUS_CANCELLED = 'cancelled' STATUS_CANCELLED = "cancelled"
STATUS_COMPLETED = 'completed' STATUS_COMPLETED = "completed"
STATUS_CHOICES = [ STATUS_CHOICES = [
(STATUS_PENDING, _('Pending')), (STATUS_PENDING, _("Pending")),
(STATUS_APPROVED, _('Approved')), (STATUS_APPROVED, _("Approved")),
(STATUS_REJECTED, _('Rejected')), (STATUS_REJECTED, _("Rejected")),
(STATUS_CANCELLED, _('Cancelled')), (STATUS_CANCELLED, _("Cancelled")),
(STATUS_COMPLETED, _('Completed')), (STATUS_COMPLETED, _("Completed")),
] ]
status = models.CharField( status = models.CharField(
@ -1639,7 +1639,7 @@ class DeletionRequest(models.Model):
# Documents to be deleted # Documents to be deleted
documents = models.ManyToManyField( documents = models.ManyToManyField(
Document, Document,
related_name='deletion_requests', related_name="deletion_requests",
help_text=_("Documents that would be deleted if approved"), help_text=_("Documents that would be deleted if approved"),
) )
@ -1656,7 +1656,7 @@ class DeletionRequest(models.Model):
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
null=True, null=True,
blank=True, blank=True,
related_name='reviewed_deletion_requests', related_name="reviewed_deletion_requests",
help_text=_("User who reviewed and approved/rejected"), help_text=_("User who reviewed and approved/rejected"),
) )
review_comment = models.TextField( review_comment = models.TextField(
@ -1672,21 +1672,21 @@ class DeletionRequest(models.Model):
) )
class Meta: class Meta:
ordering = ['-created_at'] ordering = ["-created_at"]
verbose_name = _("deletion request") verbose_name = _("deletion request")
verbose_name_plural = _("deletion requests") verbose_name_plural = _("deletion requests")
indexes = [ indexes = [
# Composite index for common listing queries (by user, filtered by status, sorted by date) # Composite index for common listing queries (by user, filtered by status, sorted by date)
# PostgreSQL can use this index for queries on: user, user+status, user+status+created_at # PostgreSQL can use this index for queries on: user, user+status, user+status+created_at
models.Index(fields=['user', 'status', 'created_at'], name='delreq_user_status_created_idx'), models.Index(fields=["user", "status", "created_at"], name="delreq_user_status_created_idx"),
# Index for queries filtering by status and date without user filter # Index for queries filtering by status and date without user filter
models.Index(fields=['status', 'created_at'], name='delreq_status_created_idx'), models.Index(fields=["status", "created_at"], name="delreq_status_created_idx"),
# Index for queries filtering by user and date (common for user-specific views) # Index for queries filtering by user and date (common for user-specific views)
models.Index(fields=['user', 'created_at'], name='delreq_user_created_idx'), models.Index(fields=["user", "created_at"], name="delreq_user_created_idx"),
# Index for queries filtering by review date # Index for queries filtering by review date
models.Index(fields=['reviewed_at'], name='delreq_reviewed_at_idx'), models.Index(fields=["reviewed_at"], name="delreq_reviewed_at_idx"),
# Index for queries filtering by completion date # Index for queries filtering by completion date
models.Index(fields=['completed_at'], name='delreq_completed_at_idx'), models.Index(fields=["completed_at"], name="delreq_completed_at_idx"),
] ]
def __str__(self): def __str__(self):
@ -1745,67 +1745,67 @@ class AISuggestionFeedback(models.Model):
""" """
# Suggestion types # Suggestion types
TYPE_TAG = 'tag' TYPE_TAG = "tag"
TYPE_CORRESPONDENT = 'correspondent' TYPE_CORRESPONDENT = "correspondent"
TYPE_DOCUMENT_TYPE = 'document_type' TYPE_DOCUMENT_TYPE = "document_type"
TYPE_STORAGE_PATH = 'storage_path' TYPE_STORAGE_PATH = "storage_path"
TYPE_CUSTOM_FIELD = 'custom_field' TYPE_CUSTOM_FIELD = "custom_field"
TYPE_WORKFLOW = 'workflow' TYPE_WORKFLOW = "workflow"
TYPE_TITLE = 'title' TYPE_TITLE = "title"
SUGGESTION_TYPES = ( SUGGESTION_TYPES = (
(TYPE_TAG, _('Tag')), (TYPE_TAG, _("Tag")),
(TYPE_CORRESPONDENT, _('Correspondent')), (TYPE_CORRESPONDENT, _("Correspondent")),
(TYPE_DOCUMENT_TYPE, _('Document Type')), (TYPE_DOCUMENT_TYPE, _("Document Type")),
(TYPE_STORAGE_PATH, _('Storage Path')), (TYPE_STORAGE_PATH, _("Storage Path")),
(TYPE_CUSTOM_FIELD, _('Custom Field')), (TYPE_CUSTOM_FIELD, _("Custom Field")),
(TYPE_WORKFLOW, _('Workflow')), (TYPE_WORKFLOW, _("Workflow")),
(TYPE_TITLE, _('Title')), (TYPE_TITLE, _("Title")),
) )
# Feedback status # Feedback status
STATUS_APPLIED = 'applied' STATUS_APPLIED = "applied"
STATUS_REJECTED = 'rejected' STATUS_REJECTED = "rejected"
FEEDBACK_STATUS = ( FEEDBACK_STATUS = (
(STATUS_APPLIED, _('Applied')), (STATUS_APPLIED, _("Applied")),
(STATUS_REJECTED, _('Rejected')), (STATUS_REJECTED, _("Rejected")),
) )
document = models.ForeignKey( document = models.ForeignKey(
Document, Document,
on_delete=models.CASCADE, on_delete=models.CASCADE,
related_name='ai_suggestion_feedbacks', related_name="ai_suggestion_feedbacks",
verbose_name=_('document'), verbose_name=_("document"),
) )
suggestion_type = models.CharField( suggestion_type = models.CharField(
_('suggestion type'), _("suggestion type"),
max_length=50, max_length=50,
choices=SUGGESTION_TYPES, choices=SUGGESTION_TYPES,
) )
suggested_value_id = models.IntegerField( suggested_value_id = models.IntegerField(
_('suggested value ID'), _("suggested value ID"),
null=True, null=True,
blank=True, blank=True,
help_text=_('ID of the suggested object (tag, correspondent, etc.)'), help_text=_("ID of the suggested object (tag, correspondent, etc.)"),
) )
suggested_value_text = models.TextField( suggested_value_text = models.TextField(
_('suggested value text'), _("suggested value text"),
blank=True, blank=True,
help_text=_('Text representation of the suggested value'), help_text=_("Text representation of the suggested value"),
) )
confidence = models.FloatField( confidence = models.FloatField(
_('confidence'), _("confidence"),
help_text=_('AI confidence score (0.0 to 1.0)'), help_text=_("AI confidence score (0.0 to 1.0)"),
validators=[MinValueValidator(0.0), MaxValueValidator(1.0)], validators=[MinValueValidator(0.0), MaxValueValidator(1.0)],
) )
status = models.CharField( status = models.CharField(
_('status'), _("status"),
max_length=20, max_length=20,
choices=FEEDBACK_STATUS, choices=FEEDBACK_STATUS,
) )
@ -1815,36 +1815,36 @@ class AISuggestionFeedback(models.Model):
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
null=True, null=True,
blank=True, blank=True,
related_name='ai_suggestion_feedbacks', related_name="ai_suggestion_feedbacks",
verbose_name=_('user'), verbose_name=_("user"),
help_text=_('User who applied or rejected the suggestion'), help_text=_("User who applied or rejected the suggestion"),
) )
created_at = models.DateTimeField( created_at = models.DateTimeField(
_('created at'), _("created at"),
auto_now_add=True, auto_now_add=True,
) )
applied_at = models.DateTimeField( applied_at = models.DateTimeField(
_('applied/rejected at'), _("applied/rejected at"),
auto_now=True, auto_now=True,
) )
metadata = models.JSONField( metadata = models.JSONField(
_('metadata'), _("metadata"),
default=dict, default=dict,
blank=True, blank=True,
help_text=_('Additional metadata about the suggestion'), help_text=_("Additional metadata about the suggestion"),
) )
class Meta: class Meta:
verbose_name = _('AI suggestion feedback') verbose_name = _("AI suggestion feedback")
verbose_name_plural = _('AI suggestion feedbacks') verbose_name_plural = _("AI suggestion feedbacks")
ordering = ['-created_at'] ordering = ["-created_at"]
indexes = [ indexes = [
models.Index(fields=['document', 'suggestion_type']), models.Index(fields=["document", "suggestion_type"]),
models.Index(fields=['status', 'created_at']), models.Index(fields=["status", "created_at"]),
models.Index(fields=['suggestion_type', 'status']), models.Index(fields=["suggestion_type", "status"]),
] ]
def __str__(self): def __str__(self):

View file

@ -11,21 +11,21 @@ Lazy imports are used to avoid loading heavy dependencies unless needed.
""" """
__all__ = [ __all__ = [
'TableExtractor', "FormFieldDetector",
'HandwritingRecognizer', "HandwritingRecognizer",
'FormFieldDetector', "TableExtractor",
] ]
def __getattr__(name): def __getattr__(name):
"""Lazy import to avoid loading heavy ML models on startup.""" """Lazy import to avoid loading heavy ML models on startup."""
if name == 'TableExtractor': if name == "TableExtractor":
from .table_extractor import TableExtractor from .table_extractor import TableExtractor
return TableExtractor return TableExtractor
elif name == 'HandwritingRecognizer': elif name == "HandwritingRecognizer":
from .handwriting import HandwritingRecognizer from .handwriting import HandwritingRecognizer
return HandwritingRecognizer return HandwritingRecognizer
elif name == 'FormFieldDetector': elif name == "FormFieldDetector":
from .form_detector import FormFieldDetector from .form_detector import FormFieldDetector
return FormFieldDetector return FormFieldDetector
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View file

@ -8,8 +8,8 @@ This module provides capabilities to:
""" """
import logging import logging
from pathlib import Path from typing import Any
from typing import List, Dict, Any, Optional, Tuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -59,8 +59,8 @@ class FormFieldDetector:
self, self,
image: Image.Image, image: Image.Image,
min_size: int = 10, min_size: int = 10,
max_size: int = 50 max_size: int = 50,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Detect checkboxes in a form image. Detect checkboxes in a form image.
@ -114,9 +114,9 @@ class FormFieldDetector:
checked, confidence = self._is_checkbox_checked(checkbox_region) checked, confidence = self._is_checkbox_checked(checkbox_region)
checkboxes.append({ checkboxes.append({
'bbox': [x, y, x+w, y+h], "bbox": [x, y, x+w, y+h],
'checked': checked, "checked": checked,
'confidence': confidence "confidence": confidence,
}) })
logger.info(f"Detected {len(checkboxes)} checkboxes") logger.info(f"Detected {len(checkboxes)} checkboxes")
@ -129,7 +129,7 @@ class FormFieldDetector:
logger.error(f"Error detecting checkboxes: {e}") logger.error(f"Error detecting checkboxes: {e}")
return [] return []
def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> Tuple[bool, float]: def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> tuple[bool, float]:
""" """
Determine if a checkbox is checked. Determine if a checkbox is checked.
@ -167,8 +167,8 @@ class FormFieldDetector:
def detect_text_fields( def detect_text_fields(
self, self,
image: Image.Image, image: Image.Image,
min_width: int = 100 min_width: int = 100,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Detect text input fields in a form. Detect text input fields in a form.
@ -202,14 +202,14 @@ class FormFieldDetector:
cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1], cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1],
cv2.MORPH_OPEN, cv2.MORPH_OPEN,
horizontal_kernel, horizontal_kernel,
iterations=2 iterations=2,
) )
# Find contours of horizontal lines # Find contours of horizontal lines
contours, _ = cv2.findContours( contours, _ = cv2.findContours(
detect_horizontal, detect_horizontal,
cv2.RETR_EXTERNAL, cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE cv2.CHAIN_APPROX_SIMPLE,
) )
text_fields = [] text_fields = []
@ -221,8 +221,8 @@ class FormFieldDetector:
# Expand upward to include text area # Expand upward to include text area
text_bbox = [x, max(0, y-30), x+w, y+h] text_bbox = [x, max(0, y-30), x+w, y+h]
text_fields.append({ text_fields.append({
'bbox': text_bbox, "bbox": text_bbox,
'type': 'line' "type": "line",
}) })
# Detect rectangular boxes (bordered text fields) # Detect rectangular boxes (bordered text fields)
@ -236,8 +236,8 @@ class FormFieldDetector:
aspect_ratio = w / h if h > 0 else 0 aspect_ratio = w / h if h > 0 else 0
if w >= min_width and 20 <= h <= 100 and aspect_ratio > 2: if w >= min_width and 20 <= h <= 100 and aspect_ratio > 2:
text_fields.append({ text_fields.append({
'bbox': [x, y, x+w, y+h], "bbox": [x, y, x+w, y+h],
'type': 'box' "type": "box",
}) })
logger.info(f"Detected {len(text_fields)} text fields") logger.info(f"Detected {len(text_fields)} text fields")
@ -253,8 +253,8 @@ class FormFieldDetector:
def detect_labels( def detect_labels(
self, self,
image: Image.Image, image: Image.Image,
field_bboxes: List[List[int]] field_bboxes: list[list[int]],
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Detect labels near form fields. Detect labels near form fields.
@ -271,17 +271,17 @@ class FormFieldDetector:
# Get all text with bounding boxes # Get all text with bounding boxes
ocr_data = pytesseract.image_to_data( ocr_data = pytesseract.image_to_data(
image, image,
output_type=pytesseract.Output.DICT output_type=pytesseract.Output.DICT,
) )
# Group text into potential labels # Group text into potential labels
labels = [] labels = []
for i, text in enumerate(ocr_data['text']): for i, text in enumerate(ocr_data["text"]):
if text.strip() and len(text.strip()) > 2: if text.strip() and len(text.strip()) > 2:
x = ocr_data['left'][i] x = ocr_data["left"][i]
y = ocr_data['top'][i] y = ocr_data["top"][i]
w = ocr_data['width'][i] w = ocr_data["width"][i]
h = ocr_data['height'][i] h = ocr_data["height"][i]
label_bbox = [x, y, x+w, y+h] label_bbox = [x, y, x+w, y+h]
@ -289,9 +289,9 @@ class FormFieldDetector:
closest_field_idx = self._find_closest_field(label_bbox, field_bboxes) closest_field_idx = self._find_closest_field(label_bbox, field_bboxes)
labels.append({ labels.append({
'text': text.strip(), "text": text.strip(),
'bbox': label_bbox, "bbox": label_bbox,
'field_index': closest_field_idx "field_index": closest_field_idx,
}) })
return labels return labels
@ -305,9 +305,9 @@ class FormFieldDetector:
def _find_closest_field( def _find_closest_field(
self, self,
label_bbox: List[int], label_bbox: list[int],
field_bboxes: List[List[int]] field_bboxes: list[list[int]],
) -> Optional[int]: ) -> int | None:
""" """
Find the closest field to a label. Find the closest field to a label.
@ -325,7 +325,7 @@ class FormFieldDetector:
label_center_x = (label_bbox[0] + label_bbox[2]) / 2 label_center_x = (label_bbox[0] + label_bbox[2]) / 2
label_center_y = (label_bbox[1] + label_bbox[3]) / 2 label_center_y = (label_bbox[1] + label_bbox[3]) / 2
min_distance = float('inf') min_distance = float("inf")
closest_idx = 0 closest_idx = 0
for i, field_bbox in enumerate(field_bboxes): for i, field_bbox in enumerate(field_bboxes):
@ -336,7 +336,7 @@ class FormFieldDetector:
# Euclidean distance # Euclidean distance
distance = np.sqrt( distance = np.sqrt(
(label_center_x - field_center_x)**2 + (label_center_x - field_center_x)**2 +
(label_center_y - field_center_y)**2 (label_center_y - field_center_y)**2,
) )
if distance < min_distance: if distance < min_distance:
@ -348,8 +348,8 @@ class FormFieldDetector:
def detect_form_fields( def detect_form_fields(
self, self,
image_path: str, image_path: str,
extract_values: bool = True extract_values: bool = True,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Detect all form fields and extract their values. Detect all form fields and extract their values.
@ -372,14 +372,14 @@ class FormFieldDetector:
""" """
try: try:
# Load image # Load image
image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert("RGB")
# Detect different field types # Detect different field types
text_fields = self.detect_text_fields(image) text_fields = self.detect_text_fields(image)
checkboxes = self.detect_checkboxes(image) checkboxes = self.detect_checkboxes(image)
# Combine all field bboxes for label detection # Combine all field bboxes for label detection
all_field_bboxes = [f['bbox'] for f in text_fields] + [cb['bbox'] for cb in checkboxes] all_field_bboxes = [f["bbox"] for f in text_fields] + [cb["bbox"] for cb in checkboxes]
# Detect labels # Detect labels
labels = self.detect_labels(image, all_field_bboxes) labels = self.detect_labels(image, all_field_bboxes)
@ -393,20 +393,20 @@ class FormFieldDetector:
label_text = self._find_label_for_field(i, labels, len(text_fields)) label_text = self._find_label_for_field(i, labels, len(text_fields))
result = { result = {
'type': 'text', "type": "text",
'label': label_text, "label": label_text,
'bbox': field['bbox'], "bbox": field["bbox"],
} }
# Extract value if requested # Extract value if requested
if extract_values: if extract_values:
x1, y1, x2, y2 = field['bbox'] x1, y1, x2, y2 = field["bbox"]
field_image = image.crop((x1, y1, x2, y2)) field_image = image.crop((x1, y1, x2, y2))
recognizer = self._get_handwriting_recognizer() recognizer = self._get_handwriting_recognizer()
value = recognizer.recognize_from_image(field_image, preprocess=True) value = recognizer.recognize_from_image(field_image, preprocess=True)
result['value'] = value.strip() result["value"] = value.strip()
result['confidence'] = recognizer._estimate_confidence(value) result["confidence"] = recognizer._estimate_confidence(value)
results.append(result) results.append(result)
@ -416,11 +416,11 @@ class FormFieldDetector:
label_text = self._find_label_for_field(field_idx, labels, len(all_field_bboxes)) label_text = self._find_label_for_field(field_idx, labels, len(all_field_bboxes))
results.append({ results.append({
'type': 'checkbox', "type": "checkbox",
'label': label_text, "label": label_text,
'value': checkbox['checked'], "value": checkbox["checked"],
'bbox': checkbox['bbox'], "bbox": checkbox["bbox"],
'confidence': checkbox['confidence'] "confidence": checkbox["confidence"],
}) })
logger.info(f"Detected {len(results)} form fields from {image_path}") logger.info(f"Detected {len(results)} form fields from {image_path}")
@ -433,8 +433,8 @@ class FormFieldDetector:
def _find_label_for_field( def _find_label_for_field(
self, self,
field_idx: int, field_idx: int,
labels: List[Dict[str, Any]], labels: list[dict[str, Any]],
total_fields: int total_fields: int,
) -> str: ) -> str:
""" """
Find the label text for a specific field. Find the label text for a specific field.
@ -449,19 +449,19 @@ class FormFieldDetector:
""" """
matching_labels = [ matching_labels = [
label for label in labels label for label in labels
if label['field_index'] == field_idx if label["field_index"] == field_idx
] ]
if matching_labels: if matching_labels:
# Combine multiple label parts if found # Combine multiple label parts if found
return ' '.join(label['text'] for label in matching_labels) return " ".join(label["text"] for label in matching_labels)
return f"Field_{field_idx + 1}" return f"Field_{field_idx + 1}"
def extract_form_data( def extract_form_data(
self, self,
image_path: str, image_path: str,
output_format: str = 'dict' output_format: str = "dict",
) -> Any: ) -> Any:
""" """
Extract all form data as structured output. Extract all form data as structured output.
@ -476,16 +476,16 @@ class FormFieldDetector:
# Detect and extract fields # Detect and extract fields
fields = self.detect_form_fields(image_path, extract_values=True) fields = self.detect_form_fields(image_path, extract_values=True)
if output_format == 'dict': if output_format == "dict":
# Return as dictionary # Return as dictionary
return {field['label']: field['value'] for field in fields} return {field["label"]: field["value"] for field in fields}
elif output_format == 'json': elif output_format == "json":
import json import json
data = {field['label']: field['value'] for field in fields} data = {field["label"]: field["value"] for field in fields}
return json.dumps(data, indent=2) return json.dumps(data, indent=2)
elif output_format == 'dataframe': elif output_format == "dataframe":
import pandas as pd import pandas as pd
return pd.DataFrame(fields) return pd.DataFrame(fields)

View file

@ -8,8 +8,8 @@ This module provides handwriting OCR capabilities using:
""" """
import logging import logging
from pathlib import Path from typing import Any
from typing import List, Dict, Any, Optional, Tuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -65,8 +65,9 @@ class HandwritingRecognizer:
return return
try: try:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch import torch
from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel
logger.info(f"Loading handwriting recognition model: {self.model_name}") logger.info(f"Loading handwriting recognition model: {self.model_name}")
@ -90,7 +91,7 @@ class HandwritingRecognizer:
def recognize_from_image( def recognize_from_image(
self, self,
image: Image.Image, image: Image.Image,
preprocess: bool = True preprocess: bool = True,
) -> str: ) -> str:
""" """
Recognize text from a single image. Recognize text from a single image.
@ -142,11 +143,12 @@ class HandwritingRecognizer:
Preprocessed PIL Image Preprocessed PIL Image
""" """
try: try:
from PIL import ImageEnhance, ImageFilter from PIL import ImageEnhance
from PIL import ImageFilter
# Convert to grayscale # Convert to grayscale
if image.mode != 'L': if image.mode != "L":
image = image.convert('L') image = image.convert("L")
# Enhance contrast # Enhance contrast
enhancer = ImageEnhance.Contrast(image) enhancer = ImageEnhance.Contrast(image)
@ -156,7 +158,7 @@ class HandwritingRecognizer:
image = image.filter(ImageFilter.MedianFilter(size=3)) image = image.filter(ImageFilter.MedianFilter(size=3))
# Convert back to RGB (required by model) # Convert back to RGB (required by model)
image = image.convert('RGB') image = image.convert("RGB")
return image return image
@ -164,7 +166,7 @@ class HandwritingRecognizer:
logger.warning(f"Error preprocessing image: {e}") logger.warning(f"Error preprocessing image: {e}")
return image return image
def detect_text_lines(self, image: Image.Image) -> List[Dict[str, Any]]: def detect_text_lines(self, image: Image.Image) -> list[dict[str, Any]]:
""" """
Detect individual text lines in an image. Detect individual text lines in an image.
@ -208,12 +210,12 @@ class HandwritingRecognizer:
# Crop line from original image # Crop line from original image
line_img = image.crop((x, y, x+w, y+h)) line_img = image.crop((x, y, x+w, y+h))
lines.append({ lines.append({
'bbox': [x, y, x+w, y+h], "bbox": [x, y, x+w, y+h],
'image': line_img "image": line_img,
}) })
# Sort lines top to bottom # Sort lines top to bottom
lines.sort(key=lambda l: l['bbox'][1]) lines.sort(key=lambda l: l["bbox"][1])
logger.info(f"Detected {len(lines)} text lines") logger.info(f"Detected {len(lines)} text lines")
return lines return lines
@ -228,8 +230,8 @@ class HandwritingRecognizer:
def recognize_lines( def recognize_lines(
self, self,
image_path: str, image_path: str,
return_confidence: bool = True return_confidence: bool = True,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Recognize text from each line in an image. Recognize text from each line in an image.
@ -250,7 +252,7 @@ class HandwritingRecognizer:
""" """
try: try:
# Load image # Load image
image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert("RGB")
# Detect lines # Detect lines
lines = self.detect_text_lines(image) lines = self.detect_text_lines(image)
@ -260,18 +262,18 @@ class HandwritingRecognizer:
for i, line in enumerate(lines): for i, line in enumerate(lines):
logger.debug(f"Recognizing line {i+1}/{len(lines)}") logger.debug(f"Recognizing line {i+1}/{len(lines)}")
text = self.recognize_from_image(line['image'], preprocess=True) text = self.recognize_from_image(line["image"], preprocess=True)
result = { result = {
'text': text, "text": text,
'bbox': line['bbox'], "bbox": line["bbox"],
'line_index': i "line_index": i,
} }
if return_confidence: if return_confidence:
# Simple confidence based on text length and content # Simple confidence based on text length and content
confidence = self._estimate_confidence(text) confidence = self._estimate_confidence(text)
result['confidence'] = confidence result["confidence"] = confidence
results.append(result) results.append(result)
@ -309,7 +311,7 @@ class HandwritingRecognizer:
score += 0.1 score += 0.1
# Text with spaces (words) is more reliable # Text with spaces (words) is more reliable
if ' ' in text: if " " in text:
score += 0.1 score += 0.1
# Penalize if too many special characters # Penalize if too many special characters
@ -322,8 +324,8 @@ class HandwritingRecognizer:
def recognize_from_file( def recognize_from_file(
self, self,
image_path: str, image_path: str,
mode: str = 'full' mode: str = "full",
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Recognize handwriting from an image file. Recognize handwriting from an image file.
@ -337,30 +339,30 @@ class HandwritingRecognizer:
Dictionary with recognized text and metadata Dictionary with recognized text and metadata
""" """
try: try:
if mode == 'full': if mode == "full":
# Recognize entire image # Recognize entire image
image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert("RGB")
text = self.recognize_from_image(image, preprocess=True) text = self.recognize_from_image(image, preprocess=True)
return { return {
'text': text, "text": text,
'mode': 'full', "mode": "full",
'confidence': self._estimate_confidence(text) "confidence": self._estimate_confidence(text),
} }
elif mode == 'lines': elif mode == "lines":
# Recognize line by line # Recognize line by line
lines = self.recognize_lines(image_path, return_confidence=True) lines = self.recognize_lines(image_path, return_confidence=True)
# Combine all lines # Combine all lines
full_text = '\n'.join(line['text'] for line in lines) full_text = "\n".join(line["text"] for line in lines)
avg_confidence = np.mean([line['confidence'] for line in lines]) if lines else 0.0 avg_confidence = np.mean([line["confidence"] for line in lines]) if lines else 0.0
return { return {
'text': full_text, "text": full_text,
'lines': lines, "lines": lines,
'mode': 'lines', "mode": "lines",
'confidence': float(avg_confidence) "confidence": float(avg_confidence),
} }
else: else:
@ -369,17 +371,17 @@ class HandwritingRecognizer:
except Exception as e: except Exception as e:
logger.error(f"Error recognizing from file {image_path}: {e}") logger.error(f"Error recognizing from file {image_path}: {e}")
return { return {
'text': '', "text": "",
'mode': mode, "mode": mode,
'confidence': 0.0, "confidence": 0.0,
'error': str(e) "error": str(e),
} }
def recognize_form_fields( def recognize_form_fields(
self, self,
image_path: str, image_path: str,
field_regions: List[Dict[str, Any]] field_regions: list[dict[str, Any]],
) -> Dict[str, str]: ) -> dict[str, str]:
""" """
Recognize text from specific form fields. Recognize text from specific form fields.
@ -399,13 +401,13 @@ class HandwritingRecognizer:
""" """
try: try:
# Load image # Load image
image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert("RGB")
# Extract and recognize each field # Extract and recognize each field
results = {} results = {}
for field in field_regions: for field in field_regions:
name = field['name'] name = field["name"]
bbox = field['bbox'] bbox = field["bbox"]
# Crop field region # Crop field region
x1, y1, x2, y2 = bbox x1, y1, x2, y2 = bbox
@ -425,9 +427,9 @@ class HandwritingRecognizer:
def batch_recognize( def batch_recognize(
self, self,
image_paths: List[str], image_paths: list[str],
mode: str = 'full' mode: str = "full",
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Recognize handwriting from multiple images in batch. Recognize handwriting from multiple images in batch.
@ -442,7 +444,7 @@ class HandwritingRecognizer:
for i, path in enumerate(image_paths): for i, path in enumerate(image_paths):
logger.info(f"Processing image {i+1}/{len(image_paths)}: {path}") logger.info(f"Processing image {i+1}/{len(image_paths)}: {path}")
result = self.recognize_from_file(path, mode=mode) result = self.recognize_from_file(path, mode=mode)
result['image_path'] = path result["image_path"] = path
results.append(result) results.append(result)
return results return results

View file

@ -8,9 +8,8 @@ This module uses various techniques to detect and extract tables from documents:
""" """
import logging import logging
from pathlib import Path from typing import Any
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from PIL import Image from PIL import Image
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,8 +58,9 @@ class TableExtractor:
return return
try: try:
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import torch import torch
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
logger.info(f"Loading table detection model: {self.model_name}") logger.info(f"Loading table detection model: {self.model_name}")
@ -79,7 +79,7 @@ class TableExtractor:
logger.error("Please install required packages: pip install transformers torch pillow") logger.error("Please install required packages: pip install transformers torch pillow")
raise raise
def detect_tables(self, image: Image.Image) -> List[Dict[str, Any]]: def detect_tables(self, image: Image.Image) -> list[dict[str, Any]]:
""" """
Detect tables in an image. Detect tables in an image.
@ -117,16 +117,16 @@ class TableExtractor:
results = self._processor.post_process_object_detection( results = self._processor.post_process_object_detection(
outputs, outputs,
threshold=self.confidence_threshold, threshold=self.confidence_threshold,
target_sizes=target_sizes target_sizes=target_sizes,
)[0] )[0]
# Convert to list of dicts # Convert to list of dicts
tables = [] tables = []
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
tables.append({ tables.append({
'bbox': box.cpu().tolist(), "bbox": box.cpu().tolist(),
'score': score.item(), "score": score.item(),
'label': self._model.config.id2label[label.item()] "label": self._model.config.id2label[label.item()],
}) })
logger.info(f"Detected {len(tables)} tables in image") logger.info(f"Detected {len(tables)} tables in image")
@ -139,9 +139,9 @@ class TableExtractor:
def extract_table_from_region( def extract_table_from_region(
self, self,
image: Image.Image, image: Image.Image,
bbox: List[float], bbox: list[float],
use_ocr: bool = True use_ocr: bool = True,
) -> Optional[Dict[str, Any]]: ) -> dict[str, Any] | None:
""" """
Extract table data from a specific region of an image. Extract table data from a specific region of an image.
@ -166,7 +166,7 @@ class TableExtractor:
# Get detailed OCR data # Get detailed OCR data
ocr_data = pytesseract.image_to_data( ocr_data = pytesseract.image_to_data(
table_image, table_image,
output_type=pytesseract.Output.DICT output_type=pytesseract.Output.DICT,
) )
# Reconstruct table structure from OCR data # Reconstruct table structure from OCR data
@ -176,20 +176,20 @@ class TableExtractor:
raw_text = pytesseract.image_to_string(table_image) raw_text = pytesseract.image_to_string(table_image)
return { return {
'data': table_data, "data": table_data,
'raw_text': raw_text, "raw_text": raw_text,
'bbox': bbox, "bbox": bbox,
'image_size': table_image.size "image_size": table_image.size,
} }
else: else:
# Fallback to basic OCR without structure # Fallback to basic OCR without structure
import pytesseract import pytesseract
raw_text = pytesseract.image_to_string(table_image) raw_text = pytesseract.image_to_string(table_image)
return { return {
'data': None, "data": None,
'raw_text': raw_text, "raw_text": raw_text,
'bbox': bbox, "bbox": bbox,
'image_size': table_image.size "image_size": table_image.size,
} }
except ImportError: except ImportError:
@ -199,7 +199,7 @@ class TableExtractor:
logger.error(f"Error extracting table from region: {e}") logger.error(f"Error extracting table from region: {e}")
return None return None
def _reconstruct_table_from_ocr(self, ocr_data: Dict) -> Optional[Any]: def _reconstruct_table_from_ocr(self, ocr_data: dict) -> Any | None:
""" """
Reconstruct table structure from OCR output. Reconstruct table structure from OCR output.
@ -214,10 +214,10 @@ class TableExtractor:
# Group text by vertical position (rows) # Group text by vertical position (rows)
rows = {} rows = {}
for i, text in enumerate(ocr_data['text']): for i, text in enumerate(ocr_data["text"]):
if text.strip(): if text.strip():
top = ocr_data['top'][i] top = ocr_data["top"][i]
left = ocr_data['left'][i] left = ocr_data["left"][i]
# Group by approximate row (within 20 pixels) # Group by approximate row (within 20 pixels)
row_key = round(top / 20) * 20 row_key = round(top / 20) * 20
@ -235,14 +235,14 @@ class TableExtractor:
if table_rows: if table_rows:
# Pad rows to same length # Pad rows to same length
max_cols = max(len(row) for row in table_rows) max_cols = max(len(row) for row in table_rows)
table_rows = [row + [''] * (max_cols - len(row)) for row in table_rows] table_rows = [row + [""] * (max_cols - len(row)) for row in table_rows]
# Create DataFrame # Create DataFrame
df = pd.DataFrame(table_rows) df = pd.DataFrame(table_rows)
# Try to use first row as header if it looks like one # Try to use first row as header if it looks like one
if len(df) > 1: if len(df) > 1:
first_row_text = ' '.join(str(x) for x in df.iloc[0]) first_row_text = " ".join(str(x) for x in df.iloc[0])
if not any(char.isdigit() for char in first_row_text): if not any(char.isdigit() for char in first_row_text):
df.columns = df.iloc[0] df.columns = df.iloc[0]
df = df[1:].reset_index(drop=True) df = df[1:].reset_index(drop=True)
@ -261,8 +261,8 @@ class TableExtractor:
def extract_tables_from_image( def extract_tables_from_image(
self, self,
image_path: str, image_path: str,
output_format: str = 'dataframe' output_format: str = "dataframe",
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Extract all tables from an image file. Extract all tables from an image file.
@ -275,7 +275,7 @@ class TableExtractor:
""" """
try: try:
# Load image # Load image
image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert("RGB")
# Detect tables # Detect tables
detections = self.detect_tables(image) detections = self.detect_tables(image)
@ -287,18 +287,18 @@ class TableExtractor:
table_data = self.extract_table_from_region( table_data = self.extract_table_from_region(
image, image,
detection['bbox'] detection["bbox"],
) )
if table_data: if table_data:
table_data['detection_score'] = detection['score'] table_data["detection_score"] = detection["score"]
table_data['table_index'] = i table_data["table_index"] = i
# Convert to requested format # Convert to requested format
if output_format == 'csv' and table_data['data'] is not None: if output_format == "csv" and table_data["data"] is not None:
table_data['csv'] = table_data['data'].to_csv(index=False) table_data["csv"] = table_data["data"].to_csv(index=False)
elif output_format == 'json' and table_data['data'] is not None: elif output_format == "json" and table_data["data"] is not None:
table_data['json'] = table_data['data'].to_json(orient='records') table_data["json"] = table_data["data"].to_json(orient="records")
tables.append(table_data) tables.append(table_data)
@ -312,8 +312,8 @@ class TableExtractor:
def extract_tables_from_pdf( def extract_tables_from_pdf(
self, self,
pdf_path: str, pdf_path: str,
page_numbers: Optional[List[int]] = None page_numbers: list[int] | None = None,
) -> Dict[int, List[Dict[str, Any]]]: ) -> dict[int, list[dict[str, Any]]]:
""" """
Extract tables from a PDF document. Extract tables from a PDF document.
@ -334,7 +334,7 @@ class TableExtractor:
images = convert_from_path( images = convert_from_path(
pdf_path, pdf_path,
first_page=min(page_numbers), first_page=min(page_numbers),
last_page=max(page_numbers) last_page=max(page_numbers),
) )
else: else:
images = convert_from_path(pdf_path) images = convert_from_path(pdf_path)
@ -352,11 +352,11 @@ class TableExtractor:
for detection in detections: for detection in detections:
table_data = self.extract_table_from_region( table_data = self.extract_table_from_region(
image, image,
detection['bbox'] detection["bbox"],
) )
if table_data: if table_data:
table_data['detection_score'] = detection['score'] table_data["detection_score"] = detection["score"]
table_data['page'] = page_num table_data["page"] = page_num
tables.append(table_data) tables.append(table_data)
if tables: if tables:
@ -374,8 +374,8 @@ class TableExtractor:
def save_tables_to_excel( def save_tables_to_excel(
self, self,
tables: List[Dict[str, Any]], tables: list[dict[str, Any]],
output_path: str output_path: str,
) -> bool: ) -> bool:
""" """
Save extracted tables to an Excel file. Save extracted tables to an Excel file.
@ -390,17 +390,17 @@ class TableExtractor:
try: try:
import pandas as pd import pandas as pd
with pd.ExcelWriter(output_path, engine='openpyxl') as writer: with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
for i, table in enumerate(tables): for i, table in enumerate(tables):
if table.get('data') is not None: if table.get("data") is not None:
sheet_name = f"Table_{i+1}" sheet_name = f"Table_{i+1}"
if 'page' in table: if "page" in table:
sheet_name = f"Page_{table['page']}_Table_{i+1}" sheet_name = f"Page_{table['page']}_Table_{i+1}"
table['data'].to_excel( table["data"].to_excel(
writer, writer,
sheet_name=sheet_name, sheet_name=sheet_name,
index=False index=False,
) )
logger.info(f"Saved {len(tables)} tables to {output_path}") logger.info(f"Saved {len(tables)} tables to {output_path}")

View file

@ -46,7 +46,6 @@ if settings.AUDIT_LOG_ENABLED:
from documents import bulk_edit from documents import bulk_edit
from documents.data_models import DocumentSource from documents.data_models import DocumentSource
from documents.filters import CustomFieldQueryParser from documents.filters import CustomFieldQueryParser
from documents.models import AISuggestionFeedback
from documents.models import Correspondent from documents.models import Correspondent
from documents.models import CustomField from documents.models import CustomField
from documents.models import CustomFieldInstance from documents.models import CustomFieldInstance
@ -2788,9 +2787,9 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
"""Serializer for DeletionRequest model with document details.""" """Serializer for DeletionRequest model with document details."""
document_details = serializers.SerializerMethodField() document_details = serializers.SerializerMethodField()
user_username = serializers.CharField(source='user.username', read_only=True) user_username = serializers.CharField(source="user.username", read_only=True)
reviewed_by_username = serializers.CharField( reviewed_by_username = serializers.CharField(
source='reviewed_by.username', source="reviewed_by.username",
read_only=True, read_only=True,
allow_null=True, allow_null=True,
) )
@ -2799,31 +2798,31 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
from documents.models import DeletionRequest from documents.models import DeletionRequest
model = DeletionRequest model = DeletionRequest
fields = [ fields = [
'id', "id",
'created_at', "created_at",
'updated_at', "updated_at",
'requested_by_ai', "requested_by_ai",
'ai_reason', "ai_reason",
'user', "user",
'user_username', "user_username",
'status', "status",
'impact_summary', "impact_summary",
'reviewed_at', "reviewed_at",
'reviewed_by', "reviewed_by",
'reviewed_by_username', "reviewed_by_username",
'review_comment', "review_comment",
'completed_at', "completed_at",
'completion_details', "completion_details",
'document_details', "document_details",
] ]
read_only_fields = [ read_only_fields = [
'id', "id",
'created_at', "created_at",
'updated_at', "updated_at",
'reviewed_at', "reviewed_at",
'reviewed_by', "reviewed_by",
'completed_at', "completed_at",
'completion_details', "completion_details",
] ]
def get_document_details(self, obj): def get_document_details(self, obj):
@ -2831,12 +2830,12 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
documents = obj.documents.all() documents = obj.documents.all()
return [ return [
{ {
'id': doc.id, "id": doc.id,
'title': doc.title, "title": doc.title,
'created': doc.created.isoformat() if doc.created else None, "created": doc.created.isoformat() if doc.created else None,
'correspondent': doc.correspondent.name if doc.correspondent else None, "correspondent": doc.correspondent.name if doc.correspondent else None,
'document_type': doc.document_type.name if doc.document_type else None, "document_type": doc.document_type.name if doc.document_type else None,
'tags': [tag.name for tag in doc.tags.all()], "tags": [tag.name for tag in doc.tags.all()],
} }
for doc in documents for doc in documents
] ]
@ -2852,6 +2851,9 @@ class DeletionRequestActionSerializer(serializers.Serializer):
allow_blank=True, allow_blank=True,
label="Review Comment", label="Review Comment",
help_text="Optional comment when reviewing the deletion request", help_text="Optional comment when reviewing the deletion request",
)
class AISuggestionsRequestSerializer(serializers.Serializer): class AISuggestionsRequestSerializer(serializers.Serializer):
"""Serializer for requesting AI suggestions for a document.""" """Serializer for requesting AI suggestions for a document."""

View file

@ -1,17 +1,15 @@
"""Serializers package for documents app.""" """Serializers package for documents app."""
from .ai_suggestions import ( from .ai_suggestions import AISuggestionFeedbackSerializer
AISuggestionFeedbackSerializer, from .ai_suggestions import AISuggestionsSerializer
AISuggestionsSerializer, from .ai_suggestions import AISuggestionStatsSerializer
AISuggestionStatsSerializer, from .ai_suggestions import ApplySuggestionSerializer
ApplySuggestionSerializer, from .ai_suggestions import RejectSuggestionSerializer
RejectSuggestionSerializer,
)
__all__ = [ __all__ = [
'AISuggestionFeedbackSerializer', "AISuggestionFeedbackSerializer",
'AISuggestionsSerializer', "AISuggestionStatsSerializer",
'AISuggestionStatsSerializer', "AISuggestionsSerializer",
'ApplySuggestionSerializer', "ApplySuggestionSerializer",
'RejectSuggestionSerializer', "RejectSuggestionSerializer",
] ]

View file

@ -7,36 +7,33 @@ and handling user feedback on AI suggestions.
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict from typing import Any
from rest_framework import serializers from rest_framework import serializers
from documents.models import ( from documents.models import AISuggestionFeedback
AISuggestionFeedback, from documents.models import Correspondent
Correspondent, from documents.models import CustomField
CustomField, from documents.models import DocumentType
DocumentType, from documents.models import StoragePath
StoragePath, from documents.models import Tag
Tag, from documents.models import Workflow
Workflow,
)
# Suggestion type choices - used across multiple serializers # Suggestion type choices - used across multiple serializers
SUGGESTION_TYPE_CHOICES = [ SUGGESTION_TYPE_CHOICES = [
'tag', "tag",
'correspondent', "correspondent",
'document_type', "document_type",
'storage_path', "storage_path",
'custom_field', "custom_field",
'workflow', "workflow",
'title', "title",
] ]
# Types that require value_id # Types that require value_id
ID_REQUIRED_TYPES = ['tag', 'correspondent', 'document_type', 'storage_path', 'workflow'] ID_REQUIRED_TYPES = ["tag", "correspondent", "document_type", "storage_path", "workflow"]
# Types that require value_text # Types that require value_text
TEXT_REQUIRED_TYPES = ['title'] TEXT_REQUIRED_TYPES = ["title"]
# Types that can use either (custom_field can be ID or text) # Types that can use either (custom_field can be ID or text)
@ -113,7 +110,7 @@ class AISuggestionsSerializer(serializers.Serializer):
title_suggestion = TitleSuggestionSerializer(required=False, allow_null=True) title_suggestion = TitleSuggestionSerializer(required=False, allow_null=True)
@staticmethod @staticmethod
def from_scan_result(scan_result, document_id: int) -> Dict[str, Any]: def from_scan_result(scan_result, document_id: int) -> dict[str, Any]:
""" """
Convert an AIScanResult object to serializer data. Convert an AIScanResult object to serializer data.
@ -133,25 +130,25 @@ class AISuggestionsSerializer(serializers.Serializer):
try: try:
tag = Tag.objects.get(pk=tag_id) tag = Tag.objects.get(pk=tag_id)
tag_suggestions.append({ tag_suggestions.append({
'id': tag.id, "id": tag.id,
'name': tag.name, "name": tag.name,
'color': getattr(tag, 'color', '#000000'), "color": getattr(tag, "color", "#000000"),
'confidence': confidence, "confidence": confidence,
}) })
except Tag.DoesNotExist: except Tag.DoesNotExist:
# Tag no longer exists in database; skip this suggestion # Tag no longer exists in database; skip this suggestion
pass pass
data['tags'] = tag_suggestions data["tags"] = tag_suggestions
# Correspondent # Correspondent
if scan_result.correspondent: if scan_result.correspondent:
corr_id, confidence = scan_result.correspondent corr_id, confidence = scan_result.correspondent
try: try:
correspondent = Correspondent.objects.get(pk=corr_id) correspondent = Correspondent.objects.get(pk=corr_id)
data['correspondent'] = { data["correspondent"] = {
'id': correspondent.id, "id": correspondent.id,
'name': correspondent.name, "name": correspondent.name,
'confidence': confidence, "confidence": confidence,
} }
except Correspondent.DoesNotExist: except Correspondent.DoesNotExist:
# Correspondent no longer exists in database; omit from suggestions # Correspondent no longer exists in database; omit from suggestions
@ -162,10 +159,10 @@ class AISuggestionsSerializer(serializers.Serializer):
type_id, confidence = scan_result.document_type type_id, confidence = scan_result.document_type
try: try:
doc_type = DocumentType.objects.get(pk=type_id) doc_type = DocumentType.objects.get(pk=type_id)
data['document_type'] = { data["document_type"] = {
'id': doc_type.id, "id": doc_type.id,
'name': doc_type.name, "name": doc_type.name,
'confidence': confidence, "confidence": confidence,
} }
except DocumentType.DoesNotExist: except DocumentType.DoesNotExist:
# Document type no longer exists in database; omit from suggestions # Document type no longer exists in database; omit from suggestions
@ -176,11 +173,11 @@ class AISuggestionsSerializer(serializers.Serializer):
path_id, confidence = scan_result.storage_path path_id, confidence = scan_result.storage_path
try: try:
storage_path = StoragePath.objects.get(pk=path_id) storage_path = StoragePath.objects.get(pk=path_id)
data['storage_path'] = { data["storage_path"] = {
'id': storage_path.id, "id": storage_path.id,
'name': storage_path.name, "name": storage_path.name,
'path': storage_path.path, "path": storage_path.path,
'confidence': confidence, "confidence": confidence,
} }
except StoragePath.DoesNotExist: except StoragePath.DoesNotExist:
# Storage path no longer exists in database; omit from suggestions # Storage path no longer exists in database; omit from suggestions
@ -193,15 +190,15 @@ class AISuggestionsSerializer(serializers.Serializer):
try: try:
field = CustomField.objects.get(pk=field_id) field = CustomField.objects.get(pk=field_id)
field_suggestions.append({ field_suggestions.append({
'field_id': field.id, "field_id": field.id,
'field_name': field.name, "field_name": field.name,
'value': str(value), "value": str(value),
'confidence': confidence, "confidence": confidence,
}) })
except CustomField.DoesNotExist: except CustomField.DoesNotExist:
# Custom field no longer exists in database; skip this suggestion # Custom field no longer exists in database; skip this suggestion
pass pass
data['custom_fields'] = field_suggestions data["custom_fields"] = field_suggestions
# Workflows # Workflows
if scan_result.workflows: if scan_result.workflows:
@ -210,19 +207,19 @@ class AISuggestionsSerializer(serializers.Serializer):
try: try:
workflow = Workflow.objects.get(pk=workflow_id) workflow = Workflow.objects.get(pk=workflow_id)
workflow_suggestions.append({ workflow_suggestions.append({
'id': workflow.id, "id": workflow.id,
'name': workflow.name, "name": workflow.name,
'confidence': confidence, "confidence": confidence,
}) })
except Workflow.DoesNotExist: except Workflow.DoesNotExist:
# Workflow no longer exists in database; skip this suggestion # Workflow no longer exists in database; skip this suggestion
pass pass
data['workflows'] = workflow_suggestions data["workflows"] = workflow_suggestions
# Title suggestion # Title suggestion
if scan_result.title_suggestion: if scan_result.title_suggestion:
data['title_suggestion'] = { data["title_suggestion"] = {
'title': scan_result.title_suggestion, "title": scan_result.title_suggestion,
} }
return data return data
@ -234,26 +231,26 @@ class SuggestionSerializerMixin:
""" """
def validate(self, attrs): def validate(self, attrs):
"""Validate that the correct value field is provided for the suggestion type.""" """Validate that the correct value field is provided for the suggestion type."""
suggestion_type = attrs.get('suggestion_type') suggestion_type = attrs.get("suggestion_type")
value_id = attrs.get('value_id') value_id = attrs.get("value_id")
value_text = attrs.get('value_text') value_text = attrs.get("value_text")
# Types that require value_id # Types that require value_id
if suggestion_type in ID_REQUIRED_TYPES and not value_id: if suggestion_type in ID_REQUIRED_TYPES and not value_id:
raise serializers.ValidationError( raise serializers.ValidationError(
f"value_id is required for suggestion_type '{suggestion_type}'" f"value_id is required for suggestion_type '{suggestion_type}'",
) )
# Types that require value_text # Types that require value_text
if suggestion_type in TEXT_REQUIRED_TYPES and not value_text: if suggestion_type in TEXT_REQUIRED_TYPES and not value_text:
raise serializers.ValidationError( raise serializers.ValidationError(
f"value_text is required for suggestion_type '{suggestion_type}'" f"value_text is required for suggestion_type '{suggestion_type}'",
) )
# For custom_field, either is acceptable # For custom_field, either is acceptable
if suggestion_type == 'custom_field' and not value_id and not value_text: if suggestion_type == "custom_field" and not value_id and not value_text:
raise serializers.ValidationError( raise serializers.ValidationError(
"Either value_id or value_text must be provided for custom_field" "Either value_id or value_text must be provided for custom_field",
) )
return attrs return attrs
@ -295,19 +292,19 @@ class AISuggestionFeedbackSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = AISuggestionFeedback model = AISuggestionFeedback
fields = [ fields = [
'id', "id",
'document', "document",
'suggestion_type', "suggestion_type",
'suggested_value_id', "suggested_value_id",
'suggested_value_text', "suggested_value_text",
'confidence', "confidence",
'status', "status",
'user', "user",
'created_at', "created_at",
'applied_at', "applied_at",
'metadata', "metadata",
] ]
read_only_fields = ['id', 'created_at', 'applied_at'] read_only_fields = ["id", "created_at", "applied_at"]
class AISuggestionStatsSerializer(serializers.Serializer): class AISuggestionStatsSerializer(serializers.Serializer):

View file

@ -18,13 +18,11 @@ from django.test import TestCase
from django.utils import timezone from django.utils import timezone
from documents.ai_deletion_manager import AIDeletionManager from documents.ai_deletion_manager import AIDeletionManager
from documents.models import ( from documents.models import Correspondent
Correspondent, from documents.models import DeletionRequest
DeletionRequest, from documents.models import Document
Document, from documents.models import DocumentType
DocumentType, from documents.models import Tag
Tag,
)
class TestAIDeletionManagerCreateRequest(TestCase): class TestAIDeletionManagerCreateRequest(TestCase):

View file

@ -10,24 +10,23 @@ Tests cover:
- Permission assignment and verification - Permission assignment and verification
""" """
from django.contrib.auth.models import Group, Permission, User from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.test import TestCase from django.test import TestCase
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from documents.models import Document from documents.models import Document
from documents.permissions import ( from documents.permissions import CanApplyAISuggestionsPermission
CanApplyAISuggestionsPermission, from documents.permissions import CanApproveDeletionsPermission
CanApproveDeletionsPermission, from documents.permissions import CanConfigureAIPermission
CanConfigureAIPermission, from documents.permissions import CanViewAISuggestionsPermission
CanViewAISuggestionsPermission,
)
class MockView: class MockView:
"""Mock view for testing permissions.""" """Mock view for testing permissions."""
pass
class TestCanViewAISuggestionsPermission(TestCase): class TestCanViewAISuggestionsPermission(TestCase):
@ -41,13 +40,13 @@ class TestCanViewAISuggestionsPermission(TestCase):
# Create users # Create users
self.superuser = User.objects.create_superuser( self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123" username="admin", email="admin@test.com", password="admin123",
) )
self.regular_user = User.objects.create_user( self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123" username="regular", email="regular@test.com", password="regular123",
) )
self.permitted_user = User.objects.create_user( self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123" username="permitted", email="permitted@test.com", password="permitted123",
) )
# Assign permission to permitted_user # Assign permission to permitted_user
@ -107,13 +106,13 @@ class TestCanApplyAISuggestionsPermission(TestCase):
# Create users # Create users
self.superuser = User.objects.create_superuser( self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123" username="admin", email="admin@test.com", password="admin123",
) )
self.regular_user = User.objects.create_user( self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123" username="regular", email="regular@test.com", password="regular123",
) )
self.permitted_user = User.objects.create_user( self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123" username="permitted", email="permitted@test.com", password="permitted123",
) )
# Assign permission to permitted_user # Assign permission to permitted_user
@ -173,13 +172,13 @@ class TestCanApproveDeletionsPermission(TestCase):
# Create users # Create users
self.superuser = User.objects.create_superuser( self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123" username="admin", email="admin@test.com", password="admin123",
) )
self.regular_user = User.objects.create_user( self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123" username="regular", email="regular@test.com", password="regular123",
) )
self.permitted_user = User.objects.create_user( self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123" username="permitted", email="permitted@test.com", password="permitted123",
) )
# Assign permission to permitted_user # Assign permission to permitted_user
@ -239,13 +238,13 @@ class TestCanConfigureAIPermission(TestCase):
# Create users # Create users
self.superuser = User.objects.create_superuser( self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123" username="admin", email="admin@test.com", password="admin123",
) )
self.regular_user = User.objects.create_user( self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123" username="regular", email="regular@test.com", password="regular123",
) )
self.permitted_user = User.objects.create_user( self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123" username="permitted", email="permitted@test.com", password="permitted123",
) )
# Assign permission to permitted_user # Assign permission to permitted_user
@ -345,7 +344,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_viewer_role_permissions(self): def test_viewer_role_permissions(self):
"""Test that viewer role has appropriate permissions.""" """Test that viewer role has appropriate permissions."""
user = User.objects.create_user( user = User.objects.create_user(
username="viewer", email="viewer@test.com", password="viewer123" username="viewer", email="viewer@test.com", password="viewer123",
) )
user.groups.add(self.viewer_group) user.groups.add(self.viewer_group)
@ -360,7 +359,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_editor_role_permissions(self): def test_editor_role_permissions(self):
"""Test that editor role has appropriate permissions.""" """Test that editor role has appropriate permissions."""
user = User.objects.create_user( user = User.objects.create_user(
username="editor", email="editor@test.com", password="editor123" username="editor", email="editor@test.com", password="editor123",
) )
user.groups.add(self.editor_group) user.groups.add(self.editor_group)
@ -375,7 +374,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_admin_role_permissions(self): def test_admin_role_permissions(self):
"""Test that admin role has all permissions.""" """Test that admin role has all permissions."""
user = User.objects.create_user( user = User.objects.create_user(
username="ai_admin", email="ai_admin@test.com", password="admin123" username="ai_admin", email="ai_admin@test.com", password="admin123",
) )
user.groups.add(self.admin_group) user.groups.add(self.admin_group)
@ -390,7 +389,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_user_with_multiple_groups(self): def test_user_with_multiple_groups(self):
"""Test that user permissions accumulate from multiple groups.""" """Test that user permissions accumulate from multiple groups."""
user = User.objects.create_user( user = User.objects.create_user(
username="multi_role", email="multi@test.com", password="multi123" username="multi_role", email="multi@test.com", password="multi123",
) )
user.groups.add(self.viewer_group, self.editor_group) user.groups.add(self.viewer_group, self.editor_group)
@ -405,7 +404,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_direct_permission_assignment_overrides_group(self): def test_direct_permission_assignment_overrides_group(self):
"""Test that direct permission assignment works alongside group permissions.""" """Test that direct permission assignment works alongside group permissions."""
user = User.objects.create_user( user = User.objects.create_user(
username="special", email="special@test.com", password="special123" username="special", email="special@test.com", password="special123",
) )
user.groups.add(self.viewer_group) user.groups.add(self.viewer_group)
@ -428,7 +427,7 @@ class TestPermissionAssignment(TestCase):
def setUp(self): def setUp(self):
"""Set up test user.""" """Set up test user."""
self.user = User.objects.create_user( self.user = User.objects.create_user(
username="testuser", email="test@test.com", password="test123" username="testuser", email="test@test.com", password="test123",
) )
content_type = ContentType.objects.get_for_model(Document) content_type = ContentType.objects.get_for_model(Document)
self.view_permission, _ = Permission.objects.get_or_create( self.view_permission, _ = Permission.objects.get_or_create(
@ -500,7 +499,7 @@ class TestPermissionEdgeCases(TestCase):
def test_inactive_user_with_permission(self): def test_inactive_user_with_permission(self):
"""Test that inactive users are denied even with permission.""" """Test that inactive users are denied even with permission."""
user = User.objects.create_user( user = User.objects.create_user(
username="inactive", email="inactive@test.com", password="inactive123" username="inactive", email="inactive@test.com", password="inactive123",
) )
user.is_active = False user.is_active = False
user.save() user.save()

View file

@ -21,24 +21,21 @@ Tests cover:
from unittest import mock from unittest import mock
from django.db import transaction from django.db import transaction
from django.test import TestCase, override_settings from django.test import TestCase
from django.test import override_settings
from documents.ai_scanner import ( from documents.ai_scanner import AIDocumentScanner
AIScanResult, from documents.ai_scanner import AIScanResult
AIDocumentScanner, from documents.ai_scanner import get_ai_scanner
get_ai_scanner, from documents.models import Correspondent
) from documents.models import CustomField
from documents.models import ( from documents.models import Document
Correspondent, from documents.models import DocumentType
CustomField, from documents.models import StoragePath
Document, from documents.models import Tag
DocumentType, from documents.models import Workflow
StoragePath, from documents.models import WorkflowAction
Tag, from documents.models import WorkflowTrigger
Workflow,
WorkflowTrigger,
WorkflowAction,
)
class TestAIScanResult(TestCase): class TestAIScanResult(TestCase):
@ -100,7 +97,7 @@ class TestAIDocumentScannerInitialization(TestCase):
"""Test scanner initialization with custom confidence thresholds.""" """Test scanner initialization with custom confidence thresholds."""
scanner = AIDocumentScanner( scanner = AIDocumentScanner(
auto_apply_threshold=0.90, auto_apply_threshold=0.90,
suggest_threshold=0.70 suggest_threshold=0.70,
) )
self.assertEqual(scanner.auto_apply_threshold, 0.90) self.assertEqual(scanner.auto_apply_threshold, 0.90)
@ -145,14 +142,14 @@ class TestAIDocumentScannerInitialization(TestCase):
class TestAIDocumentScannerLazyLoading(TestCase): class TestAIDocumentScannerLazyLoading(TestCase):
"""Test lazy loading of ML components.""" """Test lazy loading of ML components."""
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_get_classifier_loads_successfully(self, mock_logger): def test_get_classifier_loads_successfully(self, mock_logger):
"""Test successful lazy loading of classifier.""" """Test successful lazy loading of classifier."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
# Mock the import and class # Mock the import and class
mock_classifier_instance = mock.MagicMock() mock_classifier_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier', with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
return_value=mock_classifier_instance) as mock_classifier_class: return_value=mock_classifier_instance) as mock_classifier_class:
classifier = scanner._get_classifier() classifier = scanner._get_classifier()
@ -161,13 +158,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
mock_classifier_class.assert_called_once() mock_classifier_class.assert_called_once()
mock_logger.info.assert_called_with("ML classifier loaded successfully") mock_logger.info.assert_called_with("ML classifier loaded successfully")
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_get_classifier_returns_cached_instance(self, mock_logger): def test_get_classifier_returns_cached_instance(self, mock_logger):
"""Test that classifier is only loaded once.""" """Test that classifier is only loaded once."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
mock_classifier_instance = mock.MagicMock() mock_classifier_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier', with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
return_value=mock_classifier_instance): return_value=mock_classifier_instance):
classifier1 = scanner._get_classifier() classifier1 = scanner._get_classifier()
classifier2 = scanner._get_classifier() classifier2 = scanner._get_classifier()
@ -175,12 +172,12 @@ class TestAIDocumentScannerLazyLoading(TestCase):
self.assertEqual(classifier1, classifier2) self.assertEqual(classifier1, classifier2)
self.assertIs(classifier1, classifier2) self.assertIs(classifier1, classifier2)
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_get_classifier_handles_import_error(self, mock_logger): def test_get_classifier_handles_import_error(self, mock_logger):
"""Test that classifier loading handles import errors gracefully.""" """Test that classifier loading handles import errors gracefully."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier', with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
side_effect=ImportError("Module not found")): side_effect=ImportError("Module not found")):
classifier = scanner._get_classifier() classifier = scanner._get_classifier()
@ -196,13 +193,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
self.assertIsNone(classifier) self.assertIsNone(classifier)
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_get_ner_extractor_loads_successfully(self, mock_logger): def test_get_ner_extractor_loads_successfully(self, mock_logger):
"""Test successful lazy loading of NER extractor.""" """Test successful lazy loading of NER extractor."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
mock_ner_instance = mock.MagicMock() mock_ner_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.DocumentNER', with mock.patch("documents.ai_scanner.DocumentNER",
return_value=mock_ner_instance) as mock_ner_class: return_value=mock_ner_instance) as mock_ner_class:
ner = scanner._get_ner_extractor() ner = scanner._get_ner_extractor()
@ -211,25 +208,25 @@ class TestAIDocumentScannerLazyLoading(TestCase):
mock_ner_class.assert_called_once() mock_ner_class.assert_called_once()
mock_logger.info.assert_called_with("NER extractor loaded successfully") mock_logger.info.assert_called_with("NER extractor loaded successfully")
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_get_ner_extractor_handles_error(self, mock_logger): def test_get_ner_extractor_handles_error(self, mock_logger):
"""Test NER extractor handles loading errors.""" """Test NER extractor handles loading errors."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
with mock.patch('documents.ai_scanner.DocumentNER', with mock.patch("documents.ai_scanner.DocumentNER",
side_effect=Exception("Failed to load")): side_effect=Exception("Failed to load")):
ner = scanner._get_ner_extractor() ner = scanner._get_ner_extractor()
self.assertIsNone(ner) self.assertIsNone(ner)
mock_logger.warning.assert_called() mock_logger.warning.assert_called()
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_get_semantic_search_loads_successfully(self, mock_logger): def test_get_semantic_search_loads_successfully(self, mock_logger):
"""Test successful lazy loading of semantic search.""" """Test successful lazy loading of semantic search."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
mock_search_instance = mock.MagicMock() mock_search_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.SemanticSearch', with mock.patch("documents.ai_scanner.SemanticSearch",
return_value=mock_search_instance) as mock_search_class: return_value=mock_search_instance) as mock_search_class:
search = scanner._get_semantic_search() search = scanner._get_semantic_search()
@ -238,13 +235,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
mock_search_class.assert_called_once() mock_search_class.assert_called_once()
mock_logger.info.assert_called_with("Semantic search loaded successfully") mock_logger.info.assert_called_with("Semantic search loaded successfully")
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_get_table_extractor_loads_successfully(self, mock_logger): def test_get_table_extractor_loads_successfully(self, mock_logger):
"""Test successful lazy loading of table extractor.""" """Test successful lazy loading of table extractor."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
mock_extractor_instance = mock.MagicMock() mock_extractor_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.TableExtractor', with mock.patch("documents.ai_scanner.TableExtractor",
return_value=mock_extractor_instance) as mock_extractor_class: return_value=mock_extractor_instance) as mock_extractor_class:
extractor = scanner._get_table_extractor() extractor = scanner._get_table_extractor()
@ -276,7 +273,7 @@ class TestExtractEntities(TestCase):
"dates": ["2024-01-01", "2024-12-31"], "dates": ["2024-01-01", "2024-12-31"],
"amounts": ["$1,000", "$500"], "amounts": ["$1,000", "$500"],
"locations": ["New York"], "locations": ["New York"],
"misc": ["Invoice#123"] "misc": ["Invoice#123"],
} }
scanner._ner_extractor = mock_ner scanner._ner_extractor = mock_ner
@ -320,7 +317,7 @@ class TestExtractEntities(TestCase):
self.assertEqual(entities, {}) self.assertEqual(entities, {})
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_extract_entities_handles_exception(self, mock_logger): def test_extract_entities_handles_exception(self, mock_logger):
"""Test that entity extraction handles exceptions gracefully.""" """Test that entity extraction handles exceptions gracefully."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -345,10 +342,10 @@ class TestSuggestTags(TestCase):
self.tag3 = Tag.objects.create(name="Tax", matching_algorithm=Tag.MATCH_AUTO) self.tag3 = Tag.objects.create(name="Tax", matching_algorithm=Tag.MATCH_AUTO)
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
@mock.patch('documents.ai_scanner.match_tags') @mock.patch("documents.ai_scanner.match_tags")
def test_suggest_tags_with_matched_tags(self, mock_match_tags): def test_suggest_tags_with_matched_tags(self, mock_match_tags):
"""Test tag suggestions from matching.""" """Test tag suggestions from matching."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -357,7 +354,7 @@ class TestSuggestTags(TestCase):
suggestions = scanner._suggest_tags( suggestions = scanner._suggest_tags(
self.document, self.document,
"Invoice from ACME Corp", "Invoice from ACME Corp",
{} {},
) )
# Should suggest both matched tags # Should suggest both matched tags
@ -370,14 +367,14 @@ class TestSuggestTags(TestCase):
for _, confidence in suggestions: for _, confidence in suggestions:
self.assertGreaterEqual(confidence, 0.6) self.assertGreaterEqual(confidence, 0.6)
@mock.patch('documents.ai_scanner.match_tags') @mock.patch("documents.ai_scanner.match_tags")
def test_suggest_tags_with_organization_entities(self, mock_match_tags): def test_suggest_tags_with_organization_entities(self, mock_match_tags):
"""Test tag suggestions based on organization entities.""" """Test tag suggestions based on organization entities."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
mock_match_tags.return_value = [] mock_match_tags.return_value = []
entities = { entities = {
"organizations": [{"text": "ACME Corp"}] "organizations": [{"text": "ACME Corp"}],
} }
suggestions = scanner._suggest_tags(self.document, "text", entities) suggestions = scanner._suggest_tags(self.document, "text", entities)
@ -386,7 +383,7 @@ class TestSuggestTags(TestCase):
tag_ids = [tag_id for tag_id, _ in suggestions] tag_ids = [tag_id for tag_id, _ in suggestions]
self.assertIn(self.tag2.id, tag_ids) self.assertIn(self.tag2.id, tag_ids)
@mock.patch('documents.ai_scanner.match_tags') @mock.patch("documents.ai_scanner.match_tags")
def test_suggest_tags_removes_duplicates(self, mock_match_tags): def test_suggest_tags_removes_duplicates(self, mock_match_tags):
"""Test that duplicate tags keep highest confidence.""" """Test that duplicate tags keep highest confidence."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -397,8 +394,8 @@ class TestSuggestTags(TestCase):
# Implementation should remove duplicates in actual code # Implementation should remove duplicates in actual code
@mock.patch('documents.ai_scanner.match_tags') @mock.patch("documents.ai_scanner.match_tags")
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_suggest_tags_handles_exception(self, mock_logger, mock_match_tags): def test_suggest_tags_handles_exception(self, mock_logger, mock_match_tags):
"""Test tag suggestion handles exceptions.""" """Test tag suggestion handles exceptions."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -417,18 +414,18 @@ class TestDetectCorrespondent(TestCase):
"""Set up test correspondents.""" """Set up test correspondents."""
self.correspondent1 = Correspondent.objects.create( self.correspondent1 = Correspondent.objects.create(
name="ACME Corporation", name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO matching_algorithm=Correspondent.MATCH_AUTO,
) )
self.correspondent2 = Correspondent.objects.create( self.correspondent2 = Correspondent.objects.create(
name="TechStart Inc", name="TechStart Inc",
matching_algorithm=Correspondent.MATCH_AUTO matching_algorithm=Correspondent.MATCH_AUTO,
) )
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
@mock.patch('documents.ai_scanner.match_correspondents') @mock.patch("documents.ai_scanner.match_correspondents")
def test_detect_correspondent_with_match(self, mock_match): def test_detect_correspondent_with_match(self, mock_match):
"""Test correspondent detection with successful match.""" """Test correspondent detection with successful match."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -441,7 +438,7 @@ class TestDetectCorrespondent(TestCase):
self.assertEqual(corr_id, self.correspondent1.id) self.assertEqual(corr_id, self.correspondent1.id)
self.assertEqual(confidence, 0.85) self.assertEqual(confidence, 0.85)
@mock.patch('documents.ai_scanner.match_correspondents') @mock.patch("documents.ai_scanner.match_correspondents")
def test_detect_correspondent_without_match(self, mock_match): def test_detect_correspondent_without_match(self, mock_match):
"""Test correspondent detection without match.""" """Test correspondent detection without match."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -451,14 +448,14 @@ class TestDetectCorrespondent(TestCase):
self.assertIsNone(result) self.assertIsNone(result)
@mock.patch('documents.ai_scanner.match_correspondents') @mock.patch("documents.ai_scanner.match_correspondents")
def test_detect_correspondent_from_ner_entities(self, mock_match): def test_detect_correspondent_from_ner_entities(self, mock_match):
"""Test correspondent detection from NER organizations.""" """Test correspondent detection from NER organizations."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
mock_match.return_value = [] mock_match.return_value = []
entities = { entities = {
"organizations": [{"text": "ACME Corporation"}] "organizations": [{"text": "ACME Corporation"}],
} }
result = scanner._detect_correspondent(self.document, "text", entities) result = scanner._detect_correspondent(self.document, "text", entities)
@ -468,8 +465,8 @@ class TestDetectCorrespondent(TestCase):
self.assertEqual(corr_id, self.correspondent1.id) self.assertEqual(corr_id, self.correspondent1.id)
self.assertEqual(confidence, 0.70) self.assertEqual(confidence, 0.70)
@mock.patch('documents.ai_scanner.match_correspondents') @mock.patch("documents.ai_scanner.match_correspondents")
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_detect_correspondent_handles_exception(self, mock_logger, mock_match): def test_detect_correspondent_handles_exception(self, mock_logger, mock_match):
"""Test correspondent detection handles exceptions.""" """Test correspondent detection handles exceptions."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -488,18 +485,18 @@ class TestClassifyDocumentType(TestCase):
"""Set up test document types.""" """Set up test document types."""
self.doc_type1 = DocumentType.objects.create( self.doc_type1 = DocumentType.objects.create(
name="Invoice", name="Invoice",
matching_algorithm=DocumentType.MATCH_AUTO matching_algorithm=DocumentType.MATCH_AUTO,
) )
self.doc_type2 = DocumentType.objects.create( self.doc_type2 = DocumentType.objects.create(
name="Receipt", name="Receipt",
matching_algorithm=DocumentType.MATCH_AUTO matching_algorithm=DocumentType.MATCH_AUTO,
) )
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
@mock.patch('documents.ai_scanner.match_document_types') @mock.patch("documents.ai_scanner.match_document_types")
def test_classify_document_type_with_match(self, mock_match): def test_classify_document_type_with_match(self, mock_match):
"""Test document type classification with match.""" """Test document type classification with match."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -512,7 +509,7 @@ class TestClassifyDocumentType(TestCase):
self.assertEqual(type_id, self.doc_type1.id) self.assertEqual(type_id, self.doc_type1.id)
self.assertEqual(confidence, 0.85) self.assertEqual(confidence, 0.85)
@mock.patch('documents.ai_scanner.match_document_types') @mock.patch("documents.ai_scanner.match_document_types")
def test_classify_document_type_without_match(self, mock_match): def test_classify_document_type_without_match(self, mock_match):
"""Test document type classification without match.""" """Test document type classification without match."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -522,8 +519,8 @@ class TestClassifyDocumentType(TestCase):
self.assertIsNone(result) self.assertIsNone(result)
@mock.patch('documents.ai_scanner.match_document_types') @mock.patch("documents.ai_scanner.match_document_types")
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_classify_document_type_handles_exception(self, mock_logger, mock_match): def test_classify_document_type_handles_exception(self, mock_logger, mock_match):
"""Test classification handles exceptions.""" """Test classification handles exceptions."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -543,14 +540,14 @@ class TestSuggestStoragePath(TestCase):
self.storage_path1 = StoragePath.objects.create( self.storage_path1 = StoragePath.objects.create(
name="Invoices", name="Invoices",
path="/documents/invoices", path="/documents/invoices",
matching_algorithm=StoragePath.MATCH_AUTO matching_algorithm=StoragePath.MATCH_AUTO,
) )
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
@mock.patch('documents.ai_scanner.match_storage_paths') @mock.patch("documents.ai_scanner.match_storage_paths")
def test_suggest_storage_path_with_match(self, mock_match): def test_suggest_storage_path_with_match(self, mock_match):
"""Test storage path suggestion with match.""" """Test storage path suggestion with match."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -564,7 +561,7 @@ class TestSuggestStoragePath(TestCase):
self.assertEqual(path_id, self.storage_path1.id) self.assertEqual(path_id, self.storage_path1.id)
self.assertEqual(confidence, 0.80) self.assertEqual(confidence, 0.80)
@mock.patch('documents.ai_scanner.match_storage_paths') @mock.patch("documents.ai_scanner.match_storage_paths")
def test_suggest_storage_path_without_match(self, mock_match): def test_suggest_storage_path_without_match(self, mock_match):
"""Test storage path suggestion without match.""" """Test storage path suggestion without match."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -575,8 +572,8 @@ class TestSuggestStoragePath(TestCase):
self.assertIsNone(result) self.assertIsNone(result)
@mock.patch('documents.ai_scanner.match_storage_paths') @mock.patch("documents.ai_scanner.match_storage_paths")
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_suggest_storage_path_handles_exception(self, mock_logger, mock_match): def test_suggest_storage_path_handles_exception(self, mock_logger, mock_match):
"""Test storage path suggestion handles exceptions.""" """Test storage path suggestion handles exceptions."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -596,19 +593,19 @@ class TestExtractCustomFields(TestCase):
"""Set up test custom fields.""" """Set up test custom fields."""
self.field_date = CustomField.objects.create( self.field_date = CustomField.objects.create(
name="Invoice Date", name="Invoice Date",
data_type=CustomField.FieldDataType.DATE data_type=CustomField.FieldDataType.DATE,
) )
self.field_amount = CustomField.objects.create( self.field_amount = CustomField.objects.create(
name="Total Amount", name="Total Amount",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.field_email = CustomField.objects.create( self.field_email = CustomField.objects.create(
name="Contact Email", name="Contact Email",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
def test_extract_custom_fields_with_entities(self): def test_extract_custom_fields_with_entities(self):
@ -618,7 +615,7 @@ class TestExtractCustomFields(TestCase):
entities = { entities = {
"dates": [{"text": "2024-01-01"}], "dates": [{"text": "2024-01-01"}],
"amounts": [{"text": "$1,000"}], "amounts": [{"text": "$1,000"}],
"emails": ["test@example.com"] "emails": ["test@example.com"],
} }
fields = scanner._extract_custom_fields(self.document, "text", entities) fields = scanner._extract_custom_fields(self.document, "text", entities)
@ -638,12 +635,12 @@ class TestExtractCustomFields(TestCase):
# Should return empty dict # Should return empty dict
self.assertEqual(fields, {}) self.assertEqual(fields, {})
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_extract_custom_fields_handles_exception(self, mock_logger): def test_extract_custom_fields_handles_exception(self, mock_logger):
"""Test custom field extraction handles exceptions.""" """Test custom field extraction handles exceptions."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
with mock.patch.object(CustomField.objects, 'all', with mock.patch.object(CustomField.objects, "all",
side_effect=Exception("DB error")): side_effect=Exception("DB error")):
fields = scanner._extract_custom_fields(self.document, "text", {}) fields = scanner._extract_custom_fields(self.document, "text", {})
@ -658,31 +655,31 @@ class TestExtractFieldValue(TestCase):
"""Set up test fields.""" """Set up test fields."""
self.field_date = CustomField.objects.create( self.field_date = CustomField.objects.create(
name="Invoice Date", name="Invoice Date",
data_type=CustomField.FieldDataType.DATE data_type=CustomField.FieldDataType.DATE,
) )
self.field_amount = CustomField.objects.create( self.field_amount = CustomField.objects.create(
name="Total Amount", name="Total Amount",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.field_invoice = CustomField.objects.create( self.field_invoice = CustomField.objects.create(
name="Invoice Number", name="Invoice Number",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.field_email = CustomField.objects.create( self.field_email = CustomField.objects.create(
name="Contact Email", name="Contact Email",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.field_phone = CustomField.objects.create( self.field_phone = CustomField.objects.create(
name="Phone Number", name="Phone Number",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.field_person = CustomField.objects.create( self.field_person = CustomField.objects.create(
name="Person Name", name="Person Name",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.field_company = CustomField.objects.create( self.field_company = CustomField.objects.create(
name="Company Name", name="Company Name",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
def test_extract_field_value_date(self): def test_extract_field_value_date(self):
@ -691,7 +688,7 @@ class TestExtractFieldValue(TestCase):
entities = {"dates": [{"text": "2024-01-01"}]} entities = {"dates": [{"text": "2024-01-01"}]}
value, confidence = scanner._extract_field_value( value, confidence = scanner._extract_field_value(
self.field_date, "text", entities self.field_date, "text", entities,
) )
self.assertEqual(value, "2024-01-01") self.assertEqual(value, "2024-01-01")
@ -703,7 +700,7 @@ class TestExtractFieldValue(TestCase):
entities = {"amounts": [{"text": "$1,000"}]} entities = {"amounts": [{"text": "$1,000"}]}
value, confidence = scanner._extract_field_value( value, confidence = scanner._extract_field_value(
self.field_amount, "text", entities self.field_amount, "text", entities,
) )
self.assertEqual(value, "$1,000") self.assertEqual(value, "$1,000")
@ -715,7 +712,7 @@ class TestExtractFieldValue(TestCase):
entities = {"invoice_numbers": ["INV-12345"]} entities = {"invoice_numbers": ["INV-12345"]}
value, confidence = scanner._extract_field_value( value, confidence = scanner._extract_field_value(
self.field_invoice, "text", entities self.field_invoice, "text", entities,
) )
self.assertEqual(value, "INV-12345") self.assertEqual(value, "INV-12345")
@ -727,7 +724,7 @@ class TestExtractFieldValue(TestCase):
entities = {"emails": ["test@example.com"]} entities = {"emails": ["test@example.com"]}
value, confidence = scanner._extract_field_value( value, confidence = scanner._extract_field_value(
self.field_email, "text", entities self.field_email, "text", entities,
) )
self.assertEqual(value, "test@example.com") self.assertEqual(value, "test@example.com")
@ -739,7 +736,7 @@ class TestExtractFieldValue(TestCase):
entities = {"phones": ["+1-555-1234"]} entities = {"phones": ["+1-555-1234"]}
value, confidence = scanner._extract_field_value( value, confidence = scanner._extract_field_value(
self.field_phone, "text", entities self.field_phone, "text", entities,
) )
self.assertEqual(value, "+1-555-1234") self.assertEqual(value, "+1-555-1234")
@ -751,7 +748,7 @@ class TestExtractFieldValue(TestCase):
entities = {"persons": [{"text": "John Doe"}]} entities = {"persons": [{"text": "John Doe"}]}
value, confidence = scanner._extract_field_value( value, confidence = scanner._extract_field_value(
self.field_person, "text", entities self.field_person, "text", entities,
) )
self.assertEqual(value, "John Doe") self.assertEqual(value, "John Doe")
@ -763,7 +760,7 @@ class TestExtractFieldValue(TestCase):
entities = {"organizations": [{"text": "ACME Corp"}]} entities = {"organizations": [{"text": "ACME Corp"}]}
value, confidence = scanner._extract_field_value( value, confidence = scanner._extract_field_value(
self.field_company, "text", entities self.field_company, "text", entities,
) )
self.assertEqual(value, "ACME Corp") self.assertEqual(value, "ACME Corp")
@ -775,7 +772,7 @@ class TestExtractFieldValue(TestCase):
entities = {} entities = {}
value, confidence = scanner._extract_field_value( value, confidence = scanner._extract_field_value(
self.field_date, "text", entities self.field_date, "text", entities,
) )
self.assertIsNone(value) self.assertIsNone(value)
@ -789,23 +786,23 @@ class TestSuggestWorkflows(TestCase):
"""Set up test workflows.""" """Set up test workflows."""
self.workflow1 = Workflow.objects.create( self.workflow1 = Workflow.objects.create(
name="Invoice Processing", name="Invoice Processing",
enabled=True enabled=True,
) )
self.trigger1 = WorkflowTrigger.objects.create( self.trigger1 = WorkflowTrigger.objects.create(
workflow=self.workflow1, workflow=self.workflow1,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
) )
self.workflow2 = Workflow.objects.create( self.workflow2 = Workflow.objects.create(
name="Document Archival", name="Document Archival",
enabled=True enabled=True,
) )
self.trigger2 = WorkflowTrigger.objects.create( self.trigger2 = WorkflowTrigger.objects.create(
workflow=self.workflow2, workflow=self.workflow2,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
) )
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
def test_suggest_workflows_with_matches(self): def test_suggest_workflows_with_matches(self):
@ -820,7 +817,7 @@ class TestSuggestWorkflows(TestCase):
# Create action for workflow # Create action for workflow
WorkflowAction.objects.create( WorkflowAction.objects.create(
workflow=self.workflow1, workflow=self.workflow1,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
) )
suggestions = scanner._suggest_workflows(self.document, "text", scan_result) suggestions = scanner._suggest_workflows(self.document, "text", scan_result)
@ -841,17 +838,17 @@ class TestSuggestWorkflows(TestCase):
# Should not suggest any (confidence too low) # Should not suggest any (confidence too low)
self.assertEqual(len(suggestions), 0) self.assertEqual(len(suggestions), 0)
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_suggest_workflows_handles_exception(self, mock_logger): def test_suggest_workflows_handles_exception(self, mock_logger):
"""Test workflow suggestion handles exceptions.""" """Test workflow suggestion handles exceptions."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
scan_result = AIScanResult() scan_result = AIScanResult()
with mock.patch.object(Workflow.objects, 'filter', with mock.patch.object(Workflow.objects, "filter",
side_effect=Exception("DB error")): side_effect=Exception("DB error")):
suggestions = scanner._suggest_workflows( suggestions = scanner._suggest_workflows(
self.document, "text", scan_result self.document, "text", scan_result,
) )
self.assertEqual(suggestions, []) self.assertEqual(suggestions, [])
@ -865,11 +862,11 @@ class TestEvaluateWorkflowMatch(TestCase):
"""Set up test workflow.""" """Set up test workflow."""
self.workflow = Workflow.objects.create( self.workflow = Workflow.objects.create(
name="Test Workflow", name="Test Workflow",
enabled=True enabled=True,
) )
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
def test_evaluate_workflow_match_base_confidence(self): def test_evaluate_workflow_match_base_confidence(self):
@ -878,7 +875,7 @@ class TestEvaluateWorkflowMatch(TestCase):
scan_result = AIScanResult() scan_result = AIScanResult()
confidence = scanner._evaluate_workflow_match( confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result self.workflow, self.document, scan_result,
) )
self.assertEqual(confidence, 0.5) self.assertEqual(confidence, 0.5)
@ -892,11 +889,11 @@ class TestEvaluateWorkflowMatch(TestCase):
# Create action for workflow # Create action for workflow
WorkflowAction.objects.create( WorkflowAction.objects.create(
workflow=self.workflow, workflow=self.workflow,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
) )
confidence = scanner._evaluate_workflow_match( confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result self.workflow, self.document, scan_result,
) )
self.assertGreater(confidence, 0.5) self.assertGreater(confidence, 0.5)
@ -908,7 +905,7 @@ class TestEvaluateWorkflowMatch(TestCase):
scan_result.correspondent = (1, 0.90) scan_result.correspondent = (1, 0.90)
confidence = scanner._evaluate_workflow_match( confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result self.workflow, self.document, scan_result,
) )
self.assertGreater(confidence, 0.5) self.assertGreater(confidence, 0.5)
@ -920,7 +917,7 @@ class TestEvaluateWorkflowMatch(TestCase):
scan_result.tags = [(1, 0.80), (2, 0.75)] scan_result.tags = [(1, 0.80), (2, 0.75)]
confidence = scanner._evaluate_workflow_match( confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result self.workflow, self.document, scan_result,
) )
self.assertGreater(confidence, 0.5) self.assertGreater(confidence, 0.5)
@ -936,11 +933,11 @@ class TestEvaluateWorkflowMatch(TestCase):
# Create action # Create action
WorkflowAction.objects.create( WorkflowAction.objects.create(
workflow=self.workflow, workflow=self.workflow,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
) )
confidence = scanner._evaluate_workflow_match( confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result self.workflow, self.document, scan_result,
) )
self.assertLessEqual(confidence, 1.0) self.assertLessEqual(confidence, 1.0)
@ -953,7 +950,7 @@ class TestSuggestTitle(TestCase):
"""Set up test document.""" """Set up test document."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
def test_suggest_title_with_all_entities(self): def test_suggest_title_with_all_entities(self):
@ -963,7 +960,7 @@ class TestSuggestTitle(TestCase):
entities = { entities = {
"document_type": "Invoice", "document_type": "Invoice",
"organizations": [{"text": "ACME Corporation"}], "organizations": [{"text": "ACME Corporation"}],
"dates": [{"text": "2024-01-01"}] "dates": [{"text": "2024-01-01"}],
} }
title = scanner._suggest_title(self.document, "text", entities) title = scanner._suggest_title(self.document, "text", entities)
@ -978,7 +975,7 @@ class TestSuggestTitle(TestCase):
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
entities = { entities = {
"organizations": [{"text": "TechStart Inc"}] "organizations": [{"text": "TechStart Inc"}],
} }
title = scanner._suggest_title(self.document, "text", entities) title = scanner._suggest_title(self.document, "text", entities)
@ -1002,7 +999,7 @@ class TestSuggestTitle(TestCase):
long_org = "A" * 100 long_org = "A" * 100
entities = { entities = {
"organizations": [{"text": long_org}], "organizations": [{"text": long_org}],
"dates": [{"text": "2024-01-01"}] "dates": [{"text": "2024-01-01"}],
} }
title = scanner._suggest_title(self.document, "text", entities) title = scanner._suggest_title(self.document, "text", entities)
@ -1010,7 +1007,7 @@ class TestSuggestTitle(TestCase):
self.assertIsNotNone(title) self.assertIsNotNone(title)
self.assertLessEqual(len(title), 127) self.assertLessEqual(len(title), 127)
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_suggest_title_handles_exception(self, mock_logger): def test_suggest_title_handles_exception(self, mock_logger):
"""Test title suggestion handles exceptions.""" """Test title suggestion handles exceptions."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -1034,7 +1031,7 @@ class TestExtractTables(TestCase):
mock_extractor = mock.MagicMock() mock_extractor = mock.MagicMock()
mock_extractor.extract_tables_from_image.return_value = [ mock_extractor.extract_tables_from_image.return_value = [
{"data": [[1, 2], [3, 4]], "headers": ["A", "B"]} {"data": [[1, 2], [3, 4]], "headers": ["A", "B"]},
] ]
scanner._table_extractor = mock_extractor scanner._table_extractor = mock_extractor
@ -1053,7 +1050,7 @@ class TestExtractTables(TestCase):
self.assertEqual(tables, []) self.assertEqual(tables, [])
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_extract_tables_handles_exception(self, mock_logger): def test_extract_tables_handles_exception(self, mock_logger):
"""Test table extraction handles exceptions.""" """Test table extraction handles exceptions."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -1075,17 +1072,17 @@ class TestScanDocument(TestCase):
"""Set up test document.""" """Set up test document."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Invoice from ACME Corporation dated 2024-01-01" content="Invoice from ACME Corporation dated 2024-01-01",
) )
@mock.patch.object(AIDocumentScanner, '_extract_entities') @mock.patch.object(AIDocumentScanner, "_extract_entities")
@mock.patch.object(AIDocumentScanner, '_suggest_tags') @mock.patch.object(AIDocumentScanner, "_suggest_tags")
@mock.patch.object(AIDocumentScanner, '_detect_correspondent') @mock.patch.object(AIDocumentScanner, "_detect_correspondent")
@mock.patch.object(AIDocumentScanner, '_classify_document_type') @mock.patch.object(AIDocumentScanner, "_classify_document_type")
@mock.patch.object(AIDocumentScanner, '_suggest_storage_path') @mock.patch.object(AIDocumentScanner, "_suggest_storage_path")
@mock.patch.object(AIDocumentScanner, '_extract_custom_fields') @mock.patch.object(AIDocumentScanner, "_extract_custom_fields")
@mock.patch.object(AIDocumentScanner, '_suggest_workflows') @mock.patch.object(AIDocumentScanner, "_suggest_workflows")
@mock.patch.object(AIDocumentScanner, '_suggest_title') @mock.patch.object(AIDocumentScanner, "_suggest_title")
def test_scan_document_orchestrates_all_methods( def test_scan_document_orchestrates_all_methods(
self, self,
mock_title, mock_title,
@ -1095,7 +1092,7 @@ class TestScanDocument(TestCase):
mock_doc_type, mock_doc_type,
mock_correspondent, mock_correspondent,
mock_tags, mock_tags,
mock_entities mock_entities,
): ):
"""Test that scan_document calls all extraction methods.""" """Test that scan_document calls all extraction methods."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -1127,26 +1124,26 @@ class TestScanDocument(TestCase):
self.assertEqual(result.correspondent, (1, 0.90)) self.assertEqual(result.correspondent, (1, 0.90))
self.assertEqual(result.document_type, (1, 0.80)) self.assertEqual(result.document_type, (1, 0.80))
@mock.patch.object(AIDocumentScanner, '_extract_tables') @mock.patch.object(AIDocumentScanner, "_extract_tables")
def test_scan_document_extracts_tables_when_enabled(self, mock_extract_tables): def test_scan_document_extracts_tables_when_enabled(self, mock_extract_tables):
"""Test that tables are extracted when OCR is enabled and file path provided.""" """Test that tables are extracted when OCR is enabled and file path provided."""
scanner = AIDocumentScanner(enable_advanced_ocr=True) scanner = AIDocumentScanner(enable_advanced_ocr=True)
mock_extract_tables.return_value = [{"data": "test"}] mock_extract_tables.return_value = [{"data": "test"}]
# Mock other methods to avoid complexity # Mock other methods to avoid complexity
with mock.patch.object(scanner, '_extract_entities', return_value={}), \ with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \ mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None): mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document( result = scanner.scan_document(
self.document, self.document,
"Document text", "Document text",
original_file_path="/path/to/file.pdf" original_file_path="/path/to/file.pdf",
) )
mock_extract_tables.assert_called_once_with("/path/to/file.pdf") mock_extract_tables.assert_called_once_with("/path/to/file.pdf")
@ -1156,15 +1153,15 @@ class TestScanDocument(TestCase):
"""Test that tables are not extracted when file path is not provided.""" """Test that tables are not extracted when file path is not provided."""
scanner = AIDocumentScanner(enable_advanced_ocr=True) scanner = AIDocumentScanner(enable_advanced_ocr=True)
with mock.patch.object(scanner, '_extract_tables') as mock_extract_tables, \ with mock.patch.object(scanner, "_extract_tables") as mock_extract_tables, \
mock.patch.object(scanner, '_extract_entities', return_value={}), \ mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \ mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None): mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(self.document, "Document text") result = scanner.scan_document(self.document, "Document text")
@ -1183,11 +1180,11 @@ class TestApplyScanResults(TestCase):
self.doc_type = DocumentType.objects.create(name="Invoice") self.doc_type = DocumentType.objects.create(name="Invoice")
self.storage_path = StoragePath.objects.create( self.storage_path = StoragePath.objects.create(
name="Invoices", name="Invoices",
path="/invoices" path="/invoices",
) )
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
def test_apply_scan_results_auto_applies_high_confidence(self): def test_apply_scan_results_auto_applies_high_confidence(self):
@ -1203,7 +1200,7 @@ class TestApplyScanResults(TestCase):
result = scanner.apply_scan_results( result = scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
# Verify auto-applied # Verify auto-applied
@ -1222,7 +1219,7 @@ class TestApplyScanResults(TestCase):
"""Test that medium confidence items are suggested, not applied.""" """Test that medium confidence items are suggested, not applied."""
scanner = AIDocumentScanner( scanner = AIDocumentScanner(
auto_apply_threshold=0.80, auto_apply_threshold=0.80,
suggest_threshold=0.60 suggest_threshold=0.60,
) )
scan_result = AIScanResult() scan_result = AIScanResult()
@ -1232,7 +1229,7 @@ class TestApplyScanResults(TestCase):
result = scanner.apply_scan_results( result = scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
# Verify suggested but not applied # Verify suggested but not applied
@ -1255,7 +1252,7 @@ class TestApplyScanResults(TestCase):
result = scanner.apply_scan_results( result = scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=False auto_apply=False,
) )
# Verify nothing was applied # Verify nothing was applied
@ -1268,21 +1265,21 @@ class TestApplyScanResults(TestCase):
scan_result = AIScanResult() scan_result = AIScanResult()
scan_result.correspondent = (self.correspondent.id, 0.90) scan_result.correspondent = (self.correspondent.id, 0.90)
with mock.patch.object(self.document, 'save', with mock.patch.object(self.document, "save",
side_effect=Exception("Save failed")): side_effect=Exception("Save failed")):
with self.assertRaises(Exception): with self.assertRaises(Exception):
with transaction.atomic(): with transaction.atomic():
scanner.apply_scan_results( scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
# Verify transaction was rolled back # Verify transaction was rolled back
self.document.refresh_from_db() self.document.refresh_from_db()
self.assertIsNone(self.document.correspondent) self.assertIsNone(self.document.correspondent)
@mock.patch('documents.ai_scanner.logger') @mock.patch("documents.ai_scanner.logger")
def test_apply_scan_results_handles_exception(self, mock_logger): def test_apply_scan_results_handles_exception(self, mock_logger):
"""Test that apply_scan_results handles exceptions gracefully.""" """Test that apply_scan_results handles exceptions gracefully."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
@ -1293,7 +1290,7 @@ class TestApplyScanResults(TestCase):
scanner.apply_scan_results( scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
mock_logger.error.assert_called() mock_logger.error.assert_called()
@ -1323,21 +1320,21 @@ class TestEdgeCasesAndErrorHandling(TestCase):
"""Set up test document.""" """Set up test document."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
def test_scan_document_with_empty_text(self): def test_scan_document_with_empty_text(self):
"""Test scanning document with empty text.""" """Test scanning document with empty text."""
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
with mock.patch.object(scanner, '_extract_entities', return_value={}), \ with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \ mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None): mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(self.document, "") result = scanner.scan_document(self.document, "")
@ -1349,14 +1346,14 @@ class TestEdgeCasesAndErrorHandling(TestCase):
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
long_text = "A" * 100000 long_text = "A" * 100000
with mock.patch.object(scanner, '_extract_entities', return_value={}), \ with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \ mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None): mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(self.document, long_text) result = scanner.scan_document(self.document, long_text)
@ -1367,14 +1364,14 @@ class TestEdgeCasesAndErrorHandling(TestCase):
scanner = AIDocumentScanner() scanner = AIDocumentScanner()
special_text = "Test with émojis 😀 and special chars: <>{}[]|\\`~" special_text = "Test with émojis 😀 and special chars: <>{}[]|\\`~"
with mock.patch.object(scanner, '_extract_entities', return_value={}), \ with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \ mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None): mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(self.document, special_text) result = scanner.scan_document(self.document, special_text)
@ -1388,7 +1385,7 @@ class TestEdgeCasesAndErrorHandling(TestCase):
result = scanner.apply_scan_results( result = scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
self.assertEqual(result["applied"]["tags"], []) self.assertEqual(result["applied"]["tags"], [])
@ -1403,12 +1400,12 @@ class TestEdgeCasesAndErrorHandling(TestCase):
# Test extreme values # Test extreme values
scanner_low = AIDocumentScanner( scanner_low = AIDocumentScanner(
auto_apply_threshold=0.01, auto_apply_threshold=0.01,
suggest_threshold=0.01 suggest_threshold=0.01,
) )
self.assertEqual(scanner_low.auto_apply_threshold, 0.01) self.assertEqual(scanner_low.auto_apply_threshold, 0.01)
scanner_high = AIDocumentScanner( scanner_high = AIDocumentScanner(
auto_apply_threshold=0.99, auto_apply_threshold=0.99,
suggest_threshold=0.80 suggest_threshold=0.80,
) )
self.assertEqual(scanner_high.auto_apply_threshold, 0.99) self.assertEqual(scanner_high.auto_apply_threshold, 0.99)

View file

@ -8,24 +8,21 @@ document consumption to metadata application.
from unittest import mock from unittest import mock
from django.test import TestCase, TransactionTestCase from django.test import TestCase
from django.test import TransactionTestCase
from documents.ai_scanner import ( from documents.ai_scanner import AIDocumentScanner
AIDocumentScanner, from documents.ai_scanner import AIScanResult
AIScanResult, from documents.ai_scanner import get_ai_scanner
get_ai_scanner, from documents.models import Correspondent
) from documents.models import CustomField
from documents.models import ( from documents.models import Document
Correspondent, from documents.models import DocumentType
CustomField, from documents.models import StoragePath
Document, from documents.models import Tag
DocumentType, from documents.models import Workflow
StoragePath, from documents.models import WorkflowAction
Tag, from documents.models import WorkflowTrigger
Workflow,
WorkflowTrigger,
WorkflowAction,
)
class TestAIScannerIntegrationBasic(TestCase): class TestAIScannerIntegrationBasic(TestCase):
@ -35,49 +32,49 @@ class TestAIScannerIntegrationBasic(TestCase):
"""Set up test data.""" """Set up test data."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Invoice from ACME Corporation", title="Invoice from ACME Corporation",
content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000" content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000",
) )
self.tag_invoice = Tag.objects.create( self.tag_invoice = Tag.objects.create(
name="Invoice", name="Invoice",
matching_algorithm=Tag.MATCH_AUTO, matching_algorithm=Tag.MATCH_AUTO,
match="invoice" match="invoice",
) )
self.tag_important = Tag.objects.create( self.tag_important = Tag.objects.create(
name="Important", name="Important",
matching_algorithm=Tag.MATCH_AUTO, matching_algorithm=Tag.MATCH_AUTO,
match="total" match="total",
) )
self.correspondent = Correspondent.objects.create( self.correspondent = Correspondent.objects.create(
name="ACME Corporation", name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO, matching_algorithm=Correspondent.MATCH_AUTO,
match="acme" match="acme",
) )
self.doc_type = DocumentType.objects.create( self.doc_type = DocumentType.objects.create(
name="Invoice", name="Invoice",
matching_algorithm=DocumentType.MATCH_AUTO, matching_algorithm=DocumentType.MATCH_AUTO,
match="invoice" match="invoice",
) )
self.storage_path = StoragePath.objects.create( self.storage_path = StoragePath.objects.create(
name="Invoices", name="Invoices",
path="/invoices", path="/invoices",
matching_algorithm=StoragePath.MATCH_AUTO, matching_algorithm=StoragePath.MATCH_AUTO,
match="invoice" match="invoice",
) )
@mock.patch('documents.ai_scanner.match_tags') @mock.patch("documents.ai_scanner.match_tags")
@mock.patch('documents.ai_scanner.match_correspondents') @mock.patch("documents.ai_scanner.match_correspondents")
@mock.patch('documents.ai_scanner.match_document_types') @mock.patch("documents.ai_scanner.match_document_types")
@mock.patch('documents.ai_scanner.match_storage_paths') @mock.patch("documents.ai_scanner.match_storage_paths")
def test_full_scan_and_apply_workflow( def test_full_scan_and_apply_workflow(
self, self,
mock_storage, mock_storage,
mock_types, mock_types,
mock_correspondents, mock_correspondents,
mock_tags mock_tags,
): ):
"""Test complete workflow from scan to application.""" """Test complete workflow from scan to application."""
# Mock the matching functions to return our test data # Mock the matching functions to return our test data
@ -91,7 +88,7 @@ class TestAIScannerIntegrationBasic(TestCase):
# Scan the document # Scan the document
scan_result = scanner.scan_document( scan_result = scanner.scan_document(
self.document, self.document,
self.document.content self.document.content,
) )
# Verify scan results # Verify scan results
@ -105,7 +102,7 @@ class TestAIScannerIntegrationBasic(TestCase):
result = scanner.apply_scan_results( result = scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
# Verify application # Verify application
@ -118,7 +115,7 @@ class TestAIScannerIntegrationBasic(TestCase):
self.assertEqual(self.document.document_type, self.doc_type) self.assertEqual(self.document.document_type, self.doc_type)
self.assertEqual(self.document.storage_path, self.storage_path) self.assertEqual(self.document.storage_path, self.storage_path)
@mock.patch('documents.ai_scanner.match_tags') @mock.patch("documents.ai_scanner.match_tags")
def test_scan_with_no_matches(self, mock_tags): def test_scan_with_no_matches(self, mock_tags):
"""Test scanning when no matches are found.""" """Test scanning when no matches are found."""
mock_tags.return_value = [] mock_tags.return_value = []
@ -127,7 +124,7 @@ class TestAIScannerIntegrationBasic(TestCase):
scan_result = scanner.scan_document( scan_result = scanner.scan_document(
self.document, self.document,
"Some random text with no matches" "Some random text with no matches",
) )
# Should return empty results # Should return empty results
@ -143,24 +140,24 @@ class TestAIScannerIntegrationCustomFields(TestCase):
"""Set up test data with custom fields.""" """Set up test data with custom fields."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Invoice", title="Invoice",
content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com" content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com",
) )
self.field_date = CustomField.objects.create( self.field_date = CustomField.objects.create(
name="Invoice Date", name="Invoice Date",
data_type=CustomField.FieldDataType.DATE data_type=CustomField.FieldDataType.DATE,
) )
self.field_number = CustomField.objects.create( self.field_number = CustomField.objects.create(
name="Invoice Number", name="Invoice Number",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.field_amount = CustomField.objects.create( self.field_amount = CustomField.objects.create(
name="Total Amount", name="Total Amount",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
self.field_email = CustomField.objects.create( self.field_email = CustomField.objects.create(
name="Contact Email", name="Contact Email",
data_type=CustomField.FieldDataType.STRING data_type=CustomField.FieldDataType.STRING,
) )
def test_custom_field_extraction_integration(self): def test_custom_field_extraction_integration(self):
@ -173,7 +170,7 @@ class TestAIScannerIntegrationCustomFields(TestCase):
"dates": [{"text": "2024-01-01"}], "dates": [{"text": "2024-01-01"}],
"amounts": [{"text": "$1,500"}], "amounts": [{"text": "$1,500"}],
"invoice_numbers": ["INV-123"], "invoice_numbers": ["INV-123"],
"emails": ["john@example.com"] "emails": ["john@example.com"],
} }
scanner._ner_extractor = mock_ner scanner._ner_extractor = mock_ner
@ -196,29 +193,29 @@ class TestAIScannerIntegrationWorkflows(TestCase):
"""Set up test workflows.""" """Set up test workflows."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Invoice", title="Invoice",
content="Invoice document" content="Invoice document",
) )
self.workflow1 = Workflow.objects.create( self.workflow1 = Workflow.objects.create(
name="Invoice Processing", name="Invoice Processing",
enabled=True enabled=True,
) )
self.trigger1 = WorkflowTrigger.objects.create( self.trigger1 = WorkflowTrigger.objects.create(
workflow=self.workflow1, workflow=self.workflow1,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
) )
self.action1 = WorkflowAction.objects.create( self.action1 = WorkflowAction.objects.create(
workflow=self.workflow1, workflow=self.workflow1,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
) )
self.workflow2 = Workflow.objects.create( self.workflow2 = Workflow.objects.create(
name="Archive Documents", name="Archive Documents",
enabled=True enabled=True,
) )
self.trigger2 = WorkflowTrigger.objects.create( self.trigger2 = WorkflowTrigger.objects.create(
workflow=self.workflow2, workflow=self.workflow2,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
) )
def test_workflow_suggestion_integration(self): def test_workflow_suggestion_integration(self):
@ -234,7 +231,7 @@ class TestAIScannerIntegrationWorkflows(TestCase):
workflows = scanner._suggest_workflows( workflows = scanner._suggest_workflows(
self.document, self.document,
self.document.content, self.document.content,
scan_result scan_result,
) )
# Should suggest workflows # Should suggest workflows
@ -250,7 +247,7 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
"""Set up test data.""" """Set up test data."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
self.tag = Tag.objects.create(name="TestTag") self.tag = Tag.objects.create(name="TestTag")
self.correspondent = Correspondent.objects.create(name="TestCorp") self.correspondent = Correspondent.objects.create(name="TestCorp")
@ -273,12 +270,12 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
raise Exception("Forced save failure") raise Exception("Forced save failure")
return original_save(self, *args, **kwargs) return original_save(self, *args, **kwargs)
with mock.patch.object(Document, 'save', failing_save): with mock.patch.object(Document, "save", failing_save):
with self.assertRaises(Exception): with self.assertRaises(Exception):
scanner.apply_scan_results( scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
# Verify changes were rolled back # Verify changes were rolled back
@ -297,19 +294,19 @@ class TestAIScannerIntegrationPerformance(TestCase):
for i in range(5): for i in range(5):
doc = Document.objects.create( doc = Document.objects.create(
title=f"Document {i}", title=f"Document {i}",
content=f"Content for document {i}" content=f"Content for document {i}",
) )
documents.append(doc) documents.append(doc)
# Mock to avoid actual ML loading # Mock to avoid actual ML loading
with mock.patch.object(scanner, '_extract_entities', return_value={}), \ with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \ mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None): mock.patch.object(scanner, "_suggest_title", return_value=None):
results = [] results = []
for doc in documents: for doc in documents:
@ -329,16 +326,16 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
"""Set up test data.""" """Set up test data."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Business Invoice", title="Business Invoice",
content="Invoice from ACME Corporation" content="Invoice from ACME Corporation",
) )
self.correspondent_acme = Correspondent.objects.create( self.correspondent_acme = Correspondent.objects.create(
name="ACME Corporation", name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO matching_algorithm=Correspondent.MATCH_AUTO,
) )
self.correspondent_other = Correspondent.objects.create( self.correspondent_other = Correspondent.objects.create(
name="Other Company", name="Other Company",
matching_algorithm=Correspondent.MATCH_AUTO matching_algorithm=Correspondent.MATCH_AUTO,
) )
def test_correspondent_matching_with_ner_entities(self): def test_correspondent_matching_with_ner_entities(self):
@ -348,16 +345,16 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
# Mock NER to extract organization # Mock NER to extract organization
mock_ner = mock.MagicMock() mock_ner = mock.MagicMock()
mock_ner.extract_all.return_value = { mock_ner.extract_all.return_value = {
"organizations": [{"text": "ACME Corporation"}] "organizations": [{"text": "ACME Corporation"}],
} }
scanner._ner_extractor = mock_ner scanner._ner_extractor = mock_ner
# Mock matching to return empty (so NER-based matching is used) # Mock matching to return empty (so NER-based matching is used)
with mock.patch('documents.ai_scanner.match_correspondents', return_value=[]): with mock.patch("documents.ai_scanner.match_correspondents", return_value=[]):
result = scanner._detect_correspondent( result = scanner._detect_correspondent(
self.document, self.document,
self.document.content, self.document.content,
{"organizations": [{"text": "ACME Corporation"}]} {"organizations": [{"text": "ACME Corporation"}]},
) )
# Should detect ACME correspondent # Should detect ACME correspondent
@ -375,13 +372,13 @@ class TestAIScannerIntegrationTitleGeneration(TestCase):
document = Document.objects.create( document = Document.objects.create(
title="document.pdf", title="document.pdf",
content="Invoice from ACME Corp dated 2024-01-15" content="Invoice from ACME Corp dated 2024-01-15",
) )
entities = { entities = {
"document_type": "Invoice", "document_type": "Invoice",
"organizations": [{"text": "ACME Corp"}], "organizations": [{"text": "ACME Corp"}],
"dates": [{"text": "2024-01-15"}] "dates": [{"text": "2024-01-15"}],
} }
title = scanner._suggest_title(document, document.content, entities) title = scanner._suggest_title(document, document.content, entities)
@ -399,7 +396,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
"""Set up test data.""" """Set up test data."""
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test", title="Test",
content="Test" content="Test",
) )
self.tag_high = Tag.objects.create(name="HighConfidence") self.tag_high = Tag.objects.create(name="HighConfidence")
self.tag_medium = Tag.objects.create(name="MediumConfidence") self.tag_medium = Tag.objects.create(name="MediumConfidence")
@ -409,7 +406,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
"""Test that only high confidence suggestions are auto-applied.""" """Test that only high confidence suggestions are auto-applied."""
scanner = AIDocumentScanner( scanner = AIDocumentScanner(
auto_apply_threshold=0.80, auto_apply_threshold=0.80,
suggest_threshold=0.60 suggest_threshold=0.60,
) )
scan_result = AIScanResult() scan_result = AIScanResult()
@ -422,7 +419,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
result = scanner.apply_scan_results( result = scanner.apply_scan_results(
self.document, self.document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
# Verify high confidence was applied # Verify high confidence was applied
@ -448,17 +445,17 @@ class TestAIScannerIntegrationGlobalInstance(TestCase):
# Should be functional # Should be functional
document = Document.objects.create( document = Document.objects.create(
title="Test", title="Test",
content="Test content" content="Test content",
) )
with mock.patch.object(scanner1, '_extract_entities', return_value={}), \ with mock.patch.object(scanner1, "_extract_entities", return_value={}), \
mock.patch.object(scanner1, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner1, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner1, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner1, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner1, '_classify_document_type', return_value=None), \ mock.patch.object(scanner1, "_classify_document_type", return_value=None), \
mock.patch.object(scanner1, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner1, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner1, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner1, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner1, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner1, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner1, '_suggest_title', return_value=None): mock.patch.object(scanner1, "_suggest_title", return_value=None):
result1 = scanner1.scan_document(document, document.content) result1 = scanner1.scan_document(document, document.content)
result2 = scanner2.scan_document(document, document.content) result2 = scanner2.scan_document(document, document.content)
@ -476,17 +473,17 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
document = Document.objects.create( document = Document.objects.create(
title="", title="",
content="" content="",
) )
with mock.patch.object(scanner, '_extract_entities', return_value={}), \ with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \ mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None): mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(document, document.content) result = scanner.scan_document(document, document.content)
@ -498,7 +495,7 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
document = Document.objects.create( document = Document.objects.create(
title="Test", title="Test",
content="Test" content="Test",
) )
scan_result = AIScanResult() scan_result = AIScanResult()
@ -509,7 +506,7 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
result = scanner.apply_scan_results( result = scanner.apply_scan_results(
document, document,
scan_result, scan_result,
auto_apply=True auto_apply=True,
) )
# Should not crash, just log errors # Should not crash, just log errors
@ -521,17 +518,17 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
document = Document.objects.create( document = Document.objects.create(
title="Factura - España 🇪🇸", title="Factura - España 🇪🇸",
content="Société française • 日本語 • Ελληνικά • مرحبا" content="Société française • 日本語 • Ελληνικά • مرحبا",
) )
with mock.patch.object(scanner, '_extract_entities', return_value={}), \ with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \ mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \ mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \ mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \ mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \ mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \ mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None): mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(document, document.content) result = scanner.scan_document(document, document.content)

View file

@ -12,18 +12,17 @@ Tests cover:
from unittest import mock from unittest import mock
from django.contrib.auth.models import Permission, User from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from rest_framework import status from rest_framework import status
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from documents.models import ( from documents.models import Correspondent
Correspondent, from documents.models import DeletionRequest
DeletionRequest, from documents.models import Document
Document, from documents.models import DocumentType
DocumentType, from documents.models import Tag
Tag,
)
from documents.tests.utils import DirectoriesMixin from documents.tests.utils import DirectoriesMixin
@ -36,13 +35,13 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
# Create users # Create users
self.superuser = User.objects.create_superuser( self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123" username="admin", email="admin@test.com", password="admin123",
) )
self.user_with_permission = User.objects.create_user( self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123" username="permitted", email="permitted@test.com", password="permitted123",
) )
self.user_without_permission = User.objects.create_user( self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123" username="regular", email="regular@test.com", password="regular123",
) )
# Assign view permission # Assign view permission
@ -57,7 +56,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
# Create test document # Create test document
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="This is a test invoice from ACME Corporation" content="This is a test invoice from ACME Corporation",
) )
# Create test metadata objects # Create test metadata objects
@ -70,7 +69,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post( response = self.client.post(
"/api/ai/suggestions/", "/api/ai/suggestions/",
{"document_id": self.document.id}, {"document_id": self.document.id},
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@ -82,7 +81,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post( response = self.client.post(
"/api/ai/suggestions/", "/api/ai/suggestions/",
{"document_id": self.document.id}, {"document_id": self.document.id},
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@ -91,7 +90,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"""Test that superusers can access the endpoint.""" """Test that superusers can access the endpoint."""
self.client.force_authenticate(user=self.superuser) self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner: with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response # Mock the scanner response
mock_scan_result = mock.MagicMock() mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [(self.tag.id, 0.85)] mock_scan_result.tags = [(self.tag.id, 0.85)]
@ -108,7 +107,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post( response = self.client.post(
"/api/ai/suggestions/", "/api/ai/suggestions/",
{"document_id": self.document.id}, {"document_id": self.document.id},
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -119,7 +118,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"""Test that users with permission can access the endpoint.""" """Test that users with permission can access the endpoint."""
self.client.force_authenticate(user=self.user_with_permission) self.client.force_authenticate(user=self.user_with_permission)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner: with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response # Mock the scanner response
mock_scan_result = mock.MagicMock() mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [] mock_scan_result.tags = []
@ -136,7 +135,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post( response = self.client.post(
"/api/ai/suggestions/", "/api/ai/suggestions/",
{"document_id": self.document.id}, {"document_id": self.document.id},
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -148,7 +147,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post( response = self.client.post(
"/api/ai/suggestions/", "/api/ai/suggestions/",
{"document_id": 99999}, {"document_id": 99999},
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -160,7 +159,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post( response = self.client.post(
"/api/ai/suggestions/", "/api/ai/suggestions/",
{}, {},
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -175,10 +174,10 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
# Create users # Create users
self.superuser = User.objects.create_superuser( self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123" username="admin", email="admin@test.com", password="admin123",
) )
self.user_with_permission = User.objects.create_user( self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123" username="permitted", email="permitted@test.com", password="permitted123",
) )
# Assign apply permission # Assign apply permission
@ -193,7 +192,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
# Create test document # Create test document
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Document", title="Test Document",
content="Test content" content="Test content",
) )
# Create test metadata # Create test metadata
@ -205,7 +204,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post( response = self.client.post(
"/api/ai/suggestions/apply/", "/api/ai/suggestions/apply/",
{"document_id": self.document.id}, {"document_id": self.document.id},
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@ -214,7 +213,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"""Test successfully applying tag suggestions.""" """Test successfully applying tag suggestions."""
self.client.force_authenticate(user=self.superuser) self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner: with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response # Mock the scanner response
mock_scan_result = mock.MagicMock() mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [(self.tag.id, 0.85)] mock_scan_result.tags = [(self.tag.id, 0.85)]
@ -233,9 +232,9 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/suggestions/apply/", "/api/ai/suggestions/apply/",
{ {
"document_id": self.document.id, "document_id": self.document.id,
"apply_tags": True "apply_tags": True,
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -245,7 +244,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"""Test successfully applying correspondent suggestion.""" """Test successfully applying correspondent suggestion."""
self.client.force_authenticate(user=self.superuser) self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner: with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response # Mock the scanner response
mock_scan_result = mock.MagicMock() mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [] mock_scan_result.tags = []
@ -264,9 +263,9 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/suggestions/apply/", "/api/ai/suggestions/apply/",
{ {
"document_id": self.document.id, "document_id": self.document.id,
"apply_correspondent": True "apply_correspondent": True,
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -285,10 +284,10 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
# Create users # Create users
self.superuser = User.objects.create_superuser( self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123" username="admin", email="admin@test.com", password="admin123",
) )
self.user_without_permission = User.objects.create_user( self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123" username="regular", email="regular@test.com", password="regular123",
) )
def test_unauthorized_access_denied(self): def test_unauthorized_access_denied(self):
@ -309,7 +308,7 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
"""Test getting AI configuration.""" """Test getting AI configuration."""
self.client.force_authenticate(user=self.superuser) self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner: with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
mock_scanner_instance = mock.MagicMock() mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.auto_apply_threshold = 0.80 mock_scanner_instance.auto_apply_threshold = 0.80
mock_scanner_instance.suggest_threshold = 0.60 mock_scanner_instance.suggest_threshold = 0.60
@ -331,9 +330,9 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/config/", "/api/ai/config/",
{ {
"auto_apply_threshold": 0.90, "auto_apply_threshold": 0.90,
"suggest_threshold": 0.70 "suggest_threshold": 0.70,
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -346,9 +345,9 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post( response = self.client.post(
"/api/ai/config/", "/api/ai/config/",
{ {
"auto_apply_threshold": 1.5 # Invalid: > 1.0 "auto_apply_threshold": 1.5, # Invalid: > 1.0
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -363,13 +362,13 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
# Create users # Create users
self.superuser = User.objects.create_superuser( self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123" username="admin", email="admin@test.com", password="admin123",
) )
self.user_with_permission = User.objects.create_user( self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123" username="permitted", email="permitted@test.com", password="permitted123",
) )
self.user_without_permission = User.objects.create_user( self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123" username="regular", email="regular@test.com", password="regular123",
) )
# Assign approval permission # Assign approval permission
@ -385,7 +384,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
self.deletion_request = DeletionRequest.objects.create( self.deletion_request = DeletionRequest.objects.create(
user=self.user_with_permission, user=self.user_with_permission,
requested_by_ai=True, requested_by_ai=True,
ai_reason="Document appears to be a duplicate" ai_reason="Document appears to be a duplicate",
) )
def test_unauthorized_access_denied(self): def test_unauthorized_access_denied(self):
@ -394,9 +393,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/", "/api/ai/deletions/approve/",
{ {
"request_id": self.deletion_request.id, "request_id": self.deletion_request.id,
"action": "approve" "action": "approve",
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@ -409,9 +408,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/", "/api/ai/deletions/approve/",
{ {
"request_id": self.deletion_request.id, "request_id": self.deletion_request.id,
"action": "approve" "action": "approve",
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@ -424,9 +423,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/", "/api/ai/deletions/approve/",
{ {
"request_id": self.deletion_request.id, "request_id": self.deletion_request.id,
"action": "approve" "action": "approve",
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -436,7 +435,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
self.deletion_request.refresh_from_db() self.deletion_request.refresh_from_db()
self.assertEqual( self.assertEqual(
self.deletion_request.status, self.deletion_request.status,
DeletionRequest.STATUS_APPROVED DeletionRequest.STATUS_APPROVED,
) )
def test_reject_deletion_success(self): def test_reject_deletion_success(self):
@ -448,9 +447,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
{ {
"request_id": self.deletion_request.id, "request_id": self.deletion_request.id,
"action": "reject", "action": "reject",
"reason": "Document is still needed" "reason": "Document is still needed",
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -459,7 +458,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
self.deletion_request.refresh_from_db() self.deletion_request.refresh_from_db()
self.assertEqual( self.assertEqual(
self.deletion_request.status, self.deletion_request.status,
DeletionRequest.STATUS_REJECTED DeletionRequest.STATUS_REJECTED,
) )
def test_invalid_request_id(self): def test_invalid_request_id(self):
@ -470,9 +469,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/", "/api/ai/deletions/approve/",
{ {
"request_id": 99999, "request_id": 99999,
"action": "approve" "action": "approve",
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -485,9 +484,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/", "/api/ai/deletions/approve/",
{ {
"request_id": self.deletion_request.id, "request_id": self.deletion_request.id,
"action": "approve" "action": "approve",
}, },
format="json" format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -502,7 +501,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
# Create user with all AI permissions # Create user with all AI permissions
self.power_user = User.objects.create_user( self.power_user = User.objects.create_user(
username="power_user", email="power@test.com", password="power123" username="power_user", email="power@test.com", password="power123",
) )
content_type = ContentType.objects.get_for_model(Document) content_type = ContentType.objects.get_for_model(Document)
@ -525,7 +524,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
self.document = Document.objects.create( self.document = Document.objects.create(
title="Test Doc", title="Test Doc",
content="Test" content="Test",
) )
def test_power_user_can_access_all_endpoints(self): def test_power_user_can_access_all_endpoints(self):
@ -533,7 +532,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
self.client.force_authenticate(user=self.power_user) self.client.force_authenticate(user=self.power_user)
# Test suggestions endpoint # Test suggestions endpoint
with mock.patch('documents.views.get_ai_scanner') as mock_scanner: with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
mock_scan_result = mock.MagicMock() mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [] mock_scan_result.tags = []
mock_scan_result.correspondent = None mock_scan_result.correspondent = None
@ -553,7 +552,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
response1 = self.client.post( response1 = self.client.post(
"/api/ai/suggestions/", "/api/ai/suggestions/",
{"document_id": self.document.id}, {"document_id": self.document.id},
format="json" format="json",
) )
self.assertEqual(response1.status_code, status.HTTP_200_OK) self.assertEqual(response1.status_code, status.HTTP_200_OK)
@ -562,9 +561,9 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
"/api/ai/suggestions/apply/", "/api/ai/suggestions/apply/",
{ {
"document_id": self.document.id, "document_id": self.document.id,
"apply_tags": False "apply_tags": False,
}, },
format="json" format="json",
) )
self.assertEqual(response2.status_code, status.HTTP_200_OK) self.assertEqual(response2.status_code, status.HTTP_200_OK)

View file

@ -9,14 +9,12 @@ from rest_framework import status
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from documents.ai_scanner import AIScanResult from documents.ai_scanner import AIScanResult
from documents.models import ( from documents.models import AISuggestionFeedback
AISuggestionFeedback, from documents.models import Correspondent
Correspondent, from documents.models import Document
Document, from documents.models import DocumentType
DocumentType, from documents.models import StoragePath
StoragePath, from documents.models import Tag
Tag,
)
from documents.tests.utils import DirectoriesMixin from documents.tests.utils import DirectoriesMixin
@ -64,12 +62,12 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_ai_suggestions_endpoint_exists(self): def test_ai_suggestions_endpoint_exists(self):
"""Test that the ai-suggestions endpoint is accessible.""" """Test that the ai-suggestions endpoint is accessible."""
response = self.client.get( response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/" f"/api/documents/{self.document.pk}/ai-suggestions/",
) )
# Should not be 404 # Should not be 404
self.assertNotEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertNotEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@mock.patch('documents.ai_scanner.get_ai_scanner') @mock.patch("documents.ai_scanner.get_ai_scanner")
def test_get_ai_suggestions_success(self, mock_get_scanner): def test_get_ai_suggestions_success(self, mock_get_scanner):
"""Test successfully getting AI suggestions for a document.""" """Test successfully getting AI suggestions for a document."""
# Create mock scan result # Create mock scan result
@ -87,7 +85,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Make request # Make request
response = self.client.get( response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/" f"/api/documents/{self.document.pk}/ai-suggestions/",
) )
# Verify response # Verify response
@ -95,23 +93,23 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
data = response.json() data = response.json()
# Check tags # Check tags
self.assertIn('tags', data) self.assertIn("tags", data)
self.assertEqual(len(data['tags']), 2) self.assertEqual(len(data["tags"]), 2)
self.assertEqual(data['tags'][0]['id'], self.tag1.id) self.assertEqual(data["tags"][0]["id"], self.tag1.id)
self.assertEqual(data['tags'][0]['confidence'], 0.85) self.assertEqual(data["tags"][0]["confidence"], 0.85)
# Check correspondent # Check correspondent
self.assertIn('correspondent', data) self.assertIn("correspondent", data)
self.assertEqual(data['correspondent']['id'], self.correspondent.id) self.assertEqual(data["correspondent"]["id"], self.correspondent.id)
self.assertEqual(data['correspondent']['confidence'], 0.90) self.assertEqual(data["correspondent"]["confidence"], 0.90)
# Check document type # Check document type
self.assertIn('document_type', data) self.assertIn("document_type", data)
self.assertEqual(data['document_type']['id'], self.doc_type.id) self.assertEqual(data["document_type"]["id"], self.doc_type.id)
# Check title suggestion # Check title suggestion
self.assertIn('title_suggestion', data) self.assertIn("title_suggestion", data)
self.assertEqual(data['title_suggestion']['title'], "Suggested Title") self.assertEqual(data["title_suggestion"]["title"], "Suggested Title")
def test_get_ai_suggestions_no_content(self): def test_get_ai_suggestions_no_content(self):
"""Test getting AI suggestions for document without content.""" """Test getting AI suggestions for document without content."""
@ -126,7 +124,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
response = self.client.get(f"/api/documents/{doc.pk}/ai-suggestions/") response = self.client.get(f"/api/documents/{doc.pk}/ai-suggestions/")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("no content", response.json()['detail'].lower()) self.assertIn("no content", response.json()["detail"].lower())
def test_get_ai_suggestions_document_not_found(self): def test_get_ai_suggestions_document_not_found(self):
"""Test getting AI suggestions for non-existent document.""" """Test getting AI suggestions for non-existent document."""
@ -137,19 +135,19 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_tag(self): def test_apply_suggestion_tag(self):
"""Test applying a tag suggestion.""" """Test applying a tag suggestion."""
request_data = { request_data = {
'suggestion_type': 'tag', "suggestion_type": "tag",
'value_id': self.tag1.id, "value_id": self.tag1.id,
'confidence': 0.85, "confidence": 0.85,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/", f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()['status'], 'success') self.assertEqual(response.json()["status"], "success")
# Verify tag was applied # Verify tag was applied
self.document.refresh_from_db() self.document.refresh_from_db()
@ -158,7 +156,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Verify feedback was recorded # Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter( feedback = AISuggestionFeedback.objects.filter(
document=self.document, document=self.document,
suggestion_type='tag', suggestion_type="tag",
).first() ).first()
self.assertIsNotNone(feedback) self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED) self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
@ -169,15 +167,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_correspondent(self): def test_apply_suggestion_correspondent(self):
"""Test applying a correspondent suggestion.""" """Test applying a correspondent suggestion."""
request_data = { request_data = {
'suggestion_type': 'correspondent', "suggestion_type": "correspondent",
'value_id': self.correspondent.id, "value_id": self.correspondent.id,
'confidence': 0.90, "confidence": 0.90,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/", f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -189,7 +187,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Verify feedback was recorded # Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter( feedback = AISuggestionFeedback.objects.filter(
document=self.document, document=self.document,
suggestion_type='correspondent', suggestion_type="correspondent",
).first() ).first()
self.assertIsNotNone(feedback) self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED) self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
@ -197,15 +195,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_document_type(self): def test_apply_suggestion_document_type(self):
"""Test applying a document type suggestion.""" """Test applying a document type suggestion."""
request_data = { request_data = {
'suggestion_type': 'document_type', "suggestion_type": "document_type",
'value_id': self.doc_type.id, "value_id": self.doc_type.id,
'confidence': 0.88, "confidence": 0.88,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/", f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -217,35 +215,35 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_title(self): def test_apply_suggestion_title(self):
"""Test applying a title suggestion.""" """Test applying a title suggestion."""
request_data = { request_data = {
'suggestion_type': 'title', "suggestion_type": "title",
'value_text': 'New Suggested Title', "value_text": "New Suggested Title",
'confidence': 0.80, "confidence": 0.80,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/", f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify title was applied # Verify title was applied
self.document.refresh_from_db() self.document.refresh_from_db()
self.assertEqual(self.document.title, 'New Suggested Title') self.assertEqual(self.document.title, "New Suggested Title")
def test_apply_suggestion_invalid_type(self): def test_apply_suggestion_invalid_type(self):
"""Test applying suggestion with invalid type.""" """Test applying suggestion with invalid type."""
request_data = { request_data = {
'suggestion_type': 'invalid_type', "suggestion_type": "invalid_type",
'value_id': 1, "value_id": 1,
'confidence': 0.85, "confidence": 0.85,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/", f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -253,14 +251,14 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_missing_value(self): def test_apply_suggestion_missing_value(self):
"""Test applying suggestion without value_id or value_text.""" """Test applying suggestion without value_id or value_text."""
request_data = { request_data = {
'suggestion_type': 'tag', "suggestion_type": "tag",
'confidence': 0.85, "confidence": 0.85,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/", f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -268,15 +266,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_nonexistent_object(self): def test_apply_suggestion_nonexistent_object(self):
"""Test applying suggestion with non-existent object ID.""" """Test applying suggestion with non-existent object ID."""
request_data = { request_data = {
'suggestion_type': 'tag', "suggestion_type": "tag",
'value_id': 99999, "value_id": 99999,
'confidence': 0.85, "confidence": 0.85,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/", f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -284,24 +282,24 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_reject_suggestion(self): def test_reject_suggestion(self):
"""Test rejecting an AI suggestion.""" """Test rejecting an AI suggestion."""
request_data = { request_data = {
'suggestion_type': 'tag', "suggestion_type": "tag",
'value_id': self.tag1.id, "value_id": self.tag1.id,
'confidence': 0.65, "confidence": 0.65,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/reject-suggestion/", f"/api/documents/{self.document.pk}/reject-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()['status'], 'success') self.assertEqual(response.json()["status"], "success")
# Verify feedback was recorded # Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter( feedback = AISuggestionFeedback.objects.filter(
document=self.document, document=self.document,
suggestion_type='tag', suggestion_type="tag",
).first() ).first()
self.assertIsNotNone(feedback) self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED) self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
@ -312,15 +310,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_reject_suggestion_with_text(self): def test_reject_suggestion_with_text(self):
"""Test rejecting a suggestion with text value.""" """Test rejecting a suggestion with text value."""
request_data = { request_data = {
'suggestion_type': 'title', "suggestion_type": "title",
'value_text': 'Bad Title Suggestion', "value_text": "Bad Title Suggestion",
'confidence': 0.50, "confidence": 0.50,
} }
response = self.client.post( response = self.client.post(
f"/api/documents/{self.document.pk}/reject-suggestion/", f"/api/documents/{self.document.pk}/reject-suggestion/",
data=request_data, data=request_data,
format='json', format="json",
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -328,11 +326,11 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Verify feedback was recorded # Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter( feedback = AISuggestionFeedback.objects.filter(
document=self.document, document=self.document,
suggestion_type='title', suggestion_type="title",
).first() ).first()
self.assertIsNotNone(feedback) self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED) self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
self.assertEqual(feedback.suggested_value_text, 'Bad Title Suggestion') self.assertEqual(feedback.suggested_value_text, "Bad Title Suggestion")
def test_ai_suggestion_stats_empty(self): def test_ai_suggestion_stats_empty(self):
"""Test getting statistics when no feedback exists.""" """Test getting statistics when no feedback exists."""
@ -341,17 +339,17 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json() data = response.json()
self.assertEqual(data['total_suggestions'], 0) self.assertEqual(data["total_suggestions"], 0)
self.assertEqual(data['total_applied'], 0) self.assertEqual(data["total_applied"], 0)
self.assertEqual(data['total_rejected'], 0) self.assertEqual(data["total_rejected"], 0)
self.assertEqual(data['accuracy_rate'], 0) self.assertEqual(data["accuracy_rate"], 0)
def test_ai_suggestion_stats_with_data(self): def test_ai_suggestion_stats_with_data(self):
"""Test getting statistics with feedback data.""" """Test getting statistics with feedback data."""
# Create some feedback entries # Create some feedback entries
AISuggestionFeedback.objects.create( AISuggestionFeedback.objects.create(
document=self.document, document=self.document,
suggestion_type='tag', suggestion_type="tag",
suggested_value_id=self.tag1.id, suggested_value_id=self.tag1.id,
confidence=0.85, confidence=0.85,
status=AISuggestionFeedback.STATUS_APPLIED, status=AISuggestionFeedback.STATUS_APPLIED,
@ -359,7 +357,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
) )
AISuggestionFeedback.objects.create( AISuggestionFeedback.objects.create(
document=self.document, document=self.document,
suggestion_type='tag', suggestion_type="tag",
suggested_value_id=self.tag2.id, suggested_value_id=self.tag2.id,
confidence=0.70, confidence=0.70,
status=AISuggestionFeedback.STATUS_APPLIED, status=AISuggestionFeedback.STATUS_APPLIED,
@ -367,7 +365,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
) )
AISuggestionFeedback.objects.create( AISuggestionFeedback.objects.create(
document=self.document, document=self.document,
suggestion_type='correspondent', suggestion_type="correspondent",
suggested_value_id=self.correspondent.id, suggested_value_id=self.correspondent.id,
confidence=0.60, confidence=0.60,
status=AISuggestionFeedback.STATUS_REJECTED, status=AISuggestionFeedback.STATUS_REJECTED,
@ -380,25 +378,25 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
data = response.json() data = response.json()
# Check overall stats # Check overall stats
self.assertEqual(data['total_suggestions'], 3) self.assertEqual(data["total_suggestions"], 3)
self.assertEqual(data['total_applied'], 2) self.assertEqual(data["total_applied"], 2)
self.assertEqual(data['total_rejected'], 1) self.assertEqual(data["total_rejected"], 1)
self.assertAlmostEqual(data['accuracy_rate'], 66.67, places=1) self.assertAlmostEqual(data["accuracy_rate"], 66.67, places=1)
# Check by_type stats # Check by_type stats
self.assertIn('by_type', data) self.assertIn("by_type", data)
self.assertIn('tag', data['by_type']) self.assertIn("tag", data["by_type"])
self.assertEqual(data['by_type']['tag']['total'], 2) self.assertEqual(data["by_type"]["tag"]["total"], 2)
self.assertEqual(data['by_type']['tag']['applied'], 2) self.assertEqual(data["by_type"]["tag"]["applied"], 2)
self.assertEqual(data['by_type']['tag']['rejected'], 0) self.assertEqual(data["by_type"]["tag"]["rejected"], 0)
# Check confidence averages # Check confidence averages
self.assertGreater(data['average_confidence_applied'], 0) self.assertGreater(data["average_confidence_applied"], 0)
self.assertGreater(data['average_confidence_rejected'], 0) self.assertGreater(data["average_confidence_rejected"], 0)
# Check recent suggestions # Check recent suggestions
self.assertIn('recent_suggestions', data) self.assertIn("recent_suggestions", data)
self.assertEqual(len(data['recent_suggestions']), 3) self.assertEqual(len(data["recent_suggestions"]), 3)
def test_ai_suggestion_stats_accuracy_calculation(self): def test_ai_suggestion_stats_accuracy_calculation(self):
"""Test that accuracy rate is calculated correctly.""" """Test that accuracy rate is calculated correctly."""
@ -406,7 +404,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
for i in range(7): for i in range(7):
AISuggestionFeedback.objects.create( AISuggestionFeedback.objects.create(
document=self.document, document=self.document,
suggestion_type='tag', suggestion_type="tag",
suggested_value_id=self.tag1.id, suggested_value_id=self.tag1.id,
confidence=0.80, confidence=0.80,
status=AISuggestionFeedback.STATUS_APPLIED, status=AISuggestionFeedback.STATUS_APPLIED,
@ -416,7 +414,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
for i in range(3): for i in range(3):
AISuggestionFeedback.objects.create( AISuggestionFeedback.objects.create(
document=self.document, document=self.document,
suggestion_type='tag', suggestion_type="tag",
suggested_value_id=self.tag2.id, suggested_value_id=self.tag2.id,
confidence=0.60, confidence=0.60,
status=AISuggestionFeedback.STATUS_REJECTED, status=AISuggestionFeedback.STATUS_REJECTED,
@ -428,10 +426,10 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json() data = response.json()
self.assertEqual(data['total_suggestions'], 10) self.assertEqual(data["total_suggestions"], 10)
self.assertEqual(data['total_applied'], 7) self.assertEqual(data["total_applied"], 7)
self.assertEqual(data['total_rejected'], 3) self.assertEqual(data["total_rejected"], 3)
self.assertEqual(data['accuracy_rate'], 70.0) self.assertEqual(data["accuracy_rate"], 70.0)
def test_authentication_required(self): def test_authentication_required(self):
"""Test that authentication is required for all endpoints.""" """Test that authentication is required for all endpoints."""
@ -439,7 +437,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Test ai-suggestions endpoint # Test ai-suggestions endpoint
response = self.client.get( response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/" f"/api/documents/{self.document.pk}/ai-suggestions/",
) )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

View file

@ -11,17 +11,14 @@ Tests cover:
""" """
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.test import override_settings
from rest_framework import status from rest_framework import status
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from documents.models import ( from documents.models import Correspondent
Correspondent, from documents.models import DeletionRequest
DeletionRequest, from documents.models import Document
Document, from documents.models import DocumentType
DocumentType, from documents.models import Tag
Tag,
)
class TestDeletionRequestAPI(APITestCase): class TestDeletionRequestAPI(APITestCase):

View file

@ -1561,7 +1561,6 @@ class TestConsumerAIScannerIntegration(
Verifies that AI scanner respects database transactions and handles Verifies that AI scanner respects database transactions and handles
rollbacks correctly. rollbacks correctly.
""" """
from django.db import transaction as db_transaction
tag = Tag.objects.create(name="Invoice") tag = Tag.objects.create(name="Invoice")

View file

@ -15,13 +15,11 @@ from django.contrib.auth.models import User
from django.test import TestCase from django.test import TestCase
from django.utils import timezone from django.utils import timezone
from documents.models import ( from documents.models import Correspondent
Correspondent, from documents.models import DeletionRequest
DeletionRequest, from documents.models import Document
Document, from documents.models import DocumentType
DocumentType, from documents.models import Tag
Tag,
)
class TestDeletionRequestModelCreation(TestCase): class TestDeletionRequestModelCreation(TestCase):

View file

@ -8,11 +8,9 @@ from unittest import mock
from django.test import TestCase from django.test import TestCase
from documents.ml.model_cache import ( from documents.ml.model_cache import CacheMetrics
CacheMetrics, from documents.ml.model_cache import LRUCache
LRUCache, from documents.ml.model_cache import ModelCacheManager
ModelCacheManager,
)
class TestCacheMetrics(TestCase): class TestCacheMetrics(TestCase):

View file

@ -150,7 +150,6 @@ class TestMLCacheDirectory:
def test_model_cache_writable(self, tmp_path): def test_model_cache_writable(self, tmp_path):
"""Test that we can write to model cache directory.""" """Test that we can write to model cache directory."""
import pathlib
# Use tmp_path fixture for testing # Use tmp_path fixture for testing
cache_dir = tmp_path / ".cache" / "huggingface" cache_dir = tmp_path / ".cache" / "huggingface"
@ -169,7 +168,6 @@ class TestMLCacheDirectory:
def test_torch_cache_directory(self, tmp_path, monkeypatch): def test_torch_cache_directory(self, tmp_path, monkeypatch):
"""Test that PyTorch can use a custom cache directory.""" """Test that PyTorch can use a custom cache directory."""
import torch
# Set custom cache directory # Set custom cache directory
cache_dir = tmp_path / ".cache" / "torch" cache_dir = tmp_path / ".cache" / "torch"
@ -204,9 +202,10 @@ class TestMLPerformanceBasic:
def test_numpy_performance_basic(self): def test_numpy_performance_basic(self):
"""Test basic NumPy performance with larger arrays.""" """Test basic NumPy performance with larger arrays."""
import numpy as np
import time import time
import numpy as np
# Create large array (10 million elements) # Create large array (10 million elements)
arr = np.random.rand(10_000_000) arr = np.random.rand(10_000_000)

View file

@ -89,6 +89,8 @@ from rest_framework.viewsets import ViewSet
from documents import bulk_edit from documents import bulk_edit
from documents import index from documents import index
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.bulk_download import ArchiveOnlyStrategy from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy from documents.bulk_download import OriginalsOnlyStrategy
@ -141,13 +143,10 @@ from documents.models import UiSettings
from documents.models import Workflow from documents.models import Workflow
from documents.models import WorkflowAction from documents.models import WorkflowAction
from documents.models import WorkflowTrigger from documents.models import WorkflowTrigger
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.parsers import get_parser_class_for_mime_type from documents.parsers import get_parser_class_for_mime_type
from documents.parsers import parse_date_generator from documents.parsers import parse_date_generator
from documents.permissions import AcknowledgeTasksPermissions from documents.permissions import AcknowledgeTasksPermissions
from documents.permissions import CanApplyAISuggestionsPermission from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission from documents.permissions import CanViewAISuggestionsPermission
from documents.permissions import PaperlessAdminPermissions from documents.permissions import PaperlessAdminPermissions
@ -1388,7 +1387,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
scan_result = scanner.scan_document( scan_result = scanner.scan_document(
document=document, document=document,
document_text=document.content, document_text=document.content,
original_file_path=document.source_path if hasattr(document, 'source_path') else None, original_file_path=document.source_path if hasattr(document, "source_path") else None,
) )
# Convert scan result to serializable format # Convert scan result to serializable format
@ -1424,43 +1423,43 @@ class UnifiedSearchViewSet(DocumentViewSet):
serializer = ApplySuggestionSerializer(data=request.data) serializer = ApplySuggestionSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
suggestion_type = serializer.validated_data['suggestion_type'] suggestion_type = serializer.validated_data["suggestion_type"]
value_id = serializer.validated_data.get('value_id') value_id = serializer.validated_data.get("value_id")
value_text = serializer.validated_data.get('value_text') value_text = serializer.validated_data.get("value_text")
confidence = serializer.validated_data['confidence'] confidence = serializer.validated_data["confidence"]
# Apply the suggestion based on type # Apply the suggestion based on type
applied = False applied = False
result_message = "" result_message = ""
if suggestion_type == 'tag' and value_id: if suggestion_type == "tag" and value_id:
tag = Tag.objects.get(pk=value_id) tag = Tag.objects.get(pk=value_id)
document.tags.add(tag) document.tags.add(tag)
applied = True applied = True
result_message = f"Tag '{tag.name}' applied" result_message = f"Tag '{tag.name}' applied"
elif suggestion_type == 'correspondent' and value_id: elif suggestion_type == "correspondent" and value_id:
correspondent = Correspondent.objects.get(pk=value_id) correspondent = Correspondent.objects.get(pk=value_id)
document.correspondent = correspondent document.correspondent = correspondent
document.save() document.save()
applied = True applied = True
result_message = f"Correspondent '{correspondent.name}' applied" result_message = f"Correspondent '{correspondent.name}' applied"
elif suggestion_type == 'document_type' and value_id: elif suggestion_type == "document_type" and value_id:
doc_type = DocumentType.objects.get(pk=value_id) doc_type = DocumentType.objects.get(pk=value_id)
document.document_type = doc_type document.document_type = doc_type
document.save() document.save()
applied = True applied = True
result_message = f"Document type '{doc_type.name}' applied" result_message = f"Document type '{doc_type.name}' applied"
elif suggestion_type == 'storage_path' and value_id: elif suggestion_type == "storage_path" and value_id:
storage_path = StoragePath.objects.get(pk=value_id) storage_path = StoragePath.objects.get(pk=value_id)
document.storage_path = storage_path document.storage_path = storage_path
document.save() document.save()
applied = True applied = True
result_message = f"Storage path '{storage_path.name}' applied" result_message = f"Storage path '{storage_path.name}' applied"
elif suggestion_type == 'title' and value_text: elif suggestion_type == "title" and value_text:
document.title = value_text document.title = value_text
document.save() document.save()
applied = True applied = True
@ -1518,10 +1517,10 @@ class UnifiedSearchViewSet(DocumentViewSet):
serializer = RejectSuggestionSerializer(data=request.data) serializer = RejectSuggestionSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
suggestion_type = serializer.validated_data['suggestion_type'] suggestion_type = serializer.validated_data["suggestion_type"]
value_id = serializer.validated_data.get('value_id') value_id = serializer.validated_data.get("value_id")
value_text = serializer.validated_data.get('value_text') value_text = serializer.validated_data.get("value_text")
confidence = serializer.validated_data['confidence'] confidence = serializer.validated_data["confidence"]
# Record feedback # Record feedback
AISuggestionFeedback.objects.create( AISuggestionFeedback.objects.create(
@ -1554,7 +1553,10 @@ class UnifiedSearchViewSet(DocumentViewSet):
Returns aggregated data about applied vs rejected suggestions, Returns aggregated data about applied vs rejected suggestions,
accuracy rates, and confidence scores. accuracy rates, and confidence scores.
""" """
from django.db.models import Avg, Count, Q from django.db.models import Avg
from django.db.models import Count
from django.db.models import Q
from documents.models import AISuggestionFeedback from documents.models import AISuggestionFeedback
from documents.serializers.ai_suggestions import AISuggestionStatsSerializer from documents.serializers.ai_suggestions import AISuggestionStatsSerializer
@ -1562,61 +1564,63 @@ class UnifiedSearchViewSet(DocumentViewSet):
# Get overall counts # Get overall counts
total_feedbacks = AISuggestionFeedback.objects.count() total_feedbacks = AISuggestionFeedback.objects.count()
total_applied = AISuggestionFeedback.objects.filter( total_applied = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_APPLIED status=AISuggestionFeedback.STATUS_APPLIED,
).count() ).count()
total_rejected = AISuggestionFeedback.objects.filter( total_rejected = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_REJECTED status=AISuggestionFeedback.STATUS_REJECTED,
).count() ).count()
# Calculate accuracy rate # Calculate accuracy rate
accuracy_rate = (total_applied / total_feedbacks * 100) if total_feedbacks > 0 else 0 accuracy_rate = (total_applied / total_feedbacks * 100) if total_feedbacks > 0 else 0
# Get statistics by suggestion type using a single aggregated query # Get statistics by suggestion type using a single aggregated query
stats_by_type = AISuggestionFeedback.objects.values('suggestion_type').annotate( stats_by_type = AISuggestionFeedback.objects.values("suggestion_type").annotate(
total=Count('id'), total=Count("id"),
applied=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)), applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
rejected=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)) rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)),
) )
# Build the by_type dictionary using the aggregated results # Build the by_type dictionary using the aggregated results
by_type = {} by_type = {}
for stat in stats_by_type: for stat in stats_by_type:
suggestion_type = stat['suggestion_type'] suggestion_type = stat["suggestion_type"]
type_total = stat['total'] type_total = stat["total"]
type_applied = stat['applied'] type_applied = stat["applied"]
type_rejected = stat['rejected'] type_rejected = stat["rejected"]
by_type[suggestion_type] = { by_type[suggestion_type] = {
'total': type_total, "total": type_total,
'applied': type_applied, "applied": type_applied,
'rejected': type_rejected, "rejected": type_rejected,
'accuracy_rate': (type_applied / type_total * 100) if type_total > 0 else 0, "accuracy_rate": (type_applied / type_total * 100) if type_total > 0 else 0,
} }
# Get average confidence scores # Get average confidence scores
avg_confidence_applied = AISuggestionFeedback.objects.filter( avg_confidence_applied = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_APPLIED status=AISuggestionFeedback.STATUS_APPLIED,
).aggregate(Avg('confidence'))['confidence__avg'] or 0.0 ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0
avg_confidence_rejected = AISuggestionFeedback.objects.filter( avg_confidence_rejected = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_REJECTED status=AISuggestionFeedback.STATUS_REJECTED,
).aggregate(Avg('confidence'))['confidence__avg'] or 0.0 ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0
# Get recent suggestions (last 10) # Get recent suggestions (last 10)
recent_suggestions = AISuggestionFeedback.objects.order_by('-created_at')[:10] recent_suggestions = AISuggestionFeedback.objects.order_by("-created_at")[:10]
# Build response data # Build response data
from documents.serializers.ai_suggestions import AISuggestionFeedbackSerializer from documents.serializers.ai_suggestions import (
AISuggestionFeedbackSerializer,
)
data = { data = {
'total_suggestions': total_feedbacks, "total_suggestions": total_feedbacks,
'total_applied': total_applied, "total_applied": total_applied,
'total_rejected': total_rejected, "total_rejected": total_rejected,
'accuracy_rate': accuracy_rate, "accuracy_rate": accuracy_rate,
'by_type': by_type, "by_type": by_type,
'average_confidence_applied': avg_confidence_applied, "average_confidence_applied": avg_confidence_applied,
'average_confidence_rejected': avg_confidence_rejected, "average_confidence_rejected": avg_confidence_rejected,
'recent_suggestions': AISuggestionFeedbackSerializer( "recent_suggestions": AISuggestionFeedbackSerializer(
recent_suggestions, many=True recent_suggestions, many=True,
).data, ).data,
} }
@ -3571,21 +3575,21 @@ class AISuggestionsView(GenericAPIView):
request_serializer = AISuggestionsRequestSerializer(data=request.data) request_serializer = AISuggestionsRequestSerializer(data=request.data)
request_serializer.is_valid(raise_exception=True) request_serializer.is_valid(raise_exception=True)
document_id = request_serializer.validated_data['document_id'] document_id = request_serializer.validated_data["document_id"]
try: try:
document = Document.objects.get(pk=document_id) document = Document.objects.get(pk=document_id)
except Document.DoesNotExist: except Document.DoesNotExist:
return Response( return Response(
{"error": "Document not found or you don't have permission to view it"}, {"error": "Document not found or you don't have permission to view it"},
status=status.HTTP_404_NOT_FOUND status=status.HTTP_404_NOT_FOUND,
) )
# Check if user has permission to view this document # Check if user has permission to view this document
if not has_perms_owner_aware(request.user, 'documents.view_document', document): if not has_perms_owner_aware(request.user, "documents.view_document", document):
return Response( return Response(
{"error": "Permission denied"}, {"error": "Permission denied"},
status=status.HTTP_403_FORBIDDEN status=status.HTTP_403_FORBIDDEN,
) )
# Get AI scanner and scan document # Get AI scanner and scan document
@ -3600,7 +3604,7 @@ class AISuggestionsView(GenericAPIView):
"document_type": None, "document_type": None,
"storage_path": None, "storage_path": None,
"title_suggestion": scan_result.title_suggestion, "title_suggestion": scan_result.title_suggestion,
"custom_fields": {} "custom_fields": {},
} }
# Format tag suggestions # Format tag suggestions
@ -3610,7 +3614,7 @@ class AISuggestionsView(GenericAPIView):
response_data["tags"].append({ response_data["tags"].append({
"id": tag.id, "id": tag.id,
"name": tag.name, "name": tag.name,
"confidence": confidence "confidence": confidence,
}) })
except Tag.DoesNotExist: except Tag.DoesNotExist:
# Tag was suggested by AI but no longer exists; skip it # Tag was suggested by AI but no longer exists; skip it
@ -3624,7 +3628,7 @@ class AISuggestionsView(GenericAPIView):
response_data["correspondent"] = { response_data["correspondent"] = {
"id": correspondent.id, "id": correspondent.id,
"name": correspondent.name, "name": correspondent.name,
"confidence": confidence "confidence": confidence,
} }
except Correspondent.DoesNotExist: except Correspondent.DoesNotExist:
# Correspondent was suggested but no longer exists; skip it # Correspondent was suggested but no longer exists; skip it
@ -3638,7 +3642,7 @@ class AISuggestionsView(GenericAPIView):
response_data["document_type"] = { response_data["document_type"] = {
"id": doc_type.id, "id": doc_type.id,
"name": doc_type.name, "name": doc_type.name,
"confidence": confidence "confidence": confidence,
} }
except DocumentType.DoesNotExist: except DocumentType.DoesNotExist:
# Document type was suggested but no longer exists; skip it # Document type was suggested but no longer exists; skip it
@ -3652,7 +3656,7 @@ class AISuggestionsView(GenericAPIView):
response_data["storage_path"] = { response_data["storage_path"] = {
"id": storage_path.id, "id": storage_path.id,
"name": storage_path.name, "name": storage_path.name,
"confidence": confidence "confidence": confidence,
} }
except StoragePath.DoesNotExist: except StoragePath.DoesNotExist:
# Storage path was suggested but no longer exists; skip it # Storage path was suggested but no longer exists; skip it
@ -3662,7 +3666,7 @@ class AISuggestionsView(GenericAPIView):
for field_id, (value, confidence) in scan_result.custom_fields.items(): for field_id, (value, confidence) in scan_result.custom_fields.items():
response_data["custom_fields"][str(field_id)] = { response_data["custom_fields"][str(field_id)] = {
"value": value, "value": value,
"confidence": confidence "confidence": confidence,
} }
return Response(response_data) return Response(response_data)
@ -3683,21 +3687,21 @@ class ApplyAISuggestionsView(GenericAPIView):
serializer = ApplyAISuggestionsSerializer(data=request.data) serializer = ApplyAISuggestionsSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
document_id = serializer.validated_data['document_id'] document_id = serializer.validated_data["document_id"]
try: try:
document = Document.objects.get(pk=document_id) document = Document.objects.get(pk=document_id)
except Document.DoesNotExist: except Document.DoesNotExist:
return Response( return Response(
{"error": "Document not found"}, {"error": "Document not found"},
status=status.HTTP_404_NOT_FOUND status=status.HTTP_404_NOT_FOUND,
) )
# Check if user has permission to change this document # Check if user has permission to change this document
if not has_perms_owner_aware(request.user, 'documents.change_document', document): if not has_perms_owner_aware(request.user, "documents.change_document", document):
return Response( return Response(
{"error": "Permission denied"}, {"error": "Permission denied"},
status=status.HTTP_403_FORBIDDEN status=status.HTTP_403_FORBIDDEN,
) )
# Get AI scanner and scan document # Get AI scanner and scan document
@ -3707,8 +3711,8 @@ class ApplyAISuggestionsView(GenericAPIView):
# Apply suggestions based on user selections # Apply suggestions based on user selections
applied = [] applied = []
if serializer.validated_data.get('apply_tags'): if serializer.validated_data.get("apply_tags"):
selected_tags = serializer.validated_data.get('selected_tags', []) selected_tags = serializer.validated_data.get("selected_tags", [])
if selected_tags: if selected_tags:
# Apply only selected tags # Apply only selected tags
tags_to_apply = [tag_id for tag_id, _ in scan_result.tags if tag_id in selected_tags] tags_to_apply = [tag_id for tag_id, _ in scan_result.tags if tag_id in selected_tags]
@ -3725,7 +3729,7 @@ class ApplyAISuggestionsView(GenericAPIView):
# Tag not found; skip applying this tag # Tag not found; skip applying this tag
pass pass
if serializer.validated_data.get('apply_correspondent') and scan_result.correspondent: if serializer.validated_data.get("apply_correspondent") and scan_result.correspondent:
corr_id, confidence = scan_result.correspondent corr_id, confidence = scan_result.correspondent
try: try:
correspondent = Correspondent.objects.get(pk=corr_id) correspondent = Correspondent.objects.get(pk=corr_id)
@ -3735,7 +3739,7 @@ class ApplyAISuggestionsView(GenericAPIView):
# Correspondent not found; skip applying # Correspondent not found; skip applying
pass pass
if serializer.validated_data.get('apply_document_type') and scan_result.document_type: if serializer.validated_data.get("apply_document_type") and scan_result.document_type:
type_id, confidence = scan_result.document_type type_id, confidence = scan_result.document_type
try: try:
doc_type = DocumentType.objects.get(pk=type_id) doc_type = DocumentType.objects.get(pk=type_id)
@ -3745,7 +3749,7 @@ class ApplyAISuggestionsView(GenericAPIView):
# Document type not found; skip applying # Document type not found; skip applying
pass pass
if serializer.validated_data.get('apply_storage_path') and scan_result.storage_path: if serializer.validated_data.get("apply_storage_path") and scan_result.storage_path:
path_id, confidence = scan_result.storage_path path_id, confidence = scan_result.storage_path
try: try:
storage_path = StoragePath.objects.get(pk=path_id) storage_path = StoragePath.objects.get(pk=path_id)
@ -3755,7 +3759,7 @@ class ApplyAISuggestionsView(GenericAPIView):
# Storage path not found; skip applying # Storage path not found; skip applying
pass pass
if serializer.validated_data.get('apply_title') and scan_result.title_suggestion: if serializer.validated_data.get("apply_title") and scan_result.title_suggestion:
document.title = scan_result.title_suggestion document.title = scan_result.title_suggestion
applied.append(f"title: {scan_result.title_suggestion}") applied.append(f"title: {scan_result.title_suggestion}")
@ -3765,7 +3769,7 @@ class ApplyAISuggestionsView(GenericAPIView):
return Response({ return Response({
"status": "success", "status": "success",
"document_id": document.id, "document_id": document.id,
"applied": applied "applied": applied,
}) })
@ -3805,14 +3809,14 @@ class AIConfigurationView(GenericAPIView):
# Create new scanner with updated configuration # Create new scanner with updated configuration
config = {} config = {}
if 'auto_apply_threshold' in serializer.validated_data: if "auto_apply_threshold" in serializer.validated_data:
config['auto_apply_threshold'] = serializer.validated_data['auto_apply_threshold'] config["auto_apply_threshold"] = serializer.validated_data["auto_apply_threshold"]
if 'suggest_threshold' in serializer.validated_data: if "suggest_threshold" in serializer.validated_data:
config['suggest_threshold'] = serializer.validated_data['suggest_threshold'] config["suggest_threshold"] = serializer.validated_data["suggest_threshold"]
if 'ml_enabled' in serializer.validated_data: if "ml_enabled" in serializer.validated_data:
config['enable_ml_features'] = serializer.validated_data['ml_enabled'] config["enable_ml_features"] = serializer.validated_data["ml_enabled"]
if 'advanced_ocr_enabled' in serializer.validated_data: if "advanced_ocr_enabled" in serializer.validated_data:
config['enable_advanced_ocr'] = serializer.validated_data['advanced_ocr_enabled'] config["enable_advanced_ocr"] = serializer.validated_data["advanced_ocr_enabled"]
# Update global scanner instance # Update global scanner instance
# WARNING: Not thread-safe. Consider storing configuration in database # WARNING: Not thread-safe. Consider storing configuration in database
@ -3822,7 +3826,7 @@ class AIConfigurationView(GenericAPIView):
return Response({ return Response({
"status": "success", "status": "success",
"message": "AI configuration updated. Changes may require server restart for consistency." "message": "AI configuration updated. Changes may require server restart for consistency.",
}) })

View file

@ -76,7 +76,7 @@ class DeletionRequestViewSet(ModelViewSet):
# Check permissions # Check permissions
if not self._can_manage_request(deletion_request): if not self._can_manage_request(deletion_request):
return HttpResponseForbidden( return HttpResponseForbidden(
"You don't have permission to approve this deletion request." "You don't have permission to approve this deletion request.",
) )
# Validate status # Validate status
@ -114,11 +114,11 @@ class DeletionRequestViewSet(ModelViewSet):
deleted_count += 1 deleted_count += 1
logger.info( logger.info(
f"Deleted document {doc_id} ('{doc_title}') " f"Deleted document {doc_id} ('{doc_title}') "
f"as part of deletion request {deletion_request.id}" f"as part of deletion request {deletion_request.id}",
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to delete document {doc.id}: {str(e)}" f"Failed to delete document {doc.id}: {e!s}",
) )
failed_deletions.append({ failed_deletions.append({
"id": doc.id, "id": doc.id,
@ -138,14 +138,14 @@ class DeletionRequestViewSet(ModelViewSet):
logger.info( logger.info(
f"Deletion request {deletion_request.id} completed. " f"Deletion request {deletion_request.id} completed. "
f"Deleted {deleted_count}/{len(documents)} documents." f"Deleted {deleted_count}/{len(documents)} documents.",
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error executing deletion request {deletion_request.id}: {str(e)}" f"Error executing deletion request {deletion_request.id}: {e!s}",
) )
return Response( return Response(
{"error": f"Failed to execute deletion: {str(e)}"}, {"error": f"Failed to execute deletion: {e!s}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR, status=status.HTTP_500_INTERNAL_SERVER_ERROR,
) )
@ -176,7 +176,7 @@ class DeletionRequestViewSet(ModelViewSet):
# Check permissions # Check permissions
if not self._can_manage_request(deletion_request): if not self._can_manage_request(deletion_request):
return HttpResponseForbidden( return HttpResponseForbidden(
"You don't have permission to reject this deletion request." "You don't have permission to reject this deletion request.",
) )
# Validate status # Validate status
@ -199,7 +199,7 @@ class DeletionRequestViewSet(ModelViewSet):
) )
logger.info( logger.info(
f"Deletion request {deletion_request.id} rejected by user {request.user.username}" f"Deletion request {deletion_request.id} rejected by user {request.user.username}",
) )
serializer = self.get_serializer(deletion_request) serializer = self.get_serializer(deletion_request)
@ -228,7 +228,7 @@ class DeletionRequestViewSet(ModelViewSet):
# Check permissions # Check permissions
if not self._can_manage_request(deletion_request): if not self._can_manage_request(deletion_request):
return HttpResponseForbidden( return HttpResponseForbidden(
"You don't have permission to cancel this deletion request." "You don't have permission to cancel this deletion request.",
) )
# Validate status # Validate status
@ -249,7 +249,7 @@ class DeletionRequestViewSet(ModelViewSet):
deletion_request.save() deletion_request.save()
logger.info( logger.info(
f"Deletion request {deletion_request.id} cancelled by user {request.user.username}" f"Deletion request {deletion_request.id} cancelled by user {request.user.username}",
) )
serializer = self.get_serializer(deletion_request) serializer = self.get_serializer(deletion_request)

View file

@ -160,7 +160,7 @@ class SecurityHeadersMiddleware:
# Store nonce in request for use in templates # Store nonce in request for use in templates
# Templates can access this via {{ request.csp_nonce }} # Templates can access this via {{ request.csp_nonce }}
if hasattr(request, '_csp_nonce'): if hasattr(request, "_csp_nonce"):
request._csp_nonce = nonce request._csp_nonce = nonce
# Prevent clickjacking attacks # Prevent clickjacking attacks

View file

@ -9,7 +9,6 @@ from __future__ import annotations
import hashlib import hashlib
import logging import logging
import mimetypes
import os import os
import re import re
from pathlib import Path from pathlib import Path
@ -26,39 +25,39 @@ logger = logging.getLogger("paperless.security")
# Lista explícita de tipos MIME permitidos # Lista explícita de tipos MIME permitidos
ALLOWED_MIME_TYPES = { ALLOWED_MIME_TYPES = {
# Documentos # Documentos
'application/pdf', "application/pdf",
'application/vnd.oasis.opendocument.text', "application/vnd.oasis.opendocument.text",
'application/msword', "application/msword",
'application/vnd.openxmlformats-officedocument.wordprocessingml.document', "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
'application/vnd.ms-excel', "application/vnd.ms-excel",
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
'application/vnd.ms-powerpoint', "application/vnd.ms-powerpoint",
'application/vnd.openxmlformats-officedocument.presentationml.presentation', "application/vnd.openxmlformats-officedocument.presentationml.presentation",
'application/vnd.oasis.opendocument.spreadsheet', "application/vnd.oasis.opendocument.spreadsheet",
'application/vnd.oasis.opendocument.presentation', "application/vnd.oasis.opendocument.presentation",
'application/rtf', "application/rtf",
'text/rtf', "text/rtf",
# Imágenes # Imágenes
'image/jpeg', "image/jpeg",
'image/png', "image/png",
'image/gif', "image/gif",
'image/tiff', "image/tiff",
'image/bmp', "image/bmp",
'image/webp', "image/webp",
# Texto # Texto
'text/plain', "text/plain",
'text/html', "text/html",
'text/csv', "text/csv",
'text/markdown', "text/markdown",
} }
# Maximum file size (100MB by default) # Maximum file size (100MB by default)
# Can be overridden by settings.MAX_UPLOAD_SIZE # Can be overridden by settings.MAX_UPLOAD_SIZE
try: try:
from django.conf import settings from django.conf import settings
MAX_FILE_SIZE = getattr(settings, 'MAX_UPLOAD_SIZE', 100 * 1024 * 1024) # 100MB por defecto MAX_FILE_SIZE = getattr(settings, "MAX_UPLOAD_SIZE", 100 * 1024 * 1024) # 100MB por defecto
except ImportError: except ImportError:
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB in bytes MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB in bytes
@ -114,7 +113,6 @@ ALLOWED_JS_PATTERNS = [
class FileValidationError(Exception): class FileValidationError(Exception):
"""Raised when file validation fails.""" """Raised when file validation fails."""
pass
def has_whitelisted_javascript(content: bytes) -> bool: def has_whitelisted_javascript(content: bytes) -> bool:
@ -143,7 +141,7 @@ def validate_mime_type(mime_type: str) -> None:
if mime_type not in ALLOWED_MIME_TYPES: if mime_type not in ALLOWED_MIME_TYPES:
raise FileValidationError( raise FileValidationError(
f"MIME type '{mime_type}' is not allowed. " f"MIME type '{mime_type}' is not allowed. "
f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}" f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}",
) )

View file

@ -86,7 +86,7 @@ api_router.register(r"config", ApplicationConfigurationViewSet)
api_router.register(r"processed_mail", ProcessedMailViewSet) api_router.register(r"processed_mail", ProcessedMailViewSet)
api_router.register(r"deletion_requests", DeletionRequestViewSet) api_router.register(r"deletion_requests", DeletionRequestViewSet)
api_router.register( api_router.register(
r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests" r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests",
) )