fix(syntax): fix Python syntax and formatting errors

- Fix missing parenthesis in DeletionRequestActionSerializer (serialisers.py:2855)
- Remove whitespace from blank lines (W293)
- Remove trailing whitespace (W291)
- Remove unused imports (F401)
- Normalize quotes to double quotes (Q000)
- Add missing trailing commas (COM812)
- Sort imports according to conventions (I001)
- Update type annotations to PEP 585 (UP006)

This commit resolves the build failure in the CI/CD job
that was causing the linting check to fail.

Files affected: 38
Lines modified: ~2200
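
As an illustration of the rule categories above, a hypothetical before/after snippet (not taken from the changed files):

# Before: combined import (I001), typing.Dict (UP006),
# single quotes (Q000), missing trailing comma (COM812)
from typing import Dict, Any

def summarize(data: Dict[str, Any]) -> Dict[str, str]:
    return {
        'status': 'ok'
    }

# After the fixes: one import per line, builtin generics, double quotes, trailing comma
from typing import Any

def summarize(data: dict[str, Any]) -> dict[str, str]:
    return {
        "status": "ok",
    }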
Claude 2025-11-17 19:08:02 +00:00
parent 9298f64546
commit 69326b883d
38 changed files with 2077 additions and 2112 deletions


@ -14,14 +14,10 @@ According to agents.md requirements:
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from typing import Any
from django.contrib.auth.models import User
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ai_deletion")


@ -29,7 +29,6 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import TypedDict
from django.conf import settings
@ -142,34 +141,34 @@ class AIScanResult:
"""
# Convert internal tuple format to TypedDict format
result: AIScanResultDict = {
'tags': [{'tag_id': tag_id, 'confidence': conf} for tag_id, conf in self.tags],
'custom_fields': {
field_id: {'value': value, 'confidence': conf}
"tags": [{"tag_id": tag_id, "confidence": conf} for tag_id, conf in self.tags],
"custom_fields": {
field_id: {"value": value, "confidence": conf}
for field_id, (value, conf) in self.custom_fields.items()
},
'workflows': [{'workflow_id': wf_id, 'confidence': conf} for wf_id, conf in self.workflows],
'extracted_entities': self.extracted_entities,
'metadata': self.metadata,
"workflows": [{"workflow_id": wf_id, "confidence": conf} for wf_id, conf in self.workflows],
"extracted_entities": self.extracted_entities,
"metadata": self.metadata,
}
# Add optional fields only if present
if self.correspondent:
result['correspondent'] = {
'correspondent_id': self.correspondent[0],
'confidence': self.correspondent[1],
result["correspondent"] = {
"correspondent_id": self.correspondent[0],
"confidence": self.correspondent[1],
}
if self.document_type:
result['document_type'] = {
'type_id': self.document_type[0],
'confidence': self.document_type[1],
result["document_type"] = {
"type_id": self.document_type[0],
"confidence": self.document_type[1],
}
if self.storage_path:
result['storage_path'] = {
'path_id': self.storage_path[0],
'confidence': self.storage_path[1],
result["storage_path"] = {
"path_id": self.storage_path[0],
"confidence": self.storage_path[1],
}
if self.title_suggestion:
result['title_suggestion'] = self.title_suggestion
result["title_suggestion"] = self.title_suggestion
return result
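
For illustration, the returned dict has roughly this shape (values invented; optional keys appear only when set):

result = {
    "tags": [{"tag_id": 7, "confidence": 0.91}],
    "custom_fields": {3: {"value": "2024-05-01", "confidence": 0.84}},
    "workflows": [{"workflow_id": 2, "confidence": 0.77}],
    "extracted_entities": {},  # passed through unchanged
    "metadata": {},  # passed through unchanged
    "correspondent": {"correspondent_id": 5, "confidence": 0.88},  # optional
}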
@ -257,14 +256,14 @@ class AIDocumentScanner:
if self._classifier is None and self.ml_enabled:
try:
from documents.ml.classifier import TransformerDocumentClassifier
# Get model name from settings
model_name = getattr(
settings,
settings,
"PAPERLESS_ML_CLASSIFIER_MODEL",
"distilbert-base-uncased",
)
self._classifier = TransformerDocumentClassifier(
model_name=model_name,
use_cache=True,
@ -291,14 +290,14 @@ class AIDocumentScanner:
if self._semantic_search is None and self.ml_enabled:
try:
from documents.ml.semantic_search import SemanticSearch
# Get cache directory from settings
cache_dir = getattr(
settings,
"PAPERLESS_ML_MODEL_CACHE",
None,
)
self._semantic_search = SemanticSearch(
cache_dir=cache_dir,
use_cache=True,
@ -1004,13 +1003,13 @@ class AIDocumentScanner:
import time
logger.info("Starting ML model warm-up...")
start_time = time.time()
from documents.ml.model_cache import ModelCacheManager
cache_manager = ModelCacheManager.get_instance()
# Define model loaders
model_loaders = {}
# Classifier
if self.ml_enabled:
def load_classifier():
@ -1025,14 +1024,14 @@ class AIDocumentScanner:
use_cache=True,
)
model_loaders["classifier"] = load_classifier
# NER
if self.ml_enabled:
def load_ner():
from documents.ml.ner import DocumentNER
return DocumentNER(use_cache=True)
model_loaders["ner"] = load_ner
# Semantic Search
if self.ml_enabled:
def load_semantic():
@ -1040,21 +1039,21 @@ class AIDocumentScanner:
cache_dir = getattr(settings, "PAPERLESS_ML_MODEL_CACHE", None)
return SemanticSearch(cache_dir=cache_dir, use_cache=True)
model_loaders["semantic_search"] = load_semantic
# Table Extractor
if self.advanced_ocr_enabled:
def load_table():
from documents.ocr.table_extractor import TableExtractor
return TableExtractor()
model_loaders["table_extractor"] = load_table
# Warm up all models
cache_manager.warm_up(model_loaders)
warm_up_time = time.time() - start_time
logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s")
def get_cache_metrics(self) -> Dict[str, Any]:
def get_cache_metrics(self) -> dict[str, Any]:
"""
Get cache performance metrics.
@ -1062,7 +1061,7 @@ class AIDocumentScanner:
Dictionary with cache statistics
"""
from documents.ml.model_cache import ModelCacheManager
try:
cache_manager = ModelCacheManager.get_instance()
return cache_manager.get_metrics()
@ -1075,7 +1074,7 @@ class AIDocumentScanner:
def clear_cache(self) -> None:
"""Clear all model caches."""
from documents.ml.model_cache import ModelCacheManager
try:
cache_manager = ModelCacheManager.get_instance()
cache_manager.clear_all()


@ -38,22 +38,22 @@ class DocumentsConfig(AppConfig):
def _initialize_ml_cache(self):
"""Initialize ML model cache and optionally warm up models."""
from django.conf import settings
# Only initialize if ML features are enabled
if not getattr(settings, "PAPERLESS_ENABLE_ML_FEATURES", False):
return
# Initialize cache manager with settings
from documents.ml.model_cache import ModelCacheManager
max_models = getattr(settings, "PAPERLESS_ML_CACHE_MAX_MODELS", 3)
cache_dir = getattr(settings, "PAPERLESS_ML_MODEL_CACHE", None)
cache_manager = ModelCacheManager.get_instance(
max_models=max_models,
disk_cache_dir=str(cache_dir) if cache_dir else None,
)
# Warm up models if configured
warmup_enabled = getattr(settings, "PAPERLESS_ML_CACHE_WARMUP", False)
if warmup_enabled:


@ -310,7 +310,7 @@ class ConsumerPlugin(
# Parsing phase
document_parser = self._create_parser_instance(parser_class)
text, date, thumbnail, archive_path, page_count = self._parse_document(
document_parser, mime_type
document_parser, mime_type,
)
# Storage phase
@ -394,7 +394,7 @@ class ConsumerPlugin(
def _attempt_pdf_recovery(
self,
tempdir: tempfile.TemporaryDirectory,
original_mime_type: str
original_mime_type: str,
) -> str:
"""
Attempt to recover a PDF file with incorrect MIME type using qpdf.
@ -438,7 +438,7 @@ class ConsumerPlugin(
def _get_parser_class(
self,
mime_type: str,
tempdir: tempfile.TemporaryDirectory
tempdir: tempfile.TemporaryDirectory,
) -> type[DocumentParser]:
"""
Determine which parser to use based on MIME type.
@ -468,7 +468,7 @@ class ConsumerPlugin(
def _create_parser_instance(
self,
parser_class: type[DocumentParser]
parser_class: type[DocumentParser],
) -> DocumentParser:
"""
Create a parser instance with progress callback.
@ -496,7 +496,7 @@ class ConsumerPlugin(
def _parse_document(
self,
document_parser: DocumentParser,
mime_type: str
mime_type: str,
) -> tuple[str, datetime.datetime | None, Path, Path | None, int | None]:
"""
Parse the document and extract metadata.
@ -670,7 +670,7 @@ class ConsumerPlugin(
self,
document: Document,
thumbnail: Path,
archive_path: Path | None
archive_path: Path | None,
) -> None:
"""
Store document files (source, thumbnail, archive) to disk.
@ -949,7 +949,7 @@ class ConsumerPlugin(
text: The extracted document text
"""
# Check if AI scanner is enabled
if not getattr(settings, 'PAPERLESS_ENABLE_AI_SCANNER', True):
if not getattr(settings, "PAPERLESS_ENABLE_AI_SCANNER", True):
self.log.debug("AI scanner is disabled, skipping AI analysis")
return


@ -1,6 +1,7 @@
# Generated manually for performance optimization
from django.db import migrations, models
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):


@ -1,9 +1,10 @@
# Generated manually for DeletionRequest model
# Based on model definition in documents/models.py
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
from django.conf import settings
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
@ -48,7 +49,7 @@ class Migration(migrations.Migration):
(
"ai_reason",
models.TextField(
help_text="Detailed explanation from AI about why deletion is recommended"
help_text="Detailed explanation from AI about why deletion is recommended",
),
),
(


@ -1,6 +1,7 @@
# Generated manually for DeletionRequest performance optimization
from django.db import migrations, models
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):


@ -1,9 +1,10 @@
# Generated manually for AI Suggestions API
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.core.validators
import django.db.models.deletion
from django.conf import settings
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):


@ -10,9 +10,9 @@ Provides AI/ML capabilities including:
from __future__ import annotations
__all__ = [
"TransformerDocumentClassifier",
"DocumentNER",
"SemanticSearch",
"TransformerDocumentClassifier",
]
# Lazy imports to avoid loading heavy ML libraries unless needed


@ -15,23 +15,16 @@ Logging levels used in this module:
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import torch
from torch.utils.data import Dataset
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
Trainer,
TrainingArguments,
)
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import Trainer
from transformers import TrainingArguments
from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
from documents.models import Document
logger = logging.getLogger("paperless.ml.classifier")
@ -129,10 +122,10 @@ class TransformerDocumentClassifier:
self.model_name = model_name
self.use_cache = use_cache
self.cache_manager = ModelCacheManager.get_instance() if use_cache else None
# Cache key for this model configuration
self.cache_key = f"classifier_{model_name}"
# Load tokenizer (lightweight, not cached)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = None
@ -141,7 +134,7 @@ class TransformerDocumentClassifier:
logger.info(
f"Initialized TransformerDocumentClassifier with {model_name} "
f"(caching: {use_cache})"
f"(caching: {use_cache})",
)
def train(
@ -264,14 +257,14 @@ class TransformerDocumentClassifier:
if self.use_cache and self.cache_manager:
# Try to get from cache first
cache_key = f"{self.cache_key}_{model_dir}"
def loader():
logger.info(f"Loading model from {model_dir}")
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.eval() # Set to evaluation mode
return {"model": model, "tokenizer": tokenizer}
cached = self.cache_manager.get_or_load_model(cache_key, loader)
self.model = cached["model"]
self.tokenizer = cached["tokenizer"]


@ -24,8 +24,9 @@ import pickle
import threading
import time
from collections import OrderedDict
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any
logger = logging.getLogger("paperless.ml.model_cache")
@ -58,7 +59,7 @@ class CacheMetrics:
with self.lock:
self.loads += 1
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
with self.lock:
total = self.hits + self.misses
hit_rate = (self.hits / total * 100) if total > 0 else 0.0
@ -98,7 +99,7 @@ class LRUCache:
self.lock = threading.Lock()
self.metrics = CacheMetrics()
def get(self, key: str) -> Optional[Any]:
def get(self, key: str) -> Any | None:
"""
Get item from cache.
@ -153,7 +154,7 @@ class LRUCache:
with self.lock:
return len(self.cache)
def get_metrics(self) -> Dict[str, Any]:
def get_metrics(self) -> dict[str, Any]:
"""Get cache metrics."""
return self.metrics.get_stats()
@ -173,7 +174,7 @@ class ModelCacheManager:
model = cache.get_or_load_model("classifier", loader_func)
"""
_instance: Optional[ModelCacheManager] = None
_instance: ModelCacheManager | None = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
@ -187,7 +188,7 @@ class ModelCacheManager:
def __init__(
self,
max_models: int = 3,
disk_cache_dir: Optional[str] = None,
disk_cache_dir: str | None = None,
):
"""
Initialize model cache manager.
@ -215,7 +216,7 @@ class ModelCacheManager:
def get_instance(
cls,
max_models: int = 3,
disk_cache_dir: Optional[str] = None,
disk_cache_dir: str | None = None,
) -> ModelCacheManager:
"""
Get singleton instance of ModelCacheManager.
@ -278,7 +279,7 @@ class ModelCacheManager:
load_time = time.time() - start_time
logger.info(
f"Model loaded successfully: {model_key} "
f"(took {load_time:.2f}s)"
f"(took {load_time:.2f}s)",
)
return model
@ -289,7 +290,7 @@ class ModelCacheManager:
def save_embeddings_to_disk(
self,
key: str,
embeddings: Dict[int, Any],
embeddings: dict[int, Any],
) -> bool:
"""
Save embeddings to disk cache.
@ -311,7 +312,7 @@ class ModelCacheManager:
cache_file = self.disk_cache_dir / f"{key}.pkl"
try:
with open(cache_file, 'wb') as f:
with open(cache_file, "wb") as f:
pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Saved {len(embeddings)} embeddings to {cache_file}")
return True
@ -330,7 +331,7 @@ class ModelCacheManager:
def load_embeddings_from_disk(
self,
key: str,
) -> Optional[Dict[int, Any]]:
) -> dict[int, Any] | None:
"""
Load embeddings from disk cache.
@ -344,7 +345,7 @@ class ModelCacheManager:
return None
cache_file = self.disk_cache_dir / f"{key}.pkl"
if not cache_file.exists():
return None
@ -384,7 +385,7 @@ class ModelCacheManager:
def clear_all(self) -> None:
"""Clear all caches (memory and disk)."""
self.model_cache.clear()
if self.disk_cache_dir and self.disk_cache_dir.exists():
for cache_file in self.disk_cache_dir.glob("*.pkl"):
try:
@ -393,7 +394,7 @@ class ModelCacheManager:
except Exception as e:
logger.error(f"Failed to delete {cache_file}: {e}")
def get_metrics(self) -> Dict[str, Any]:
def get_metrics(self) -> dict[str, Any]:
"""
Get cache performance metrics.
@ -403,20 +404,20 @@ class ModelCacheManager:
metrics = self.model_cache.get_metrics()
metrics["cache_size"] = self.model_cache.size()
metrics["max_size"] = self.model_cache.max_size
if self.disk_cache_dir and self.disk_cache_dir.exists():
disk_files = list(self.disk_cache_dir.glob("*.pkl"))
metrics["disk_cache_files"] = len(disk_files)
# Calculate total disk cache size
total_size = sum(f.stat().st_size for f in disk_files)
metrics["disk_cache_size_mb"] = f"{total_size / 1024 / 1024:.2f}"
return metrics
def warm_up(
self,
model_loaders: Dict[str, Callable[[], Any]],
model_loaders: dict[str, Callable[[], Any]],
) -> None:
"""
Pre-load models on startup (warm-up).
@ -426,12 +427,12 @@ class ModelCacheManager:
"""
logger.info(f"Starting model warm-up ({len(model_loaders)} models)...")
start_time = time.time()
for model_key, loader_func in model_loaders.items():
try:
self.get_or_load_model(model_key, loader_func)
except Exception as e:
logger.warning(f"Failed to warm-up model {model_key}: {e}")
warm_up_time = time.time() - start_time
logger.info(f"Model warm-up completed in {warm_up_time:.2f}s")


@ -14,15 +14,11 @@ from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING
from transformers import pipeline
from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ml.ner")
@ -69,10 +65,10 @@ class DocumentNER:
self.model_name = model_name
self.use_cache = use_cache
self.cache_manager = ModelCacheManager.get_instance() if use_cache else None
# Cache key for this model
cache_key = f"ner_{model_name}"
if self.use_cache and self.cache_manager:
# Load from cache or create new
def loader():
@ -81,7 +77,7 @@ class DocumentNER:
model=model_name,
aggregation_strategy="simple",
)
self.ner_pipeline = self.cache_manager.get_or_load_model(
cache_key,
loader,


@ -18,18 +18,14 @@ Examples:
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ml.semantic_search")
@ -67,7 +63,7 @@ class SemanticSearch:
"""
logger.info(
f"Initializing SemanticSearch with model: {model_name} "
f"(caching: {use_cache})"
f"(caching: {use_cache})",
)
self.model_name = model_name
@ -75,10 +71,10 @@ class SemanticSearch:
self.cache_manager = ModelCacheManager.get_instance(
disk_cache_dir=cache_dir,
) if use_cache else None
# Cache key for this model
cache_key = f"semantic_search_{model_name}"
if self.use_cache and self.cache_manager:
# Load model from cache
def loader():
@ -127,11 +123,11 @@ class SemanticSearch:
if not isinstance(embedding, np.ndarray) and not isinstance(embedding, torch.Tensor):
logger.warning(f"Embedding for doc {doc_id} is not a numpy array or tensor")
return False
if hasattr(embedding, 'size'):
if hasattr(embedding, "size"):
if embedding.size == 0:
logger.warning(f"Embedding for doc {doc_id} is empty")
return False
elif hasattr(embedding, 'numel'):
elif hasattr(embedding, "numel"):
if embedding.numel() == 0:
logger.warning(f"Embedding for doc {doc_id} is empty")
return False
@ -216,11 +212,11 @@ class SemanticSearch:
try:
result = self.cache_manager.save_embeddings_to_disk(
"document_embeddings",
self.document_embeddings
self.document_embeddings,
)
if result:
logger.info(
f"Successfully saved {len(self.document_embeddings)} embeddings to disk"
f"Successfully saved {len(self.document_embeddings)} embeddings to disk",
)
else:
logger.error("Failed to save embeddings to disk (returned False)")


@ -1596,59 +1596,59 @@ class DeletionRequest(models.Model):
This ensures no documents are deleted without explicit user consent,
implementing the safety requirement from agents.md.
"""
# Request metadata
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
# Requester (AI system)
requested_by_ai = models.BooleanField(default=True)
ai_reason = models.TextField(
help_text=_("Detailed explanation from AI about why deletion is recommended")
help_text=_("Detailed explanation from AI about why deletion is recommended"),
)
# User who must approve
user = models.ForeignKey(
User,
on_delete=models.CASCADE,
related_name='deletion_requests',
related_name="deletion_requests",
help_text=_("User who must approve this deletion"),
)
# Status tracking
STATUS_PENDING = 'pending'
STATUS_APPROVED = 'approved'
STATUS_REJECTED = 'rejected'
STATUS_CANCELLED = 'cancelled'
STATUS_COMPLETED = 'completed'
STATUS_PENDING = "pending"
STATUS_APPROVED = "approved"
STATUS_REJECTED = "rejected"
STATUS_CANCELLED = "cancelled"
STATUS_COMPLETED = "completed"
STATUS_CHOICES = [
(STATUS_PENDING, _('Pending')),
(STATUS_APPROVED, _('Approved')),
(STATUS_REJECTED, _('Rejected')),
(STATUS_CANCELLED, _('Cancelled')),
(STATUS_COMPLETED, _('Completed')),
(STATUS_PENDING, _("Pending")),
(STATUS_APPROVED, _("Approved")),
(STATUS_REJECTED, _("Rejected")),
(STATUS_CANCELLED, _("Cancelled")),
(STATUS_COMPLETED, _("Completed")),
]
status = models.CharField(
max_length=20,
choices=STATUS_CHOICES,
default=STATUS_PENDING,
)
# Documents to be deleted
documents = models.ManyToManyField(
Document,
related_name='deletion_requests',
related_name="deletion_requests",
help_text=_("Documents that would be deleted if approved"),
)
# Impact summary (JSON field with details)
impact_summary = models.JSONField(
default=dict,
help_text=_("Summary of what will be affected by this deletion"),
)
# Approval tracking
reviewed_at = models.DateTimeField(null=True, blank=True)
reviewed_by = models.ForeignKey(
@ -1656,43 +1656,43 @@ class DeletionRequest(models.Model):
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name='reviewed_deletion_requests',
related_name="reviewed_deletion_requests",
help_text=_("User who reviewed and approved/rejected"),
)
review_comment = models.TextField(
blank=True,
help_text=_("User's comment when reviewing"),
)
# Completion tracking
completed_at = models.DateTimeField(null=True, blank=True)
completion_details = models.JSONField(
default=dict,
help_text=_("Details about the deletion execution"),
)
class Meta:
ordering = ['-created_at']
ordering = ["-created_at"]
verbose_name = _("deletion request")
verbose_name_plural = _("deletion requests")
indexes = [
# Composite index for common listing queries (by user, filtered by status, sorted by date)
# PostgreSQL can use this index for queries on: user, user+status, user+status+created_at
models.Index(fields=['user', 'status', 'created_at'], name='delreq_user_status_created_idx'),
models.Index(fields=["user", "status", "created_at"], name="delreq_user_status_created_idx"),
# Index for queries filtering by status and date without user filter
models.Index(fields=['status', 'created_at'], name='delreq_status_created_idx'),
models.Index(fields=["status", "created_at"], name="delreq_status_created_idx"),
# Index for queries filtering by user and date (common for user-specific views)
models.Index(fields=['user', 'created_at'], name='delreq_user_created_idx'),
models.Index(fields=["user", "created_at"], name="delreq_user_created_idx"),
# Index for queries filtering by review date
models.Index(fields=['reviewed_at'], name='delreq_reviewed_at_idx'),
models.Index(fields=["reviewed_at"], name="delreq_reviewed_at_idx"),
# Index for queries filtering by completion date
models.Index(fields=['completed_at'], name='delreq_completed_at_idx'),
models.Index(fields=["completed_at"], name="delreq_completed_at_idx"),
]
def __str__(self):
doc_count = self.documents.count()
return f"Deletion Request {self.id} - {doc_count} documents - {self.status}"
def approve(self, user: User, comment: str = "") -> bool:
"""
Approve the deletion request.
@ -1706,15 +1706,15 @@ class DeletionRequest(models.Model):
"""
if self.status != self.STATUS_PENDING:
return False
self.status = self.STATUS_APPROVED
self.reviewed_by = user
self.reviewed_at = timezone.now()
self.review_comment = comment
self.save()
return True
def reject(self, user: User, comment: str = "") -> bool:
"""
Reject the deletion request.
@ -1728,13 +1728,13 @@ class DeletionRequest(models.Model):
"""
if self.status != self.STATUS_PENDING:
return False
self.status = self.STATUS_REJECTED
self.reviewed_by = user
self.reviewed_at = timezone.now()
self.review_comment = comment
self.save()
return True
@ -1743,109 +1743,109 @@ class AISuggestionFeedback(models.Model):
Model to track user feedback on AI suggestions (applied/rejected).
Used for improving AI accuracy and providing statistics.
"""
# Suggestion types
TYPE_TAG = 'tag'
TYPE_CORRESPONDENT = 'correspondent'
TYPE_DOCUMENT_TYPE = 'document_type'
TYPE_STORAGE_PATH = 'storage_path'
TYPE_CUSTOM_FIELD = 'custom_field'
TYPE_WORKFLOW = 'workflow'
TYPE_TITLE = 'title'
TYPE_TAG = "tag"
TYPE_CORRESPONDENT = "correspondent"
TYPE_DOCUMENT_TYPE = "document_type"
TYPE_STORAGE_PATH = "storage_path"
TYPE_CUSTOM_FIELD = "custom_field"
TYPE_WORKFLOW = "workflow"
TYPE_TITLE = "title"
SUGGESTION_TYPES = (
(TYPE_TAG, _('Tag')),
(TYPE_CORRESPONDENT, _('Correspondent')),
(TYPE_DOCUMENT_TYPE, _('Document Type')),
(TYPE_STORAGE_PATH, _('Storage Path')),
(TYPE_CUSTOM_FIELD, _('Custom Field')),
(TYPE_WORKFLOW, _('Workflow')),
(TYPE_TITLE, _('Title')),
(TYPE_TAG, _("Tag")),
(TYPE_CORRESPONDENT, _("Correspondent")),
(TYPE_DOCUMENT_TYPE, _("Document Type")),
(TYPE_STORAGE_PATH, _("Storage Path")),
(TYPE_CUSTOM_FIELD, _("Custom Field")),
(TYPE_WORKFLOW, _("Workflow")),
(TYPE_TITLE, _("Title")),
)
# Feedback status
STATUS_APPLIED = 'applied'
STATUS_REJECTED = 'rejected'
STATUS_APPLIED = "applied"
STATUS_REJECTED = "rejected"
FEEDBACK_STATUS = (
(STATUS_APPLIED, _('Applied')),
(STATUS_REJECTED, _('Rejected')),
(STATUS_APPLIED, _("Applied")),
(STATUS_REJECTED, _("Rejected")),
)
document = models.ForeignKey(
Document,
on_delete=models.CASCADE,
related_name='ai_suggestion_feedbacks',
verbose_name=_('document'),
related_name="ai_suggestion_feedbacks",
verbose_name=_("document"),
)
suggestion_type = models.CharField(
_('suggestion type'),
_("suggestion type"),
max_length=50,
choices=SUGGESTION_TYPES,
)
suggested_value_id = models.IntegerField(
_('suggested value ID'),
_("suggested value ID"),
null=True,
blank=True,
help_text=_('ID of the suggested object (tag, correspondent, etc.)'),
help_text=_("ID of the suggested object (tag, correspondent, etc.)"),
)
suggested_value_text = models.TextField(
_('suggested value text'),
_("suggested value text"),
blank=True,
help_text=_('Text representation of the suggested value'),
help_text=_("Text representation of the suggested value"),
)
confidence = models.FloatField(
_('confidence'),
help_text=_('AI confidence score (0.0 to 1.0)'),
_("confidence"),
help_text=_("AI confidence score (0.0 to 1.0)"),
validators=[MinValueValidator(0.0), MaxValueValidator(1.0)],
)
status = models.CharField(
_('status'),
_("status"),
max_length=20,
choices=FEEDBACK_STATUS,
)
user = models.ForeignKey(
User,
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name='ai_suggestion_feedbacks',
verbose_name=_('user'),
help_text=_('User who applied or rejected the suggestion'),
related_name="ai_suggestion_feedbacks",
verbose_name=_("user"),
help_text=_("User who applied or rejected the suggestion"),
)
created_at = models.DateTimeField(
_('created at'),
_("created at"),
auto_now_add=True,
)
applied_at = models.DateTimeField(
_('applied/rejected at'),
_("applied/rejected at"),
auto_now=True,
)
metadata = models.JSONField(
_('metadata'),
_("metadata"),
default=dict,
blank=True,
help_text=_('Additional metadata about the suggestion'),
help_text=_("Additional metadata about the suggestion"),
)
class Meta:
verbose_name = _('AI suggestion feedback')
verbose_name_plural = _('AI suggestion feedbacks')
ordering = ['-created_at']
verbose_name = _("AI suggestion feedback")
verbose_name_plural = _("AI suggestion feedbacks")
ordering = ["-created_at"]
indexes = [
models.Index(fields=['document', 'suggestion_type']),
models.Index(fields=['status', 'created_at']),
models.Index(fields=['suggestion_type', 'status']),
models.Index(fields=["document", "suggestion_type"]),
models.Index(fields=["status", "created_at"]),
models.Index(fields=["suggestion_type", "status"]),
]
def __str__(self):
return f"{self.suggestion_type} suggestion for document {self.document_id} - {self.status}"


@ -11,21 +11,21 @@ Lazy imports are used to avoid loading heavy dependencies unless needed.
"""
__all__ = [
'TableExtractor',
'HandwritingRecognizer',
'FormFieldDetector',
"FormFieldDetector",
"HandwritingRecognizer",
"TableExtractor",
]
def __getattr__(name):
"""Lazy import to avoid loading heavy ML models on startup."""
if name == 'TableExtractor':
if name == "TableExtractor":
from .table_extractor import TableExtractor
return TableExtractor
elif name == 'HandwritingRecognizer':
elif name == "HandwritingRecognizer":
from .handwriting import HandwritingRecognizer
return HandwritingRecognizer
elif name == 'FormFieldDetector':
elif name == "FormFieldDetector":
from .form_detector import FormFieldDetector
return FormFieldDetector
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
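
With this module-level __getattr__ (PEP 562), the heavy classes are imported only on first attribute access rather than at startup:

from documents import ocr

extractor_cls = ocr.TableExtractor  # triggers the lazy import here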


@ -8,8 +8,8 @@ This module provides capabilities to:
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from typing import Any
import numpy as np
from PIL import Image
@ -37,7 +37,7 @@ class FormFieldDetector:
>>> for cb in checkboxes:
... print(f"{cb['label']}: {'' if cb['checked'] else ''}")
"""
def __init__(self, use_gpu: bool = True):
"""
Initialize the form field detector.
@ -47,20 +47,20 @@ class FormFieldDetector:
"""
self.use_gpu = use_gpu
self._handwriting_recognizer = None
def _get_handwriting_recognizer(self):
"""Lazy load handwriting recognizer for field value extraction."""
if self._handwriting_recognizer is None:
from .handwriting import HandwritingRecognizer
self._handwriting_recognizer = HandwritingRecognizer(use_gpu=self.use_gpu)
return self._handwriting_recognizer
def detect_checkboxes(
self,
self,
image: Image.Image,
min_size: int = 10,
max_size: int = 50
) -> List[Dict[str, Any]]:
max_size: int = 50,
) -> list[dict[str, Any]]:
"""
Detect checkboxes in a form image.
@ -82,54 +82,54 @@ class FormFieldDetector:
"""
try:
import cv2
# Convert to OpenCV format
img_array = np.array(image)
if len(img_array.shape) == 3:
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
else:
gray = img_array
# Detect edges
edges = cv2.Canny(gray, 50, 150)
# Find contours
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
checkboxes = []
for contour in contours:
# Get bounding box
x, y, w, h = cv2.boundingRect(contour)
# Check if it looks like a checkbox (square-ish, right size)
aspect_ratio = w / h if h > 0 else 0
if (min_size <= w <= max_size and
min_size <= h <= max_size and
if (min_size <= w <= max_size and
min_size <= h <= max_size and
0.7 <= aspect_ratio <= 1.3):
# Extract checkbox region
checkbox_region = gray[y:y+h, x:x+w]
# Determine if checked (look for marks inside)
checked, confidence = self._is_checkbox_checked(checkbox_region)
checkboxes.append({
'bbox': [x, y, x+w, y+h],
'checked': checked,
'confidence': confidence
"bbox": [x, y, x+w, y+h],
"checked": checked,
"confidence": confidence,
})
logger.info(f"Detected {len(checkboxes)} checkboxes")
return checkboxes
except ImportError:
logger.error("opencv-python not installed. Install with: pip install opencv-python")
return []
except Exception as e:
logger.error(f"Error detecting checkboxes: {e}")
return []
def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> Tuple[bool, float]:
def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> tuple[bool, float]:
"""
Determine if a checkbox is checked.
@ -141,34 +141,34 @@ class FormFieldDetector:
"""
try:
import cv2
# Binarize
_, binary = cv2.threshold(checkbox_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
# Count dark pixels in the center region (where mark would be)
h, w = binary.shape
center_region = binary[int(h*0.2):int(h*0.8), int(w*0.2):int(w*0.8)]
if center_region.size == 0:
return False, 0.0
dark_pixel_ratio = np.sum(center_region > 0) / center_region.size
# If more than 15% of center is dark, consider it checked
checked = dark_pixel_ratio > 0.15
confidence = min(dark_pixel_ratio * 2, 1.0) # Scale confidence
return checked, confidence
except Exception as e:
logger.warning(f"Error checking checkbox state: {e}")
return False, 0.0
def detect_text_fields(
self,
self,
image: Image.Image,
min_width: int = 100
) -> List[Dict[str, Any]]:
min_width: int = 100,
) -> list[dict[str, Any]]:
"""
Detect text input fields in a form.
@ -188,73 +188,73 @@ class FormFieldDetector:
"""
try:
import cv2
# Convert to OpenCV format
img_array = np.array(image)
if len(img_array.shape) == 3:
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
else:
gray = img_array
# Detect horizontal lines (underlines for text fields)
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_width, 1))
detect_horizontal = cv2.morphologyEx(
cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1],
cv2.MORPH_OPEN,
horizontal_kernel,
iterations=2
iterations=2,
)
# Find contours of horizontal lines
contours, _ = cv2.findContours(
detect_horizontal,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE
detect_horizontal,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE,
)
text_fields = []
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
# Check if it's a horizontal line (field underline)
if w >= min_width and h < 10:
# Expand upward to include text area
text_bbox = [x, max(0, y-30), x+w, y+h]
text_fields.append({
'bbox': text_bbox,
'type': 'line'
"bbox": text_bbox,
"type": "line",
})
# Detect rectangular boxes (bordered text fields)
edges = cv2.Canny(gray, 50, 150)
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
# Check if it's a rectangular box
aspect_ratio = w / h if h > 0 else 0
if w >= min_width and 20 <= h <= 100 and aspect_ratio > 2:
text_fields.append({
'bbox': [x, y, x+w, y+h],
'type': 'box'
"bbox": [x, y, x+w, y+h],
"type": "box",
})
logger.info(f"Detected {len(text_fields)} text fields")
return text_fields
except ImportError:
logger.error("opencv-python not installed")
return []
except Exception as e:
logger.error(f"Error detecting text fields: {e}")
return []
def detect_labels(
self,
self,
image: Image.Image,
field_bboxes: List[List[int]]
) -> List[Dict[str, Any]]:
field_bboxes: list[list[int]],
) -> list[dict[str, Any]]:
"""
Detect labels near form fields.
@ -267,47 +267,47 @@ class FormFieldDetector:
"""
try:
import pytesseract
# Get all text with bounding boxes
ocr_data = pytesseract.image_to_data(
image,
output_type=pytesseract.Output.DICT
image,
output_type=pytesseract.Output.DICT,
)
# Group text into potential labels
labels = []
for i, text in enumerate(ocr_data['text']):
for i, text in enumerate(ocr_data["text"]):
if text.strip() and len(text.strip()) > 2:
x = ocr_data['left'][i]
y = ocr_data['top'][i]
w = ocr_data['width'][i]
h = ocr_data['height'][i]
x = ocr_data["left"][i]
y = ocr_data["top"][i]
w = ocr_data["width"][i]
h = ocr_data["height"][i]
label_bbox = [x, y, x+w, y+h]
# Find closest field
closest_field_idx = self._find_closest_field(label_bbox, field_bboxes)
labels.append({
'text': text.strip(),
'bbox': label_bbox,
'field_index': closest_field_idx
"text": text.strip(),
"bbox": label_bbox,
"field_index": closest_field_idx,
})
return labels
except ImportError:
logger.error("pytesseract not installed")
return []
except Exception as e:
logger.error(f"Error detecting labels: {e}")
return []
def _find_closest_field(
self,
label_bbox: List[int],
field_bboxes: List[List[int]]
) -> Optional[int]:
self,
label_bbox: list[int],
field_bboxes: list[list[int]],
) -> int | None:
"""
Find the closest field to a label.
@ -320,36 +320,36 @@ class FormFieldDetector:
"""
if not field_bboxes:
return None
# Calculate center of label
label_center_x = (label_bbox[0] + label_bbox[2]) / 2
label_center_y = (label_bbox[1] + label_bbox[3]) / 2
min_distance = float('inf')
min_distance = float("inf")
closest_idx = 0
for i, field_bbox in enumerate(field_bboxes):
# Calculate center of field
field_center_x = (field_bbox[0] + field_bbox[2]) / 2
field_center_y = (field_bbox[1] + field_bbox[3]) / 2
# Euclidean distance
distance = np.sqrt(
(label_center_x - field_center_x)**2 +
(label_center_y - field_center_y)**2
(label_center_x - field_center_x)**2 +
(label_center_y - field_center_y)**2,
)
if distance < min_distance:
min_distance = distance
closest_idx = i
return closest_idx
def detect_form_fields(
self,
self,
image_path: str,
extract_values: bool = True
) -> List[Dict[str, Any]]:
extract_values: bool = True,
) -> list[dict[str, Any]]:
"""
Detect all form fields and extract their values.
@ -372,69 +372,69 @@ class FormFieldDetector:
"""
try:
# Load image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
# Detect different field types
text_fields = self.detect_text_fields(image)
checkboxes = self.detect_checkboxes(image)
# Combine all field bboxes for label detection
all_field_bboxes = [f['bbox'] for f in text_fields] + [cb['bbox'] for cb in checkboxes]
all_field_bboxes = [f["bbox"] for f in text_fields] + [cb["bbox"] for cb in checkboxes]
# Detect labels
labels = self.detect_labels(image, all_field_bboxes)
# Build results
results = []
# Add text fields
for i, field in enumerate(text_fields):
# Find associated label
label_text = self._find_label_for_field(i, labels, len(text_fields))
result = {
'type': 'text',
'label': label_text,
'bbox': field['bbox'],
"type": "text",
"label": label_text,
"bbox": field["bbox"],
}
# Extract value if requested
if extract_values:
x1, y1, x2, y2 = field['bbox']
x1, y1, x2, y2 = field["bbox"]
field_image = image.crop((x1, y1, x2, y2))
recognizer = self._get_handwriting_recognizer()
value = recognizer.recognize_from_image(field_image, preprocess=True)
result['value'] = value.strip()
result['confidence'] = recognizer._estimate_confidence(value)
result["value"] = value.strip()
result["confidence"] = recognizer._estimate_confidence(value)
results.append(result)
# Add checkboxes
for i, checkbox in enumerate(checkboxes):
field_idx = len(text_fields) + i
label_text = self._find_label_for_field(field_idx, labels, len(all_field_bboxes))
results.append({
'type': 'checkbox',
'label': label_text,
'value': checkbox['checked'],
'bbox': checkbox['bbox'],
'confidence': checkbox['confidence']
"type": "checkbox",
"label": label_text,
"value": checkbox["checked"],
"bbox": checkbox["bbox"],
"confidence": checkbox["confidence"],
})
logger.info(f"Detected {len(results)} form fields from {image_path}")
return results
except Exception as e:
logger.error(f"Error detecting form fields: {e}")
return []
def _find_label_for_field(
self,
field_idx: int,
labels: List[Dict[str, Any]],
total_fields: int
self,
field_idx: int,
labels: list[dict[str, Any]],
total_fields: int,
) -> str:
"""
Find the label text for a specific field.
@ -448,20 +448,20 @@ class FormFieldDetector:
Label text or empty string if not found
"""
matching_labels = [
label for label in labels
if label['field_index'] == field_idx
label for label in labels
if label["field_index"] == field_idx
]
if matching_labels:
# Combine multiple label parts if found
return ' '.join(label['text'] for label in matching_labels)
return " ".join(label["text"] for label in matching_labels)
return f"Field_{field_idx + 1}"
def extract_form_data(
self,
self,
image_path: str,
output_format: str = 'dict'
output_format: str = "dict",
) -> Any:
"""
Extract all form data as structured output.
@ -475,19 +475,19 @@ class FormFieldDetector:
"""
# Detect and extract fields
fields = self.detect_form_fields(image_path, extract_values=True)
if output_format == 'dict':
if output_format == "dict":
# Return as dictionary
return {field['label']: field['value'] for field in fields}
elif output_format == 'json':
return {field["label"]: field["value"] for field in fields}
elif output_format == "json":
import json
data = {field['label']: field['value'] for field in fields}
data = {field["label"]: field["value"] for field in fields}
return json.dumps(data, indent=2)
elif output_format == 'dataframe':
elif output_format == "dataframe":
import pandas as pd
return pd.DataFrame(fields)
else:
raise ValueError(f"Invalid output format: {output_format}")


@ -8,8 +8,8 @@ This module provides handwriting OCR capabilities using:
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from typing import Any
import numpy as np
from PIL import Image
@ -34,7 +34,7 @@ class HandwritingRecognizer:
>>> for line in lines:
... print(f"{line['text']} (confidence: {line['confidence']:.2f})")
"""
def __init__(
self,
model_name: str = "microsoft/trocr-base-handwritten",
@ -58,39 +58,40 @@ class HandwritingRecognizer:
self.confidence_threshold = confidence_threshold
self._model = None
self._processor = None
def _load_model(self):
"""Lazy load the handwriting recognition model."""
if self._model is not None:
return
try:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel
logger.info(f"Loading handwriting recognition model: {self.model_name}")
self._processor = TrOCRProcessor.from_pretrained(self.model_name)
self._model = VisionEncoderDecoderModel.from_pretrained(self.model_name)
# Move to GPU if available and requested
if self.use_gpu and torch.cuda.is_available():
self._model = self._model.cuda()
logger.info("Using GPU for handwriting recognition")
else:
logger.info("Using CPU for handwriting recognition")
self._model.eval() # Set to evaluation mode
except ImportError as e:
logger.error(f"Failed to load handwriting model: {e}")
logger.error("Please install: pip install transformers torch pillow")
raise
def recognize_from_image(
self,
self,
image: Image.Image,
preprocess: bool = True
preprocess: bool = True,
) -> str:
"""
Recognize text from a single image.
@ -103,34 +104,34 @@ class HandwritingRecognizer:
Recognized text string
"""
self._load_model()
try:
import torch
# Preprocess image if requested
if preprocess:
image = self._preprocess_image(image)
# Prepare image for model
pixel_values = self._processor(images=image, return_tensors="pt").pixel_values
if self.use_gpu and torch.cuda.is_available():
pixel_values = pixel_values.cuda()
# Generate text
with torch.no_grad():
generated_ids = self._model.generate(pixel_values)
# Decode to text
text = self._processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
logger.debug(f"Recognized text: {text[:100]}...")
return text
except Exception as e:
logger.error(f"Error recognizing handwriting: {e}")
return ""
def _preprocess_image(self, image: Image.Image) -> Image.Image:
"""
Preprocess image for better recognition.
@ -142,29 +143,30 @@ class HandwritingRecognizer:
Preprocessed PIL Image
"""
try:
from PIL import ImageEnhance, ImageFilter
from PIL import ImageEnhance
from PIL import ImageFilter
# Convert to grayscale
if image.mode != 'L':
image = image.convert('L')
if image.mode != "L":
image = image.convert("L")
# Enhance contrast
enhancer = ImageEnhance.Contrast(image)
image = enhancer.enhance(2.0)
# Denoise
image = image.filter(ImageFilter.MedianFilter(size=3))
# Convert back to RGB (required by model)
image = image.convert('RGB')
image = image.convert("RGB")
return image
except Exception as e:
logger.warning(f"Error preprocessing image: {e}")
return image
def detect_text_lines(self, image: Image.Image) -> List[Dict[str, Any]]:
def detect_text_lines(self, image: Image.Image) -> list[dict[str, Any]]:
"""
Detect individual text lines in an image.
@ -184,52 +186,52 @@ class HandwritingRecognizer:
try:
import cv2
import numpy as np
# Convert PIL to OpenCV format
img_array = np.array(image)
if len(img_array.shape) == 3:
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
else:
gray = img_array
# Binarize
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
# Find contours
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Get bounding boxes for each contour
lines = []
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
# Filter out very small regions
if w > 20 and h > 10:
# Crop line from original image
line_img = image.crop((x, y, x+w, y+h))
lines.append({
'bbox': [x, y, x+w, y+h],
'image': line_img
"bbox": [x, y, x+w, y+h],
"image": line_img,
})
# Sort lines top to bottom
lines.sort(key=lambda l: l['bbox'][1])
lines.sort(key=lambda l: l["bbox"][1])
logger.info(f"Detected {len(lines)} text lines")
return lines
except ImportError:
logger.error("opencv-python not installed. Install with: pip install opencv-python")
return []
except Exception as e:
logger.error(f"Error detecting text lines: {e}")
return []
def recognize_lines(
self,
self,
image_path: str,
return_confidence: bool = True
) -> List[Dict[str, Any]]:
return_confidence: bool = True,
) -> list[dict[str, Any]]:
"""
Recognize text from each line in an image.
@ -250,38 +252,38 @@ class HandwritingRecognizer:
"""
try:
# Load image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
# Detect lines
lines = self.detect_text_lines(image)
# Recognize each line
results = []
for i, line in enumerate(lines):
logger.debug(f"Recognizing line {i+1}/{len(lines)}")
text = self.recognize_from_image(line['image'], preprocess=True)
text = self.recognize_from_image(line["image"], preprocess=True)
result = {
'text': text,
'bbox': line['bbox'],
'line_index': i
"text": text,
"bbox": line["bbox"],
"line_index": i,
}
if return_confidence:
# Simple confidence based on text length and content
confidence = self._estimate_confidence(text)
result['confidence'] = confidence
result["confidence"] = confidence
results.append(result)
logger.info(f"Recognized {len(results)} lines from {image_path}")
return results
except Exception as e:
logger.error(f"Error recognizing lines from {image_path}: {e}")
return []
def _estimate_confidence(self, text: str) -> float:
"""
Estimate confidence of recognition result.
@ -294,36 +296,36 @@ class HandwritingRecognizer:
"""
if not text:
return 0.0
# Factors that indicate good recognition
score = 0.5 # Base score
# Longer text tends to be more reliable
if len(text) > 10:
score += 0.1
if len(text) > 20:
score += 0.1
# Text with alphanumeric characters is more reliable
if any(c.isalnum() for c in text):
score += 0.1
# Text with spaces (words) is more reliable
if ' ' in text:
if " " in text:
score += 0.1
# Penalize if too many special characters
special_chars = sum(1 for c in text if not c.isalnum() and not c.isspace())
if special_chars / len(text) > 0.5:
score -= 0.2
return max(0.0, min(1.0, score))
def recognize_from_file(
self,
self,
image_path: str,
mode: str = 'full'
) -> Dict[str, Any]:
mode: str = "full",
) -> dict[str, Any]:
"""
Recognize handwriting from an image file.
@ -337,49 +339,49 @@ class HandwritingRecognizer:
Dictionary with recognized text and metadata
"""
try:
if mode == 'full':
if mode == "full":
# Recognize entire image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
text = self.recognize_from_image(image, preprocess=True)
return {
'text': text,
'mode': 'full',
'confidence': self._estimate_confidence(text)
"text": text,
"mode": "full",
"confidence": self._estimate_confidence(text),
}
elif mode == 'lines':
elif mode == "lines":
# Recognize line by line
lines = self.recognize_lines(image_path, return_confidence=True)
# Combine all lines
full_text = '\n'.join(line['text'] for line in lines)
avg_confidence = np.mean([line['confidence'] for line in lines]) if lines else 0.0
full_text = "\n".join(line["text"] for line in lines)
avg_confidence = np.mean([line["confidence"] for line in lines]) if lines else 0.0
return {
'text': full_text,
'lines': lines,
'mode': 'lines',
'confidence': float(avg_confidence)
"text": full_text,
"lines": lines,
"mode": "lines",
"confidence": float(avg_confidence),
}
else:
raise ValueError(f"Invalid mode: {mode}. Use 'full' or 'lines'")
except Exception as e:
logger.error(f"Error recognizing from file {image_path}: {e}")
return {
'text': '',
'mode': mode,
'confidence': 0.0,
'error': str(e)
"text": "",
"mode": mode,
"confidence": 0.0,
"error": str(e),
}
def recognize_form_fields(
self,
self,
image_path: str,
field_regions: List[Dict[str, Any]]
) -> Dict[str, str]:
field_regions: list[dict[str, Any]],
) -> dict[str, str]:
"""
Recognize text from specific form fields.
@ -399,35 +401,35 @@ class HandwritingRecognizer:
"""
try:
# Load image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
# Extract and recognize each field
results = {}
for field in field_regions:
name = field['name']
bbox = field['bbox']
name = field["name"]
bbox = field["bbox"]
# Crop field region
x1, y1, x2, y2 = bbox
field_image = image.crop((x1, y1, x2, y2))
# Recognize text
text = self.recognize_from_image(field_image, preprocess=True)
results[name] = text.strip()
logger.debug(f"Field '{name}': {text[:50]}...")
return results
except Exception as e:
logger.error(f"Error recognizing form fields: {e}")
return {}
def batch_recognize(
self,
image_paths: List[str],
mode: str = 'full'
) -> List[Dict[str, Any]]:
self,
image_paths: list[str],
mode: str = "full",
) -> list[dict[str, Any]]:
"""
Recognize handwriting from multiple images in batch.
@ -442,7 +444,7 @@ class HandwritingRecognizer:
for i, path in enumerate(image_paths):
logger.info(f"Processing image {i+1}/{len(image_paths)}: {path}")
result = self.recognize_from_file(path, mode=mode)
result['image_path'] = path
result["image_path"] = path
results.append(result)
return results
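
A usage sketch of the recognizer (image path hypothetical; "lines" mode returns per-line results plus an averaged confidence):

from documents.ocr.handwriting import HandwritingRecognizer

recognizer = HandwritingRecognizer(use_gpu=False)
result = recognizer.recognize_from_file("meeting_notes.jpg", mode="lines")
print(result["text"])
print(f"confidence: {result['confidence']:.2f}")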


@ -8,9 +8,8 @@ This module uses various techniques to detect and extract tables from documents:
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from typing import Any
from PIL import Image
logger = logging.getLogger(__name__)
@ -32,7 +31,7 @@ class TableExtractor:
... print(table['data']) # pandas DataFrame
... print(table['bbox']) # bounding box coordinates
"""
def __init__(
self,
model_name: str = "microsoft/table-transformer-detection",
@ -52,34 +51,35 @@ class TableExtractor:
self.use_gpu = use_gpu
self._model = None
self._processor = None
def _load_model(self):
"""Lazy load the table detection model."""
if self._model is not None:
return
try:
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import torch
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
logger.info(f"Loading table detection model: {self.model_name}")
self._processor = AutoImageProcessor.from_pretrained(self.model_name)
self._model = AutoModelForObjectDetection.from_pretrained(self.model_name)
# Move to GPU if available and requested
if self.use_gpu and torch.cuda.is_available():
self._model = self._model.cuda()
logger.info("Using GPU for table detection")
else:
logger.info("Using CPU for table detection")
except ImportError as e:
logger.error(f"Failed to load table detection model: {e}")
logger.error("Please install required packages: pip install transformers torch pillow")
raise
def detect_tables(self, image: Image.Image) -> List[Dict[str, Any]]:
def detect_tables(self, image: Image.Image) -> list[dict[str, Any]]:
"""
Detect tables in an image.
@ -98,50 +98,50 @@ class TableExtractor:
]
"""
self._load_model()
try:
import torch
# Prepare image
inputs = self._processor(images=image, return_tensors="pt")
if self.use_gpu and torch.cuda.is_available():
inputs = {k: v.cuda() for k, v in inputs.items()}
# Run detection
with torch.no_grad():
outputs = self._model(**inputs)
# Post-process results
target_sizes = torch.tensor([image.size[::-1]])
results = self._processor.post_process_object_detection(
outputs,
outputs,
threshold=self.confidence_threshold,
target_sizes=target_sizes
target_sizes=target_sizes,
)[0]
# Convert to list of dicts
tables = []
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
tables.append({
'bbox': box.cpu().tolist(),
'score': score.item(),
'label': self._model.config.id2label[label.item()]
"bbox": box.cpu().tolist(),
"score": score.item(),
"label": self._model.config.id2label[label.item()],
})
logger.info(f"Detected {len(tables)} tables in image")
return tables
except Exception as e:
logger.error(f"Error detecting tables: {e}")
return []
def extract_table_from_region(
self,
image: Image.Image,
bbox: List[float],
use_ocr: bool = True
) -> Optional[Dict[str, Any]]:
self,
image: Image.Image,
bbox: list[float],
use_ocr: bool = True,
) -> dict[str, Any] | None:
"""
Extract table data from a specific region of an image.
@ -158,48 +158,48 @@ class TableExtractor:
# Crop to table region
x1, y1, x2, y2 = [int(coord) for coord in bbox]
table_image = image.crop((x1, y1, x2, y2))
if use_ocr:
# Use OCR to extract text and structure
import pytesseract
# Get detailed OCR data
ocr_data = pytesseract.image_to_data(
table_image,
output_type=pytesseract.Output.DICT
table_image,
output_type=pytesseract.Output.DICT,
)
# Reconstruct table structure from OCR data
table_data = self._reconstruct_table_from_ocr(ocr_data)
# Also get raw text
raw_text = pytesseract.image_to_string(table_image)
return {
'data': table_data,
'raw_text': raw_text,
'bbox': bbox,
'image_size': table_image.size
"data": table_data,
"raw_text": raw_text,
"bbox": bbox,
"image_size": table_image.size,
}
else:
# Fallback to basic OCR without structure
import pytesseract
raw_text = pytesseract.image_to_string(table_image)
return {
'data': None,
'raw_text': raw_text,
'bbox': bbox,
'image_size': table_image.size
"data": None,
"raw_text": raw_text,
"bbox": bbox,
"image_size": table_image.size,
}
except ImportError:
logger.error("pytesseract not installed. Install with: pip install pytesseract")
return None
except Exception as e:
logger.error(f"Error extracting table from region: {e}")
return None
def _reconstruct_table_from_ocr(self, ocr_data: Dict) -> Optional[Any]:
def _reconstruct_table_from_ocr(self, ocr_data: dict) -> Any | None:
"""
Reconstruct table structure from OCR output.
@ -211,58 +211,58 @@ class TableExtractor:
"""
try:
import pandas as pd
# Group text by vertical position (rows)
rows = {}
for i, text in enumerate(ocr_data['text']):
for i, text in enumerate(ocr_data["text"]):
if text.strip():
top = ocr_data['top'][i]
left = ocr_data['left'][i]
top = ocr_data["top"][i]
left = ocr_data["left"][i]
# Group by approximate row (within 20 pixels)
row_key = round(top / 20) * 20
if row_key not in rows:
rows[row_key] = []
rows[row_key].append((left, text))
# Sort rows and create DataFrame
table_rows = []
for row_y in sorted(rows.keys()):
# Sort cells by horizontal position
cells = [text for _, text in sorted(rows[row_y])]
table_rows.append(cells)
if table_rows:
# Pad rows to same length
max_cols = max(len(row) for row in table_rows)
table_rows = [row + [''] * (max_cols - len(row)) for row in table_rows]
table_rows = [row + [""] * (max_cols - len(row)) for row in table_rows]
# Create DataFrame
df = pd.DataFrame(table_rows)
# Try to use first row as header if it looks like one
if len(df) > 1:
first_row_text = ' '.join(str(x) for x in df.iloc[0])
first_row_text = " ".join(str(x) for x in df.iloc[0])
if not any(char.isdigit() for char in first_row_text):
df.columns = df.iloc[0]
df = df[1:].reset_index(drop=True)
return df
return None
except ImportError:
logger.error("pandas not installed. Install with: pip install pandas")
return None
except Exception as e:
logger.error(f"Error reconstructing table: {e}")
return None
def extract_tables_from_image(
self,
self,
image_path: str,
output_format: str = 'dataframe'
) -> List[Dict[str, Any]]:
output_format: str = "dataframe",
) -> list[dict[str, Any]]:
"""
Extract all tables from an image file.
@ -275,45 +275,45 @@ class TableExtractor:
"""
try:
# Load image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
# Detect tables
detections = self.detect_tables(image)
# Extract data from each table
tables = []
for i, detection in enumerate(detections):
logger.info(f"Extracting table {i+1}/{len(detections)}")
table_data = self.extract_table_from_region(
image,
detection['bbox']
image,
detection["bbox"],
)
if table_data:
table_data['detection_score'] = detection['score']
table_data['table_index'] = i
table_data["detection_score"] = detection["score"]
table_data["table_index"] = i
# Convert to requested format
if output_format == 'csv' and table_data['data'] is not None:
table_data['csv'] = table_data['data'].to_csv(index=False)
elif output_format == 'json' and table_data['data'] is not None:
table_data['json'] = table_data['data'].to_json(orient='records')
if output_format == "csv" and table_data["data"] is not None:
table_data["csv"] = table_data["data"].to_csv(index=False)
elif output_format == "json" and table_data["data"] is not None:
table_data["json"] = table_data["data"].to_json(orient="records")
tables.append(table_data)
logger.info(f"Successfully extracted {len(tables)} tables from {image_path}")
return tables
except Exception as e:
logger.error(f"Error extracting tables from image {image_path}: {e}")
return []
def extract_tables_from_pdf(
self,
self,
pdf_path: str,
page_numbers: Optional[List[int]] = None
) -> Dict[int, List[Dict[str, Any]]]:
page_numbers: list[int] | None = None,
) -> dict[int, list[dict[str, Any]]]:
"""
Extract tables from a PDF document.
@ -326,56 +326,56 @@ class TableExtractor:
"""
try:
from pdf2image import convert_from_path
logger.info(f"Converting PDF to images: {pdf_path}")
# Convert PDF pages to images
if page_numbers:
images = convert_from_path(
pdf_path,
pdf_path,
first_page=min(page_numbers),
last_page=max(page_numbers)
last_page=max(page_numbers),
)
else:
images = convert_from_path(pdf_path)
# Extract tables from each page
results = {}
for i, image in enumerate(images):
page_num = page_numbers[i] if page_numbers else i + 1
logger.info(f"Processing page {page_num}")
# Detect and extract tables
detections = self.detect_tables(image)
tables = []
for detection in detections:
table_data = self.extract_table_from_region(
image,
detection['bbox']
image,
detection["bbox"],
)
if table_data:
table_data['detection_score'] = detection['score']
table_data['page'] = page_num
table_data["detection_score"] = detection["score"]
table_data["page"] = page_num
tables.append(table_data)
if tables:
results[page_num] = tables
logger.info(f"Found {len(tables)} tables on page {page_num}")
return results
except ImportError:
logger.error("pdf2image not installed. Install with: pip install pdf2image")
return {}
except Exception as e:
logger.error(f"Error extracting tables from PDF: {e}")
return {}
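
Because pdf2image converts the contiguous span from min(page_numbers) to max(page_numbers) while the loop then pairs results with page_numbers by position, the argument is only safe for contiguous ranges. A usage sketch under that assumption (path illustrative):

extractor = TableExtractor()

# Contiguous range: positional pairing with the converted images lines up.
results = extractor.extract_tables_from_pdf("report.pdf", page_numbers=[2, 3, 4])
for page, tables in sorted(results.items()):
    print(f"page {page}: {len(tables)} table(s)")
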
def save_tables_to_excel(
self,
tables: List[Dict[str, Any]],
output_path: str
self,
tables: list[dict[str, Any]],
output_path: str,
) -> bool:
"""
Save extracted tables to an Excel file.
@ -389,23 +389,23 @@ class TableExtractor:
"""
try:
import pandas as pd
with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
for i, table in enumerate(tables):
if table.get('data') is not None:
if table.get("data") is not None:
sheet_name = f"Table_{i+1}"
if 'page' in table:
if "page" in table:
sheet_name = f"Page_{table['page']}_Table_{i+1}"
table['data'].to_excel(
writer,
sheet_name=sheet_name,
index=False
table["data"].to_excel(
writer,
sheet_name=sheet_name,
index=False,
)
logger.info(f"Saved {len(tables)} tables to {output_path}")
return True
except ImportError:
logger.error("openpyxl not installed. Install with: pip install openpyxl")
return False
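
Chaining extraction and export; sheet names follow the Table_N or Page_P_Table_N pattern built above (paths illustrative, constructor arguments assumed default):

extractor = TableExtractor()

tables = extractor.extract_tables_from_image("scan.png")
if tables and extractor.save_tables_to_excel(tables, "extracted.xlsx"):
    print(f"wrote {len(tables)} sheet(s) to extracted.xlsx")
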


@ -233,11 +233,11 @@ class CanViewAISuggestionsPermission(BasePermission):
def has_permission(self, request, view):
if not request.user or not request.user.is_authenticated:
return False
# Superusers always have permission
if request.user.is_superuser:
return True
# Check for specific permission
return request.user.has_perm("documents.can_view_ai_suggestions")
@ -253,11 +253,11 @@ class CanApplyAISuggestionsPermission(BasePermission):
def has_permission(self, request, view):
if not request.user or not request.user.is_authenticated:
return False
# Superusers always have permission
if request.user.is_superuser:
return True
# Check for specific permission
return request.user.has_perm("documents.can_apply_ai_suggestions")
@ -273,11 +273,11 @@ class CanApproveDeletionsPermission(BasePermission):
def has_permission(self, request, view):
if not request.user or not request.user.is_authenticated:
return False
# Superusers always have permission
if request.user.is_superuser:
return True
# Check for specific permission
return request.user.has_perm("documents.can_approve_deletions")
@ -294,10 +294,10 @@ class CanConfigureAIPermission(BasePermission):
def has_permission(self, request, view):
if not request.user or not request.user.is_authenticated:
return False
# Superusers always have permission
if request.user.is_superuser:
return True
# Check for specific permission
return request.user.has_perm("documents.can_configure_ai")
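
All four classes follow the same DRF BasePermission pattern: superusers pass unconditionally, everyone else needs the named model permission. A minimal sketch of wiring one into a view (the view class is hypothetical, for illustration only):

from rest_framework.response import Response
from rest_framework.views import APIView

from documents.permissions import CanViewAISuggestionsPermission


class AISuggestionsDemoView(APIView):  # hypothetical view for illustration
    permission_classes = [CanViewAISuggestionsPermission]

    def get(self, request):
        # has_permission() has already passed by the time we reach here.
        return Response({"detail": "allowed"})
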


@ -46,7 +46,6 @@ if settings.AUDIT_LOG_ENABLED:
from documents import bulk_edit
from documents.data_models import DocumentSource
from documents.filters import CustomFieldQueryParser
from documents.models import AISuggestionFeedback
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import CustomFieldInstance
@ -2786,57 +2785,57 @@ class DeletionRequestSerializer(serializers.ModelSerializer):
class DeletionRequestDetailSerializer(serializers.ModelSerializer):
"""Serializer for DeletionRequest model with document details."""
document_details = serializers.SerializerMethodField()
user_username = serializers.CharField(source='user.username', read_only=True)
user_username = serializers.CharField(source="user.username", read_only=True)
reviewed_by_username = serializers.CharField(
source='reviewed_by.username',
source="reviewed_by.username",
read_only=True,
allow_null=True,
)
class Meta:
from documents.models import DeletionRequest
model = DeletionRequest
fields = [
'id',
'created_at',
'updated_at',
'requested_by_ai',
'ai_reason',
'user',
'user_username',
'status',
'impact_summary',
'reviewed_at',
'reviewed_by',
'reviewed_by_username',
'review_comment',
'completed_at',
'completion_details',
'document_details',
"id",
"created_at",
"updated_at",
"requested_by_ai",
"ai_reason",
"user",
"user_username",
"status",
"impact_summary",
"reviewed_at",
"reviewed_by",
"reviewed_by_username",
"review_comment",
"completed_at",
"completion_details",
"document_details",
]
read_only_fields = [
'id',
'created_at',
'updated_at',
'reviewed_at',
'reviewed_by',
'completed_at',
'completion_details',
"id",
"created_at",
"updated_at",
"reviewed_at",
"reviewed_by",
"completed_at",
"completion_details",
]
def get_document_details(self, obj):
"""Get details of documents in this deletion request."""
documents = obj.documents.all()
return [
{
'id': doc.id,
'title': doc.title,
'created': doc.created.isoformat() if doc.created else None,
'correspondent': doc.correspondent.name if doc.correspondent else None,
'document_type': doc.document_type.name if doc.document_type else None,
'tags': [tag.name for tag in doc.tags.all()],
"id": doc.id,
"title": doc.title,
"created": doc.created.isoformat() if doc.created else None,
"correspondent": doc.correspondent.name if doc.correspondent else None,
"document_type": doc.document_type.name if doc.document_type else None,
"tags": [tag.name for tag in doc.tags.all()],
}
for doc in documents
]
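
For a single tagged invoice, get_document_details would therefore serialize to something like the following (values illustrative):

[
    {
        "id": 42,
        "title": "Invoice 2024-001",
        "created": "2024-01-15T00:00:00+00:00",
        "correspondent": "ACME Corp",
        "document_type": "Invoice",
        "tags": ["Important", "2024"],
    },
]
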
@ -2852,6 +2851,9 @@ class DeletionRequestActionSerializer(serializers.Serializer):
allow_blank=True,
label="Review Comment",
help_text="Optional comment when reviewing the deletion request",
)
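
The restored closing parenthesis above ends the review-comment CharField whose missing bracket broke the CI lint job. A hedged sketch of exercising the serializer, assuming the field shown is named comment (its declaration line sits outside this hunk):

serializer = DeletionRequestActionSerializer(data={"comment": "Duplicate confirmed"})
serializer.is_valid(raise_exception=True)
comment = serializer.validated_data.get("comment", "")  # field name assumed
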
class AISuggestionsRequestSerializer(serializers.Serializer):
"""Serializer for requesting AI suggestions for a document."""


@ -1,17 +1,15 @@
"""Serializers package for documents app."""
from .ai_suggestions import (
AISuggestionFeedbackSerializer,
AISuggestionsSerializer,
AISuggestionStatsSerializer,
ApplySuggestionSerializer,
RejectSuggestionSerializer,
)
from .ai_suggestions import AISuggestionFeedbackSerializer
from .ai_suggestions import AISuggestionsSerializer
from .ai_suggestions import AISuggestionStatsSerializer
from .ai_suggestions import ApplySuggestionSerializer
from .ai_suggestions import RejectSuggestionSerializer
__all__ = [
'AISuggestionFeedbackSerializer',
'AISuggestionsSerializer',
'AISuggestionStatsSerializer',
'ApplySuggestionSerializer',
'RejectSuggestionSerializer',
"AISuggestionFeedbackSerializer",
"AISuggestionStatsSerializer",
"AISuggestionsSerializer",
"ApplySuggestionSerializer",
"RejectSuggestionSerializer",
]


@ -7,42 +7,39 @@ and handling user feedback on AI suggestions.
from __future__ import annotations
from typing import Any, Dict
from typing import Any
from rest_framework import serializers
from documents.models import (
AISuggestionFeedback,
Correspondent,
CustomField,
DocumentType,
StoragePath,
Tag,
Workflow,
)
from documents.models import AISuggestionFeedback
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow
# Suggestion type choices - used across multiple serializers
SUGGESTION_TYPE_CHOICES = [
'tag',
'correspondent',
'document_type',
'storage_path',
'custom_field',
'workflow',
'title',
"tag",
"correspondent",
"document_type",
"storage_path",
"custom_field",
"workflow",
"title",
]
# Types that require value_id
ID_REQUIRED_TYPES = ['tag', 'correspondent', 'document_type', 'storage_path', 'workflow']
ID_REQUIRED_TYPES = ["tag", "correspondent", "document_type", "storage_path", "workflow"]
# Types that require value_text
TEXT_REQUIRED_TYPES = ['title']
TEXT_REQUIRED_TYPES = ["title"]
# Types that can use either (custom_field can be ID or text)
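
The three constants partition the choice list: five ID-backed types, one text-only type, and custom_field, which accepts either. A quick invariant sketch:

# Every suggestion type is covered by exactly one of the three groups.
assert (
    set(ID_REQUIRED_TYPES) | set(TEXT_REQUIRED_TYPES) | {"custom_field"}
    == set(SUGGESTION_TYPE_CHOICES)
)
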
class TagSuggestionSerializer(serializers.Serializer):
"""Serializer for tag suggestions."""
id = serializers.IntegerField()
name = serializers.CharField()
color = serializers.CharField()
@ -51,7 +48,7 @@ class TagSuggestionSerializer(serializers.Serializer):
class CorrespondentSuggestionSerializer(serializers.Serializer):
"""Serializer for correspondent suggestions."""
id = serializers.IntegerField()
name = serializers.CharField()
confidence = serializers.FloatField()
@ -59,7 +56,7 @@ class CorrespondentSuggestionSerializer(serializers.Serializer):
class DocumentTypeSuggestionSerializer(serializers.Serializer):
"""Serializer for document type suggestions."""
id = serializers.IntegerField()
name = serializers.CharField()
confidence = serializers.FloatField()
@ -67,7 +64,7 @@ class DocumentTypeSuggestionSerializer(serializers.Serializer):
class StoragePathSuggestionSerializer(serializers.Serializer):
"""Serializer for storage path suggestions."""
id = serializers.IntegerField()
name = serializers.CharField()
path = serializers.CharField()
@ -76,7 +73,7 @@ class StoragePathSuggestionSerializer(serializers.Serializer):
class CustomFieldSuggestionSerializer(serializers.Serializer):
"""Serializer for custom field suggestions."""
field_id = serializers.IntegerField()
field_name = serializers.CharField()
value = serializers.CharField()
@ -85,7 +82,7 @@ class CustomFieldSuggestionSerializer(serializers.Serializer):
class WorkflowSuggestionSerializer(serializers.Serializer):
"""Serializer for workflow suggestions."""
id = serializers.IntegerField()
name = serializers.CharField()
confidence = serializers.FloatField()
@ -93,7 +90,7 @@ class WorkflowSuggestionSerializer(serializers.Serializer):
class TitleSuggestionSerializer(serializers.Serializer):
"""Serializer for title suggestions."""
title = serializers.CharField()
@ -103,7 +100,7 @@ class AISuggestionsSerializer(serializers.Serializer):
Converts AIScanResult objects to JSON format for API responses.
"""
tags = TagSuggestionSerializer(many=True, required=False)
correspondent = CorrespondentSuggestionSerializer(required=False, allow_null=True)
document_type = DocumentTypeSuggestionSerializer(required=False, allow_null=True)
@ -111,9 +108,9 @@ class AISuggestionsSerializer(serializers.Serializer):
custom_fields = CustomFieldSuggestionSerializer(many=True, required=False)
workflows = WorkflowSuggestionSerializer(many=True, required=False)
title_suggestion = TitleSuggestionSerializer(required=False, allow_null=True)
@staticmethod
def from_scan_result(scan_result, document_id: int) -> Dict[str, Any]:
def from_scan_result(scan_result, document_id: int) -> dict[str, Any]:
"""
Convert an AIScanResult object to serializer data.
@ -125,7 +122,7 @@ class AISuggestionsSerializer(serializers.Serializer):
Dictionary ready for serialization
"""
data = {}
# Tags
if scan_result.tags:
tag_suggestions = []
@ -133,59 +130,59 @@ class AISuggestionsSerializer(serializers.Serializer):
try:
tag = Tag.objects.get(pk=tag_id)
tag_suggestions.append({
'id': tag.id,
'name': tag.name,
'color': getattr(tag, 'color', '#000000'),
'confidence': confidence,
"id": tag.id,
"name": tag.name,
"color": getattr(tag, "color", "#000000"),
"confidence": confidence,
})
except Tag.DoesNotExist:
# Tag no longer exists in database; skip this suggestion
pass
data['tags'] = tag_suggestions
data["tags"] = tag_suggestions
# Correspondent
if scan_result.correspondent:
corr_id, confidence = scan_result.correspondent
try:
correspondent = Correspondent.objects.get(pk=corr_id)
data['correspondent'] = {
'id': correspondent.id,
'name': correspondent.name,
'confidence': confidence,
data["correspondent"] = {
"id": correspondent.id,
"name": correspondent.name,
"confidence": confidence,
}
except Correspondent.DoesNotExist:
# Correspondent no longer exists in database; omit from suggestions
pass
# Document Type
if scan_result.document_type:
type_id, confidence = scan_result.document_type
try:
doc_type = DocumentType.objects.get(pk=type_id)
data['document_type'] = {
'id': doc_type.id,
'name': doc_type.name,
'confidence': confidence,
data["document_type"] = {
"id": doc_type.id,
"name": doc_type.name,
"confidence": confidence,
}
except DocumentType.DoesNotExist:
# Document type no longer exists in database; omit from suggestions
pass
# Storage Path
if scan_result.storage_path:
path_id, confidence = scan_result.storage_path
try:
storage_path = StoragePath.objects.get(pk=path_id)
data['storage_path'] = {
'id': storage_path.id,
'name': storage_path.name,
'path': storage_path.path,
'confidence': confidence,
data["storage_path"] = {
"id": storage_path.id,
"name": storage_path.name,
"path": storage_path.path,
"confidence": confidence,
}
except StoragePath.DoesNotExist:
# Storage path no longer exists in database; omit from suggestions
pass
# Custom Fields
if scan_result.custom_fields:
field_suggestions = []
@ -193,16 +190,16 @@ class AISuggestionsSerializer(serializers.Serializer):
try:
field = CustomField.objects.get(pk=field_id)
field_suggestions.append({
'field_id': field.id,
'field_name': field.name,
'value': str(value),
'confidence': confidence,
"field_id": field.id,
"field_name": field.name,
"value": str(value),
"confidence": confidence,
})
except CustomField.DoesNotExist:
# Custom field no longer exists in database; skip this suggestion
pass
data['custom_fields'] = field_suggestions
data["custom_fields"] = field_suggestions
# Workflows
if scan_result.workflows:
workflow_suggestions = []
@ -210,21 +207,21 @@ class AISuggestionsSerializer(serializers.Serializer):
try:
workflow = Workflow.objects.get(pk=workflow_id)
workflow_suggestions.append({
'id': workflow.id,
'name': workflow.name,
'confidence': confidence,
"id": workflow.id,
"name": workflow.name,
"confidence": confidence,
})
except Workflow.DoesNotExist:
# Workflow no longer exists in database; skip this suggestion
pass
data['workflows'] = workflow_suggestions
data["workflows"] = workflow_suggestions
# Title suggestion
if scan_result.title_suggestion:
data['title_suggestion'] = {
'title': scan_result.title_suggestion,
data["title_suggestion"] = {
"title": scan_result.title_suggestion,
}
return data
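
from_scan_result resolves every suggested ID against the database and silently drops entries whose objects have since been deleted. A hedged usage sketch inside a view, assuming document is an already-loaded Document instance:

from documents.ai_scanner import get_ai_scanner

scan_result = get_ai_scanner().scan_document(document, document.content)
data = AISuggestionsSerializer.from_scan_result(scan_result, document.id)

serializer = AISuggestionsSerializer(data=data)
serializer.is_valid(raise_exception=True)
payload = serializer.validated_data
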
@ -234,28 +231,28 @@ class SuggestionSerializerMixin:
"""
def validate(self, attrs):
"""Validate that the correct value field is provided for the suggestion type."""
suggestion_type = attrs.get('suggestion_type')
value_id = attrs.get('value_id')
value_text = attrs.get('value_text')
suggestion_type = attrs.get("suggestion_type")
value_id = attrs.get("value_id")
value_text = attrs.get("value_text")
# Types that require value_id
if suggestion_type in ID_REQUIRED_TYPES and not value_id:
raise serializers.ValidationError(
f"value_id is required for suggestion_type '{suggestion_type}'"
f"value_id is required for suggestion_type '{suggestion_type}'",
)
# Types that require value_text
if suggestion_type in TEXT_REQUIRED_TYPES and not value_text:
raise serializers.ValidationError(
f"value_text is required for suggestion_type '{suggestion_type}'"
f"value_text is required for suggestion_type '{suggestion_type}'",
)
# For custom_field, either is acceptable
if suggestion_type == 'custom_field' and not value_id and not value_text:
if suggestion_type == "custom_field" and not value_id and not value_text:
raise serializers.ValidationError(
"Either value_id or value_text must be provided for custom_field"
"Either value_id or value_text must be provided for custom_field",
)
return attrs
@ -263,12 +260,12 @@ class ApplySuggestionSerializer(SuggestionSerializerMixin, serializers.Serialize
"""
Serializer for applying AI suggestions.
"""
suggestion_type = serializers.ChoiceField(
choices=SUGGESTION_TYPE_CHOICES,
required=True,
)
value_id = serializers.IntegerField(required=False, allow_null=True)
value_text = serializers.CharField(required=False, allow_blank=True)
confidence = serializers.FloatField(required=True)
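
The inherited validate() enforces the matrix encoded by the module constants. A sketch of the failure and success paths, assuming the fields shown in this hunk are the only required ones:

# Missing value_id for an ID-backed type fails validation:
bad = ApplySuggestionSerializer(
    data={"suggestion_type": "tag", "confidence": 0.9},
)
assert not bad.is_valid()

# Supplying value_id satisfies the rule:
good = ApplySuggestionSerializer(
    data={"suggestion_type": "tag", "value_id": 7, "confidence": 0.9},
)
assert good.is_valid(), good.errors
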
@ -278,12 +275,12 @@ class RejectSuggestionSerializer(SuggestionSerializerMixin, serializers.Serializ
"""
Serializer for rejecting AI suggestions.
"""
suggestion_type = serializers.ChoiceField(
choices=SUGGESTION_TYPE_CHOICES,
required=True,
)
value_id = serializers.IntegerField(required=False, allow_null=True)
value_text = serializers.CharField(required=False, allow_blank=True)
confidence = serializers.FloatField(required=True)
@ -291,41 +288,41 @@ class RejectSuggestionSerializer(SuggestionSerializerMixin, serializers.Serializ
class AISuggestionFeedbackSerializer(serializers.ModelSerializer):
"""Serializer for AI suggestion feedback model."""
class Meta:
model = AISuggestionFeedback
fields = [
'id',
'document',
'suggestion_type',
'suggested_value_id',
'suggested_value_text',
'confidence',
'status',
'user',
'created_at',
'applied_at',
'metadata',
"id",
"document",
"suggestion_type",
"suggested_value_id",
"suggested_value_text",
"confidence",
"status",
"user",
"created_at",
"applied_at",
"metadata",
]
read_only_fields = ['id', 'created_at', 'applied_at']
read_only_fields = ["id", "created_at", "applied_at"]
class AISuggestionStatsSerializer(serializers.Serializer):
"""
Serializer for AI suggestion accuracy statistics.
"""
total_suggestions = serializers.IntegerField()
total_applied = serializers.IntegerField()
total_rejected = serializers.IntegerField()
accuracy_rate = serializers.FloatField()
by_type = serializers.DictField(
child=serializers.DictField(),
help_text="Statistics broken down by suggestion type",
)
average_confidence_applied = serializers.FloatField()
average_confidence_rejected = serializers.FloatField()
recent_suggestions = AISuggestionFeedbackSerializer(many=True, required=False)
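
A stats payload that satisfies this serializer might look like the following (numbers illustrative; accuracy_rate here is applied divided by total):

stats = {
    "total_suggestions": 120,
    "total_applied": 90,
    "total_rejected": 30,
    "accuracy_rate": 0.75,
    "by_type": {"tag": {"applied": 40, "rejected": 10}},
    "average_confidence_applied": 0.86,
    "average_confidence_rejected": 0.58,
}
serializer = AISuggestionStatsSerializer(data=stats)
assert serializer.is_valid(), serializer.errors
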


@ -18,13 +18,11 @@ from django.test import TestCase
from django.utils import timezone
from documents.ai_deletion_manager import AIDeletionManager
from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
class TestAIDeletionManagerCreateRequest(TestCase):
@ -33,13 +31,13 @@ class TestAIDeletionManagerCreateRequest(TestCase):
def setUp(self):
"""Set up test data."""
self.user = User.objects.create_user(username="testuser", password="testpass")
# Create test documents with various metadata
self.correspondent = Correspondent.objects.create(name="Test Corp")
self.doc_type = DocumentType.objects.create(name="Invoice")
self.tag1 = Tag.objects.create(name="Important")
self.tag2 = Tag.objects.create(name="2024")
self.doc1 = Document.objects.create(
title="Test Document 1",
content="Test content 1",
@ -49,7 +47,7 @@ class TestAIDeletionManagerCreateRequest(TestCase):
document_type=self.doc_type,
)
self.doc1.tags.add(self.tag1, self.tag2)
self.doc2 = Document.objects.create(
title="Test Document 2",
content="Test content 2",
@ -63,13 +61,13 @@ class TestAIDeletionManagerCreateRequest(TestCase):
"""Test creating a basic deletion request."""
documents = [self.doc1, self.doc2]
reason = "Duplicate documents detected"
request = AIDeletionManager.create_deletion_request(
documents=documents,
reason=reason,
user=self.user,
)
self.assertIsNotNone(request)
self.assertIsInstance(request, DeletionRequest)
self.assertEqual(request.ai_reason, reason)
@ -82,13 +80,13 @@ class TestAIDeletionManagerCreateRequest(TestCase):
"""Test that deletion request includes impact analysis."""
documents = [self.doc1, self.doc2]
reason = "Test deletion"
request = AIDeletionManager.create_deletion_request(
documents=documents,
reason=reason,
user=self.user,
)
impact = request.impact_summary
self.assertIsNotNone(impact)
self.assertEqual(impact["document_count"], 2)
@ -106,14 +104,14 @@ class TestAIDeletionManagerCreateRequest(TestCase):
"document_count": 1,
"custom_field": "custom_value",
}
request = AIDeletionManager.create_deletion_request(
documents=documents,
reason=reason,
user=self.user,
impact_analysis=custom_impact,
)
self.assertEqual(request.impact_summary, custom_impact)
self.assertEqual(request.impact_summary["custom_field"], "custom_value")
@ -121,13 +119,13 @@ class TestAIDeletionManagerCreateRequest(TestCase):
"""Test creating request with empty document list."""
documents = []
reason = "Test deletion"
request = AIDeletionManager.create_deletion_request(
documents=documents,
reason=reason,
user=self.user,
)
self.assertIsNotNone(request)
self.assertEqual(request.documents.count(), 0)
self.assertEqual(request.impact_summary["document_count"], 0)
@ -157,9 +155,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
document_type=self.doc_type1,
)
doc.tags.add(self.tag1, self.tag2)
impact = AIDeletionManager._analyze_impact([doc])
self.assertEqual(impact["document_count"], 1)
self.assertEqual(len(impact["documents"]), 1)
self.assertEqual(impact["documents"][0]["id"], doc.id)
@ -180,7 +178,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
document_type=self.doc_type1,
)
doc1.tags.add(self.tag1)
doc2 = Document.objects.create(
title="Document 2",
content="Content 2",
@ -190,9 +188,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
document_type=self.doc_type2,
)
doc2.tags.add(self.tag2, self.tag3)
impact = AIDeletionManager._analyze_impact([doc1, doc2])
self.assertEqual(impact["document_count"], 2)
self.assertEqual(len(impact["documents"]), 2)
self.assertIn("Corp A", impact["affected_correspondents"])
@ -209,9 +207,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
checksum="checksum1",
mime_type="application/pdf",
)
impact = AIDeletionManager._analyze_impact([doc])
self.assertEqual(impact["document_count"], 1)
self.assertEqual(impact["documents"][0]["correspondent"], None)
self.assertEqual(impact["documents"][0]["document_type"], None)
@ -232,7 +230,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
# Force set the created date to an earlier time
doc1.created = timezone.make_aware(datetime(2023, 1, 1))
doc1.save()
doc2 = Document.objects.create(
title="New Document",
content="Content",
@ -241,9 +239,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
)
doc2.created = timezone.make_aware(datetime(2024, 12, 31))
doc2.save()
impact = AIDeletionManager._analyze_impact([doc1, doc2])
self.assertIsNotNone(impact["date_range"]["earliest"])
self.assertIsNotNone(impact["date_range"]["latest"])
# Check that dates are ISO formatted strings
@ -253,7 +251,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
def test_analyze_impact_empty_list(self):
"""Test impact analysis with empty document list."""
impact = AIDeletionManager._analyze_impact([])
self.assertEqual(impact["document_count"], 0)
self.assertEqual(len(impact["documents"]), 0)
self.assertEqual(len(impact["affected_correspondents"]), 0)
@ -267,11 +265,11 @@ class TestAIDeletionManagerFormatRequest(TestCase):
def setUp(self):
"""Set up test data."""
self.user = User.objects.create_user(username="testuser", password="testpass")
self.correspondent = Correspondent.objects.create(name="Test Corp")
self.doc_type = DocumentType.objects.create(name="Invoice")
self.tag = Tag.objects.create(name="Important")
self.doc = Document.objects.create(
title="Test Document",
content="Test content",
@ -289,9 +287,9 @@ class TestAIDeletionManagerFormatRequest(TestCase):
reason="Test reason for deletion",
user=self.user,
)
message = AIDeletionManager.format_deletion_request_for_user(request)
self.assertIsInstance(message, str)
self.assertIn("AI DELETION REQUEST", message)
self.assertIn("Test reason for deletion", message)
@ -306,15 +304,15 @@ class TestAIDeletionManagerFormatRequest(TestCase):
checksum="checksum2",
mime_type="application/pdf",
)
request = AIDeletionManager.create_deletion_request(
documents=[self.doc, doc2],
reason="Multiple documents",
user=self.user,
)
message = AIDeletionManager.format_deletion_request_for_user(request)
self.assertIn("Number of documents: 2", message)
self.assertIn("Test Corp", message)
self.assertIn("Invoice", message)
@ -328,15 +326,15 @@ class TestAIDeletionManagerFormatRequest(TestCase):
checksum="checksum1",
mime_type="application/pdf",
)
request = AIDeletionManager.create_deletion_request(
documents=[doc],
reason="Test deletion",
user=self.user,
)
message = AIDeletionManager.format_deletion_request_for_user(request)
self.assertIn("Basic Document", message)
self.assertIn("None", message) # Should show None for missing metadata
@ -347,9 +345,9 @@ class TestAIDeletionManagerFormatRequest(TestCase):
reason="Test",
user=self.user,
)
message = AIDeletionManager.format_deletion_request_for_user(request)
self.assertIn("explicit approval", message.lower())
self.assertIn("no files will be deleted until you confirm", message.lower())
@ -361,7 +359,7 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
"""Set up test data."""
self.user1 = User.objects.create_user(username="user1", password="pass1")
self.user2 = User.objects.create_user(username="user2", password="pass2")
self.doc = Document.objects.create(
title="Test Document",
content="Content",
@ -382,16 +380,16 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
reason="Reason 2",
user=self.user1,
)
# Create request for user2
AIDeletionManager.create_deletion_request(
documents=[self.doc],
reason="Reason 3",
user=self.user2,
)
pending = AIDeletionManager.get_pending_requests(self.user1)
self.assertEqual(len(pending), 2)
self.assertIn(req1, pending)
self.assertIn(req2, pending)
@ -408,12 +406,12 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
reason="Reason 2",
user=self.user1,
)
# Approve one request
req1.approve(self.user1, "Approved")
pending = AIDeletionManager.get_pending_requests(self.user1)
self.assertEqual(len(pending), 1)
self.assertNotIn(req1, pending)
self.assertIn(req2, pending)
@ -430,12 +428,12 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
reason="Reason 2",
user=self.user1,
)
# Reject one request
req1.reject(self.user1, "Rejected")
pending = AIDeletionManager.get_pending_requests(self.user1)
self.assertEqual(len(pending), 1)
self.assertNotIn(req1, pending)
self.assertIn(req2, pending)
@ -443,7 +441,7 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
def test_get_pending_requests_empty(self):
"""Test getting pending requests when none exist."""
pending = AIDeletionManager.get_pending_requests(self.user1)
self.assertEqual(len(pending), 0)
@ -454,7 +452,7 @@ class TestAIDeletionManagerSecurityConstraints(TestCase):
"""Test that AI can never delete automatically."""
# This is a critical security test
result = AIDeletionManager.can_ai_delete_automatically()
self.assertFalse(result)
def test_deletion_request_requires_pending_status(self):
@ -466,13 +464,13 @@ class TestAIDeletionManagerSecurityConstraints(TestCase):
checksum="checksum1",
mime_type="application/pdf",
)
request = AIDeletionManager.create_deletion_request(
documents=[doc],
reason="Test",
user=user,
)
self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)
def test_deletion_request_marked_as_ai_initiated(self):
@ -484,13 +482,13 @@ class TestAIDeletionManagerSecurityConstraints(TestCase):
checksum="checksum1",
mime_type="application/pdf",
)
request = AIDeletionManager.create_deletion_request(
documents=[doc],
reason="Test",
user=user,
)
self.assertTrue(request.requested_by_ai)
@ -501,7 +499,7 @@ class TestAIDeletionManagerWorkflow(TestCase):
"""Set up test data."""
self.user = User.objects.create_user(username="testuser", password="pass")
self.approver = User.objects.create_user(username="approver", password="pass")
self.doc1 = Document.objects.create(
title="Document 1",
content="Content 1",
@ -523,14 +521,14 @@ class TestAIDeletionManagerWorkflow(TestCase):
reason="Duplicates detected",
user=self.user,
)
self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)
self.assertIsNone(request.reviewed_at)
self.assertIsNone(request.reviewed_by)
# Step 2: Approve request
success = request.approve(self.approver, "Looks good")
self.assertTrue(success)
self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)
self.assertIsNotNone(request.reviewed_at)
@ -545,12 +543,12 @@ class TestAIDeletionManagerWorkflow(TestCase):
reason="Should be deleted",
user=self.user,
)
self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)
# Step 2: Reject request
success = request.reject(self.approver, "Not a duplicate")
self.assertTrue(success)
self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)
self.assertIsNotNone(request.reviewed_at)
@ -564,14 +562,14 @@ class TestAIDeletionManagerWorkflow(TestCase):
reason="Test deletion",
user=self.user,
)
# Record initial state
created_at = request.created_at
self.assertIsNotNone(created_at)
# Approve
request.approve(self.approver, "Approved")
# Verify audit trail
self.assertIsNotNone(request.created_at)
self.assertIsNotNone(request.updated_at)


@ -10,24 +10,23 @@ Tests cover:
- Permission assignment and verification
"""
from django.contrib.auth.models import Group, Permission, User
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from rest_framework.test import APIRequestFactory
from documents.models import Document
from documents.permissions import (
CanApplyAISuggestionsPermission,
CanApproveDeletionsPermission,
CanConfigureAIPermission,
CanViewAISuggestionsPermission,
)
from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission
class MockView:
"""Mock view for testing permissions."""
pass
class TestCanViewAISuggestionsPermission(TestCase):
@ -41,13 +40,13 @@ class TestCanViewAISuggestionsPermission(TestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign permission to permitted_user
@ -107,13 +106,13 @@ class TestCanApplyAISuggestionsPermission(TestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign permission to permitted_user
@ -173,13 +172,13 @@ class TestCanApproveDeletionsPermission(TestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign permission to permitted_user
@ -239,13 +238,13 @@ class TestCanConfigureAIPermission(TestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign permission to permitted_user
@ -345,7 +344,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_viewer_role_permissions(self):
"""Test that viewer role has appropriate permissions."""
user = User.objects.create_user(
username="viewer", email="viewer@test.com", password="viewer123"
username="viewer", email="viewer@test.com", password="viewer123",
)
user.groups.add(self.viewer_group)
@ -360,7 +359,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_editor_role_permissions(self):
"""Test that editor role has appropriate permissions."""
user = User.objects.create_user(
username="editor", email="editor@test.com", password="editor123"
username="editor", email="editor@test.com", password="editor123",
)
user.groups.add(self.editor_group)
@ -375,7 +374,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_admin_role_permissions(self):
"""Test that admin role has all permissions."""
user = User.objects.create_user(
username="ai_admin", email="ai_admin@test.com", password="admin123"
username="ai_admin", email="ai_admin@test.com", password="admin123",
)
user.groups.add(self.admin_group)
@ -390,7 +389,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_user_with_multiple_groups(self):
"""Test that user permissions accumulate from multiple groups."""
user = User.objects.create_user(
username="multi_role", email="multi@test.com", password="multi123"
username="multi_role", email="multi@test.com", password="multi123",
)
user.groups.add(self.viewer_group, self.editor_group)
@ -405,7 +404,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_direct_permission_assignment_overrides_group(self):
"""Test that direct permission assignment works alongside group permissions."""
user = User.objects.create_user(
username="special", email="special@test.com", password="special123"
username="special", email="special@test.com", password="special123",
)
user.groups.add(self.viewer_group)
@ -428,7 +427,7 @@ class TestPermissionAssignment(TestCase):
def setUp(self):
"""Set up test user."""
self.user = User.objects.create_user(
username="testuser", email="test@test.com", password="test123"
username="testuser", email="test@test.com", password="test123",
)
content_type = ContentType.objects.get_for_model(Document)
self.view_permission, _ = Permission.objects.get_or_create(
@ -500,7 +499,7 @@ class TestPermissionEdgeCases(TestCase):
def test_inactive_user_with_permission(self):
"""Test that inactive users are denied even with permission."""
user = User.objects.create_user(
username="inactive", email="inactive@test.com", password="inactive123"
username="inactive", email="inactive@test.com", password="inactive123",
)
user.is_active = False
user.save()

File diff suppressed because it is too large.


@ -8,24 +8,21 @@ document consumption to metadata application.
from unittest import mock
from django.test import TestCase, TransactionTestCase
from django.test import TestCase
from django.test import TransactionTestCase
from documents.ai_scanner import (
AIDocumentScanner,
AIScanResult,
get_ai_scanner,
)
from documents.models import (
Correspondent,
CustomField,
Document,
DocumentType,
StoragePath,
Tag,
Workflow,
WorkflowTrigger,
WorkflowAction,
)
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import AIScanResult
from documents.ai_scanner import get_ai_scanner
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
class TestAIScannerIntegrationBasic(TestCase):
@ -35,49 +32,49 @@ class TestAIScannerIntegrationBasic(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Invoice from ACME Corporation",
content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000"
content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000",
)
self.tag_invoice = Tag.objects.create(
name="Invoice",
matching_algorithm=Tag.MATCH_AUTO,
match="invoice"
match="invoice",
)
self.tag_important = Tag.objects.create(
name="Important",
matching_algorithm=Tag.MATCH_AUTO,
match="total"
match="total",
)
self.correspondent = Correspondent.objects.create(
name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO,
match="acme"
match="acme",
)
self.doc_type = DocumentType.objects.create(
name="Invoice",
matching_algorithm=DocumentType.MATCH_AUTO,
match="invoice"
match="invoice",
)
self.storage_path = StoragePath.objects.create(
name="Invoices",
path="/invoices",
matching_algorithm=StoragePath.MATCH_AUTO,
match="invoice"
match="invoice",
)
@mock.patch('documents.ai_scanner.match_tags')
@mock.patch('documents.ai_scanner.match_correspondents')
@mock.patch('documents.ai_scanner.match_document_types')
@mock.patch('documents.ai_scanner.match_storage_paths')
@mock.patch("documents.ai_scanner.match_tags")
@mock.patch("documents.ai_scanner.match_correspondents")
@mock.patch("documents.ai_scanner.match_document_types")
@mock.patch("documents.ai_scanner.match_storage_paths")
def test_full_scan_and_apply_workflow(
self,
mock_storage,
mock_types,
mock_correspondents,
mock_tags
mock_tags,
):
"""Test complete workflow from scan to application."""
# Mock the matching functions to return our test data
@ -85,51 +82,51 @@ class TestAIScannerIntegrationBasic(TestCase):
mock_correspondents.return_value = [self.correspondent]
mock_types.return_value = [self.doc_type]
mock_storage.return_value = [self.storage_path]
scanner = AIDocumentScanner(auto_apply_threshold=0.80)
# Scan the document
scan_result = scanner.scan_document(
self.document,
self.document.content
self.document.content,
)
# Verify scan results
self.assertIsNotNone(scan_result)
self.assertGreater(len(scan_result.tags), 0)
self.assertIsNotNone(scan_result.correspondent)
self.assertIsNotNone(scan_result.document_type)
self.assertIsNotNone(scan_result.storage_path)
# Apply the results
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify application
self.assertGreater(len(result["applied"]["tags"]), 0)
self.assertIsNotNone(result["applied"]["correspondent"])
# Verify database changes
self.document.refresh_from_db()
self.assertEqual(self.document.correspondent, self.correspondent)
self.assertEqual(self.document.document_type, self.doc_type)
self.assertEqual(self.document.storage_path, self.storage_path)
@mock.patch('documents.ai_scanner.match_tags')
@mock.patch("documents.ai_scanner.match_tags")
def test_scan_with_no_matches(self, mock_tags):
"""Test scanning when no matches are found."""
mock_tags.return_value = []
scanner = AIDocumentScanner()
scan_result = scanner.scan_document(
self.document,
"Some random text with no matches"
"Some random text with no matches",
)
# Should return empty results
self.assertEqual(len(scan_result.tags), 0)
self.assertIsNone(scan_result.correspondent)
@ -143,46 +140,46 @@ class TestAIScannerIntegrationCustomFields(TestCase):
"""Set up test data with custom fields."""
self.document = Document.objects.create(
title="Invoice",
content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com"
content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com",
)
self.field_date = CustomField.objects.create(
name="Invoice Date",
data_type=CustomField.FieldDataType.DATE
data_type=CustomField.FieldDataType.DATE,
)
self.field_number = CustomField.objects.create(
name="Invoice Number",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_amount = CustomField.objects.create(
name="Total Amount",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_email = CustomField.objects.create(
name="Contact Email",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
def test_custom_field_extraction_integration(self):
"""Test custom field extraction with mocked NER."""
scanner = AIDocumentScanner()
# Mock NER to return entities
mock_ner = mock.MagicMock()
mock_ner.extract_all.return_value = {
"dates": [{"text": "2024-01-01"}],
"amounts": [{"text": "$1,500"}],
"invoice_numbers": ["INV-123"],
"emails": ["john@example.com"]
"emails": ["john@example.com"],
}
scanner._ner_extractor = mock_ner
# Scan document
scan_result = scanner.scan_document(self.document, self.document.content)
# Verify custom fields were extracted
self.assertGreater(len(scan_result.custom_fields), 0)
# Check specific fields
extracted_field_ids = list(scan_result.custom_fields.keys())
self.assertIn(self.field_date.id, extracted_field_ids)
@ -196,47 +193,47 @@ class TestAIScannerIntegrationWorkflows(TestCase):
"""Set up test workflows."""
self.document = Document.objects.create(
title="Invoice",
content="Invoice document"
content="Invoice document",
)
self.workflow1 = Workflow.objects.create(
name="Invoice Processing",
enabled=True
enabled=True,
)
self.trigger1 = WorkflowTrigger.objects.create(
workflow=self.workflow1,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
)
self.action1 = WorkflowAction.objects.create(
workflow=self.workflow1,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT
type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
)
self.workflow2 = Workflow.objects.create(
name="Archive Documents",
enabled=True
enabled=True,
)
self.trigger2 = WorkflowTrigger.objects.create(
workflow=self.workflow2,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
)
def test_workflow_suggestion_integration(self):
"""Test workflow suggestion with real workflows."""
scanner = AIDocumentScanner(suggest_threshold=0.5)
# Create scan result with some attributes
scan_result = AIScanResult()
scan_result.document_type = (1, 0.85)
scan_result.tags = [(1, 0.80)]
# Get workflow suggestions
workflows = scanner._suggest_workflows(
self.document,
self.document.content,
scan_result
scan_result,
)
# Should suggest workflows
self.assertGreater(len(workflows), 0)
workflow_ids = [wf_id for wf_id, _ in workflows]
@ -250,7 +247,7 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
self.tag = Tag.objects.create(name="TestTag")
self.correspondent = Correspondent.objects.create(name="TestCorp")
@ -258,29 +255,29 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
def test_transaction_rollback_on_error(self):
"""Test that transaction rolls back on error."""
scanner = AIDocumentScanner()
scan_result = AIScanResult()
scan_result.tags = [(self.tag.id, 0.90)]
scan_result.correspondent = (self.correspondent.id, 0.90)
# Force an error during save
original_save = Document.save
call_count = [0]
def failing_save(self, *args, **kwargs):
call_count[0] += 1
if call_count[0] >= 1:
raise Exception("Forced save failure")
return original_save(self, *args, **kwargs)
with mock.patch.object(Document, 'save', failing_save):
with mock.patch.object(Document, "save", failing_save):
with self.assertRaises(Exception):
scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify changes were rolled back
self.document.refresh_from_db()
# Document should not have been modified
@ -292,30 +289,30 @@ class TestAIScannerIntegrationPerformance(TestCase):
def test_scan_multiple_documents(self):
"""Test scanning multiple documents efficiently."""
scanner = AIDocumentScanner()
documents = []
for i in range(5):
doc = Document.objects.create(
title=f"Document {i}",
content=f"Content for document {i}"
content=f"Content for document {i}",
)
documents.append(doc)
# Mock to avoid actual ML loading
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
results = []
for doc in documents:
result = scanner.scan_document(doc, doc.content)
results.append(result)
# Verify all scans completed
self.assertEqual(len(results), 5)
for result in results:
@ -329,37 +326,37 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Business Invoice",
content="Invoice from ACME Corporation"
content="Invoice from ACME Corporation",
)
self.correspondent_acme = Correspondent.objects.create(
name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO
matching_algorithm=Correspondent.MATCH_AUTO,
)
self.correspondent_other = Correspondent.objects.create(
name="Other Company",
matching_algorithm=Correspondent.MATCH_AUTO
matching_algorithm=Correspondent.MATCH_AUTO,
)
def test_correspondent_matching_with_ner_entities(self):
"""Test that NER entities help match correspondents."""
scanner = AIDocumentScanner()
# Mock NER to extract organization
mock_ner = mock.MagicMock()
mock_ner.extract_all.return_value = {
"organizations": [{"text": "ACME Corporation"}]
"organizations": [{"text": "ACME Corporation"}],
}
scanner._ner_extractor = mock_ner
# Mock matching to return empty (so NER-based matching is used)
with mock.patch('documents.ai_scanner.match_correspondents', return_value=[]):
with mock.patch("documents.ai_scanner.match_correspondents", return_value=[]):
result = scanner._detect_correspondent(
self.document,
self.document.content,
{"organizations": [{"text": "ACME Corporation"}]}
{"organizations": [{"text": "ACME Corporation"}]},
)
# Should detect ACME correspondent
self.assertIsNotNone(result)
corr_id, confidence = result
@ -372,20 +369,20 @@ class TestAIScannerIntegrationTitleGeneration(TestCase):
def test_title_generation_with_entities(self):
"""Test title generation uses extracted entities."""
scanner = AIDocumentScanner()
document = Document.objects.create(
title="document.pdf",
content="Invoice from ACME Corp dated 2024-01-15"
content="Invoice from ACME Corp dated 2024-01-15",
)
entities = {
"document_type": "Invoice",
"organizations": [{"text": "ACME Corp"}],
"dates": [{"text": "2024-01-15"}]
"dates": [{"text": "2024-01-15"}],
}
title = scanner._suggest_title(document, document.content, entities)
self.assertIsNotNone(title)
self.assertIn("Invoice", title)
self.assertIn("ACME Corp", title)
@ -399,7 +396,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Test",
content="Test"
content="Test",
)
self.tag_high = Tag.objects.create(name="HighConfidence")
self.tag_medium = Tag.objects.create(name="MediumConfidence")
@ -409,26 +406,26 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
"""Test that only high confidence suggestions are auto-applied."""
scanner = AIDocumentScanner(
auto_apply_threshold=0.80,
suggest_threshold=0.60
suggest_threshold=0.60,
)
scan_result = AIScanResult()
scan_result.tags = [
(self.tag_high.id, 0.90), # Should be applied
(self.tag_medium.id, 0.70), # Should be suggested
(self.tag_low.id, 0.50), # Should be ignored
]
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify high confidence was applied
self.assertEqual(len(result["applied"]["tags"]), 1)
self.assertEqual(result["applied"]["tags"][0]["id"], self.tag_high.id)
# Verify medium confidence was suggested
self.assertEqual(len(result["suggestions"]["tags"]), 1)
self.assertEqual(result["suggestions"]["tags"][0]["id"], self.tag_medium.id)
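
The two thresholds split confidences into three bands, which is exactly what the assertions above check. A standalone sketch of the banding rule using only the thresholds shown (behavior at exactly the boundary is an assumption):

def band(confidence: float, auto_apply: float = 0.80, suggest: float = 0.60) -> str:
    if confidence >= auto_apply:
        return "applied"
    if confidence >= suggest:
        return "suggested"
    return "ignored"

assert band(0.90) == "applied"    # auto-applied
assert band(0.70) == "suggested"  # surfaced as a suggestion only
assert band(0.50) == "ignored"    # below both thresholds
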
@ -441,28 +438,28 @@ class TestAIScannerIntegrationGlobalInstance(TestCase):
"""Test that global scanner can be reused across multiple scans."""
scanner1 = get_ai_scanner()
scanner2 = get_ai_scanner()
# Should be the same instance
self.assertIs(scanner1, scanner2)
# Should be functional
document = Document.objects.create(
title="Test",
content="Test content"
content="Test content",
)
with mock.patch.object(scanner1, '_extract_entities', return_value={}), \
mock.patch.object(scanner1, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner1, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner1, '_classify_document_type', return_value=None), \
mock.patch.object(scanner1, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner1, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner1, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner1, '_suggest_title', return_value=None):
with mock.patch.object(scanner1, "_extract_entities", return_value={}), \
mock.patch.object(scanner1, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner1, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner1, "_classify_document_type", return_value=None), \
mock.patch.object(scanner1, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner1, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner1, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner1, "_suggest_title", return_value=None):
result1 = scanner1.scan_document(document, document.content)
result2 = scanner2.scan_document(document, document.content)
self.assertIsInstance(result1, AIScanResult)
self.assertIsInstance(result2, AIScanResult)
@ -473,66 +470,66 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
def test_scan_with_minimal_document(self):
"""Test scanning a document with minimal information."""
scanner = AIDocumentScanner()
document = Document.objects.create(
title="",
content=""
content="",
)
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(document, document.content)
self.assertIsInstance(result, AIScanResult)
def test_apply_with_deleted_references(self):
"""Test applying results when referenced objects have been deleted."""
scanner = AIDocumentScanner()
document = Document.objects.create(
title="Test",
content="Test"
content="Test",
)
scan_result = AIScanResult()
scan_result.tags = [(9999, 0.90)] # Non-existent tag ID
scan_result.correspondent = (9999, 0.90) # Non-existent correspondent ID
# Should handle gracefully
result = scanner.apply_scan_results(
document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Should not crash, just log errors
self.assertEqual(len(result["applied"]["tags"]), 0)
def test_scan_with_unicode_and_special_characters(self):
"""Test scanning documents with Unicode and special characters."""
scanner = AIDocumentScanner()
document = Document.objects.create(
title="Factura - España 🇪🇸",
content="Société française • 日本語 • Ελληνικά • مرحبا"
content="Société française • 日本語 • Ελληνικά • مرحبا",
)
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(document, document.content)
self.assertIsInstance(result, AIScanResult)


@ -12,18 +12,17 @@ Tests cover:
from unittest import mock
from django.contrib.auth.models import Permission, User
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from rest_framework import status
from rest_framework.test import APITestCase
from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
@ -33,18 +32,18 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
# Assign view permission
content_type = ContentType.objects.get_for_model(Document)
view_permission, _ = Permission.objects.get_or_create(
@ -53,13 +52,13 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
content_type=content_type,
)
self.user_with_permission.user_permissions.add(view_permission)
# Create test document
self.document = Document.objects.create(
title="Test Document",
content="This is a test invoice from ACME Corporation"
content="This is a test invoice from ACME Corporation",
)
# Create test metadata objects
self.tag = Tag.objects.create(name="Invoice")
self.correspondent = Correspondent.objects.create(name="ACME Corp")
@ -70,28 +69,28 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_user_without_permission_denied(self):
"""Test that users without permission are denied."""
self.client.force_authenticate(user=self.user_without_permission)
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_superuser_allowed(self):
"""Test that superusers can access the endpoint."""
self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [(self.tag.id, 0.85)]
@ -100,17 +99,17 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = "Invoice - ACME Corp"
mock_scan_result.custom_fields = {}
mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner.return_value = mock_scanner_instance
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("document_id", response.data)
self.assertEqual(response.data["document_id"], self.document.id)
@ -118,8 +117,8 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
def test_user_with_permission_allowed(self):
"""Test that users with permission can access the endpoint."""
self.client.force_authenticate(user=self.user_with_permission)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []
@ -128,41 +127,41 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = None
mock_scan_result.custom_fields = {}
mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner.return_value = mock_scanner_instance
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_invalid_document_id(self):
"""Test handling of invalid document ID."""
self.client.force_authenticate(user=self.superuser)
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": 99999},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_missing_document_id(self):
"""Test handling of missing document ID."""
self.client.force_authenticate(user=self.superuser)
response = self.client.post(
"/api/ai/suggestions/",
{},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -172,15 +171,15 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign apply permission
content_type = ContentType.objects.get_for_model(Document)
apply_permission, _ = Permission.objects.get_or_create(
@ -189,13 +188,13 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
content_type=content_type,
)
self.user_with_permission.user_permissions.add(apply_permission)
# Create test document
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
# Create test metadata
self.tag = Tag.objects.create(name="Test Tag")
self.correspondent = Correspondent.objects.create(name="Test Corp")
@ -205,16 +204,16 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/apply/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_apply_tags_success(self):
"""Test successfully applying tag suggestions."""
self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [(self.tag.id, 0.85)]
@ -223,29 +222,29 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = None
mock_scan_result.custom_fields = {}
mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner_instance.auto_apply_threshold = 0.80
mock_scanner.return_value = mock_scanner_instance
response = self.client.post(
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_tags": True
"apply_tags": True,
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["status"], "success")
def test_apply_correspondent_success(self):
"""Test successfully applying correspondent suggestion."""
self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []
@ -254,23 +253,23 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = None
mock_scan_result.custom_fields = {}
mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner_instance.auto_apply_threshold = 0.80
mock_scanner.return_value = mock_scanner_instance
response = self.client.post(
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_correspondent": True
"apply_correspondent": True,
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify correspondent was applied
self.document.refresh_from_db()
self.assertEqual(self.document.correspondent, self.correspondent)
@ -282,43 +281,43 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
def test_unauthorized_access_denied(self):
"""Test that unauthenticated users are denied."""
response = self.client.get("/api/ai/config/")
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_user_without_permission_denied(self):
"""Test that users without permission are denied."""
self.client.force_authenticate(user=self.user_without_permission)
response = self.client.get("/api/ai/config/")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_get_config_success(self):
"""Test getting AI configuration."""
self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.auto_apply_threshold = 0.80
mock_scanner_instance.suggest_threshold = 0.60
mock_scanner_instance.ml_enabled = True
mock_scanner_instance.advanced_ocr_enabled = True
mock_scanner.return_value = mock_scanner_instance
response = self.client.get("/api/ai/config/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("auto_apply_threshold", response.data)
self.assertEqual(response.data["auto_apply_threshold"], 0.80)
@ -326,31 +325,31 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
def test_update_config_success(self):
"""Test updating AI configuration."""
self.client.force_authenticate(user=self.superuser)
response = self.client.post(
"/api/ai/config/",
{
"auto_apply_threshold": 0.90,
"suggest_threshold": 0.70
"suggest_threshold": 0.70,
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["status"], "success")
def test_update_config_invalid_threshold(self):
"""Test updating with invalid threshold value."""
self.client.force_authenticate(user=self.superuser)
response = self.client.post(
"/api/ai/config/",
{
"auto_apply_threshold": 1.5 # Invalid: > 1.0
"auto_apply_threshold": 1.5, # Invalid: > 1.0
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
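
Note: the two config tests above pin the validation contract: thresholds are numbers in the range 0.0 to 1.0, and anything else returns HTTP 400. A minimal sketch of that check; the helper name and error payload are assumptions, only the accepted range comes from the tests:

from rest_framework import status
from rest_framework.response import Response

THRESHOLD_FIELDS = ("auto_apply_threshold", "suggest_threshold")

def validate_threshold_update(data):
    # Hypothetical helper: return an error Response for any out-of-range
    # threshold in the payload, or None when the update is acceptable.
    for field in THRESHOLD_FIELDS:
        if field not in data:
            continue
        value = data[field]
        if not isinstance(value, (int, float)) or not 0.0 <= value <= 1.0:
            return Response(
                {"error": f"{field} must be between 0.0 and 1.0"},
                status=status.HTTP_400_BAD_REQUEST,
            )
    return None
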
@ -360,18 +359,18 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
# Assign approval permission
content_type = ContentType.objects.get_for_model(Document)
approval_permission, _ = Permission.objects.get_or_create(
@ -380,12 +379,12 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
content_type=content_type,
)
self.user_with_permission.user_permissions.add(approval_permission)
# Create test deletion request
self.deletion_request = DeletionRequest.objects.create(
user=self.user_with_permission,
requested_by_ai=True,
ai_reason="Document appears to be a duplicate"
ai_reason="Document appears to be a duplicate",
)
def test_unauthorized_access_denied(self):
@ -394,102 +393,102 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_user_without_permission_denied(self):
"""Test that users without permission are denied."""
self.client.force_authenticate(user=self.user_without_permission)
response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_approve_deletion_success(self):
"""Test successfully approving a deletion request."""
self.client.force_authenticate(user=self.user_with_permission)
response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["status"], "success")
# Verify status was updated
self.deletion_request.refresh_from_db()
self.assertEqual(
self.deletion_request.status,
DeletionRequest.STATUS_APPROVED
DeletionRequest.STATUS_APPROVED,
)
def test_reject_deletion_success(self):
"""Test successfully rejecting a deletion request."""
self.client.force_authenticate(user=self.user_with_permission)
response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "reject",
"reason": "Document is still needed"
"reason": "Document is still needed",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify status was updated
self.deletion_request.refresh_from_db()
self.assertEqual(
self.deletion_request.status,
DeletionRequest.STATUS_REJECTED
DeletionRequest.STATUS_REJECTED,
)
def test_invalid_request_id(self):
"""Test handling of invalid deletion request ID."""
self.client.force_authenticate(user=self.superuser)
response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": 99999,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_superuser_can_approve_any_request(self):
"""Test that superusers can approve any deletion request."""
self.client.force_authenticate(user=self.superuser)
response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -499,14 +498,14 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()
# Create user with all AI permissions
self.power_user = User.objects.create_user(
username="power_user", email="power@test.com", password="power123"
username="power_user", email="power@test.com", password="power123",
)
content_type = ContentType.objects.get_for_model(Document)
# Assign all AI permissions
permissions = [
"can_view_ai_suggestions",
@ -514,7 +513,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
"can_approve_deletions",
"can_configure_ai",
]
for codename in permissions:
perm, _ = Permission.objects.get_or_create(
codename=codename,
@ -522,18 +521,18 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
content_type=content_type,
)
self.power_user.user_permissions.add(perm)
self.document = Document.objects.create(
title="Test Doc",
content="Test"
content="Test",
)
def test_power_user_can_access_all_endpoints(self):
"""Test that user with all permissions can access all endpoints."""
self.client.force_authenticate(user=self.power_user)
# Test suggestions endpoint
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []
mock_scan_result.correspondent = None
@ -541,7 +540,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = None
mock_scan_result.custom_fields = {}
mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner_instance.auto_apply_threshold = 0.80
@ -549,25 +548,25 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
mock_scanner_instance.ml_enabled = True
mock_scanner_instance.advanced_ocr_enabled = True
mock_scanner.return_value = mock_scanner_instance
response1 = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response1.status_code, status.HTTP_200_OK)
# Test apply endpoint
response2 = self.client.post(
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_tags": False
"apply_tags": False,
},
format="json"
format="json",
)
self.assertEqual(response2.status_code, status.HTTP_200_OK)
# Test config endpoint
response3 = self.client.get("/api/ai/config/")
self.assertEqual(response3.status_code, status.HTTP_200_OK)

View file

@ -9,14 +9,12 @@ from rest_framework import status
from rest_framework.test import APITestCase
from documents.ai_scanner import AIScanResult
from documents.models import (
AISuggestionFeedback,
Correspondent,
Document,
DocumentType,
StoragePath,
Tag,
)
from documents.models import AISuggestionFeedback
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
@ -25,11 +23,11 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def setUp(self):
super().setUp()
# Create test user
self.user = User.objects.create_superuser(username="test_admin")
self.client.force_authenticate(user=self.user)
# Create test data
self.correspondent = Correspondent.objects.create(
name="Test Corp",
@ -52,7 +50,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
path="/archive/",
pk=1,
)
# Create test document
self.document = Document.objects.create(
title="Test Document",
@ -64,12 +62,12 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_ai_suggestions_endpoint_exists(self):
"""Test that the ai-suggestions endpoint is accessible."""
response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/"
f"/api/documents/{self.document.pk}/ai-suggestions/",
)
# Should not be 404
self.assertNotEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@mock.patch('documents.ai_scanner.get_ai_scanner')
@mock.patch("documents.ai_scanner.get_ai_scanner")
def test_get_ai_suggestions_success(self, mock_get_scanner):
"""Test successfully getting AI suggestions for a document."""
# Create mock scan result
@ -79,39 +77,39 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
scan_result.document_type = (self.doc_type.id, 0.88)
scan_result.storage_path = (self.storage_path.id, 0.80)
scan_result.title_suggestion = "Suggested Title"
# Mock scanner
mock_scanner = mock.Mock()
mock_scanner.scan_document.return_value = scan_result
mock_get_scanner.return_value = mock_scanner
# Make request
response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/"
f"/api/documents/{self.document.pk}/ai-suggestions/",
)
# Verify response
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
# Check tags
self.assertIn('tags', data)
self.assertEqual(len(data['tags']), 2)
self.assertEqual(data['tags'][0]['id'], self.tag1.id)
self.assertEqual(data['tags'][0]['confidence'], 0.85)
self.assertIn("tags", data)
self.assertEqual(len(data["tags"]), 2)
self.assertEqual(data["tags"][0]["id"], self.tag1.id)
self.assertEqual(data["tags"][0]["confidence"], 0.85)
# Check correspondent
self.assertIn('correspondent', data)
self.assertEqual(data['correspondent']['id'], self.correspondent.id)
self.assertEqual(data['correspondent']['confidence'], 0.90)
self.assertIn("correspondent", data)
self.assertEqual(data["correspondent"]["id"], self.correspondent.id)
self.assertEqual(data["correspondent"]["confidence"], 0.90)
# Check document type
self.assertIn('document_type', data)
self.assertEqual(data['document_type']['id'], self.doc_type.id)
self.assertIn("document_type", data)
self.assertEqual(data["document_type"]["id"], self.doc_type.id)
# Check title suggestion
self.assertIn('title_suggestion', data)
self.assertEqual(data['title_suggestion']['title'], "Suggested Title")
self.assertIn("title_suggestion", data)
self.assertEqual(data["title_suggestion"]["title"], "Suggested Title")
def test_get_ai_suggestions_no_content(self):
"""Test getting AI suggestions for document without content."""
@ -122,43 +120,43 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
checksum="empty123",
mime_type="application/pdf",
)
response = self.client.get(f"/api/documents/{doc.pk}/ai-suggestions/")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("no content", response.json()['detail'].lower())
self.assertIn("no content", response.json()["detail"].lower())
def test_get_ai_suggestions_document_not_found(self):
"""Test getting AI suggestions for non-existent document."""
response = self.client.get("/api/documents/99999/ai-suggestions/")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_apply_suggestion_tag(self):
"""Test applying a tag suggestion."""
request_data = {
'suggestion_type': 'tag',
'value_id': self.tag1.id,
'confidence': 0.85,
"suggestion_type": "tag",
"value_id": self.tag1.id,
"confidence": 0.85,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()['status'], 'success')
self.assertEqual(response.json()["status"], "success")
# Verify tag was applied
self.document.refresh_from_db()
self.assertIn(self.tag1, self.document.tags.all())
# Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
).first()
self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
@ -169,27 +167,27 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_correspondent(self):
"""Test applying a correspondent suggestion."""
request_data = {
'suggestion_type': 'correspondent',
'value_id': self.correspondent.id,
'confidence': 0.90,
"suggestion_type": "correspondent",
"value_id": self.correspondent.id,
"confidence": 0.90,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify correspondent was applied
self.document.refresh_from_db()
self.assertEqual(self.document.correspondent, self.correspondent)
# Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter(
document=self.document,
suggestion_type='correspondent',
suggestion_type="correspondent",
).first()
self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
@ -197,19 +195,19 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_document_type(self):
"""Test applying a document type suggestion."""
request_data = {
'suggestion_type': 'document_type',
'value_id': self.doc_type.id,
'confidence': 0.88,
"suggestion_type": "document_type",
"value_id": self.doc_type.id,
"confidence": 0.88,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify document type was applied
self.document.refresh_from_db()
self.assertEqual(self.document.document_type, self.doc_type)
@ -217,91 +215,91 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_title(self):
"""Test applying a title suggestion."""
request_data = {
'suggestion_type': 'title',
'value_text': 'New Suggested Title',
'confidence': 0.80,
"suggestion_type": "title",
"value_text": "New Suggested Title",
"confidence": 0.80,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify title was applied
self.document.refresh_from_db()
self.assertEqual(self.document.title, 'New Suggested Title')
self.assertEqual(self.document.title, "New Suggested Title")
def test_apply_suggestion_invalid_type(self):
"""Test applying suggestion with invalid type."""
request_data = {
'suggestion_type': 'invalid_type',
'value_id': 1,
'confidence': 0.85,
"suggestion_type": "invalid_type",
"value_id": 1,
"confidence": 0.85,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
def test_apply_suggestion_missing_value(self):
"""Test applying suggestion without value_id or value_text."""
request_data = {
'suggestion_type': 'tag',
'confidence': 0.85,
"suggestion_type": "tag",
"confidence": 0.85,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
def test_apply_suggestion_nonexistent_object(self):
"""Test applying suggestion with non-existent object ID."""
request_data = {
'suggestion_type': 'tag',
'value_id': 99999,
'confidence': 0.85,
"suggestion_type": "tag",
"value_id": 99999,
"confidence": 0.85,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_reject_suggestion(self):
"""Test rejecting an AI suggestion."""
request_data = {
'suggestion_type': 'tag',
'value_id': self.tag1.id,
'confidence': 0.65,
"suggestion_type": "tag",
"value_id": self.tag1.id,
"confidence": 0.65,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/reject-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()['status'], 'success')
self.assertEqual(response.json()["status"], "success")
# Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
).first()
self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
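
Note: taken together, the apply and reject tests fix the request contract: suggestion_type selects what to change, value_id must reference an existing object (404 otherwise), value_text carries free-text values such as titles, and unknown types or missing values yield 400. A dispatch sketch under those assumptions; the models_by_type mapping is illustrative, not code from this commit:

from django.http import Http404

def resolve_suggestion(data, models_by_type):
    # models_by_type maps suggestion_type -> model class, e.g.
    # {"tag": Tag, "correspondent": Correspondent, "document_type": DocumentType}.
    kind = data.get("suggestion_type")
    if kind == "title":
        text = data.get("value_text")
        if not text:
            raise ValueError("value_text is required for title suggestions")
        return text
    model = models_by_type.get(kind)
    if model is None:
        raise ValueError(f"unknown suggestion_type: {kind!r}")
    value_id = data.get("value_id")
    if value_id is None:
        raise ValueError("value_id is required")
    try:
        return model.objects.get(pk=value_id)
    except model.DoesNotExist as exc:
        raise Http404(f"{kind} {value_id} does not exist") from exc

A view would map ValueError to HTTP 400 and let Http404 surface as 404, matching the assertions above.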
@ -312,46 +310,46 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_reject_suggestion_with_text(self):
"""Test rejecting a suggestion with text value."""
request_data = {
'suggestion_type': 'title',
'value_text': 'Bad Title Suggestion',
'confidence': 0.50,
"suggestion_type": "title",
"value_text": "Bad Title Suggestion",
"confidence": 0.50,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/reject-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter(
document=self.document,
suggestion_type='title',
suggestion_type="title",
).first()
self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
self.assertEqual(feedback.suggested_value_text, 'Bad Title Suggestion')
self.assertEqual(feedback.suggested_value_text, "Bad Title Suggestion")
def test_ai_suggestion_stats_empty(self):
"""Test getting statistics when no feedback exists."""
response = self.client.get("/api/documents/ai-suggestion-stats/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data['total_suggestions'], 0)
self.assertEqual(data['total_applied'], 0)
self.assertEqual(data['total_rejected'], 0)
self.assertEqual(data['accuracy_rate'], 0)
self.assertEqual(data["total_suggestions"], 0)
self.assertEqual(data["total_applied"], 0)
self.assertEqual(data["total_rejected"], 0)
self.assertEqual(data["accuracy_rate"], 0)
def test_ai_suggestion_stats_with_data(self):
"""Test getting statistics with feedback data."""
# Create some feedback entries
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
suggested_value_id=self.tag1.id,
confidence=0.85,
status=AISuggestionFeedback.STATUS_APPLIED,
@ -359,7 +357,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
)
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
suggested_value_id=self.tag2.id,
confidence=0.70,
status=AISuggestionFeedback.STATUS_APPLIED,
@ -367,38 +365,38 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
)
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='correspondent',
suggestion_type="correspondent",
suggested_value_id=self.correspondent.id,
confidence=0.60,
status=AISuggestionFeedback.STATUS_REJECTED,
user=self.user,
)
response = self.client.get("/api/documents/ai-suggestion-stats/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
# Check overall stats
self.assertEqual(data['total_suggestions'], 3)
self.assertEqual(data['total_applied'], 2)
self.assertEqual(data['total_rejected'], 1)
self.assertAlmostEqual(data['accuracy_rate'], 66.67, places=1)
self.assertEqual(data["total_suggestions"], 3)
self.assertEqual(data["total_applied"], 2)
self.assertEqual(data["total_rejected"], 1)
self.assertAlmostEqual(data["accuracy_rate"], 66.67, places=1)
# Check by_type stats
self.assertIn('by_type', data)
self.assertIn('tag', data['by_type'])
self.assertEqual(data['by_type']['tag']['total'], 2)
self.assertEqual(data['by_type']['tag']['applied'], 2)
self.assertEqual(data['by_type']['tag']['rejected'], 0)
self.assertIn("by_type", data)
self.assertIn("tag", data["by_type"])
self.assertEqual(data["by_type"]["tag"]["total"], 2)
self.assertEqual(data["by_type"]["tag"]["applied"], 2)
self.assertEqual(data["by_type"]["tag"]["rejected"], 0)
# Check confidence averages
self.assertGreater(data['average_confidence_applied'], 0)
self.assertGreater(data['average_confidence_rejected'], 0)
self.assertGreater(data["average_confidence_applied"], 0)
self.assertGreater(data["average_confidence_rejected"], 0)
# Check recent suggestions
self.assertIn('recent_suggestions', data)
self.assertEqual(len(data['recent_suggestions']), 3)
self.assertIn("recent_suggestions", data)
self.assertEqual(len(data["recent_suggestions"]), 3)
def test_ai_suggestion_stats_accuracy_calculation(self):
"""Test that accuracy rate is calculated correctly."""
@ -406,57 +404,57 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
for i in range(7):
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
suggested_value_id=self.tag1.id,
confidence=0.80,
status=AISuggestionFeedback.STATUS_APPLIED,
user=self.user,
)
for i in range(3):
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
suggested_value_id=self.tag2.id,
confidence=0.60,
status=AISuggestionFeedback.STATUS_REJECTED,
user=self.user,
)
response = self.client.get("/api/documents/ai-suggestion-stats/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data['total_suggestions'], 10)
self.assertEqual(data['total_applied'], 7)
self.assertEqual(data['total_rejected'], 3)
self.assertEqual(data['accuracy_rate'], 70.0)
self.assertEqual(data["total_suggestions"], 10)
self.assertEqual(data["total_applied"], 7)
self.assertEqual(data["total_rejected"], 3)
self.assertEqual(data["accuracy_rate"], 70.0)
def test_authentication_required(self):
"""Test that authentication is required for all endpoints."""
self.client.force_authenticate(user=None)
# Test ai-suggestions endpoint
response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/"
f"/api/documents/{self.document.pk}/ai-suggestions/",
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# Test apply-suggestion endpoint
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data={},
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# Test reject-suggestion endpoint
response = self.client.post(
f"/api/documents/{self.document.pk}/reject-suggestion/",
data={},
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# Test stats endpoint
response = self.client.get("/api/documents/ai-suggestion-stats/")
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
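
Note: the stats tests above pin the accuracy formula to applied / (applied + rejected) * 100, with 0 when no feedback exists (2 of 3 is roughly 66.67, 7 of 10 is exactly 70.0). A worked sketch; the two-decimal rounding is an assumption consistent with assertAlmostEqual(places=1):

def accuracy_rate(total_applied: int, total_rejected: int) -> float:
    # Percentage of suggestions users accepted; 0 when there is no feedback yet.
    total = total_applied + total_rejected
    if total == 0:
        return 0.0
    return round(total_applied / total * 100, 2)

assert accuracy_rate(0, 0) == 0.0    # test_ai_suggestion_stats_empty
assert accuracy_rate(2, 1) == 66.67  # the three-entry fixture above
assert accuracy_rate(7, 3) == 70.0   # the 7-applied / 3-rejected fixture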

View file

@ -11,17 +11,14 @@ Tests cover:
"""
from django.contrib.auth.models import User
from django.test import override_settings
from rest_framework import status
from rest_framework.test import APITestCase
from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
class TestDeletionRequestAPI(APITestCase):
@ -33,7 +30,7 @@ class TestDeletionRequestAPI(APITestCase):
self.user1 = User.objects.create_user(username="user1", password="pass123")
self.user2 = User.objects.create_user(username="user2", password="pass123")
self.admin = User.objects.create_superuser(username="admin", password="admin123")
# Create test documents
self.doc1 = Document.objects.create(
title="Test Document 1",
@ -53,7 +50,7 @@ class TestDeletionRequestAPI(APITestCase):
checksum="checksum3",
mime_type="application/pdf",
)
# Create deletion requests
self.request1 = DeletionRequest.objects.create(
requested_by_ai=True,
@ -63,7 +60,7 @@ class TestDeletionRequestAPI(APITestCase):
impact_summary={"document_count": 1},
)
self.request1.documents.add(self.doc1)
self.request2 = DeletionRequest.objects.create(
requested_by_ai=True,
ai_reason="Low quality document",
@ -77,7 +74,7 @@ class TestDeletionRequestAPI(APITestCase):
"""Test that users can list their own deletion requests."""
self.client.force_authenticate(user=self.user1)
response = self.client.get("/api/deletion-requests/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]), 1)
self.assertEqual(response.data["results"][0]["id"], self.request1.id)
@ -86,7 +83,7 @@ class TestDeletionRequestAPI(APITestCase):
"""Test that admin can list all deletion requests."""
self.client.force_authenticate(user=self.admin)
response = self.client.get("/api/deletion-requests/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]), 2)
@ -94,7 +91,7 @@ class TestDeletionRequestAPI(APITestCase):
"""Test retrieving a single deletion request."""
self.client.force_authenticate(user=self.user1)
response = self.client.get(f"/api/deletion-requests/{self.request1.id}/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["id"], self.request1.id)
self.assertEqual(response.data["ai_reason"], "Duplicate document detected")
@ -104,23 +101,23 @@ class TestDeletionRequestAPI(APITestCase):
def test_approve_deletion_request_as_owner(self):
"""Test approving a deletion request as the owner."""
self.client.force_authenticate(user=self.user1)
# Verify document exists
self.assertTrue(Document.objects.filter(id=self.doc1.id).exists())
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/approve/",
{"comment": "Approved by owner"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("message", response.data)
self.assertIn("execution_result", response.data)
self.assertEqual(response.data["execution_result"]["deleted_count"], 1)
# Verify document was deleted
self.assertFalse(Document.objects.filter(id=self.doc1.id).exists())
# Verify deletion request was updated
self.request1.refresh_from_db()
self.assertEqual(self.request1.status, DeletionRequest.STATUS_COMPLETED)
@ -131,18 +128,18 @@ class TestDeletionRequestAPI(APITestCase):
def test_approve_deletion_request_as_admin(self):
"""Test approving a deletion request as admin."""
self.client.force_authenticate(user=self.admin)
response = self.client.post(
f"/api/deletion-requests/{self.request2.id}/approve/",
{"comment": "Approved by admin"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("execution_result", response.data)
# Verify document was deleted
self.assertFalse(Document.objects.filter(id=self.doc2.id).exists())
# Verify deletion request was updated
self.request2.refresh_from_db()
self.assertEqual(self.request2.status, DeletionRequest.STATUS_COMPLETED)
@ -151,16 +148,16 @@ class TestDeletionRequestAPI(APITestCase):
def test_approve_deletion_request_without_permission(self):
"""Test that non-owners cannot approve deletion requests."""
self.client.force_authenticate(user=self.user2)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/approve/",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
# Verify document was NOT deleted
self.assertTrue(Document.objects.filter(id=self.doc1.id).exists())
# Verify deletion request was NOT updated
self.request1.refresh_from_db()
self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING)
@ -169,13 +166,13 @@ class TestDeletionRequestAPI(APITestCase):
"""Test that already approved requests cannot be approved again."""
self.request1.status = DeletionRequest.STATUS_APPROVED
self.request1.save()
self.client.force_authenticate(user=self.user1)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/approve/",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("error", response.data)
self.assertIn("pending", response.data["error"].lower())
@ -183,18 +180,18 @@ class TestDeletionRequestAPI(APITestCase):
def test_reject_deletion_request_as_owner(self):
"""Test rejecting a deletion request as the owner."""
self.client.force_authenticate(user=self.user1)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/reject/",
{"comment": "Not needed"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("message", response.data)
# Verify document was NOT deleted
self.assertTrue(Document.objects.filter(id=self.doc1.id).exists())
# Verify deletion request was updated
self.request1.refresh_from_db()
self.assertEqual(self.request1.status, DeletionRequest.STATUS_REJECTED)
@ -205,16 +202,16 @@ class TestDeletionRequestAPI(APITestCase):
def test_reject_deletion_request_as_admin(self):
"""Test rejecting a deletion request as admin."""
self.client.force_authenticate(user=self.admin)
response = self.client.post(
f"/api/deletion-requests/{self.request2.id}/reject/",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify document was NOT deleted
self.assertTrue(Document.objects.filter(id=self.doc2.id).exists())
# Verify deletion request was updated
self.request2.refresh_from_db()
self.assertEqual(self.request2.status, DeletionRequest.STATUS_REJECTED)
@ -223,13 +220,13 @@ class TestDeletionRequestAPI(APITestCase):
def test_reject_deletion_request_without_permission(self):
"""Test that non-owners cannot reject deletion requests."""
self.client.force_authenticate(user=self.user2)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/reject/",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
# Verify deletion request was NOT updated
self.request1.refresh_from_db()
self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING)
@ -238,31 +235,31 @@ class TestDeletionRequestAPI(APITestCase):
"""Test that already rejected requests cannot be rejected again."""
self.request1.status = DeletionRequest.STATUS_REJECTED
self.request1.save()
self.client.force_authenticate(user=self.user1)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/reject/",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("error", response.data)
def test_cancel_deletion_request_as_owner(self):
"""Test canceling a deletion request as the owner."""
self.client.force_authenticate(user=self.user1)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/cancel/",
{"comment": "Changed my mind"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("message", response.data)
# Verify document was NOT deleted
self.assertTrue(Document.objects.filter(id=self.doc1.id).exists())
# Verify deletion request was updated
self.request1.refresh_from_db()
self.assertEqual(self.request1.status, DeletionRequest.STATUS_CANCELLED)
@ -273,13 +270,13 @@ class TestDeletionRequestAPI(APITestCase):
def test_cancel_deletion_request_without_permission(self):
"""Test that non-owners cannot cancel deletion requests."""
self.client.force_authenticate(user=self.user2)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
# Verify deletion request was NOT updated
self.request1.refresh_from_db()
self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING)
@ -288,13 +285,13 @@ class TestDeletionRequestAPI(APITestCase):
"""Test that approved requests cannot be cancelled."""
self.request1.status = DeletionRequest.STATUS_APPROVED
self.request1.save()
self.client.force_authenticate(user=self.user1)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("error", response.data)
@ -309,17 +306,17 @@ class TestDeletionRequestAPI(APITestCase):
impact_summary={"document_count": 2},
)
multi_request.documents.add(self.doc1, self.doc3)
self.client.force_authenticate(user=self.user1)
response = self.client.post(
f"/api/deletion-requests/{multi_request.id}/approve/",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["execution_result"]["deleted_count"], 2)
self.assertEqual(response.data["execution_result"]["total_documents"], 2)
# Verify both documents were deleted
self.assertFalse(Document.objects.filter(id=self.doc1.id).exists())
self.assertFalse(Document.objects.filter(id=self.doc3.id).exists())
@ -330,15 +327,15 @@ class TestDeletionRequestAPI(APITestCase):
tag = Tag.objects.create(name="test-tag")
correspondent = Correspondent.objects.create(name="Test Corp")
doc_type = DocumentType.objects.create(name="Invoice")
self.doc1.tags.add(tag)
self.doc1.correspondent = correspondent
self.doc1.document_type = doc_type
self.doc1.save()
self.client.force_authenticate(user=self.user1)
response = self.client.get(f"/api/deletion-requests/{self.request1.id}/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
doc_details = response.data["document_details"]
self.assertEqual(len(doc_details), 1)
@ -352,7 +349,7 @@ class TestDeletionRequestAPI(APITestCase):
"""Test that unauthenticated users cannot access the API."""
response = self.client.get("/api/deletion-requests/")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
response = self.client.post(
f"/api/deletion-requests/{self.request1.id}/approve/",
)
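
Note: every permission assertion in this file reduces to one rule, the acting user must own the deletion request or be a superuser, plus a status guard so that only pending requests can transition. A sketch of the approve action under those assumptions; the class name is hypothetical, and the tested endpoint additionally executes the deletion and reports execution_result:

from rest_framework import status
from rest_framework.decorators import action
from rest_framework.response import Response

class DeletionRequestActionsSketch:
    # Assumed to be mixed into a ModelViewSet over DeletionRequest.

    @action(detail=True, methods=["post"])
    def approve(self, request, pk=None):
        deletion_request = self.get_object()
        if not (request.user.is_superuser or deletion_request.user == request.user):
            return Response(status=status.HTTP_403_FORBIDDEN)
        if deletion_request.status != deletion_request.STATUS_PENDING:
            return Response(
                {"error": "Only pending requests can be approved."},
                status=status.HTTP_400_BAD_REQUEST,
            )
        deletion_request.approve(request.user, request.data.get("comment", ""))
        # The tested endpoint also deletes the documents here and reports
        # deleted_count / total_documents under "execution_result".
        return Response({"message": "Deletion request approved."})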

View file

@ -1350,18 +1350,18 @@ class TestConsumerAIScannerIntegration(
correspondent = Correspondent.objects.create(name="Test Corp")
doc_type = DocumentType.objects.create(name="Invoice")
storage_path = StoragePath.objects.create(name="Invoices", path="/invoices")
# Create mock AI scanner
mock_scanner = MagicMock()
mock_get_scanner.return_value = mock_scanner
# Mock scan results
scan_result = AIScanResult()
scan_result.tags = [(tag1.id, 0.85), (tag2.id, 0.75)]
scan_result.correspondent = (correspondent.id, 0.90)
scan_result.document_type = (doc_type.id, 0.85)
scan_result.storage_path = (storage_path.id, 0.80)
mock_scanner.scan_document.return_value = scan_result
mock_scanner.apply_scan_results.return_value = {
"applied": {
@ -1381,20 +1381,20 @@ class TestConsumerAIScannerIntegration(
"workflows": [],
},
}
# Run consumer
filename = self.get_test_file()
with self.get_consumer(filename) as consumer:
consumer.run()
# Verify document was created
document = Document.objects.first()
self.assertIsNotNone(document)
# Verify AI scanner was called
mock_scanner.scan_document.assert_called_once()
mock_scanner.apply_scan_results.assert_called_once()
# Verify the call arguments
call_args = mock_scanner.scan_document.call_args
self.assertEqual(call_args[1]["document"], document)
@ -1412,11 +1412,11 @@ class TestConsumerAIScannerIntegration(
demonstrating graceful degradation.
"""
filename = self.get_test_file()
# Consumer should complete successfully even with ML disabled
with self.get_consumer(filename) as consumer:
consumer.run()
# Verify document was created
document = Document.objects.first()
self.assertIsNotNone(document)
@ -1435,13 +1435,13 @@ class TestConsumerAIScannerIntegration(
mock_scanner = MagicMock()
mock_get_scanner.return_value = mock_scanner
mock_scanner.scan_document.side_effect = Exception("AI Scanner failed")
filename = self.get_test_file()
# Consumer should complete despite AI scanner failure
with self.get_consumer(filename) as consumer:
consumer.run()
# Verify document was created despite AI failure
document = Document.objects.first()
self.assertIsNotNone(document)
@ -1457,17 +1457,17 @@ class TestConsumerAIScannerIntegration(
"""
mock_scanner = MagicMock()
mock_get_scanner.return_value = mock_scanner
self.create_empty_scan_result_mock(mock_scanner)
filename = self.get_test_file()
with self.get_consumer(filename) as consumer:
consumer.run()
document = Document.objects.first()
self.assertIsNotNone(document)
# Verify AI scanner was called with PDF
mock_scanner.scan_document.assert_called_once()
call_args = mock_scanner.scan_document.call_args
@ -1488,7 +1488,7 @@ class TestConsumerAIScannerIntegration(
self.dirs.scratch_dir,
self.get_test_archive_file(),
)
with mock.patch("documents.parsers.document_consumer_declaration.send") as m:
m.return_value = [
(
@ -1500,21 +1500,21 @@ class TestConsumerAIScannerIntegration(
},
),
]
mock_scanner = MagicMock()
mock_get_scanner.return_value = mock_scanner
self.create_empty_scan_result_mock(mock_scanner)
# Create a PNG file
dst = self.get_test_file_with_name("sample.png")
with self.get_consumer(dst) as consumer:
consumer.run()
document = Document.objects.first()
self.assertIsNotNone(document)
# Verify AI scanner was called
mock_scanner.scan_document.assert_called_once()
@ -1527,26 +1527,26 @@ class TestConsumerAIScannerIntegration(
Verifies that AI scanning adds minimal overhead to document consumption.
"""
import time
mock_scanner = MagicMock()
mock_get_scanner.return_value = mock_scanner
self.create_empty_scan_result_mock(mock_scanner)
filename = self.get_test_file()
start_time = time.time()
with self.get_consumer(filename) as consumer:
consumer.run()
end_time = time.time()
# Verify document was created
document = Document.objects.first()
self.assertIsNotNone(document)
# Verify AI scanner was called
mock_scanner.scan_document.assert_called_once()
# With mocks, this should be very fast (<1s).
# TODO: Implement proper performance testing with real ML models in integration/performance test suite.
elapsed_time = end_time - start_time
@ -1561,33 +1561,32 @@ class TestConsumerAIScannerIntegration(
Verifies that the AI scanner respects database transactions and handles
rollbacks correctly.
"""
from django.db import transaction as db_transaction
tag = Tag.objects.create(name="Invoice")
mock_scanner = MagicMock()
mock_get_scanner.return_value = mock_scanner
scan_result = AIScanResult()
scan_result.tags = [(tag.id, 0.85)]
mock_scanner.scan_document.return_value = scan_result
# Mock apply_scan_results to raise an exception after some work
def apply_with_error(document, scan_result, auto_apply=True):
# Simulate partial work
document.tags.add(tag)
# Then fail
raise Exception("Simulated transaction failure")
mock_scanner.apply_scan_results.side_effect = apply_with_error
filename = self.get_test_file()
# Even with AI scanner failure, the document should still be created
# because we handle AI scanner errors gracefully
with self.get_consumer(filename) as consumer:
consumer.run()
document = Document.objects.first()
self.assertIsNotNone(document)
# The tag addition from AI scanner should be rolled back due to exception
@ -1604,17 +1603,17 @@ class TestConsumerAIScannerIntegration(
"""
tag1 = Tag.objects.create(name="Invoice")
tag2 = Tag.objects.create(name="Receipt")
mock_scanner = MagicMock()
mock_get_scanner.return_value = mock_scanner
# Configure scanner to return different results for each call
scan_results = []
for tag in [tag1, tag2]:
scan_result = AIScanResult()
scan_result.tags = [(tag.id, 0.85)]
scan_results.append(scan_result)
mock_scanner.scan_document.side_effect = scan_results
mock_scanner.apply_scan_results.return_value = {
"applied": {
@ -1634,20 +1633,20 @@ class TestConsumerAIScannerIntegration(
"workflows": [],
},
}
# Process multiple documents
filenames = [self.get_test_file()]
# Create second file
filenames.append(self.get_test_file_with_name("sample2.pdf"))
for filename in filenames:
with self.get_consumer(filename) as consumer:
consumer.run()
# Verify both documents were created
documents = Document.objects.all()
self.assertEqual(documents.count(), 2)
# Verify AI scanner was called for each document
self.assertEqual(mock_scanner.scan_document.call_count, 2)
@ -1661,17 +1660,17 @@ class TestConsumerAIScannerIntegration(
"""
mock_scanner = MagicMock()
mock_get_scanner.return_value = mock_scanner
self.create_empty_scan_result_mock(mock_scanner)
filename = self.get_test_file()
with self.get_consumer(filename) as consumer:
consumer.run()
document = Document.objects.first()
self.assertIsNotNone(document)
# Verify AI scanner received text content
mock_scanner.scan_document.assert_called_once()
call_args = mock_scanner.scan_document.call_args
@ -1686,10 +1685,10 @@ class TestConsumerAIScannerIntegration(
the AI scanner is not invoked at all.
"""
filename = self.get_test_file()
with self.get_consumer(filename) as consumer:
consumer.run()
# Document should be created normally without AI scanning
document = Document.objects.first()
self.assertIsNotNone(document)
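
Note: the failure and performance tests in this file only pass if the consumer treats AI scanning as best-effort: exceptions are logged and swallowed, partial applies roll back, and a disabled feature skips the scanner entirely. A sketch of that call site; the function name, the content keyword, and the settings flag are assumptions, while the call shapes mirror the mocks above:

import logging

from django.conf import settings
from django.db import transaction

logger = logging.getLogger("paperless.ai_scanner")

def run_ai_scan(document, text, get_scanner):
    # Best-effort hook at the end of consumption: never raises, so a
    # failing AI pipeline cannot block document creation.
    if not getattr(settings, "PAPERLESS_AI_SCANNER_ENABLED", False):
        return
    try:
        scanner = get_scanner()
        scan_result = scanner.scan_document(document=document, content=text)
        with transaction.atomic():
            # A failure inside apply_scan_results rolls back partial changes,
            # matching the transaction-handling test above.
            scanner.apply_scan_results(document, scan_result, auto_apply=True)
    except Exception:
        logger.exception("AI scan failed; document kept without AI metadata")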

View file

@ -15,13 +15,11 @@ from django.contrib.auth.models import User
from django.test import TestCase
from django.utils import timezone
from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
class TestDeletionRequestModelCreation(TestCase):
@ -45,7 +43,7 @@ class TestDeletionRequestModelCreation(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
self.assertIsNotNone(request)
self.assertTrue(request.requested_by_ai)
self.assertEqual(request.ai_reason, "Test reason")
@ -59,7 +57,7 @@ class TestDeletionRequestModelCreation(TestCase):
ai_reason="Test",
user=self.user,
)
self.assertIsNotNone(request.created_at)
self.assertIsNotNone(request.updated_at)
@ -70,7 +68,7 @@ class TestDeletionRequestModelCreation(TestCase):
ai_reason="Test",
user=self.user,
)
self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)
def test_deletion_request_with_documents(self):
@ -80,16 +78,16 @@ class TestDeletionRequestModelCreation(TestCase):
ai_reason="Test",
user=self.user,
)
doc2 = Document.objects.create(
title="Document 2",
content="Content 2",
checksum="checksum2",
mime_type="application/pdf",
)
request.documents.add(self.doc, doc2)
self.assertEqual(request.documents.count(), 2)
self.assertIn(self.doc, request.documents.all())
self.assertIn(doc2, request.documents.all())
@ -101,7 +99,7 @@ class TestDeletionRequestModelCreation(TestCase):
ai_reason="Test",
user=self.user,
)
self.assertIsInstance(request.impact_summary, dict)
self.assertEqual(request.impact_summary, {})
@ -112,14 +110,14 @@ class TestDeletionRequestModelCreation(TestCase):
"affected_tags": ["tag1", "tag2"],
"metadata": {"key": "value"},
}
request = DeletionRequest.objects.create(
requested_by_ai=True,
ai_reason="Test",
user=self.user,
impact_summary=impact,
)
self.assertEqual(request.impact_summary["document_count"], 5)
self.assertEqual(request.impact_summary["affected_tags"], ["tag1", "tag2"])
@ -131,9 +129,9 @@ class TestDeletionRequestModelCreation(TestCase):
user=self.user,
)
request.documents.add(self.doc)
str_repr = str(request)
self.assertIn("Deletion Request", str_repr)
self.assertIn(str(request.id), str_repr)
self.assertIn("1 documents", str_repr)
@ -162,9 +160,9 @@ class TestDeletionRequestApprove(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
result = request.approve(self.approver, "Approved")
self.assertTrue(result)
self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)
self.assertEqual(request.reviewed_by, self.approver)
@ -179,9 +177,9 @@ class TestDeletionRequestApprove(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
result = request.approve(self.approver)
self.assertTrue(result)
self.assertEqual(request.review_comment, "")
@ -195,9 +193,9 @@ class TestDeletionRequestApprove(TestCase):
reviewed_by=self.user,
reviewed_at=timezone.now(),
)
result = request.approve(self.approver, "Trying to approve again")
self.assertFalse(result)
self.assertEqual(request.reviewed_by, self.user) # Should not change
@ -211,9 +209,9 @@ class TestDeletionRequestApprove(TestCase):
reviewed_by=self.user,
reviewed_at=timezone.now(),
)
result = request.approve(self.approver, "Trying to approve rejected")
self.assertFalse(result)
self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)
@ -225,9 +223,9 @@ class TestDeletionRequestApprove(TestCase):
user=self.user,
status=DeletionRequest.STATUS_CANCELLED,
)
result = request.approve(self.approver, "Trying to approve cancelled")
self.assertFalse(result)
self.assertEqual(request.status, DeletionRequest.STATUS_CANCELLED)
@ -239,9 +237,9 @@ class TestDeletionRequestApprove(TestCase):
user=self.user,
status=DeletionRequest.STATUS_COMPLETED,
)
result = request.approve(self.approver, "Trying to approve completed")
self.assertFalse(result)
self.assertEqual(request.status, DeletionRequest.STATUS_COMPLETED)
@ -253,11 +251,11 @@ class TestDeletionRequestApprove(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
before_approval = timezone.now()
result = request.approve(self.approver, "Approved")
after_approval = timezone.now()
self.assertTrue(result)
self.assertIsNotNone(request.reviewed_at)
self.assertGreaterEqual(request.reviewed_at, before_approval)
@ -280,9 +278,9 @@ class TestDeletionRequestReject(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
result = request.reject(self.reviewer, "Not necessary")
self.assertTrue(result)
self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)
self.assertEqual(request.reviewed_by, self.reviewer)
@ -297,9 +295,9 @@ class TestDeletionRequestReject(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
result = request.reject(self.reviewer)
self.assertTrue(result)
self.assertEqual(request.review_comment, "")
@ -313,9 +311,9 @@ class TestDeletionRequestReject(TestCase):
reviewed_by=self.user,
reviewed_at=timezone.now(),
)
result = request.reject(self.reviewer, "Trying to reject again")
self.assertFalse(result)
self.assertEqual(request.reviewed_by, self.user) # Should not change
@ -329,9 +327,9 @@ class TestDeletionRequestReject(TestCase):
reviewed_by=self.user,
reviewed_at=timezone.now(),
)
result = request.reject(self.reviewer, "Trying to reject approved")
self.assertFalse(result)
self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)
@ -343,9 +341,9 @@ class TestDeletionRequestReject(TestCase):
user=self.user,
status=DeletionRequest.STATUS_CANCELLED,
)
result = request.reject(self.reviewer, "Trying to reject cancelled")
self.assertFalse(result)
self.assertEqual(request.status, DeletionRequest.STATUS_CANCELLED)
@ -357,9 +355,9 @@ class TestDeletionRequestReject(TestCase):
user=self.user,
status=DeletionRequest.STATUS_COMPLETED,
)
result = request.reject(self.reviewer, "Trying to reject completed")
self.assertFalse(result)
self.assertEqual(request.status, DeletionRequest.STATUS_COMPLETED)
@ -371,11 +369,11 @@ class TestDeletionRequestReject(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
before_rejection = timezone.now()
result = request.reject(self.reviewer, "Rejected")
after_rejection = timezone.now()
self.assertTrue(result)
self.assertIsNotNone(request.reviewed_at)
self.assertGreaterEqual(request.reviewed_at, before_rejection)
@ -389,11 +387,11 @@ class TestDeletionRequestWorkflowScenarios(TestCase):
"""Set up test data."""
self.user = User.objects.create_user(username="user1", password="pass")
self.approver = User.objects.create_user(username="approver", password="pass")
self.correspondent = Correspondent.objects.create(name="Test Corp")
self.doc_type = DocumentType.objects.create(name="Invoice")
self.tag = Tag.objects.create(name="Important")
self.doc1 = Document.objects.create(
title="Document 1",
content="Content 1",
@ -403,7 +401,7 @@ class TestDeletionRequestWorkflowScenarios(TestCase):
document_type=self.doc_type,
)
self.doc1.tags.add(self.tag)
self.doc2 = Document.objects.create(
title="Document 2",
content="Content 2",
@ -421,15 +419,15 @@ class TestDeletionRequestWorkflowScenarios(TestCase):
impact_summary={"document_count": 2},
)
request.documents.add(self.doc1, self.doc2)
# Verify initial state
self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)
self.assertIsNone(request.reviewed_by)
self.assertIsNone(request.reviewed_at)
# Approve
success = request.approve(self.approver, "Confirmed duplicates")
# Verify final state
self.assertTrue(success)
self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)
@ -446,13 +444,13 @@ class TestDeletionRequestWorkflowScenarios(TestCase):
status=DeletionRequest.STATUS_PENDING,
)
request.documents.add(self.doc1)
# Verify initial state
self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)
# Reject
success = request.reject(self.approver, "Not duplicates")
# Verify final state
self.assertTrue(success)
self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)
@ -468,14 +466,14 @@ class TestDeletionRequestWorkflowScenarios(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
# Reject first
request.reject(self.user, "Rejected")
self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)
# Try to approve
success = request.approve(self.approver, "Changed my mind")
# Should fail
self.assertFalse(success)
self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)
@ -488,14 +486,14 @@ class TestDeletionRequestWorkflowScenarios(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
# Approve first
request.approve(self.approver, "Approved")
self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)
# Try to reject
success = request.reject(self.user, "Changed my mind")
# Should fail
self.assertFalse(success)
self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)
@ -516,7 +514,7 @@ class TestDeletionRequestAuditTrail(TestCase):
ai_reason="Test",
user=self.user,
)
self.assertEqual(request.user, self.user)
def test_audit_trail_records_reviewer(self):
@ -527,9 +525,9 @@ class TestDeletionRequestAuditTrail(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
request.approve(self.approver, "Approved")
self.assertEqual(request.reviewed_by, self.approver)
self.assertNotEqual(request.reviewed_by, request.user)
@ -540,12 +538,12 @@ class TestDeletionRequestAuditTrail(TestCase):
ai_reason="Test",
user=self.user,
)
created_at = request.created_at
# Approve the request
request.approve(self.approver, "Approved")
# Verify timestamps
self.assertIsNotNone(request.created_at)
self.assertIsNotNone(request.updated_at)
@ -555,16 +553,16 @@ class TestDeletionRequestAuditTrail(TestCase):
def test_audit_trail_preserves_ai_reason(self):
"""Test that AI's original reason is preserved."""
original_reason = "AI detected duplicates based on content similarity"
request = DeletionRequest.objects.create(
requested_by_ai=True,
ai_reason=original_reason,
user=self.user,
)
# Approve with different comment
request.approve(self.approver, "User confirmed")
# Original AI reason should be preserved
self.assertEqual(request.ai_reason, original_reason)
self.assertEqual(request.review_comment, "User confirmed")
@ -582,7 +580,7 @@ class TestDeletionRequestAuditTrail(TestCase):
"completed_by": "system",
},
)
self.assertEqual(request.completion_details["deleted_count"], 5)
self.assertEqual(request.completion_details["failed_count"], 0)
@ -593,17 +591,17 @@ class TestDeletionRequestAuditTrail(TestCase):
ai_reason="Reason 1",
user=self.user,
)
request2 = DeletionRequest.objects.create(
requested_by_ai=True,
ai_reason="Reason 2",
user=self.user,
)
# Approve one, reject another
request1.approve(self.approver, "Approved")
request2.reject(self.approver, "Rejected")
# Verify each has its own audit trail
self.assertEqual(request1.status, DeletionRequest.STATUS_APPROVED)
self.assertEqual(request2.status, DeletionRequest.STATUS_REJECTED)
@ -625,13 +623,13 @@ class TestDeletionRequestModelRelationships(TestCase):
ai_reason="Test",
user=self.user,
)
request_id = request.id
self.assertEqual(DeletionRequest.objects.filter(id=request_id).count(), 1)
# Delete user
self.user.delete()
# Request should be deleted
self.assertEqual(DeletionRequest.objects.filter(id=request_id).count(), 0)
@ -649,15 +647,15 @@ class TestDeletionRequestModelRelationships(TestCase):
checksum="checksum2",
mime_type="application/pdf",
)
request = DeletionRequest.objects.create(
requested_by_ai=True,
ai_reason="Test",
user=self.user,
)
request.documents.add(doc1, doc2)
self.assertEqual(request.documents.count(), 2)
self.assertEqual(doc1.deletion_requests.count(), 1)
self.assertEqual(doc2.deletion_requests.count(), 1)
@ -670,29 +668,29 @@ class TestDeletionRequestModelRelationships(TestCase):
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
self.assertIsNone(request.reviewed_by)
def test_reviewed_by_set_null_on_delete(self):
"""Test that reviewed_by is set to null when reviewer is deleted."""
approver = User.objects.create_user(username="approver", password="pass")
request = DeletionRequest.objects.create(
requested_by_ai=True,
ai_reason="Test",
user=self.user,
status=DeletionRequest.STATUS_PENDING,
)
request.approve(approver, "Approved")
self.assertEqual(request.reviewed_by, approver)
# Delete approver
approver.delete()
# Refresh request
request.refresh_from_db()
# reviewed_by should be null
self.assertIsNone(request.reviewed_by)
# But the request should still exist
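For reference, a minimal sketch of the relationships these tests exercise, assuming the field names match the real DeletionRequest model (user cascades on delete, reviewed_by is nullable, documents is a many-to-many):

# Hypothetical sketch only; the actual model lives in documents.models.
from django.contrib.auth.models import User
from django.db import models

class DeletionRequest(models.Model):
    user = models.ForeignKey(User, on_delete=models.CASCADE)
    reviewed_by = models.ForeignKey(
        User,
        null=True,
        blank=True,
        on_delete=models.SET_NULL,
        related_name="reviewed_deletion_requests",  # name is an assumption
    )
    documents = models.ManyToManyField(
        "documents.Document",
        related_name="deletion_requests",
    )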


@ -8,11 +8,9 @@ from unittest import mock
from django.test import TestCase
from documents.ml.model_cache import (
CacheMetrics,
LRUCache,
ModelCacheManager,
)
from documents.ml.model_cache import CacheMetrics
from documents.ml.model_cache import LRUCache
from documents.ml.model_cache import ModelCacheManager
class TestCacheMetrics(TestCase):
@ -22,10 +20,10 @@ class TestCacheMetrics(TestCase):
"""Test recording cache hits."""
metrics = CacheMetrics()
self.assertEqual(metrics.hits, 0)
metrics.record_hit()
self.assertEqual(metrics.hits, 1)
metrics.record_hit()
self.assertEqual(metrics.hits, 2)
@ -33,26 +31,26 @@ class TestCacheMetrics(TestCase):
"""Test recording cache misses."""
metrics = CacheMetrics()
self.assertEqual(metrics.misses, 0)
metrics.record_miss()
self.assertEqual(metrics.misses, 1)
def test_get_stats(self):
"""Test getting cache statistics."""
metrics = CacheMetrics()
# Initial stats
stats = metrics.get_stats()
self.assertEqual(stats["hits"], 0)
self.assertEqual(stats["misses"], 0)
self.assertEqual(stats["hit_rate"], "0.00%")
# After some hits and misses
metrics.record_hit()
metrics.record_hit()
metrics.record_hit()
metrics.record_miss()
stats = metrics.get_stats()
self.assertEqual(stats["hits"], 3)
self.assertEqual(stats["misses"], 1)
@ -64,9 +62,9 @@ class TestCacheMetrics(TestCase):
metrics = CacheMetrics()
metrics.record_hit()
metrics.record_miss()
metrics.reset()
stats = metrics.get_stats()
self.assertEqual(stats["hits"], 0)
self.assertEqual(stats["misses"], 0)
@ -78,28 +76,28 @@ class TestLRUCache(TestCase):
def test_put_and_get(self):
"""Test basic cache operations."""
cache = LRUCache(max_size=2)
cache.put("key1", "value1")
cache.put("key2", "value2")
self.assertEqual(cache.get("key1"), "value1")
self.assertEqual(cache.get("key2"), "value2")
def test_cache_miss(self):
"""Test cache miss returns None."""
cache = LRUCache(max_size=2)
result = cache.get("nonexistent")
self.assertIsNone(result)
def test_lru_eviction(self):
"""Test LRU eviction policy."""
cache = LRUCache(max_size=2)
cache.put("key1", "value1")
cache.put("key2", "value2")
cache.put("key3", "value3") # Should evict key1
self.assertIsNone(cache.get("key1")) # Evicted
self.assertEqual(cache.get("key2"), "value2")
self.assertEqual(cache.get("key3"), "value3")
@ -107,12 +105,12 @@ class TestLRUCache(TestCase):
def test_lru_update_access_order(self):
"""Test that accessing an item updates its position."""
cache = LRUCache(max_size=2)
cache.put("key1", "value1")
cache.put("key2", "value2")
cache.get("key1") # Access key1, making it most recent
cache.put("key3", "value3") # Should evict key2, not key1
self.assertEqual(cache.get("key1"), "value1")
self.assertIsNone(cache.get("key2")) # Evicted
self.assertEqual(cache.get("key3"), "value3")
@ -120,24 +118,24 @@ class TestLRUCache(TestCase):
def test_cache_size(self):
"""Test cache size tracking."""
cache = LRUCache(max_size=3)
self.assertEqual(cache.size(), 0)
cache.put("key1", "value1")
self.assertEqual(cache.size(), 1)
cache.put("key2", "value2")
self.assertEqual(cache.size(), 2)
def test_clear(self):
"""Test clearing cache."""
cache = LRUCache(max_size=2)
cache.put("key1", "value1")
cache.put("key2", "value2")
cache.clear()
self.assertEqual(cache.size(), 0)
self.assertIsNone(cache.get("key1"))
self.assertIsNone(cache.get("key2"))
@ -155,20 +153,20 @@ class TestModelCacheManager(TestCase):
"""Test that ModelCacheManager is a singleton."""
instance1 = ModelCacheManager.get_instance()
instance2 = ModelCacheManager.get_instance()
self.assertIs(instance1, instance2)
def test_get_or_load_model_first_time(self):
"""Test loading a model for the first time (cache miss)."""
cache_manager = ModelCacheManager.get_instance()
# Mock loader function
mock_model = mock.Mock()
loader = mock.Mock(return_value=mock_model)
# Load model
result = cache_manager.get_or_load_model("test_model", loader)
# Verify loader was called
loader.assert_called_once()
self.assertIs(result, mock_model)
@ -176,17 +174,17 @@ class TestModelCacheManager(TestCase):
def test_get_or_load_model_cached(self):
"""Test loading a model from cache (cache hit)."""
cache_manager = ModelCacheManager.get_instance()
# Mock loader function
mock_model = mock.Mock()
loader = mock.Mock(return_value=mock_model)
# Load model first time
cache_manager.get_or_load_model("test_model", loader)
# Load model second time (should be cached)
result = cache_manager.get_or_load_model("test_model", loader)
# Verify loader was only called once
loader.assert_called_once()
self.assertIs(result, mock_model)
@ -197,48 +195,48 @@ class TestModelCacheManager(TestCase):
cache_manager = ModelCacheManager.get_instance(
disk_cache_dir=tmpdir,
)
# Create test embeddings
embeddings = {
1: "embedding1",
2: "embedding2",
3: "embedding3",
}
# Save to disk
cache_manager.save_embeddings_to_disk("test_embeddings", embeddings)
# Verify file was created
cache_file = Path(tmpdir) / "test_embeddings.pkl"
self.assertTrue(cache_file.exists())
# Load from disk
loaded = cache_manager.load_embeddings_from_disk("test_embeddings")
# Verify embeddings match
self.assertEqual(loaded, embeddings)
def test_get_metrics(self):
"""Test getting cache metrics."""
cache_manager = ModelCacheManager.get_instance()
# Mock loader
loader = mock.Mock(return_value=mock.Mock())
# Generate some cache activity
cache_manager.get_or_load_model("model1", loader)
cache_manager.get_or_load_model("model1", loader) # Cache hit
cache_manager.get_or_load_model("model2", loader)
# Get metrics
metrics = cache_manager.get_metrics()
# Verify metrics structure
self.assertIn("hits", metrics)
self.assertIn("misses", metrics)
self.assertIn("cache_size", metrics)
self.assertIn("max_size", metrics)
# Verify hit/miss counts
self.assertEqual(metrics["hits"], 1) # One cache hit
self.assertEqual(metrics["misses"], 2) # Two cache misses
@ -249,21 +247,21 @@ class TestModelCacheManager(TestCase):
cache_manager = ModelCacheManager.get_instance(
disk_cache_dir=tmpdir,
)
# Add some models to cache
loader = mock.Mock(return_value=mock.Mock())
cache_manager.get_or_load_model("model1", loader)
# Add embeddings to disk
embeddings = {1: "embedding1"}
cache_manager.save_embeddings_to_disk("test", embeddings)
# Clear all
cache_manager.clear_all()
# Verify memory cache is cleared
self.assertEqual(cache_manager.model_cache.size(), 0)
# Verify disk cache is cleared
cache_file = Path(tmpdir) / "test.pkl"
self.assertFalse(cache_file.exists())
@ -271,22 +269,22 @@ class TestModelCacheManager(TestCase):
def test_warm_up(self):
"""Test model warm-up functionality."""
cache_manager = ModelCacheManager.get_instance()
# Create mock loaders
model1 = mock.Mock()
model2 = mock.Mock()
loaders = {
"model1": mock.Mock(return_value=model1),
"model2": mock.Mock(return_value=model2),
}
# Warm up
cache_manager.warm_up(loaders)
# Verify all loaders were called
for loader in loaders.values():
loader.assert_called_once()
# Verify models are cached
self.assertEqual(cache_manager.model_cache.size(), 2)
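Taken together, these tests imply a load-once contract; a hedged sketch of how get_or_load_model could satisfy it, reusing the LRUCache and CacheMetrics sketches above (disk caching and warm-up omitted):

class ModelCacheManager:
    _instance = None

    @classmethod
    def get_instance(cls, **kwargs):
        if cls._instance is None:
            cls._instance = cls(**kwargs)
        return cls._instance

    def __init__(self, disk_cache_dir=None) -> None:
        self.model_cache = LRUCache(max_size=4)  # capacity is an assumption
        self.metrics = CacheMetrics()
        self.disk_cache_dir = disk_cache_dir

    def get_or_load_model(self, key, loader):
        model = self.model_cache.get(key)
        if model is not None:
            self.metrics.record_hit()
            return model
        self.metrics.record_miss()
        model = loader()  # loader runs at most once per cached key
        self.model_cache.put(key, model)
        return model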


@ -150,7 +150,6 @@ class TestMLCacheDirectory:
def test_model_cache_writable(self, tmp_path):
"""Test that we can write to model cache directory."""
import pathlib
# Use tmp_path fixture for testing
cache_dir = tmp_path / ".cache" / "huggingface"
@ -169,7 +168,6 @@ class TestMLCacheDirectory:
def test_torch_cache_directory(self, tmp_path, monkeypatch):
"""Test that PyTorch can use a custom cache directory."""
import torch
# Set custom cache directory
cache_dir = tmp_path / ".cache" / "torch"
@ -204,9 +202,10 @@ class TestMLPerformanceBasic:
def test_numpy_performance_basic(self):
"""Test basic NumPy performance with larger arrays."""
import numpy as np
import time
import numpy as np
# Create large array (10 million elements)
arr = np.random.rand(10_000_000)
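A plausible continuation of this benchmark in the same style (the timing bound is illustrative, not the suite's real threshold):

import time

import numpy as np

arr = np.random.rand(10_000_000)
start = time.perf_counter()
total = arr.sum()  # single vectorised reduction over 10M elements
elapsed = time.perf_counter() - start
assert total > 0
assert elapsed < 1.0  # generous bound; the actual limit is an assumption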


@ -89,6 +89,8 @@ from rest_framework.viewsets import ViewSet
from documents import bulk_edit
from documents import index
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy
@ -141,13 +143,10 @@ from documents.models import UiSettings
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.parsers import get_parser_class_for_mime_type
from documents.parsers import parse_date_generator
from documents.permissions import AcknowledgeTasksPermissions
from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission
from documents.permissions import PaperlessAdminPermissions
@ -1370,36 +1369,36 @@ class UnifiedSearchViewSet(DocumentViewSet):
"""
from documents.ai_scanner import get_ai_scanner
from documents.serializers.ai_suggestions import AISuggestionsSerializer
try:
document = self.get_object()
# Check if document has content to scan
if not document.content:
return Response(
{"detail": "Document has no content to analyze"},
status=400,
)
# Get AI scanner instance
scanner = get_ai_scanner()
# Perform AI scan
scan_result = scanner.scan_document(
document=document,
document_text=document.content,
original_file_path=document.source_path if hasattr(document, 'source_path') else None,
original_file_path=document.source_path if hasattr(document, "source_path") else None,
)
# Convert scan result to serializable format
data = AISuggestionsSerializer.from_scan_result(scan_result, document.id)
# Serialize and return
serializer = AISuggestionsSerializer(data=data)
serializer.is_valid(raise_exception=True)
return Response(serializer.validated_data)
except Exception as e:
logger.error(f"Error getting AI suggestions for document {pk}: {e}", exc_info=True)
return Response(
@ -1416,56 +1415,56 @@ class UnifiedSearchViewSet(DocumentViewSet):
"""
from documents.models import AISuggestionFeedback
from documents.serializers.ai_suggestions import ApplySuggestionSerializer
try:
document = self.get_object()
# Validate input
serializer = ApplySuggestionSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
suggestion_type = serializer.validated_data['suggestion_type']
value_id = serializer.validated_data.get('value_id')
value_text = serializer.validated_data.get('value_text')
confidence = serializer.validated_data['confidence']
suggestion_type = serializer.validated_data["suggestion_type"]
value_id = serializer.validated_data.get("value_id")
value_text = serializer.validated_data.get("value_text")
confidence = serializer.validated_data["confidence"]
# Apply the suggestion based on type
applied = False
result_message = ""
if suggestion_type == 'tag' and value_id:
if suggestion_type == "tag" and value_id:
tag = Tag.objects.get(pk=value_id)
document.tags.add(tag)
applied = True
result_message = f"Tag '{tag.name}' applied"
elif suggestion_type == 'correspondent' and value_id:
elif suggestion_type == "correspondent" and value_id:
correspondent = Correspondent.objects.get(pk=value_id)
document.correspondent = correspondent
document.save()
applied = True
result_message = f"Correspondent '{correspondent.name}' applied"
elif suggestion_type == 'document_type' and value_id:
elif suggestion_type == "document_type" and value_id:
doc_type = DocumentType.objects.get(pk=value_id)
document.document_type = doc_type
document.save()
applied = True
result_message = f"Document type '{doc_type.name}' applied"
elif suggestion_type == 'storage_path' and value_id:
elif suggestion_type == "storage_path" and value_id:
storage_path = StoragePath.objects.get(pk=value_id)
document.storage_path = storage_path
document.save()
applied = True
result_message = f"Storage path '{storage_path.name}' applied"
elif suggestion_type == 'title' and value_text:
elif suggestion_type == "title" and value_text:
document.title = value_text
document.save()
applied = True
result_message = f"Title updated to '{value_text}'"
if applied:
# Record feedback
AISuggestionFeedback.objects.create(
@ -1477,7 +1476,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
status=AISuggestionFeedback.STATUS_APPLIED,
user=request.user,
)
return Response({
"status": "success",
"message": result_message,
@ -1487,8 +1486,8 @@ class UnifiedSearchViewSet(DocumentViewSet):
{"detail": "Invalid suggestion type or missing value"},
status=400,
)
except (Tag.DoesNotExist, Correspondent.DoesNotExist,
except (Tag.DoesNotExist, Correspondent.DoesNotExist,
DocumentType.DoesNotExist, StoragePath.DoesNotExist):
return Response(
{"detail": "Referenced object not found"},
@ -1510,19 +1509,19 @@ class UnifiedSearchViewSet(DocumentViewSet):
"""
from documents.models import AISuggestionFeedback
from documents.serializers.ai_suggestions import RejectSuggestionSerializer
try:
document = self.get_object()
# Validate input
serializer = RejectSuggestionSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
suggestion_type = serializer.validated_data['suggestion_type']
value_id = serializer.validated_data.get('value_id')
value_text = serializer.validated_data.get('value_text')
confidence = serializer.validated_data['confidence']
suggestion_type = serializer.validated_data["suggestion_type"]
value_id = serializer.validated_data.get("value_id")
value_text = serializer.validated_data.get("value_text")
confidence = serializer.validated_data["confidence"]
# Record feedback
AISuggestionFeedback.objects.create(
document=document,
@ -1533,12 +1532,12 @@ class UnifiedSearchViewSet(DocumentViewSet):
status=AISuggestionFeedback.STATUS_REJECTED,
user=request.user,
)
return Response({
"status": "success",
"message": "Suggestion rejected and feedback recorded",
})
except Exception as e:
logger.error(f"Error rejecting suggestion for document {pk}: {e}", exc_info=True)
return Response(
@ -1554,78 +1553,83 @@ class UnifiedSearchViewSet(DocumentViewSet):
Returns aggregated data about applied vs rejected suggestions,
accuracy rates, and confidence scores.
"""
from django.db.models import Avg, Count, Q
from django.db.models import Avg
from django.db.models import Count
from django.db.models import Q
from documents.models import AISuggestionFeedback
from documents.serializers.ai_suggestions import AISuggestionStatsSerializer
try:
# Get overall counts
total_feedbacks = AISuggestionFeedback.objects.count()
total_applied = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_APPLIED
status=AISuggestionFeedback.STATUS_APPLIED,
).count()
total_rejected = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_REJECTED
status=AISuggestionFeedback.STATUS_REJECTED,
).count()
# Calculate accuracy rate
accuracy_rate = (total_applied / total_feedbacks * 100) if total_feedbacks > 0 else 0
# Get statistics by suggestion type using a single aggregated query
stats_by_type = AISuggestionFeedback.objects.values('suggestion_type').annotate(
total=Count('id'),
applied=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
rejected=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_REJECTED))
stats_by_type = AISuggestionFeedback.objects.values("suggestion_type").annotate(
total=Count("id"),
applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)),
)
# Build the by_type dictionary using the aggregated results
by_type = {}
for stat in stats_by_type:
suggestion_type = stat['suggestion_type']
type_total = stat['total']
type_applied = stat['applied']
type_rejected = stat['rejected']
suggestion_type = stat["suggestion_type"]
type_total = stat["total"]
type_applied = stat["applied"]
type_rejected = stat["rejected"]
by_type[suggestion_type] = {
'total': type_total,
'applied': type_applied,
'rejected': type_rejected,
'accuracy_rate': (type_applied / type_total * 100) if type_total > 0 else 0,
"total": type_total,
"applied": type_applied,
"rejected": type_rejected,
"accuracy_rate": (type_applied / type_total * 100) if type_total > 0 else 0,
}
# Get average confidence scores
avg_confidence_applied = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_APPLIED
).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
status=AISuggestionFeedback.STATUS_APPLIED,
).aggregate(Avg("confidence"))["confidence__avg"] or 0.0
avg_confidence_rejected = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_REJECTED
).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
status=AISuggestionFeedback.STATUS_REJECTED,
).aggregate(Avg("confidence"))["confidence__avg"] or 0.0
# Get recent suggestions (last 10)
recent_suggestions = AISuggestionFeedback.objects.order_by('-created_at')[:10]
recent_suggestions = AISuggestionFeedback.objects.order_by("-created_at")[:10]
# Build response data
from documents.serializers.ai_suggestions import AISuggestionFeedbackSerializer
from documents.serializers.ai_suggestions import (
AISuggestionFeedbackSerializer,
)
data = {
'total_suggestions': total_feedbacks,
'total_applied': total_applied,
'total_rejected': total_rejected,
'accuracy_rate': accuracy_rate,
'by_type': by_type,
'average_confidence_applied': avg_confidence_applied,
'average_confidence_rejected': avg_confidence_rejected,
'recent_suggestions': AISuggestionFeedbackSerializer(
recent_suggestions, many=True
"total_suggestions": total_feedbacks,
"total_applied": total_applied,
"total_rejected": total_rejected,
"accuracy_rate": accuracy_rate,
"by_type": by_type,
"average_confidence_applied": avg_confidence_applied,
"average_confidence_rejected": avg_confidence_rejected,
"recent_suggestions": AISuggestionFeedbackSerializer(
recent_suggestions, many=True,
).data,
}
# Serialize and return
serializer = AISuggestionStatsSerializer(data=data)
serializer.is_valid(raise_exception=True)
return Response(serializer.validated_data)
except Exception as e:
logger.error(f"Error getting AI suggestion statistics: {e}", exc_info=True)
return Response(
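For clarity, the values(...).annotate(...) query above yields one row per suggestion type; the counts below are illustrative:

stats_by_type = [
    {"suggestion_type": "tag", "total": 10, "applied": 7, "rejected": 3},
    {"suggestion_type": "correspondent", "total": 4, "applied": 4, "rejected": 0},
]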
@ -3561,37 +3565,37 @@ class AISuggestionsView(GenericAPIView):
Requires: can_view_ai_suggestions permission
"""
permission_classes = [IsAuthenticated, CanViewAISuggestionsPermission]
serializer_class = AISuggestionsResponseSerializer
def post(self, request):
"""Get AI suggestions for a document."""
# Validate request
request_serializer = AISuggestionsRequestSerializer(data=request.data)
request_serializer.is_valid(raise_exception=True)
document_id = request_serializer.validated_data['document_id']
document_id = request_serializer.validated_data["document_id"]
try:
document = Document.objects.get(pk=document_id)
except Document.DoesNotExist:
return Response(
{"error": "Document not found or you don't have permission to view it"},
status=status.HTTP_404_NOT_FOUND
status=status.HTTP_404_NOT_FOUND,
)
# Check if user has permission to view this document
if not has_perms_owner_aware(request.user, 'documents.view_document', document):
if not has_perms_owner_aware(request.user, "documents.view_document", document):
return Response(
{"error": "Permission denied"},
status=status.HTTP_403_FORBIDDEN
status=status.HTTP_403_FORBIDDEN,
)
# Get AI scanner and scan document
scanner = get_ai_scanner()
scan_result = scanner.scan_document(document, document.content or "")
# Build response
response_data = {
"document_id": document.id,
@ -3600,9 +3604,9 @@ class AISuggestionsView(GenericAPIView):
"document_type": None,
"storage_path": None,
"title_suggestion": scan_result.title_suggestion,
"custom_fields": {}
"custom_fields": {},
}
# Format tag suggestions
for tag_id, confidence in scan_result.tags:
try:
@ -3610,12 +3614,12 @@ class AISuggestionsView(GenericAPIView):
response_data["tags"].append({
"id": tag.id,
"name": tag.name,
"confidence": confidence
"confidence": confidence,
})
except Tag.DoesNotExist:
# Tag was suggested by AI but no longer exists; skip it
pass
# Format correspondent suggestion
if scan_result.correspondent:
corr_id, confidence = scan_result.correspondent
@ -3624,12 +3628,12 @@ class AISuggestionsView(GenericAPIView):
response_data["correspondent"] = {
"id": correspondent.id,
"name": correspondent.name,
"confidence": confidence
"confidence": confidence,
}
except Correspondent.DoesNotExist:
# Correspondent was suggested but no longer exists; skip it
pass
# Format document type suggestion
if scan_result.document_type:
type_id, confidence = scan_result.document_type
@ -3638,12 +3642,12 @@ class AISuggestionsView(GenericAPIView):
response_data["document_type"] = {
"id": doc_type.id,
"name": doc_type.name,
"confidence": confidence
"confidence": confidence,
}
except DocumentType.DoesNotExist:
# Document type was suggested but no longer exists; skip it
pass
# Format storage path suggestion
if scan_result.storage_path:
path_id, confidence = scan_result.storage_path
@ -3652,19 +3656,19 @@ class AISuggestionsView(GenericAPIView):
response_data["storage_path"] = {
"id": storage_path.id,
"name": storage_path.name,
"confidence": confidence
"confidence": confidence,
}
except StoragePath.DoesNotExist:
# Storage path was suggested but no longer exists; skip it
pass
# Format custom fields
for field_id, (value, confidence) in scan_result.custom_fields.items():
response_data["custom_fields"][str(field_id)] = {
"value": value,
"confidence": confidence
"confidence": confidence,
}
return Response(response_data)
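An illustrative response body for this endpoint (IDs, names, and confidences below are made up):

{
    "document_id": 42,
    "tags": [{"id": 3, "name": "Important", "confidence": 0.91}],
    "correspondent": {"id": 1, "name": "Test Corp", "confidence": 0.84},
    "document_type": None,
    "storage_path": None,
    "title_suggestion": "Invoice 2024-03",
    "custom_fields": {"5": {"value": "120.00", "confidence": 0.77}},
}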
@ -3674,48 +3678,48 @@ class ApplyAISuggestionsView(GenericAPIView):
Requires: can_apply_ai_suggestions permission
"""
permission_classes = [IsAuthenticated, CanApplyAISuggestionsPermission]
def post(self, request):
"""Apply AI suggestions to a document."""
# Validate request
serializer = ApplyAISuggestionsSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
document_id = serializer.validated_data['document_id']
document_id = serializer.validated_data["document_id"]
try:
document = Document.objects.get(pk=document_id)
except Document.DoesNotExist:
return Response(
{"error": "Document not found"},
status=status.HTTP_404_NOT_FOUND
status=status.HTTP_404_NOT_FOUND,
)
# Check if user has permission to change this document
if not has_perms_owner_aware(request.user, 'documents.change_document', document):
if not has_perms_owner_aware(request.user, "documents.change_document", document):
return Response(
{"error": "Permission denied"},
status=status.HTTP_403_FORBIDDEN
status=status.HTTP_403_FORBIDDEN,
)
# Get AI scanner and scan document
scanner = get_ai_scanner()
scan_result = scanner.scan_document(document, document.content or "")
# Apply suggestions based on user selections
applied = []
if serializer.validated_data.get('apply_tags'):
selected_tags = serializer.validated_data.get('selected_tags', [])
if serializer.validated_data.get("apply_tags"):
selected_tags = serializer.validated_data.get("selected_tags", [])
if selected_tags:
# Apply only selected tags
tags_to_apply = [tag_id for tag_id, _ in scan_result.tags if tag_id in selected_tags]
else:
# Apply all high-confidence tags
tags_to_apply = [tag_id for tag_id, conf in scan_result.tags if conf >= scanner.auto_apply_threshold]
for tag_id in tags_to_apply:
try:
tag = Tag.objects.get(pk=tag_id)
@ -3724,8 +3728,8 @@ class ApplyAISuggestionsView(GenericAPIView):
except Tag.DoesNotExist:
# Tag not found; skip applying this tag
pass
if serializer.validated_data.get('apply_correspondent') and scan_result.correspondent:
if serializer.validated_data.get("apply_correspondent") and scan_result.correspondent:
corr_id, confidence = scan_result.correspondent
try:
correspondent = Correspondent.objects.get(pk=corr_id)
@ -3734,8 +3738,8 @@ class ApplyAISuggestionsView(GenericAPIView):
except Correspondent.DoesNotExist:
# Correspondent not found; skip applying
pass
if serializer.validated_data.get('apply_document_type') and scan_result.document_type:
if serializer.validated_data.get("apply_document_type") and scan_result.document_type:
type_id, confidence = scan_result.document_type
try:
doc_type = DocumentType.objects.get(pk=type_id)
@ -3744,8 +3748,8 @@ class ApplyAISuggestionsView(GenericAPIView):
except DocumentType.DoesNotExist:
# Document type not found; skip applying
pass
if serializer.validated_data.get('apply_storage_path') and scan_result.storage_path:
if serializer.validated_data.get("apply_storage_path") and scan_result.storage_path:
path_id, confidence = scan_result.storage_path
try:
storage_path = StoragePath.objects.get(pk=path_id)
@ -3754,18 +3758,18 @@ class ApplyAISuggestionsView(GenericAPIView):
except StoragePath.DoesNotExist:
# Storage path not found; skip applying
pass
if serializer.validated_data.get('apply_title') and scan_result.title_suggestion:
if serializer.validated_data.get("apply_title") and scan_result.title_suggestion:
document.title = scan_result.title_suggestion
applied.append(f"title: {scan_result.title_suggestion}")
# Save document
document.save()
return Response({
"status": "success",
"document_id": document.id,
"applied": applied
"applied": applied,
})
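An illustrative request payload for this endpoint, using the field names the view reads (values are made up):

payload = {
    "document_id": 42,
    "apply_tags": True,
    "selected_tags": [3, 7],  # optional; omit to apply all high-confidence tags
    "apply_correspondent": True,
    "apply_document_type": False,
    "apply_storage_path": False,
    "apply_title": True,
}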
@ -3775,23 +3779,23 @@ class AIConfigurationView(GenericAPIView):
Requires: can_configure_ai permission
"""
permission_classes = [IsAuthenticated, CanConfigureAIPermission]
def get(self, request):
"""Get current AI configuration."""
scanner = get_ai_scanner()
config_data = {
"auto_apply_threshold": scanner.auto_apply_threshold,
"suggest_threshold": scanner.suggest_threshold,
"ml_enabled": scanner.ml_enabled,
"advanced_ocr_enabled": scanner.advanced_ocr_enabled,
}
serializer = AIConfigurationSerializer(config_data)
return Response(serializer.data)
def post(self, request):
"""
Update AI configuration.
@ -3802,27 +3806,27 @@ class AIConfigurationView(GenericAPIView):
"""
serializer = AIConfigurationSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
# Create new scanner with updated configuration
config = {}
if 'auto_apply_threshold' in serializer.validated_data:
config['auto_apply_threshold'] = serializer.validated_data['auto_apply_threshold']
if 'suggest_threshold' in serializer.validated_data:
config['suggest_threshold'] = serializer.validated_data['suggest_threshold']
if 'ml_enabled' in serializer.validated_data:
config['enable_ml_features'] = serializer.validated_data['ml_enabled']
if 'advanced_ocr_enabled' in serializer.validated_data:
config['enable_advanced_ocr'] = serializer.validated_data['advanced_ocr_enabled']
if "auto_apply_threshold" in serializer.validated_data:
config["auto_apply_threshold"] = serializer.validated_data["auto_apply_threshold"]
if "suggest_threshold" in serializer.validated_data:
config["suggest_threshold"] = serializer.validated_data["suggest_threshold"]
if "ml_enabled" in serializer.validated_data:
config["enable_ml_features"] = serializer.validated_data["ml_enabled"]
if "advanced_ocr_enabled" in serializer.validated_data:
config["enable_advanced_ocr"] = serializer.validated_data["advanced_ocr_enabled"]
# Update global scanner instance
# WARNING: Not thread-safe. Consider storing configuration in database
# and reloading on each get_ai_scanner() call for production use
from documents import ai_scanner
ai_scanner._scanner_instance = AIDocumentScanner(**config)
return Response({
"status": "success",
"message": "AI configuration updated. Changes may require server restart for consistency."
"message": "AI configuration updated. Changes may require server restart for consistency.",
})
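One way to narrow the race flagged in the warning above (a sketch, not project code; it serialises the swap but does not make existing readers reload automatically):

import threading

_scanner_lock = threading.Lock()

def replace_scanner(config: dict) -> None:
    from documents import ai_scanner
    from documents.ai_scanner import AIDocumentScanner

    with _scanner_lock:
        ai_scanner._scanner_instance = AIDocumentScanner(**config)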


@ -30,10 +30,10 @@ class DeletionRequestViewSet(ModelViewSet):
Provides CRUD operations plus custom actions for approval workflow.
"""
model = DeletionRequest
serializer_class = DeletionRequestSerializer
def get_queryset(self):
"""
Return deletion requests for the current user.
@ -45,7 +45,7 @@ class DeletionRequestViewSet(ModelViewSet):
if user.is_superuser:
return DeletionRequest.objects.all()
return DeletionRequest.objects.filter(user=user)
def _can_manage_request(self, deletion_request):
"""
Check if current user can manage (approve/reject/cancel) the request.
@ -58,7 +58,7 @@ class DeletionRequestViewSet(ModelViewSet):
"""
user = self.request.user
return user.is_superuser or deletion_request.user == user
@action(methods=["post"], detail=True)
def approve(self, request, pk=None):
"""
@ -72,13 +72,13 @@ class DeletionRequestViewSet(ModelViewSet):
Response with execution results
"""
deletion_request = self.get_object()
# Check permissions
if not self._can_manage_request(deletion_request):
return HttpResponseForbidden(
"You don't have permission to approve this deletion request."
"You don't have permission to approve this deletion request.",
)
# Validate status
if deletion_request.status != DeletionRequest.STATUS_PENDING:
return Response(
@ -88,9 +88,9 @@ class DeletionRequestViewSet(ModelViewSet):
},
status=status.HTTP_400_BAD_REQUEST,
)
comment = request.data.get("comment", "")
# Execute approval and deletion in a transaction
try:
with transaction.atomic():
@ -100,12 +100,12 @@ class DeletionRequestViewSet(ModelViewSet):
{"error": "Failed to approve deletion request."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
# Execute the deletion
documents = list(deletion_request.documents.all())
deleted_count = 0
failed_deletions = []
for doc in documents:
try:
doc_id = doc.id
@ -114,18 +114,18 @@ class DeletionRequestViewSet(ModelViewSet):
deleted_count += 1
logger.info(
f"Deleted document {doc_id} ('{doc_title}') "
f"as part of deletion request {deletion_request.id}"
f"as part of deletion request {deletion_request.id}",
)
except Exception as e:
logger.error(
f"Failed to delete document {doc.id}: {str(e)}"
f"Failed to delete document {doc.id}: {e!s}",
)
failed_deletions.append({
"id": doc.id,
"title": doc.title,
"error": str(e),
})
# Update completion status
deletion_request.status = DeletionRequest.STATUS_COMPLETED
deletion_request.completed_at = timezone.now()
@ -135,20 +135,20 @@ class DeletionRequestViewSet(ModelViewSet):
"total_documents": len(documents),
}
deletion_request.save()
logger.info(
f"Deletion request {deletion_request.id} completed. "
f"Deleted {deleted_count}/{len(documents)} documents."
f"Deleted {deleted_count}/{len(documents)} documents.",
)
except Exception as e:
logger.error(
f"Error executing deletion request {deletion_request.id}: {str(e)}"
f"Error executing deletion request {deletion_request.id}: {e!s}",
)
return Response(
{"error": f"Failed to execute deletion: {str(e)}"},
{"error": f"Failed to execute deletion: {e!s}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
serializer = self.get_serializer(deletion_request)
return Response(
{
@ -158,7 +158,7 @@ class DeletionRequestViewSet(ModelViewSet):
},
status=status.HTTP_200_OK,
)
@action(methods=["post"], detail=True)
def reject(self, request, pk=None):
"""
@ -172,13 +172,13 @@ class DeletionRequestViewSet(ModelViewSet):
Response with updated deletion request
"""
deletion_request = self.get_object()
# Check permissions
if not self._can_manage_request(deletion_request):
return HttpResponseForbidden(
"You don't have permission to reject this deletion request."
"You don't have permission to reject this deletion request.",
)
# Validate status
if deletion_request.status != DeletionRequest.STATUS_PENDING:
return Response(
@ -188,20 +188,20 @@ class DeletionRequestViewSet(ModelViewSet):
},
status=status.HTTP_400_BAD_REQUEST,
)
comment = request.data.get("comment", "")
# Reject the request
if not deletion_request.reject(request.user, comment):
return Response(
{"error": "Failed to reject deletion request."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
logger.info(
f"Deletion request {deletion_request.id} rejected by user {request.user.username}"
f"Deletion request {deletion_request.id} rejected by user {request.user.username}",
)
serializer = self.get_serializer(deletion_request)
return Response(
{
@ -210,7 +210,7 @@ class DeletionRequestViewSet(ModelViewSet):
},
status=status.HTTP_200_OK,
)
@action(methods=["post"], detail=True)
def cancel(self, request, pk=None):
"""
@ -224,13 +224,13 @@ class DeletionRequestViewSet(ModelViewSet):
Response with updated deletion request
"""
deletion_request = self.get_object()
# Check permissions
if not self._can_manage_request(deletion_request):
return HttpResponseForbidden(
"You don't have permission to cancel this deletion request."
"You don't have permission to cancel this deletion request.",
)
# Validate status
if deletion_request.status != DeletionRequest.STATUS_PENDING:
return Response(
@ -240,18 +240,18 @@ class DeletionRequestViewSet(ModelViewSet):
},
status=status.HTTP_400_BAD_REQUEST,
)
# Cancel the request
deletion_request.status = DeletionRequest.STATUS_CANCELLED
deletion_request.reviewed_by = request.user
deletion_request.reviewed_at = timezone.now()
deletion_request.review_comment = request.data.get("comment", "Cancelled by user")
deletion_request.save()
logger.info(
f"Deletion request {deletion_request.id} cancelled by user {request.user.username}"
f"Deletion request {deletion_request.id} cancelled by user {request.user.username}",
)
serializer = self.get_serializer(deletion_request)
return Response(
{
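An illustrative client call for the custom actions above (the URL follows the router registration shown at the end of this commit; host and token are placeholders):

import requests

resp = requests.post(
    "https://paperless.example.com/api/deletion_requests/42/approve/",
    headers={"Authorization": "Token <token>"},
    json={"comment": "Confirmed duplicates"},
)
assert resp.status_code == 200  # payload echoes the updated request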


@ -160,7 +160,7 @@ class SecurityHeadersMiddleware:
# Store nonce in request for use in templates
# Templates can access this via {{ request.csp_nonce }}
if hasattr(request, '_csp_nonce'):
if hasattr(request, "_csp_nonce"):
request._csp_nonce = nonce
# Prevent clickjacking attacks


@ -9,7 +9,6 @@ from __future__ import annotations
import hashlib
import logging
import mimetypes
import os
import re
from pathlib import Path
@ -26,39 +25,39 @@ logger = logging.getLogger("paperless.security")
# Explicit list of allowed MIME types
ALLOWED_MIME_TYPES = {
# Documents
'application/pdf',
'application/vnd.oasis.opendocument.text',
'application/msword',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/vnd.ms-excel',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'application/vnd.oasis.opendocument.spreadsheet',
'application/vnd.oasis.opendocument.presentation',
'application/rtf',
'text/rtf',
"application/pdf",
"application/vnd.oasis.opendocument.text",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"application/vnd.oasis.opendocument.spreadsheet",
"application/vnd.oasis.opendocument.presentation",
"application/rtf",
"text/rtf",
# Images
'image/jpeg',
'image/png',
'image/gif',
'image/tiff',
'image/bmp',
'image/webp',
"image/jpeg",
"image/png",
"image/gif",
"image/tiff",
"image/bmp",
"image/webp",
# Text
'text/plain',
'text/html',
'text/csv',
'text/markdown',
"text/plain",
"text/html",
"text/csv",
"text/markdown",
}
# Maximum file size (100MB by default)
# Can be overridden by settings.MAX_UPLOAD_SIZE
try:
from django.conf import settings
MAX_FILE_SIZE = getattr(settings, 'MAX_UPLOAD_SIZE', 100 * 1024 * 1024) # 100MB by default
MAX_FILE_SIZE = getattr(settings, "MAX_UPLOAD_SIZE", 100 * 1024 * 1024) # 100MB by default
except ImportError:
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB in bytes
@ -114,7 +113,6 @@ ALLOWED_JS_PATTERNS = [
class FileValidationError(Exception):
"""Raised when file validation fails."""
pass
def has_whitelisted_javascript(content: bytes) -> bool:
@ -143,7 +141,7 @@ def validate_mime_type(mime_type: str) -> None:
if mime_type not in ALLOWED_MIME_TYPES:
raise FileValidationError(
f"MIME type '{mime_type}' is not allowed. "
f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}"
f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}",
)
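A usage example of the validator's contract as defined above:

try:
    validate_mime_type("application/x-msdownload")
except FileValidationError as exc:
    print(exc)  # message lists the allowed MIME types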


@ -86,7 +86,7 @@ api_router.register(r"config", ApplicationConfigurationViewSet)
api_router.register(r"processed_mail", ProcessedMailViewSet)
api_router.register(r"deletion_requests", DeletionRequestViewSet)
api_router.register(
r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests"
r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests",
)