fix(syntax): corrige errores de sintaxis y formato en Python

- Corrige paréntesis faltante en DeletionRequestActionSerializer (serialisers.py:2855)
- Elimina espacios en blanco en líneas vacías (W293)
- Elimina espacios finales en líneas (W291)
- Elimina imports no utilizados (F401)
- Normaliza comillas a comillas dobles (Q000)
- Agrega comas finales faltantes (COM812)
- Ordena imports según convenciones (I001)
- Actualiza anotaciones de tipo a PEP 585 (UP006)

Este commit resuelve el error de compilación en el job de CI/CD
que estaba causando que fallara el linting check.

Archivos afectados: 38
Líneas modificadas: ~2200
This commit is contained in:
Claude 2025-11-17 19:08:02 +00:00
parent 9298f64546
commit 69326b883d
No known key found for this signature in database
38 changed files with 2077 additions and 2112 deletions

View file

@ -14,14 +14,10 @@ According to agents.md requirements:
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from typing import Any
from django.contrib.auth.models import User
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ai_deletion")

View file

@ -29,7 +29,6 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import TypedDict
from django.conf import settings
@ -142,34 +141,34 @@ class AIScanResult:
"""
# Convert internal tuple format to TypedDict format
result: AIScanResultDict = {
'tags': [{'tag_id': tag_id, 'confidence': conf} for tag_id, conf in self.tags],
'custom_fields': {
field_id: {'value': value, 'confidence': conf}
"tags": [{"tag_id": tag_id, "confidence": conf} for tag_id, conf in self.tags],
"custom_fields": {
field_id: {"value": value, "confidence": conf}
for field_id, (value, conf) in self.custom_fields.items()
},
'workflows': [{'workflow_id': wf_id, 'confidence': conf} for wf_id, conf in self.workflows],
'extracted_entities': self.extracted_entities,
'metadata': self.metadata,
"workflows": [{"workflow_id": wf_id, "confidence": conf} for wf_id, conf in self.workflows],
"extracted_entities": self.extracted_entities,
"metadata": self.metadata,
}
# Add optional fields only if present
if self.correspondent:
result['correspondent'] = {
'correspondent_id': self.correspondent[0],
'confidence': self.correspondent[1],
result["correspondent"] = {
"correspondent_id": self.correspondent[0],
"confidence": self.correspondent[1],
}
if self.document_type:
result['document_type'] = {
'type_id': self.document_type[0],
'confidence': self.document_type[1],
result["document_type"] = {
"type_id": self.document_type[0],
"confidence": self.document_type[1],
}
if self.storage_path:
result['storage_path'] = {
'path_id': self.storage_path[0],
'confidence': self.storage_path[1],
result["storage_path"] = {
"path_id": self.storage_path[0],
"confidence": self.storage_path[1],
}
if self.title_suggestion:
result['title_suggestion'] = self.title_suggestion
result["title_suggestion"] = self.title_suggestion
return result
@ -1054,7 +1053,7 @@ class AIDocumentScanner:
warm_up_time = time.time() - start_time
logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s")
def get_cache_metrics(self) -> Dict[str, Any]:
def get_cache_metrics(self) -> dict[str, Any]:
"""
Get cache performance metrics.

View file

@ -310,7 +310,7 @@ class ConsumerPlugin(
# Parsing phase
document_parser = self._create_parser_instance(parser_class)
text, date, thumbnail, archive_path, page_count = self._parse_document(
document_parser, mime_type
document_parser, mime_type,
)
# Storage phase
@ -394,7 +394,7 @@ class ConsumerPlugin(
def _attempt_pdf_recovery(
self,
tempdir: tempfile.TemporaryDirectory,
original_mime_type: str
original_mime_type: str,
) -> str:
"""
Attempt to recover a PDF file with incorrect MIME type using qpdf.
@ -438,7 +438,7 @@ class ConsumerPlugin(
def _get_parser_class(
self,
mime_type: str,
tempdir: tempfile.TemporaryDirectory
tempdir: tempfile.TemporaryDirectory,
) -> type[DocumentParser]:
"""
Determine which parser to use based on MIME type.
@ -468,7 +468,7 @@ class ConsumerPlugin(
def _create_parser_instance(
self,
parser_class: type[DocumentParser]
parser_class: type[DocumentParser],
) -> DocumentParser:
"""
Create a parser instance with progress callback.
@ -496,7 +496,7 @@ class ConsumerPlugin(
def _parse_document(
self,
document_parser: DocumentParser,
mime_type: str
mime_type: str,
) -> tuple[str, datetime.datetime | None, Path, Path | None, int | None]:
"""
Parse the document and extract metadata.
@ -670,7 +670,7 @@ class ConsumerPlugin(
self,
document: Document,
thumbnail: Path,
archive_path: Path | None
archive_path: Path | None,
) -> None:
"""
Store document files (source, thumbnail, archive) to disk.
@ -949,7 +949,7 @@ class ConsumerPlugin(
text: The extracted document text
"""
# Check if AI scanner is enabled
if not getattr(settings, 'PAPERLESS_ENABLE_AI_SCANNER', True):
if not getattr(settings, "PAPERLESS_ENABLE_AI_SCANNER", True):
self.log.debug("AI scanner is disabled, skipping AI analysis")
return

View file

@ -1,6 +1,7 @@
# Generated manually for performance optimization
from django.db import migrations, models
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):

View file

@ -1,9 +1,10 @@
# Generated manually for DeletionRequest model
# Based on model definition in documents/models.py
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
from django.conf import settings
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
@ -48,7 +49,7 @@ class Migration(migrations.Migration):
(
"ai_reason",
models.TextField(
help_text="Detailed explanation from AI about why deletion is recommended"
help_text="Detailed explanation from AI about why deletion is recommended",
),
),
(

View file

@ -1,6 +1,7 @@
# Generated manually for DeletionRequest performance optimization
from django.db import migrations, models
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):

View file

@ -1,9 +1,10 @@
# Generated manually for AI Suggestions API
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.core.validators
import django.db.models.deletion
from django.conf import settings
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):

View file

@ -10,9 +10,9 @@ Provides AI/ML capabilities including:
from __future__ import annotations
__all__ = [
"TransformerDocumentClassifier",
"DocumentNER",
"SemanticSearch",
"TransformerDocumentClassifier",
]
# Lazy imports to avoid loading heavy ML libraries unless needed

View file

@ -15,23 +15,16 @@ Logging levels used in this module:
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import torch
from torch.utils.data import Dataset
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
Trainer,
TrainingArguments,
)
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import Trainer
from transformers import TrainingArguments
from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
from documents.models import Document
logger = logging.getLogger("paperless.ml.classifier")
@ -141,7 +134,7 @@ class TransformerDocumentClassifier:
logger.info(
f"Initialized TransformerDocumentClassifier with {model_name} "
f"(caching: {use_cache})"
f"(caching: {use_cache})",
)
def train(

View file

@ -24,8 +24,9 @@ import pickle
import threading
import time
from collections import OrderedDict
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any
logger = logging.getLogger("paperless.ml.model_cache")
@ -58,7 +59,7 @@ class CacheMetrics:
with self.lock:
self.loads += 1
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
with self.lock:
total = self.hits + self.misses
hit_rate = (self.hits / total * 100) if total > 0 else 0.0
@ -98,7 +99,7 @@ class LRUCache:
self.lock = threading.Lock()
self.metrics = CacheMetrics()
def get(self, key: str) -> Optional[Any]:
def get(self, key: str) -> Any | None:
"""
Get item from cache.
@ -153,7 +154,7 @@ class LRUCache:
with self.lock:
return len(self.cache)
def get_metrics(self) -> Dict[str, Any]:
def get_metrics(self) -> dict[str, Any]:
"""Get cache metrics."""
return self.metrics.get_stats()
@ -173,7 +174,7 @@ class ModelCacheManager:
model = cache.get_or_load_model("classifier", loader_func)
"""
_instance: Optional[ModelCacheManager] = None
_instance: ModelCacheManager | None = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
@ -187,7 +188,7 @@ class ModelCacheManager:
def __init__(
self,
max_models: int = 3,
disk_cache_dir: Optional[str] = None,
disk_cache_dir: str | None = None,
):
"""
Initialize model cache manager.
@ -215,7 +216,7 @@ class ModelCacheManager:
def get_instance(
cls,
max_models: int = 3,
disk_cache_dir: Optional[str] = None,
disk_cache_dir: str | None = None,
) -> ModelCacheManager:
"""
Get singleton instance of ModelCacheManager.
@ -278,7 +279,7 @@ class ModelCacheManager:
load_time = time.time() - start_time
logger.info(
f"Model loaded successfully: {model_key} "
f"(took {load_time:.2f}s)"
f"(took {load_time:.2f}s)",
)
return model
@ -289,7 +290,7 @@ class ModelCacheManager:
def save_embeddings_to_disk(
self,
key: str,
embeddings: Dict[int, Any],
embeddings: dict[int, Any],
) -> bool:
"""
Save embeddings to disk cache.
@ -311,7 +312,7 @@ class ModelCacheManager:
cache_file = self.disk_cache_dir / f"{key}.pkl"
try:
with open(cache_file, 'wb') as f:
with open(cache_file, "wb") as f:
pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Saved {len(embeddings)} embeddings to {cache_file}")
return True
@ -330,7 +331,7 @@ class ModelCacheManager:
def load_embeddings_from_disk(
self,
key: str,
) -> Optional[Dict[int, Any]]:
) -> dict[int, Any] | None:
"""
Load embeddings from disk cache.
@ -393,7 +394,7 @@ class ModelCacheManager:
except Exception as e:
logger.error(f"Failed to delete {cache_file}: {e}")
def get_metrics(self) -> Dict[str, Any]:
def get_metrics(self) -> dict[str, Any]:
"""
Get cache performance metrics.
@ -416,7 +417,7 @@ class ModelCacheManager:
def warm_up(
self,
model_loaders: Dict[str, Callable[[], Any]],
model_loaders: dict[str, Callable[[], Any]],
) -> None:
"""
Pre-load models on startup (warm-up).

View file

@ -14,15 +14,11 @@ from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING
from transformers import pipeline
from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ml.ner")

View file

@ -18,18 +18,14 @@ Examples:
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
from documents.ml.model_cache import ModelCacheManager
if TYPE_CHECKING:
pass
logger = logging.getLogger("paperless.ml.semantic_search")
@ -67,7 +63,7 @@ class SemanticSearch:
"""
logger.info(
f"Initializing SemanticSearch with model: {model_name} "
f"(caching: {use_cache})"
f"(caching: {use_cache})",
)
self.model_name = model_name
@ -127,11 +123,11 @@ class SemanticSearch:
if not isinstance(embedding, np.ndarray) and not isinstance(embedding, torch.Tensor):
logger.warning(f"Embedding for doc {doc_id} is not a numpy array or tensor")
return False
if hasattr(embedding, 'size'):
if hasattr(embedding, "size"):
if embedding.size == 0:
logger.warning(f"Embedding for doc {doc_id} is empty")
return False
elif hasattr(embedding, 'numel'):
elif hasattr(embedding, "numel"):
if embedding.numel() == 0:
logger.warning(f"Embedding for doc {doc_id} is empty")
return False
@ -216,11 +212,11 @@ class SemanticSearch:
try:
result = self.cache_manager.save_embeddings_to_disk(
"document_embeddings",
self.document_embeddings
self.document_embeddings,
)
if result:
logger.info(
f"Successfully saved {len(self.document_embeddings)} embeddings to disk"
f"Successfully saved {len(self.document_embeddings)} embeddings to disk",
)
else:
logger.error("Failed to save embeddings to disk (returned False)")

View file

@ -1604,30 +1604,30 @@ class DeletionRequest(models.Model):
# Requester (AI system)
requested_by_ai = models.BooleanField(default=True)
ai_reason = models.TextField(
help_text=_("Detailed explanation from AI about why deletion is recommended")
help_text=_("Detailed explanation from AI about why deletion is recommended"),
)
# User who must approve
user = models.ForeignKey(
User,
on_delete=models.CASCADE,
related_name='deletion_requests',
related_name="deletion_requests",
help_text=_("User who must approve this deletion"),
)
# Status tracking
STATUS_PENDING = 'pending'
STATUS_APPROVED = 'approved'
STATUS_REJECTED = 'rejected'
STATUS_CANCELLED = 'cancelled'
STATUS_COMPLETED = 'completed'
STATUS_PENDING = "pending"
STATUS_APPROVED = "approved"
STATUS_REJECTED = "rejected"
STATUS_CANCELLED = "cancelled"
STATUS_COMPLETED = "completed"
STATUS_CHOICES = [
(STATUS_PENDING, _('Pending')),
(STATUS_APPROVED, _('Approved')),
(STATUS_REJECTED, _('Rejected')),
(STATUS_CANCELLED, _('Cancelled')),
(STATUS_COMPLETED, _('Completed')),
(STATUS_PENDING, _("Pending")),
(STATUS_APPROVED, _("Approved")),
(STATUS_REJECTED, _("Rejected")),
(STATUS_CANCELLED, _("Cancelled")),
(STATUS_COMPLETED, _("Completed")),
]
status = models.CharField(
@ -1639,7 +1639,7 @@ class DeletionRequest(models.Model):
# Documents to be deleted
documents = models.ManyToManyField(
Document,
related_name='deletion_requests',
related_name="deletion_requests",
help_text=_("Documents that would be deleted if approved"),
)
@ -1656,7 +1656,7 @@ class DeletionRequest(models.Model):
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name='reviewed_deletion_requests',
related_name="reviewed_deletion_requests",
help_text=_("User who reviewed and approved/rejected"),
)
review_comment = models.TextField(
@ -1672,21 +1672,21 @@ class DeletionRequest(models.Model):
)
class Meta:
ordering = ['-created_at']
ordering = ["-created_at"]
verbose_name = _("deletion request")
verbose_name_plural = _("deletion requests")
indexes = [
# Composite index for common listing queries (by user, filtered by status, sorted by date)
# PostgreSQL can use this index for queries on: user, user+status, user+status+created_at
models.Index(fields=['user', 'status', 'created_at'], name='delreq_user_status_created_idx'),
models.Index(fields=["user", "status", "created_at"], name="delreq_user_status_created_idx"),
# Index for queries filtering by status and date without user filter
models.Index(fields=['status', 'created_at'], name='delreq_status_created_idx'),
models.Index(fields=["status", "created_at"], name="delreq_status_created_idx"),
# Index for queries filtering by user and date (common for user-specific views)
models.Index(fields=['user', 'created_at'], name='delreq_user_created_idx'),
models.Index(fields=["user", "created_at"], name="delreq_user_created_idx"),
# Index for queries filtering by review date
models.Index(fields=['reviewed_at'], name='delreq_reviewed_at_idx'),
models.Index(fields=["reviewed_at"], name="delreq_reviewed_at_idx"),
# Index for queries filtering by completion date
models.Index(fields=['completed_at'], name='delreq_completed_at_idx'),
models.Index(fields=["completed_at"], name="delreq_completed_at_idx"),
]
def __str__(self):
@ -1745,67 +1745,67 @@ class AISuggestionFeedback(models.Model):
"""
# Suggestion types
TYPE_TAG = 'tag'
TYPE_CORRESPONDENT = 'correspondent'
TYPE_DOCUMENT_TYPE = 'document_type'
TYPE_STORAGE_PATH = 'storage_path'
TYPE_CUSTOM_FIELD = 'custom_field'
TYPE_WORKFLOW = 'workflow'
TYPE_TITLE = 'title'
TYPE_TAG = "tag"
TYPE_CORRESPONDENT = "correspondent"
TYPE_DOCUMENT_TYPE = "document_type"
TYPE_STORAGE_PATH = "storage_path"
TYPE_CUSTOM_FIELD = "custom_field"
TYPE_WORKFLOW = "workflow"
TYPE_TITLE = "title"
SUGGESTION_TYPES = (
(TYPE_TAG, _('Tag')),
(TYPE_CORRESPONDENT, _('Correspondent')),
(TYPE_DOCUMENT_TYPE, _('Document Type')),
(TYPE_STORAGE_PATH, _('Storage Path')),
(TYPE_CUSTOM_FIELD, _('Custom Field')),
(TYPE_WORKFLOW, _('Workflow')),
(TYPE_TITLE, _('Title')),
(TYPE_TAG, _("Tag")),
(TYPE_CORRESPONDENT, _("Correspondent")),
(TYPE_DOCUMENT_TYPE, _("Document Type")),
(TYPE_STORAGE_PATH, _("Storage Path")),
(TYPE_CUSTOM_FIELD, _("Custom Field")),
(TYPE_WORKFLOW, _("Workflow")),
(TYPE_TITLE, _("Title")),
)
# Feedback status
STATUS_APPLIED = 'applied'
STATUS_REJECTED = 'rejected'
STATUS_APPLIED = "applied"
STATUS_REJECTED = "rejected"
FEEDBACK_STATUS = (
(STATUS_APPLIED, _('Applied')),
(STATUS_REJECTED, _('Rejected')),
(STATUS_APPLIED, _("Applied")),
(STATUS_REJECTED, _("Rejected")),
)
document = models.ForeignKey(
Document,
on_delete=models.CASCADE,
related_name='ai_suggestion_feedbacks',
verbose_name=_('document'),
related_name="ai_suggestion_feedbacks",
verbose_name=_("document"),
)
suggestion_type = models.CharField(
_('suggestion type'),
_("suggestion type"),
max_length=50,
choices=SUGGESTION_TYPES,
)
suggested_value_id = models.IntegerField(
_('suggested value ID'),
_("suggested value ID"),
null=True,
blank=True,
help_text=_('ID of the suggested object (tag, correspondent, etc.)'),
help_text=_("ID of the suggested object (tag, correspondent, etc.)"),
)
suggested_value_text = models.TextField(
_('suggested value text'),
_("suggested value text"),
blank=True,
help_text=_('Text representation of the suggested value'),
help_text=_("Text representation of the suggested value"),
)
confidence = models.FloatField(
_('confidence'),
help_text=_('AI confidence score (0.0 to 1.0)'),
_("confidence"),
help_text=_("AI confidence score (0.0 to 1.0)"),
validators=[MinValueValidator(0.0), MaxValueValidator(1.0)],
)
status = models.CharField(
_('status'),
_("status"),
max_length=20,
choices=FEEDBACK_STATUS,
)
@ -1815,36 +1815,36 @@ class AISuggestionFeedback(models.Model):
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name='ai_suggestion_feedbacks',
verbose_name=_('user'),
help_text=_('User who applied or rejected the suggestion'),
related_name="ai_suggestion_feedbacks",
verbose_name=_("user"),
help_text=_("User who applied or rejected the suggestion"),
)
created_at = models.DateTimeField(
_('created at'),
_("created at"),
auto_now_add=True,
)
applied_at = models.DateTimeField(
_('applied/rejected at'),
_("applied/rejected at"),
auto_now=True,
)
metadata = models.JSONField(
_('metadata'),
_("metadata"),
default=dict,
blank=True,
help_text=_('Additional metadata about the suggestion'),
help_text=_("Additional metadata about the suggestion"),
)
class Meta:
verbose_name = _('AI suggestion feedback')
verbose_name_plural = _('AI suggestion feedbacks')
ordering = ['-created_at']
verbose_name = _("AI suggestion feedback")
verbose_name_plural = _("AI suggestion feedbacks")
ordering = ["-created_at"]
indexes = [
models.Index(fields=['document', 'suggestion_type']),
models.Index(fields=['status', 'created_at']),
models.Index(fields=['suggestion_type', 'status']),
models.Index(fields=["document", "suggestion_type"]),
models.Index(fields=["status", "created_at"]),
models.Index(fields=["suggestion_type", "status"]),
]
def __str__(self):

View file

@ -11,21 +11,21 @@ Lazy imports are used to avoid loading heavy dependencies unless needed.
"""
__all__ = [
'TableExtractor',
'HandwritingRecognizer',
'FormFieldDetector',
"FormFieldDetector",
"HandwritingRecognizer",
"TableExtractor",
]
def __getattr__(name):
"""Lazy import to avoid loading heavy ML models on startup."""
if name == 'TableExtractor':
if name == "TableExtractor":
from .table_extractor import TableExtractor
return TableExtractor
elif name == 'HandwritingRecognizer':
elif name == "HandwritingRecognizer":
from .handwriting import HandwritingRecognizer
return HandwritingRecognizer
elif name == 'FormFieldDetector':
elif name == "FormFieldDetector":
from .form_detector import FormFieldDetector
return FormFieldDetector
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View file

@ -8,8 +8,8 @@ This module provides capabilities to:
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from typing import Any
import numpy as np
from PIL import Image
@ -59,8 +59,8 @@ class FormFieldDetector:
self,
image: Image.Image,
min_size: int = 10,
max_size: int = 50
) -> List[Dict[str, Any]]:
max_size: int = 50,
) -> list[dict[str, Any]]:
"""
Detect checkboxes in a form image.
@ -114,9 +114,9 @@ class FormFieldDetector:
checked, confidence = self._is_checkbox_checked(checkbox_region)
checkboxes.append({
'bbox': [x, y, x+w, y+h],
'checked': checked,
'confidence': confidence
"bbox": [x, y, x+w, y+h],
"checked": checked,
"confidence": confidence,
})
logger.info(f"Detected {len(checkboxes)} checkboxes")
@ -129,7 +129,7 @@ class FormFieldDetector:
logger.error(f"Error detecting checkboxes: {e}")
return []
def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> Tuple[bool, float]:
def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> tuple[bool, float]:
"""
Determine if a checkbox is checked.
@ -167,8 +167,8 @@ class FormFieldDetector:
def detect_text_fields(
self,
image: Image.Image,
min_width: int = 100
) -> List[Dict[str, Any]]:
min_width: int = 100,
) -> list[dict[str, Any]]:
"""
Detect text input fields in a form.
@ -202,14 +202,14 @@ class FormFieldDetector:
cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1],
cv2.MORPH_OPEN,
horizontal_kernel,
iterations=2
iterations=2,
)
# Find contours of horizontal lines
contours, _ = cv2.findContours(
detect_horizontal,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE
cv2.CHAIN_APPROX_SIMPLE,
)
text_fields = []
@ -221,8 +221,8 @@ class FormFieldDetector:
# Expand upward to include text area
text_bbox = [x, max(0, y-30), x+w, y+h]
text_fields.append({
'bbox': text_bbox,
'type': 'line'
"bbox": text_bbox,
"type": "line",
})
# Detect rectangular boxes (bordered text fields)
@ -236,8 +236,8 @@ class FormFieldDetector:
aspect_ratio = w / h if h > 0 else 0
if w >= min_width and 20 <= h <= 100 and aspect_ratio > 2:
text_fields.append({
'bbox': [x, y, x+w, y+h],
'type': 'box'
"bbox": [x, y, x+w, y+h],
"type": "box",
})
logger.info(f"Detected {len(text_fields)} text fields")
@ -253,8 +253,8 @@ class FormFieldDetector:
def detect_labels(
self,
image: Image.Image,
field_bboxes: List[List[int]]
) -> List[Dict[str, Any]]:
field_bboxes: list[list[int]],
) -> list[dict[str, Any]]:
"""
Detect labels near form fields.
@ -271,17 +271,17 @@ class FormFieldDetector:
# Get all text with bounding boxes
ocr_data = pytesseract.image_to_data(
image,
output_type=pytesseract.Output.DICT
output_type=pytesseract.Output.DICT,
)
# Group text into potential labels
labels = []
for i, text in enumerate(ocr_data['text']):
for i, text in enumerate(ocr_data["text"]):
if text.strip() and len(text.strip()) > 2:
x = ocr_data['left'][i]
y = ocr_data['top'][i]
w = ocr_data['width'][i]
h = ocr_data['height'][i]
x = ocr_data["left"][i]
y = ocr_data["top"][i]
w = ocr_data["width"][i]
h = ocr_data["height"][i]
label_bbox = [x, y, x+w, y+h]
@ -289,9 +289,9 @@ class FormFieldDetector:
closest_field_idx = self._find_closest_field(label_bbox, field_bboxes)
labels.append({
'text': text.strip(),
'bbox': label_bbox,
'field_index': closest_field_idx
"text": text.strip(),
"bbox": label_bbox,
"field_index": closest_field_idx,
})
return labels
@ -305,9 +305,9 @@ class FormFieldDetector:
def _find_closest_field(
self,
label_bbox: List[int],
field_bboxes: List[List[int]]
) -> Optional[int]:
label_bbox: list[int],
field_bboxes: list[list[int]],
) -> int | None:
"""
Find the closest field to a label.
@ -325,7 +325,7 @@ class FormFieldDetector:
label_center_x = (label_bbox[0] + label_bbox[2]) / 2
label_center_y = (label_bbox[1] + label_bbox[3]) / 2
min_distance = float('inf')
min_distance = float("inf")
closest_idx = 0
for i, field_bbox in enumerate(field_bboxes):
@ -336,7 +336,7 @@ class FormFieldDetector:
# Euclidean distance
distance = np.sqrt(
(label_center_x - field_center_x)**2 +
(label_center_y - field_center_y)**2
(label_center_y - field_center_y)**2,
)
if distance < min_distance:
@ -348,8 +348,8 @@ class FormFieldDetector:
def detect_form_fields(
self,
image_path: str,
extract_values: bool = True
) -> List[Dict[str, Any]]:
extract_values: bool = True,
) -> list[dict[str, Any]]:
"""
Detect all form fields and extract their values.
@ -372,14 +372,14 @@ class FormFieldDetector:
"""
try:
# Load image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
# Detect different field types
text_fields = self.detect_text_fields(image)
checkboxes = self.detect_checkboxes(image)
# Combine all field bboxes for label detection
all_field_bboxes = [f['bbox'] for f in text_fields] + [cb['bbox'] for cb in checkboxes]
all_field_bboxes = [f["bbox"] for f in text_fields] + [cb["bbox"] for cb in checkboxes]
# Detect labels
labels = self.detect_labels(image, all_field_bboxes)
@ -393,20 +393,20 @@ class FormFieldDetector:
label_text = self._find_label_for_field(i, labels, len(text_fields))
result = {
'type': 'text',
'label': label_text,
'bbox': field['bbox'],
"type": "text",
"label": label_text,
"bbox": field["bbox"],
}
# Extract value if requested
if extract_values:
x1, y1, x2, y2 = field['bbox']
x1, y1, x2, y2 = field["bbox"]
field_image = image.crop((x1, y1, x2, y2))
recognizer = self._get_handwriting_recognizer()
value = recognizer.recognize_from_image(field_image, preprocess=True)
result['value'] = value.strip()
result['confidence'] = recognizer._estimate_confidence(value)
result["value"] = value.strip()
result["confidence"] = recognizer._estimate_confidence(value)
results.append(result)
@ -416,11 +416,11 @@ class FormFieldDetector:
label_text = self._find_label_for_field(field_idx, labels, len(all_field_bboxes))
results.append({
'type': 'checkbox',
'label': label_text,
'value': checkbox['checked'],
'bbox': checkbox['bbox'],
'confidence': checkbox['confidence']
"type": "checkbox",
"label": label_text,
"value": checkbox["checked"],
"bbox": checkbox["bbox"],
"confidence": checkbox["confidence"],
})
logger.info(f"Detected {len(results)} form fields from {image_path}")
@ -433,8 +433,8 @@ class FormFieldDetector:
def _find_label_for_field(
self,
field_idx: int,
labels: List[Dict[str, Any]],
total_fields: int
labels: list[dict[str, Any]],
total_fields: int,
) -> str:
"""
Find the label text for a specific field.
@ -449,19 +449,19 @@ class FormFieldDetector:
"""
matching_labels = [
label for label in labels
if label['field_index'] == field_idx
if label["field_index"] == field_idx
]
if matching_labels:
# Combine multiple label parts if found
return ' '.join(label['text'] for label in matching_labels)
return " ".join(label["text"] for label in matching_labels)
return f"Field_{field_idx + 1}"
def extract_form_data(
self,
image_path: str,
output_format: str = 'dict'
output_format: str = "dict",
) -> Any:
"""
Extract all form data as structured output.
@ -476,16 +476,16 @@ class FormFieldDetector:
# Detect and extract fields
fields = self.detect_form_fields(image_path, extract_values=True)
if output_format == 'dict':
if output_format == "dict":
# Return as dictionary
return {field['label']: field['value'] for field in fields}
return {field["label"]: field["value"] for field in fields}
elif output_format == 'json':
elif output_format == "json":
import json
data = {field['label']: field['value'] for field in fields}
data = {field["label"]: field["value"] for field in fields}
return json.dumps(data, indent=2)
elif output_format == 'dataframe':
elif output_format == "dataframe":
import pandas as pd
return pd.DataFrame(fields)

View file

@ -8,8 +8,8 @@ This module provides handwriting OCR capabilities using:
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from typing import Any
import numpy as np
from PIL import Image
@ -65,8 +65,9 @@ class HandwritingRecognizer:
return
try:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel
logger.info(f"Loading handwriting recognition model: {self.model_name}")
@ -90,7 +91,7 @@ class HandwritingRecognizer:
def recognize_from_image(
self,
image: Image.Image,
preprocess: bool = True
preprocess: bool = True,
) -> str:
"""
Recognize text from a single image.
@ -142,11 +143,12 @@ class HandwritingRecognizer:
Preprocessed PIL Image
"""
try:
from PIL import ImageEnhance, ImageFilter
from PIL import ImageEnhance
from PIL import ImageFilter
# Convert to grayscale
if image.mode != 'L':
image = image.convert('L')
if image.mode != "L":
image = image.convert("L")
# Enhance contrast
enhancer = ImageEnhance.Contrast(image)
@ -156,7 +158,7 @@ class HandwritingRecognizer:
image = image.filter(ImageFilter.MedianFilter(size=3))
# Convert back to RGB (required by model)
image = image.convert('RGB')
image = image.convert("RGB")
return image
@ -164,7 +166,7 @@ class HandwritingRecognizer:
logger.warning(f"Error preprocessing image: {e}")
return image
def detect_text_lines(self, image: Image.Image) -> List[Dict[str, Any]]:
def detect_text_lines(self, image: Image.Image) -> list[dict[str, Any]]:
"""
Detect individual text lines in an image.
@ -208,12 +210,12 @@ class HandwritingRecognizer:
# Crop line from original image
line_img = image.crop((x, y, x+w, y+h))
lines.append({
'bbox': [x, y, x+w, y+h],
'image': line_img
"bbox": [x, y, x+w, y+h],
"image": line_img,
})
# Sort lines top to bottom
lines.sort(key=lambda l: l['bbox'][1])
lines.sort(key=lambda l: l["bbox"][1])
logger.info(f"Detected {len(lines)} text lines")
return lines
@ -228,8 +230,8 @@ class HandwritingRecognizer:
def recognize_lines(
self,
image_path: str,
return_confidence: bool = True
) -> List[Dict[str, Any]]:
return_confidence: bool = True,
) -> list[dict[str, Any]]:
"""
Recognize text from each line in an image.
@ -250,7 +252,7 @@ class HandwritingRecognizer:
"""
try:
# Load image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
# Detect lines
lines = self.detect_text_lines(image)
@ -260,18 +262,18 @@ class HandwritingRecognizer:
for i, line in enumerate(lines):
logger.debug(f"Recognizing line {i+1}/{len(lines)}")
text = self.recognize_from_image(line['image'], preprocess=True)
text = self.recognize_from_image(line["image"], preprocess=True)
result = {
'text': text,
'bbox': line['bbox'],
'line_index': i
"text": text,
"bbox": line["bbox"],
"line_index": i,
}
if return_confidence:
# Simple confidence based on text length and content
confidence = self._estimate_confidence(text)
result['confidence'] = confidence
result["confidence"] = confidence
results.append(result)
@ -309,7 +311,7 @@ class HandwritingRecognizer:
score += 0.1
# Text with spaces (words) is more reliable
if ' ' in text:
if " " in text:
score += 0.1
# Penalize if too many special characters
@ -322,8 +324,8 @@ class HandwritingRecognizer:
def recognize_from_file(
self,
image_path: str,
mode: str = 'full'
) -> Dict[str, Any]:
mode: str = "full",
) -> dict[str, Any]:
"""
Recognize handwriting from an image file.
@ -337,30 +339,30 @@ class HandwritingRecognizer:
Dictionary with recognized text and metadata
"""
try:
if mode == 'full':
if mode == "full":
# Recognize entire image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
text = self.recognize_from_image(image, preprocess=True)
return {
'text': text,
'mode': 'full',
'confidence': self._estimate_confidence(text)
"text": text,
"mode": "full",
"confidence": self._estimate_confidence(text),
}
elif mode == 'lines':
elif mode == "lines":
# Recognize line by line
lines = self.recognize_lines(image_path, return_confidence=True)
# Combine all lines
full_text = '\n'.join(line['text'] for line in lines)
avg_confidence = np.mean([line['confidence'] for line in lines]) if lines else 0.0
full_text = "\n".join(line["text"] for line in lines)
avg_confidence = np.mean([line["confidence"] for line in lines]) if lines else 0.0
return {
'text': full_text,
'lines': lines,
'mode': 'lines',
'confidence': float(avg_confidence)
"text": full_text,
"lines": lines,
"mode": "lines",
"confidence": float(avg_confidence),
}
else:
@ -369,17 +371,17 @@ class HandwritingRecognizer:
except Exception as e:
logger.error(f"Error recognizing from file {image_path}: {e}")
return {
'text': '',
'mode': mode,
'confidence': 0.0,
'error': str(e)
"text": "",
"mode": mode,
"confidence": 0.0,
"error": str(e),
}
def recognize_form_fields(
self,
image_path: str,
field_regions: List[Dict[str, Any]]
) -> Dict[str, str]:
field_regions: list[dict[str, Any]],
) -> dict[str, str]:
"""
Recognize text from specific form fields.
@ -399,13 +401,13 @@ class HandwritingRecognizer:
"""
try:
# Load image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
# Extract and recognize each field
results = {}
for field in field_regions:
name = field['name']
bbox = field['bbox']
name = field["name"]
bbox = field["bbox"]
# Crop field region
x1, y1, x2, y2 = bbox
@ -425,9 +427,9 @@ class HandwritingRecognizer:
def batch_recognize(
self,
image_paths: List[str],
mode: str = 'full'
) -> List[Dict[str, Any]]:
image_paths: list[str],
mode: str = "full",
) -> list[dict[str, Any]]:
"""
Recognize handwriting from multiple images in batch.
@ -442,7 +444,7 @@ class HandwritingRecognizer:
for i, path in enumerate(image_paths):
logger.info(f"Processing image {i+1}/{len(image_paths)}: {path}")
result = self.recognize_from_file(path, mode=mode)
result['image_path'] = path
result["image_path"] = path
results.append(result)
return results

View file

@ -8,9 +8,8 @@ This module uses various techniques to detect and extract tables from documents:
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from typing import Any
from PIL import Image
logger = logging.getLogger(__name__)
@ -59,8 +58,9 @@ class TableExtractor:
return
try:
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import torch
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
logger.info(f"Loading table detection model: {self.model_name}")
@ -79,7 +79,7 @@ class TableExtractor:
logger.error("Please install required packages: pip install transformers torch pillow")
raise
def detect_tables(self, image: Image.Image) -> List[Dict[str, Any]]:
def detect_tables(self, image: Image.Image) -> list[dict[str, Any]]:
"""
Detect tables in an image.
@ -117,16 +117,16 @@ class TableExtractor:
results = self._processor.post_process_object_detection(
outputs,
threshold=self.confidence_threshold,
target_sizes=target_sizes
target_sizes=target_sizes,
)[0]
# Convert to list of dicts
tables = []
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
tables.append({
'bbox': box.cpu().tolist(),
'score': score.item(),
'label': self._model.config.id2label[label.item()]
"bbox": box.cpu().tolist(),
"score": score.item(),
"label": self._model.config.id2label[label.item()],
})
logger.info(f"Detected {len(tables)} tables in image")
@ -139,9 +139,9 @@ class TableExtractor:
def extract_table_from_region(
self,
image: Image.Image,
bbox: List[float],
use_ocr: bool = True
) -> Optional[Dict[str, Any]]:
bbox: list[float],
use_ocr: bool = True,
) -> dict[str, Any] | None:
"""
Extract table data from a specific region of an image.
@ -166,7 +166,7 @@ class TableExtractor:
# Get detailed OCR data
ocr_data = pytesseract.image_to_data(
table_image,
output_type=pytesseract.Output.DICT
output_type=pytesseract.Output.DICT,
)
# Reconstruct table structure from OCR data
@ -176,20 +176,20 @@ class TableExtractor:
raw_text = pytesseract.image_to_string(table_image)
return {
'data': table_data,
'raw_text': raw_text,
'bbox': bbox,
'image_size': table_image.size
"data": table_data,
"raw_text": raw_text,
"bbox": bbox,
"image_size": table_image.size,
}
else:
# Fallback to basic OCR without structure
import pytesseract
raw_text = pytesseract.image_to_string(table_image)
return {
'data': None,
'raw_text': raw_text,
'bbox': bbox,
'image_size': table_image.size
"data": None,
"raw_text": raw_text,
"bbox": bbox,
"image_size": table_image.size,
}
except ImportError:
@ -199,7 +199,7 @@ class TableExtractor:
logger.error(f"Error extracting table from region: {e}")
return None
def _reconstruct_table_from_ocr(self, ocr_data: Dict) -> Optional[Any]:
def _reconstruct_table_from_ocr(self, ocr_data: dict) -> Any | None:
"""
Reconstruct table structure from OCR output.
@ -214,10 +214,10 @@ class TableExtractor:
# Group text by vertical position (rows)
rows = {}
for i, text in enumerate(ocr_data['text']):
for i, text in enumerate(ocr_data["text"]):
if text.strip():
top = ocr_data['top'][i]
left = ocr_data['left'][i]
top = ocr_data["top"][i]
left = ocr_data["left"][i]
# Group by approximate row (within 20 pixels)
row_key = round(top / 20) * 20
@ -235,14 +235,14 @@ class TableExtractor:
if table_rows:
# Pad rows to same length
max_cols = max(len(row) for row in table_rows)
table_rows = [row + [''] * (max_cols - len(row)) for row in table_rows]
table_rows = [row + [""] * (max_cols - len(row)) for row in table_rows]
# Create DataFrame
df = pd.DataFrame(table_rows)
# Try to use first row as header if it looks like one
if len(df) > 1:
first_row_text = ' '.join(str(x) for x in df.iloc[0])
first_row_text = " ".join(str(x) for x in df.iloc[0])
if not any(char.isdigit() for char in first_row_text):
df.columns = df.iloc[0]
df = df[1:].reset_index(drop=True)
@ -261,8 +261,8 @@ class TableExtractor:
def extract_tables_from_image(
self,
image_path: str,
output_format: str = 'dataframe'
) -> List[Dict[str, Any]]:
output_format: str = "dataframe",
) -> list[dict[str, Any]]:
"""
Extract all tables from an image file.
@ -275,7 +275,7 @@ class TableExtractor:
"""
try:
# Load image
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
# Detect tables
detections = self.detect_tables(image)
@ -287,18 +287,18 @@ class TableExtractor:
table_data = self.extract_table_from_region(
image,
detection['bbox']
detection["bbox"],
)
if table_data:
table_data['detection_score'] = detection['score']
table_data['table_index'] = i
table_data["detection_score"] = detection["score"]
table_data["table_index"] = i
# Convert to requested format
if output_format == 'csv' and table_data['data'] is not None:
table_data['csv'] = table_data['data'].to_csv(index=False)
elif output_format == 'json' and table_data['data'] is not None:
table_data['json'] = table_data['data'].to_json(orient='records')
if output_format == "csv" and table_data["data"] is not None:
table_data["csv"] = table_data["data"].to_csv(index=False)
elif output_format == "json" and table_data["data"] is not None:
table_data["json"] = table_data["data"].to_json(orient="records")
tables.append(table_data)
@ -312,8 +312,8 @@ class TableExtractor:
def extract_tables_from_pdf(
self,
pdf_path: str,
page_numbers: Optional[List[int]] = None
) -> Dict[int, List[Dict[str, Any]]]:
page_numbers: list[int] | None = None,
) -> dict[int, list[dict[str, Any]]]:
"""
Extract tables from a PDF document.
@ -334,7 +334,7 @@ class TableExtractor:
images = convert_from_path(
pdf_path,
first_page=min(page_numbers),
last_page=max(page_numbers)
last_page=max(page_numbers),
)
else:
images = convert_from_path(pdf_path)
@ -352,11 +352,11 @@ class TableExtractor:
for detection in detections:
table_data = self.extract_table_from_region(
image,
detection['bbox']
detection["bbox"],
)
if table_data:
table_data['detection_score'] = detection['score']
table_data['page'] = page_num
table_data["detection_score"] = detection["score"]
table_data["page"] = page_num
tables.append(table_data)
if tables:
@ -374,8 +374,8 @@ class TableExtractor:
def save_tables_to_excel(
self,
tables: List[Dict[str, Any]],
output_path: str
tables: list[dict[str, Any]],
output_path: str,
) -> bool:
"""
Save extracted tables to an Excel file.
@ -390,17 +390,17 @@ class TableExtractor:
try:
import pandas as pd
with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
for i, table in enumerate(tables):
if table.get('data') is not None:
if table.get("data") is not None:
sheet_name = f"Table_{i+1}"
if 'page' in table:
if "page" in table:
sheet_name = f"Page_{table['page']}_Table_{i+1}"
table['data'].to_excel(
table["data"].to_excel(
writer,
sheet_name=sheet_name,
index=False
index=False,
)
logger.info(f"Saved {len(tables)} tables to {output_path}")

View file

@ -46,7 +46,6 @@ if settings.AUDIT_LOG_ENABLED:
from documents import bulk_edit
from documents.data_models import DocumentSource
from documents.filters import CustomFieldQueryParser
from documents.models import AISuggestionFeedback
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import CustomFieldInstance
@ -2788,9 +2787,9 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
"""Serializer for DeletionRequest model with document details."""
document_details = serializers.SerializerMethodField()
user_username = serializers.CharField(source='user.username', read_only=True)
user_username = serializers.CharField(source="user.username", read_only=True)
reviewed_by_username = serializers.CharField(
source='reviewed_by.username',
source="reviewed_by.username",
read_only=True,
allow_null=True,
)
@ -2799,31 +2798,31 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
from documents.models import DeletionRequest
model = DeletionRequest
fields = [
'id',
'created_at',
'updated_at',
'requested_by_ai',
'ai_reason',
'user',
'user_username',
'status',
'impact_summary',
'reviewed_at',
'reviewed_by',
'reviewed_by_username',
'review_comment',
'completed_at',
'completion_details',
'document_details',
"id",
"created_at",
"updated_at",
"requested_by_ai",
"ai_reason",
"user",
"user_username",
"status",
"impact_summary",
"reviewed_at",
"reviewed_by",
"reviewed_by_username",
"review_comment",
"completed_at",
"completion_details",
"document_details",
]
read_only_fields = [
'id',
'created_at',
'updated_at',
'reviewed_at',
'reviewed_by',
'completed_at',
'completion_details',
"id",
"created_at",
"updated_at",
"reviewed_at",
"reviewed_by",
"completed_at",
"completion_details",
]
def get_document_details(self, obj):
@ -2831,12 +2830,12 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
documents = obj.documents.all()
return [
{
'id': doc.id,
'title': doc.title,
'created': doc.created.isoformat() if doc.created else None,
'correspondent': doc.correspondent.name if doc.correspondent else None,
'document_type': doc.document_type.name if doc.document_type else None,
'tags': [tag.name for tag in doc.tags.all()],
"id": doc.id,
"title": doc.title,
"created": doc.created.isoformat() if doc.created else None,
"correspondent": doc.correspondent.name if doc.correspondent else None,
"document_type": doc.document_type.name if doc.document_type else None,
"tags": [tag.name for tag in doc.tags.all()],
}
for doc in documents
]
@ -2852,6 +2851,9 @@ class DeletionRequestActionSerializer(serializers.Serializer):
allow_blank=True,
label="Review Comment",
help_text="Optional comment when reviewing the deletion request",
)
class AISuggestionsRequestSerializer(serializers.Serializer):
"""Serializer for requesting AI suggestions for a document."""

View file

@ -1,17 +1,15 @@
"""Serializers package for documents app."""
from .ai_suggestions import (
AISuggestionFeedbackSerializer,
AISuggestionsSerializer,
AISuggestionStatsSerializer,
ApplySuggestionSerializer,
RejectSuggestionSerializer,
)
from .ai_suggestions import AISuggestionFeedbackSerializer
from .ai_suggestions import AISuggestionsSerializer
from .ai_suggestions import AISuggestionStatsSerializer
from .ai_suggestions import ApplySuggestionSerializer
from .ai_suggestions import RejectSuggestionSerializer
__all__ = [
'AISuggestionFeedbackSerializer',
'AISuggestionsSerializer',
'AISuggestionStatsSerializer',
'ApplySuggestionSerializer',
'RejectSuggestionSerializer',
"AISuggestionFeedbackSerializer",
"AISuggestionStatsSerializer",
"AISuggestionsSerializer",
"ApplySuggestionSerializer",
"RejectSuggestionSerializer",
]

View file

@ -7,36 +7,33 @@ and handling user feedback on AI suggestions.
from __future__ import annotations
from typing import Any, Dict
from typing import Any
from rest_framework import serializers
from documents.models import (
AISuggestionFeedback,
Correspondent,
CustomField,
DocumentType,
StoragePath,
Tag,
Workflow,
)
from documents.models import AISuggestionFeedback
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow
# Suggestion type choices - used across multiple serializers
SUGGESTION_TYPE_CHOICES = [
'tag',
'correspondent',
'document_type',
'storage_path',
'custom_field',
'workflow',
'title',
"tag",
"correspondent",
"document_type",
"storage_path",
"custom_field",
"workflow",
"title",
]
# Types that require value_id
ID_REQUIRED_TYPES = ['tag', 'correspondent', 'document_type', 'storage_path', 'workflow']
ID_REQUIRED_TYPES = ["tag", "correspondent", "document_type", "storage_path", "workflow"]
# Types that require value_text
TEXT_REQUIRED_TYPES = ['title']
TEXT_REQUIRED_TYPES = ["title"]
# Types that can use either (custom_field can be ID or text)
@ -113,7 +110,7 @@ class AISuggestionsSerializer(serializers.Serializer):
title_suggestion = TitleSuggestionSerializer(required=False, allow_null=True)
@staticmethod
def from_scan_result(scan_result, document_id: int) -> Dict[str, Any]:
def from_scan_result(scan_result, document_id: int) -> dict[str, Any]:
"""
Convert an AIScanResult object to serializer data.
@ -133,25 +130,25 @@ class AISuggestionsSerializer(serializers.Serializer):
try:
tag = Tag.objects.get(pk=tag_id)
tag_suggestions.append({
'id': tag.id,
'name': tag.name,
'color': getattr(tag, 'color', '#000000'),
'confidence': confidence,
"id": tag.id,
"name": tag.name,
"color": getattr(tag, "color", "#000000"),
"confidence": confidence,
})
except Tag.DoesNotExist:
# Tag no longer exists in database; skip this suggestion
pass
data['tags'] = tag_suggestions
data["tags"] = tag_suggestions
# Correspondent
if scan_result.correspondent:
corr_id, confidence = scan_result.correspondent
try:
correspondent = Correspondent.objects.get(pk=corr_id)
data['correspondent'] = {
'id': correspondent.id,
'name': correspondent.name,
'confidence': confidence,
data["correspondent"] = {
"id": correspondent.id,
"name": correspondent.name,
"confidence": confidence,
}
except Correspondent.DoesNotExist:
# Correspondent no longer exists in database; omit from suggestions
@ -162,10 +159,10 @@ class AISuggestionsSerializer(serializers.Serializer):
type_id, confidence = scan_result.document_type
try:
doc_type = DocumentType.objects.get(pk=type_id)
data['document_type'] = {
'id': doc_type.id,
'name': doc_type.name,
'confidence': confidence,
data["document_type"] = {
"id": doc_type.id,
"name": doc_type.name,
"confidence": confidence,
}
except DocumentType.DoesNotExist:
# Document type no longer exists in database; omit from suggestions
@ -176,11 +173,11 @@ class AISuggestionsSerializer(serializers.Serializer):
path_id, confidence = scan_result.storage_path
try:
storage_path = StoragePath.objects.get(pk=path_id)
data['storage_path'] = {
'id': storage_path.id,
'name': storage_path.name,
'path': storage_path.path,
'confidence': confidence,
data["storage_path"] = {
"id": storage_path.id,
"name": storage_path.name,
"path": storage_path.path,
"confidence": confidence,
}
except StoragePath.DoesNotExist:
# Storage path no longer exists in database; omit from suggestions
@ -193,15 +190,15 @@ class AISuggestionsSerializer(serializers.Serializer):
try:
field = CustomField.objects.get(pk=field_id)
field_suggestions.append({
'field_id': field.id,
'field_name': field.name,
'value': str(value),
'confidence': confidence,
"field_id": field.id,
"field_name": field.name,
"value": str(value),
"confidence": confidence,
})
except CustomField.DoesNotExist:
# Custom field no longer exists in database; skip this suggestion
pass
data['custom_fields'] = field_suggestions
data["custom_fields"] = field_suggestions
# Workflows
if scan_result.workflows:
@ -210,19 +207,19 @@ class AISuggestionsSerializer(serializers.Serializer):
try:
workflow = Workflow.objects.get(pk=workflow_id)
workflow_suggestions.append({
'id': workflow.id,
'name': workflow.name,
'confidence': confidence,
"id": workflow.id,
"name": workflow.name,
"confidence": confidence,
})
except Workflow.DoesNotExist:
# Workflow no longer exists in database; skip this suggestion
pass
data['workflows'] = workflow_suggestions
data["workflows"] = workflow_suggestions
# Title suggestion
if scan_result.title_suggestion:
data['title_suggestion'] = {
'title': scan_result.title_suggestion,
data["title_suggestion"] = {
"title": scan_result.title_suggestion,
}
return data
@ -234,26 +231,26 @@ class SuggestionSerializerMixin:
"""
def validate(self, attrs):
"""Validate that the correct value field is provided for the suggestion type."""
suggestion_type = attrs.get('suggestion_type')
value_id = attrs.get('value_id')
value_text = attrs.get('value_text')
suggestion_type = attrs.get("suggestion_type")
value_id = attrs.get("value_id")
value_text = attrs.get("value_text")
# Types that require value_id
if suggestion_type in ID_REQUIRED_TYPES and not value_id:
raise serializers.ValidationError(
f"value_id is required for suggestion_type '{suggestion_type}'"
f"value_id is required for suggestion_type '{suggestion_type}'",
)
# Types that require value_text
if suggestion_type in TEXT_REQUIRED_TYPES and not value_text:
raise serializers.ValidationError(
f"value_text is required for suggestion_type '{suggestion_type}'"
f"value_text is required for suggestion_type '{suggestion_type}'",
)
# For custom_field, either is acceptable
if suggestion_type == 'custom_field' and not value_id and not value_text:
if suggestion_type == "custom_field" and not value_id and not value_text:
raise serializers.ValidationError(
"Either value_id or value_text must be provided for custom_field"
"Either value_id or value_text must be provided for custom_field",
)
return attrs
@ -295,19 +292,19 @@ class AISuggestionFeedbackSerializer(serializers.ModelSerializer):
class Meta:
model = AISuggestionFeedback
fields = [
'id',
'document',
'suggestion_type',
'suggested_value_id',
'suggested_value_text',
'confidence',
'status',
'user',
'created_at',
'applied_at',
'metadata',
"id",
"document",
"suggestion_type",
"suggested_value_id",
"suggested_value_text",
"confidence",
"status",
"user",
"created_at",
"applied_at",
"metadata",
]
read_only_fields = ['id', 'created_at', 'applied_at']
read_only_fields = ["id", "created_at", "applied_at"]
class AISuggestionStatsSerializer(serializers.Serializer):

View file

@ -18,13 +18,11 @@ from django.test import TestCase
from django.utils import timezone
from documents.ai_deletion_manager import AIDeletionManager
from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
class TestAIDeletionManagerCreateRequest(TestCase):

View file

@ -10,24 +10,23 @@ Tests cover:
- Permission assignment and verification
"""
from django.contrib.auth.models import Group, Permission, User
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from rest_framework.test import APIRequestFactory
from documents.models import Document
from documents.permissions import (
CanApplyAISuggestionsPermission,
CanApproveDeletionsPermission,
CanConfigureAIPermission,
CanViewAISuggestionsPermission,
)
from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission
class MockView:
"""Mock view for testing permissions."""
pass
class TestCanViewAISuggestionsPermission(TestCase):
@ -41,13 +40,13 @@ class TestCanViewAISuggestionsPermission(TestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign permission to permitted_user
@ -107,13 +106,13 @@ class TestCanApplyAISuggestionsPermission(TestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign permission to permitted_user
@ -173,13 +172,13 @@ class TestCanApproveDeletionsPermission(TestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign permission to permitted_user
@ -239,13 +238,13 @@ class TestCanConfigureAIPermission(TestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign permission to permitted_user
@ -345,7 +344,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_viewer_role_permissions(self):
"""Test that viewer role has appropriate permissions."""
user = User.objects.create_user(
username="viewer", email="viewer@test.com", password="viewer123"
username="viewer", email="viewer@test.com", password="viewer123",
)
user.groups.add(self.viewer_group)
@ -360,7 +359,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_editor_role_permissions(self):
"""Test that editor role has appropriate permissions."""
user = User.objects.create_user(
username="editor", email="editor@test.com", password="editor123"
username="editor", email="editor@test.com", password="editor123",
)
user.groups.add(self.editor_group)
@ -375,7 +374,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_admin_role_permissions(self):
"""Test that admin role has all permissions."""
user = User.objects.create_user(
username="ai_admin", email="ai_admin@test.com", password="admin123"
username="ai_admin", email="ai_admin@test.com", password="admin123",
)
user.groups.add(self.admin_group)
@ -390,7 +389,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_user_with_multiple_groups(self):
"""Test that user permissions accumulate from multiple groups."""
user = User.objects.create_user(
username="multi_role", email="multi@test.com", password="multi123"
username="multi_role", email="multi@test.com", password="multi123",
)
user.groups.add(self.viewer_group, self.editor_group)
@ -405,7 +404,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_direct_permission_assignment_overrides_group(self):
"""Test that direct permission assignment works alongside group permissions."""
user = User.objects.create_user(
username="special", email="special@test.com", password="special123"
username="special", email="special@test.com", password="special123",
)
user.groups.add(self.viewer_group)
@ -428,7 +427,7 @@ class TestPermissionAssignment(TestCase):
def setUp(self):
"""Set up test user."""
self.user = User.objects.create_user(
username="testuser", email="test@test.com", password="test123"
username="testuser", email="test@test.com", password="test123",
)
content_type = ContentType.objects.get_for_model(Document)
self.view_permission, _ = Permission.objects.get_or_create(
@ -500,7 +499,7 @@ class TestPermissionEdgeCases(TestCase):
def test_inactive_user_with_permission(self):
"""Test that inactive users are denied even with permission."""
user = User.objects.create_user(
username="inactive", email="inactive@test.com", password="inactive123"
username="inactive", email="inactive@test.com", password="inactive123",
)
user.is_active = False
user.save()

View file

@ -21,24 +21,21 @@ Tests cover:
from unittest import mock
from django.db import transaction
from django.test import TestCase, override_settings
from django.test import TestCase
from django.test import override_settings
from documents.ai_scanner import (
AIScanResult,
AIDocumentScanner,
get_ai_scanner,
)
from documents.models import (
Correspondent,
CustomField,
Document,
DocumentType,
StoragePath,
Tag,
Workflow,
WorkflowTrigger,
WorkflowAction,
)
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import AIScanResult
from documents.ai_scanner import get_ai_scanner
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
class TestAIScanResult(TestCase):
@ -100,7 +97,7 @@ class TestAIDocumentScannerInitialization(TestCase):
"""Test scanner initialization with custom confidence thresholds."""
scanner = AIDocumentScanner(
auto_apply_threshold=0.90,
suggest_threshold=0.70
suggest_threshold=0.70,
)
self.assertEqual(scanner.auto_apply_threshold, 0.90)
@ -145,14 +142,14 @@ class TestAIDocumentScannerInitialization(TestCase):
class TestAIDocumentScannerLazyLoading(TestCase):
"""Test lazy loading of ML components."""
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_get_classifier_loads_successfully(self, mock_logger):
"""Test successful lazy loading of classifier."""
scanner = AIDocumentScanner()
# Mock the import and class
mock_classifier_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
return_value=mock_classifier_instance) as mock_classifier_class:
classifier = scanner._get_classifier()
@ -161,13 +158,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
mock_classifier_class.assert_called_once()
mock_logger.info.assert_called_with("ML classifier loaded successfully")
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_get_classifier_returns_cached_instance(self, mock_logger):
"""Test that classifier is only loaded once."""
scanner = AIDocumentScanner()
mock_classifier_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
return_value=mock_classifier_instance):
classifier1 = scanner._get_classifier()
classifier2 = scanner._get_classifier()
@ -175,12 +172,12 @@ class TestAIDocumentScannerLazyLoading(TestCase):
self.assertEqual(classifier1, classifier2)
self.assertIs(classifier1, classifier2)
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_get_classifier_handles_import_error(self, mock_logger):
"""Test that classifier loading handles import errors gracefully."""
scanner = AIDocumentScanner()
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
side_effect=ImportError("Module not found")):
classifier = scanner._get_classifier()
@ -196,13 +193,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
self.assertIsNone(classifier)
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_get_ner_extractor_loads_successfully(self, mock_logger):
"""Test successful lazy loading of NER extractor."""
scanner = AIDocumentScanner()
mock_ner_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.DocumentNER',
with mock.patch("documents.ai_scanner.DocumentNER",
return_value=mock_ner_instance) as mock_ner_class:
ner = scanner._get_ner_extractor()
@ -211,25 +208,25 @@ class TestAIDocumentScannerLazyLoading(TestCase):
mock_ner_class.assert_called_once()
mock_logger.info.assert_called_with("NER extractor loaded successfully")
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_get_ner_extractor_handles_error(self, mock_logger):
"""Test NER extractor handles loading errors."""
scanner = AIDocumentScanner()
with mock.patch('documents.ai_scanner.DocumentNER',
with mock.patch("documents.ai_scanner.DocumentNER",
side_effect=Exception("Failed to load")):
ner = scanner._get_ner_extractor()
self.assertIsNone(ner)
mock_logger.warning.assert_called()
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_get_semantic_search_loads_successfully(self, mock_logger):
"""Test successful lazy loading of semantic search."""
scanner = AIDocumentScanner()
mock_search_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.SemanticSearch',
with mock.patch("documents.ai_scanner.SemanticSearch",
return_value=mock_search_instance) as mock_search_class:
search = scanner._get_semantic_search()
@ -238,13 +235,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
mock_search_class.assert_called_once()
mock_logger.info.assert_called_with("Semantic search loaded successfully")
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_get_table_extractor_loads_successfully(self, mock_logger):
"""Test successful lazy loading of table extractor."""
scanner = AIDocumentScanner()
mock_extractor_instance = mock.MagicMock()
with mock.patch('documents.ai_scanner.TableExtractor',
with mock.patch("documents.ai_scanner.TableExtractor",
return_value=mock_extractor_instance) as mock_extractor_class:
extractor = scanner._get_table_extractor()
@ -276,7 +273,7 @@ class TestExtractEntities(TestCase):
"dates": ["2024-01-01", "2024-12-31"],
"amounts": ["$1,000", "$500"],
"locations": ["New York"],
"misc": ["Invoice#123"]
"misc": ["Invoice#123"],
}
scanner._ner_extractor = mock_ner
@ -320,7 +317,7 @@ class TestExtractEntities(TestCase):
self.assertEqual(entities, {})
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_extract_entities_handles_exception(self, mock_logger):
"""Test that entity extraction handles exceptions gracefully."""
scanner = AIDocumentScanner()
@ -345,10 +342,10 @@ class TestSuggestTags(TestCase):
self.tag3 = Tag.objects.create(name="Tax", matching_algorithm=Tag.MATCH_AUTO)
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
@mock.patch('documents.ai_scanner.match_tags')
@mock.patch("documents.ai_scanner.match_tags")
def test_suggest_tags_with_matched_tags(self, mock_match_tags):
"""Test tag suggestions from matching."""
scanner = AIDocumentScanner()
@ -357,7 +354,7 @@ class TestSuggestTags(TestCase):
suggestions = scanner._suggest_tags(
self.document,
"Invoice from ACME Corp",
{}
{},
)
# Should suggest both matched tags
@ -370,14 +367,14 @@ class TestSuggestTags(TestCase):
for _, confidence in suggestions:
self.assertGreaterEqual(confidence, 0.6)
@mock.patch('documents.ai_scanner.match_tags')
@mock.patch("documents.ai_scanner.match_tags")
def test_suggest_tags_with_organization_entities(self, mock_match_tags):
"""Test tag suggestions based on organization entities."""
scanner = AIDocumentScanner()
mock_match_tags.return_value = []
entities = {
"organizations": [{"text": "ACME Corp"}]
"organizations": [{"text": "ACME Corp"}],
}
suggestions = scanner._suggest_tags(self.document, "text", entities)
@ -386,7 +383,7 @@ class TestSuggestTags(TestCase):
tag_ids = [tag_id for tag_id, _ in suggestions]
self.assertIn(self.tag2.id, tag_ids)
@mock.patch('documents.ai_scanner.match_tags')
@mock.patch("documents.ai_scanner.match_tags")
def test_suggest_tags_removes_duplicates(self, mock_match_tags):
"""Test that duplicate tags keep highest confidence."""
scanner = AIDocumentScanner()
@ -397,8 +394,8 @@ class TestSuggestTags(TestCase):
# Implementation should remove duplicates in actual code
@mock.patch('documents.ai_scanner.match_tags')
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.match_tags")
@mock.patch("documents.ai_scanner.logger")
def test_suggest_tags_handles_exception(self, mock_logger, mock_match_tags):
"""Test tag suggestion handles exceptions."""
scanner = AIDocumentScanner()
@ -417,18 +414,18 @@ class TestDetectCorrespondent(TestCase):
"""Set up test correspondents."""
self.correspondent1 = Correspondent.objects.create(
name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO
matching_algorithm=Correspondent.MATCH_AUTO,
)
self.correspondent2 = Correspondent.objects.create(
name="TechStart Inc",
matching_algorithm=Correspondent.MATCH_AUTO
matching_algorithm=Correspondent.MATCH_AUTO,
)
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
@mock.patch('documents.ai_scanner.match_correspondents')
@mock.patch("documents.ai_scanner.match_correspondents")
def test_detect_correspondent_with_match(self, mock_match):
"""Test correspondent detection with successful match."""
scanner = AIDocumentScanner()
@ -441,7 +438,7 @@ class TestDetectCorrespondent(TestCase):
self.assertEqual(corr_id, self.correspondent1.id)
self.assertEqual(confidence, 0.85)
@mock.patch('documents.ai_scanner.match_correspondents')
@mock.patch("documents.ai_scanner.match_correspondents")
def test_detect_correspondent_without_match(self, mock_match):
"""Test correspondent detection without match."""
scanner = AIDocumentScanner()
@ -451,14 +448,14 @@ class TestDetectCorrespondent(TestCase):
self.assertIsNone(result)
@mock.patch('documents.ai_scanner.match_correspondents')
@mock.patch("documents.ai_scanner.match_correspondents")
def test_detect_correspondent_from_ner_entities(self, mock_match):
"""Test correspondent detection from NER organizations."""
scanner = AIDocumentScanner()
mock_match.return_value = []
entities = {
"organizations": [{"text": "ACME Corporation"}]
"organizations": [{"text": "ACME Corporation"}],
}
result = scanner._detect_correspondent(self.document, "text", entities)
@ -468,8 +465,8 @@ class TestDetectCorrespondent(TestCase):
self.assertEqual(corr_id, self.correspondent1.id)
self.assertEqual(confidence, 0.70)
@mock.patch('documents.ai_scanner.match_correspondents')
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.match_correspondents")
@mock.patch("documents.ai_scanner.logger")
def test_detect_correspondent_handles_exception(self, mock_logger, mock_match):
"""Test correspondent detection handles exceptions."""
scanner = AIDocumentScanner()
@ -488,18 +485,18 @@ class TestClassifyDocumentType(TestCase):
"""Set up test document types."""
self.doc_type1 = DocumentType.objects.create(
name="Invoice",
matching_algorithm=DocumentType.MATCH_AUTO
matching_algorithm=DocumentType.MATCH_AUTO,
)
self.doc_type2 = DocumentType.objects.create(
name="Receipt",
matching_algorithm=DocumentType.MATCH_AUTO
matching_algorithm=DocumentType.MATCH_AUTO,
)
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
@mock.patch('documents.ai_scanner.match_document_types')
@mock.patch("documents.ai_scanner.match_document_types")
def test_classify_document_type_with_match(self, mock_match):
"""Test document type classification with match."""
scanner = AIDocumentScanner()
@ -512,7 +509,7 @@ class TestClassifyDocumentType(TestCase):
self.assertEqual(type_id, self.doc_type1.id)
self.assertEqual(confidence, 0.85)
@mock.patch('documents.ai_scanner.match_document_types')
@mock.patch("documents.ai_scanner.match_document_types")
def test_classify_document_type_without_match(self, mock_match):
"""Test document type classification without match."""
scanner = AIDocumentScanner()
@ -522,8 +519,8 @@ class TestClassifyDocumentType(TestCase):
self.assertIsNone(result)
@mock.patch('documents.ai_scanner.match_document_types')
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.match_document_types")
@mock.patch("documents.ai_scanner.logger")
def test_classify_document_type_handles_exception(self, mock_logger, mock_match):
"""Test classification handles exceptions."""
scanner = AIDocumentScanner()
@ -543,14 +540,14 @@ class TestSuggestStoragePath(TestCase):
self.storage_path1 = StoragePath.objects.create(
name="Invoices",
path="/documents/invoices",
matching_algorithm=StoragePath.MATCH_AUTO
matching_algorithm=StoragePath.MATCH_AUTO,
)
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
@mock.patch('documents.ai_scanner.match_storage_paths')
@mock.patch("documents.ai_scanner.match_storage_paths")
def test_suggest_storage_path_with_match(self, mock_match):
"""Test storage path suggestion with match."""
scanner = AIDocumentScanner()
@ -564,7 +561,7 @@ class TestSuggestStoragePath(TestCase):
self.assertEqual(path_id, self.storage_path1.id)
self.assertEqual(confidence, 0.80)
@mock.patch('documents.ai_scanner.match_storage_paths')
@mock.patch("documents.ai_scanner.match_storage_paths")
def test_suggest_storage_path_without_match(self, mock_match):
"""Test storage path suggestion without match."""
scanner = AIDocumentScanner()
@ -575,8 +572,8 @@ class TestSuggestStoragePath(TestCase):
self.assertIsNone(result)
@mock.patch('documents.ai_scanner.match_storage_paths')
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.match_storage_paths")
@mock.patch("documents.ai_scanner.logger")
def test_suggest_storage_path_handles_exception(self, mock_logger, mock_match):
"""Test storage path suggestion handles exceptions."""
scanner = AIDocumentScanner()
@ -596,19 +593,19 @@ class TestExtractCustomFields(TestCase):
"""Set up test custom fields."""
self.field_date = CustomField.objects.create(
name="Invoice Date",
data_type=CustomField.FieldDataType.DATE
data_type=CustomField.FieldDataType.DATE,
)
self.field_amount = CustomField.objects.create(
name="Total Amount",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_email = CustomField.objects.create(
name="Contact Email",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
def test_extract_custom_fields_with_entities(self):
@ -618,7 +615,7 @@ class TestExtractCustomFields(TestCase):
entities = {
"dates": [{"text": "2024-01-01"}],
"amounts": [{"text": "$1,000"}],
"emails": ["test@example.com"]
"emails": ["test@example.com"],
}
fields = scanner._extract_custom_fields(self.document, "text", entities)
@ -638,12 +635,12 @@ class TestExtractCustomFields(TestCase):
# Should return empty dict
self.assertEqual(fields, {})
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_extract_custom_fields_handles_exception(self, mock_logger):
"""Test custom field extraction handles exceptions."""
scanner = AIDocumentScanner()
with mock.patch.object(CustomField.objects, 'all',
with mock.patch.object(CustomField.objects, "all",
side_effect=Exception("DB error")):
fields = scanner._extract_custom_fields(self.document, "text", {})
@ -658,31 +655,31 @@ class TestExtractFieldValue(TestCase):
"""Set up test fields."""
self.field_date = CustomField.objects.create(
name="Invoice Date",
data_type=CustomField.FieldDataType.DATE
data_type=CustomField.FieldDataType.DATE,
)
self.field_amount = CustomField.objects.create(
name="Total Amount",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_invoice = CustomField.objects.create(
name="Invoice Number",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_email = CustomField.objects.create(
name="Contact Email",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_phone = CustomField.objects.create(
name="Phone Number",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_person = CustomField.objects.create(
name="Person Name",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_company = CustomField.objects.create(
name="Company Name",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
def test_extract_field_value_date(self):
@ -691,7 +688,7 @@ class TestExtractFieldValue(TestCase):
entities = {"dates": [{"text": "2024-01-01"}]}
value, confidence = scanner._extract_field_value(
self.field_date, "text", entities
self.field_date, "text", entities,
)
self.assertEqual(value, "2024-01-01")
@ -703,7 +700,7 @@ class TestExtractFieldValue(TestCase):
entities = {"amounts": [{"text": "$1,000"}]}
value, confidence = scanner._extract_field_value(
self.field_amount, "text", entities
self.field_amount, "text", entities,
)
self.assertEqual(value, "$1,000")
@ -715,7 +712,7 @@ class TestExtractFieldValue(TestCase):
entities = {"invoice_numbers": ["INV-12345"]}
value, confidence = scanner._extract_field_value(
self.field_invoice, "text", entities
self.field_invoice, "text", entities,
)
self.assertEqual(value, "INV-12345")
@ -727,7 +724,7 @@ class TestExtractFieldValue(TestCase):
entities = {"emails": ["test@example.com"]}
value, confidence = scanner._extract_field_value(
self.field_email, "text", entities
self.field_email, "text", entities,
)
self.assertEqual(value, "test@example.com")
@ -739,7 +736,7 @@ class TestExtractFieldValue(TestCase):
entities = {"phones": ["+1-555-1234"]}
value, confidence = scanner._extract_field_value(
self.field_phone, "text", entities
self.field_phone, "text", entities,
)
self.assertEqual(value, "+1-555-1234")
@ -751,7 +748,7 @@ class TestExtractFieldValue(TestCase):
entities = {"persons": [{"text": "John Doe"}]}
value, confidence = scanner._extract_field_value(
self.field_person, "text", entities
self.field_person, "text", entities,
)
self.assertEqual(value, "John Doe")
@ -763,7 +760,7 @@ class TestExtractFieldValue(TestCase):
entities = {"organizations": [{"text": "ACME Corp"}]}
value, confidence = scanner._extract_field_value(
self.field_company, "text", entities
self.field_company, "text", entities,
)
self.assertEqual(value, "ACME Corp")
@ -775,7 +772,7 @@ class TestExtractFieldValue(TestCase):
entities = {}
value, confidence = scanner._extract_field_value(
self.field_date, "text", entities
self.field_date, "text", entities,
)
self.assertIsNone(value)
@ -789,23 +786,23 @@ class TestSuggestWorkflows(TestCase):
"""Set up test workflows."""
self.workflow1 = Workflow.objects.create(
name="Invoice Processing",
enabled=True
enabled=True,
)
self.trigger1 = WorkflowTrigger.objects.create(
workflow=self.workflow1,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
)
self.workflow2 = Workflow.objects.create(
name="Document Archival",
enabled=True
enabled=True,
)
self.trigger2 = WorkflowTrigger.objects.create(
workflow=self.workflow2,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
)
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
def test_suggest_workflows_with_matches(self):
@ -820,7 +817,7 @@ class TestSuggestWorkflows(TestCase):
# Create action for workflow
WorkflowAction.objects.create(
workflow=self.workflow1,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT
type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
)
suggestions = scanner._suggest_workflows(self.document, "text", scan_result)
@ -841,17 +838,17 @@ class TestSuggestWorkflows(TestCase):
# Should not suggest any (confidence too low)
self.assertEqual(len(suggestions), 0)
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_suggest_workflows_handles_exception(self, mock_logger):
"""Test workflow suggestion handles exceptions."""
scanner = AIDocumentScanner()
scan_result = AIScanResult()
with mock.patch.object(Workflow.objects, 'filter',
with mock.patch.object(Workflow.objects, "filter",
side_effect=Exception("DB error")):
suggestions = scanner._suggest_workflows(
self.document, "text", scan_result
self.document, "text", scan_result,
)
self.assertEqual(suggestions, [])
@ -865,11 +862,11 @@ class TestEvaluateWorkflowMatch(TestCase):
"""Set up test workflow."""
self.workflow = Workflow.objects.create(
name="Test Workflow",
enabled=True
enabled=True,
)
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
def test_evaluate_workflow_match_base_confidence(self):
@ -878,7 +875,7 @@ class TestEvaluateWorkflowMatch(TestCase):
scan_result = AIScanResult()
confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result
self.workflow, self.document, scan_result,
)
self.assertEqual(confidence, 0.5)
@ -892,11 +889,11 @@ class TestEvaluateWorkflowMatch(TestCase):
# Create action for workflow
WorkflowAction.objects.create(
workflow=self.workflow,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT
type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
)
confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result
self.workflow, self.document, scan_result,
)
self.assertGreater(confidence, 0.5)
@ -908,7 +905,7 @@ class TestEvaluateWorkflowMatch(TestCase):
scan_result.correspondent = (1, 0.90)
confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result
self.workflow, self.document, scan_result,
)
self.assertGreater(confidence, 0.5)
@ -920,7 +917,7 @@ class TestEvaluateWorkflowMatch(TestCase):
scan_result.tags = [(1, 0.80), (2, 0.75)]
confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result
self.workflow, self.document, scan_result,
)
self.assertGreater(confidence, 0.5)
@ -936,11 +933,11 @@ class TestEvaluateWorkflowMatch(TestCase):
# Create action
WorkflowAction.objects.create(
workflow=self.workflow,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT
type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
)
confidence = scanner._evaluate_workflow_match(
self.workflow, self.document, scan_result
self.workflow, self.document, scan_result,
)
self.assertLessEqual(confidence, 1.0)
@ -953,7 +950,7 @@ class TestSuggestTitle(TestCase):
"""Set up test document."""
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
def test_suggest_title_with_all_entities(self):
@ -963,7 +960,7 @@ class TestSuggestTitle(TestCase):
entities = {
"document_type": "Invoice",
"organizations": [{"text": "ACME Corporation"}],
"dates": [{"text": "2024-01-01"}]
"dates": [{"text": "2024-01-01"}],
}
title = scanner._suggest_title(self.document, "text", entities)
@ -978,7 +975,7 @@ class TestSuggestTitle(TestCase):
scanner = AIDocumentScanner()
entities = {
"organizations": [{"text": "TechStart Inc"}]
"organizations": [{"text": "TechStart Inc"}],
}
title = scanner._suggest_title(self.document, "text", entities)
@ -1002,7 +999,7 @@ class TestSuggestTitle(TestCase):
long_org = "A" * 100
entities = {
"organizations": [{"text": long_org}],
"dates": [{"text": "2024-01-01"}]
"dates": [{"text": "2024-01-01"}],
}
title = scanner._suggest_title(self.document, "text", entities)
@ -1010,7 +1007,7 @@ class TestSuggestTitle(TestCase):
self.assertIsNotNone(title)
self.assertLessEqual(len(title), 127)
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_suggest_title_handles_exception(self, mock_logger):
"""Test title suggestion handles exceptions."""
scanner = AIDocumentScanner()
@ -1034,7 +1031,7 @@ class TestExtractTables(TestCase):
mock_extractor = mock.MagicMock()
mock_extractor.extract_tables_from_image.return_value = [
{"data": [[1, 2], [3, 4]], "headers": ["A", "B"]}
{"data": [[1, 2], [3, 4]], "headers": ["A", "B"]},
]
scanner._table_extractor = mock_extractor
@ -1053,7 +1050,7 @@ class TestExtractTables(TestCase):
self.assertEqual(tables, [])
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_extract_tables_handles_exception(self, mock_logger):
"""Test table extraction handles exceptions."""
scanner = AIDocumentScanner()
@ -1075,17 +1072,17 @@ class TestScanDocument(TestCase):
"""Set up test document."""
self.document = Document.objects.create(
title="Test Document",
content="Invoice from ACME Corporation dated 2024-01-01"
content="Invoice from ACME Corporation dated 2024-01-01",
)
@mock.patch.object(AIDocumentScanner, '_extract_entities')
@mock.patch.object(AIDocumentScanner, '_suggest_tags')
@mock.patch.object(AIDocumentScanner, '_detect_correspondent')
@mock.patch.object(AIDocumentScanner, '_classify_document_type')
@mock.patch.object(AIDocumentScanner, '_suggest_storage_path')
@mock.patch.object(AIDocumentScanner, '_extract_custom_fields')
@mock.patch.object(AIDocumentScanner, '_suggest_workflows')
@mock.patch.object(AIDocumentScanner, '_suggest_title')
@mock.patch.object(AIDocumentScanner, "_extract_entities")
@mock.patch.object(AIDocumentScanner, "_suggest_tags")
@mock.patch.object(AIDocumentScanner, "_detect_correspondent")
@mock.patch.object(AIDocumentScanner, "_classify_document_type")
@mock.patch.object(AIDocumentScanner, "_suggest_storage_path")
@mock.patch.object(AIDocumentScanner, "_extract_custom_fields")
@mock.patch.object(AIDocumentScanner, "_suggest_workflows")
@mock.patch.object(AIDocumentScanner, "_suggest_title")
def test_scan_document_orchestrates_all_methods(
self,
mock_title,
@ -1095,7 +1092,7 @@ class TestScanDocument(TestCase):
mock_doc_type,
mock_correspondent,
mock_tags,
mock_entities
mock_entities,
):
"""Test that scan_document calls all extraction methods."""
scanner = AIDocumentScanner()
@ -1127,26 +1124,26 @@ class TestScanDocument(TestCase):
self.assertEqual(result.correspondent, (1, 0.90))
self.assertEqual(result.document_type, (1, 0.80))
@mock.patch.object(AIDocumentScanner, '_extract_tables')
@mock.patch.object(AIDocumentScanner, "_extract_tables")
def test_scan_document_extracts_tables_when_enabled(self, mock_extract_tables):
"""Test that tables are extracted when OCR is enabled and file path provided."""
scanner = AIDocumentScanner(enable_advanced_ocr=True)
mock_extract_tables.return_value = [{"data": "test"}]
# Mock other methods to avoid complexity
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(
self.document,
"Document text",
original_file_path="/path/to/file.pdf"
original_file_path="/path/to/file.pdf",
)
mock_extract_tables.assert_called_once_with("/path/to/file.pdf")
@ -1156,15 +1153,15 @@ class TestScanDocument(TestCase):
"""Test that tables are not extracted when file path is not provided."""
scanner = AIDocumentScanner(enable_advanced_ocr=True)
with mock.patch.object(scanner, '_extract_tables') as mock_extract_tables, \
mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_tables") as mock_extract_tables, \
mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(self.document, "Document text")
@ -1183,11 +1180,11 @@ class TestApplyScanResults(TestCase):
self.doc_type = DocumentType.objects.create(name="Invoice")
self.storage_path = StoragePath.objects.create(
name="Invoices",
path="/invoices"
path="/invoices",
)
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
def test_apply_scan_results_auto_applies_high_confidence(self):
@ -1203,7 +1200,7 @@ class TestApplyScanResults(TestCase):
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify auto-applied
@ -1222,7 +1219,7 @@ class TestApplyScanResults(TestCase):
"""Test that medium confidence items are suggested, not applied."""
scanner = AIDocumentScanner(
auto_apply_threshold=0.80,
suggest_threshold=0.60
suggest_threshold=0.60,
)
scan_result = AIScanResult()
@ -1232,7 +1229,7 @@ class TestApplyScanResults(TestCase):
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify suggested but not applied
@ -1255,7 +1252,7 @@ class TestApplyScanResults(TestCase):
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=False
auto_apply=False,
)
# Verify nothing was applied
@ -1268,21 +1265,21 @@ class TestApplyScanResults(TestCase):
scan_result = AIScanResult()
scan_result.correspondent = (self.correspondent.id, 0.90)
with mock.patch.object(self.document, 'save',
with mock.patch.object(self.document, "save",
side_effect=Exception("Save failed")):
with self.assertRaises(Exception):
with transaction.atomic():
scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify transaction was rolled back
self.document.refresh_from_db()
self.assertIsNone(self.document.correspondent)
@mock.patch('documents.ai_scanner.logger')
@mock.patch("documents.ai_scanner.logger")
def test_apply_scan_results_handles_exception(self, mock_logger):
"""Test that apply_scan_results handles exceptions gracefully."""
scanner = AIDocumentScanner()
@ -1293,7 +1290,7 @@ class TestApplyScanResults(TestCase):
scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
mock_logger.error.assert_called()
@ -1323,21 +1320,21 @@ class TestEdgeCasesAndErrorHandling(TestCase):
"""Set up test document."""
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
def test_scan_document_with_empty_text(self):
"""Test scanning document with empty text."""
scanner = AIDocumentScanner()
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(self.document, "")
@ -1349,14 +1346,14 @@ class TestEdgeCasesAndErrorHandling(TestCase):
scanner = AIDocumentScanner()
long_text = "A" * 100000
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(self.document, long_text)
@ -1367,14 +1364,14 @@ class TestEdgeCasesAndErrorHandling(TestCase):
scanner = AIDocumentScanner()
special_text = "Test with émojis 😀 and special chars: <>{}[]|\\`~"
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(self.document, special_text)
@ -1388,7 +1385,7 @@ class TestEdgeCasesAndErrorHandling(TestCase):
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
self.assertEqual(result["applied"]["tags"], [])
@ -1403,12 +1400,12 @@ class TestEdgeCasesAndErrorHandling(TestCase):
# Test extreme values
scanner_low = AIDocumentScanner(
auto_apply_threshold=0.01,
suggest_threshold=0.01
suggest_threshold=0.01,
)
self.assertEqual(scanner_low.auto_apply_threshold, 0.01)
scanner_high = AIDocumentScanner(
auto_apply_threshold=0.99,
suggest_threshold=0.80
suggest_threshold=0.80,
)
self.assertEqual(scanner_high.auto_apply_threshold, 0.99)

View file

@ -8,24 +8,21 @@ document consumption to metadata application.
from unittest import mock
from django.test import TestCase, TransactionTestCase
from django.test import TestCase
from django.test import TransactionTestCase
from documents.ai_scanner import (
AIDocumentScanner,
AIScanResult,
get_ai_scanner,
)
from documents.models import (
Correspondent,
CustomField,
Document,
DocumentType,
StoragePath,
Tag,
Workflow,
WorkflowTrigger,
WorkflowAction,
)
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import AIScanResult
from documents.ai_scanner import get_ai_scanner
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
class TestAIScannerIntegrationBasic(TestCase):
@ -35,49 +32,49 @@ class TestAIScannerIntegrationBasic(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Invoice from ACME Corporation",
content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000"
content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000",
)
self.tag_invoice = Tag.objects.create(
name="Invoice",
matching_algorithm=Tag.MATCH_AUTO,
match="invoice"
match="invoice",
)
self.tag_important = Tag.objects.create(
name="Important",
matching_algorithm=Tag.MATCH_AUTO,
match="total"
match="total",
)
self.correspondent = Correspondent.objects.create(
name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO,
match="acme"
match="acme",
)
self.doc_type = DocumentType.objects.create(
name="Invoice",
matching_algorithm=DocumentType.MATCH_AUTO,
match="invoice"
match="invoice",
)
self.storage_path = StoragePath.objects.create(
name="Invoices",
path="/invoices",
matching_algorithm=StoragePath.MATCH_AUTO,
match="invoice"
match="invoice",
)
@mock.patch('documents.ai_scanner.match_tags')
@mock.patch('documents.ai_scanner.match_correspondents')
@mock.patch('documents.ai_scanner.match_document_types')
@mock.patch('documents.ai_scanner.match_storage_paths')
@mock.patch("documents.ai_scanner.match_tags")
@mock.patch("documents.ai_scanner.match_correspondents")
@mock.patch("documents.ai_scanner.match_document_types")
@mock.patch("documents.ai_scanner.match_storage_paths")
def test_full_scan_and_apply_workflow(
self,
mock_storage,
mock_types,
mock_correspondents,
mock_tags
mock_tags,
):
"""Test complete workflow from scan to application."""
# Mock the matching functions to return our test data
@ -91,7 +88,7 @@ class TestAIScannerIntegrationBasic(TestCase):
# Scan the document
scan_result = scanner.scan_document(
self.document,
self.document.content
self.document.content,
)
# Verify scan results
@ -105,7 +102,7 @@ class TestAIScannerIntegrationBasic(TestCase):
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify application
@ -118,7 +115,7 @@ class TestAIScannerIntegrationBasic(TestCase):
self.assertEqual(self.document.document_type, self.doc_type)
self.assertEqual(self.document.storage_path, self.storage_path)
@mock.patch('documents.ai_scanner.match_tags')
@mock.patch("documents.ai_scanner.match_tags")
def test_scan_with_no_matches(self, mock_tags):
"""Test scanning when no matches are found."""
mock_tags.return_value = []
@ -127,7 +124,7 @@ class TestAIScannerIntegrationBasic(TestCase):
scan_result = scanner.scan_document(
self.document,
"Some random text with no matches"
"Some random text with no matches",
)
# Should return empty results
@ -143,24 +140,24 @@ class TestAIScannerIntegrationCustomFields(TestCase):
"""Set up test data with custom fields."""
self.document = Document.objects.create(
title="Invoice",
content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com"
content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com",
)
self.field_date = CustomField.objects.create(
name="Invoice Date",
data_type=CustomField.FieldDataType.DATE
data_type=CustomField.FieldDataType.DATE,
)
self.field_number = CustomField.objects.create(
name="Invoice Number",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_amount = CustomField.objects.create(
name="Total Amount",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_email = CustomField.objects.create(
name="Contact Email",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
def test_custom_field_extraction_integration(self):
@ -173,7 +170,7 @@ class TestAIScannerIntegrationCustomFields(TestCase):
"dates": [{"text": "2024-01-01"}],
"amounts": [{"text": "$1,500"}],
"invoice_numbers": ["INV-123"],
"emails": ["john@example.com"]
"emails": ["john@example.com"],
}
scanner._ner_extractor = mock_ner
@ -196,29 +193,29 @@ class TestAIScannerIntegrationWorkflows(TestCase):
"""Set up test workflows."""
self.document = Document.objects.create(
title="Invoice",
content="Invoice document"
content="Invoice document",
)
self.workflow1 = Workflow.objects.create(
name="Invoice Processing",
enabled=True
enabled=True,
)
self.trigger1 = WorkflowTrigger.objects.create(
workflow=self.workflow1,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
)
self.action1 = WorkflowAction.objects.create(
workflow=self.workflow1,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT
type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
)
self.workflow2 = Workflow.objects.create(
name="Archive Documents",
enabled=True
enabled=True,
)
self.trigger2 = WorkflowTrigger.objects.create(
workflow=self.workflow2,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
)
def test_workflow_suggestion_integration(self):
@ -234,7 +231,7 @@ class TestAIScannerIntegrationWorkflows(TestCase):
workflows = scanner._suggest_workflows(
self.document,
self.document.content,
scan_result
scan_result,
)
# Should suggest workflows
@ -250,7 +247,7 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
self.tag = Tag.objects.create(name="TestTag")
self.correspondent = Correspondent.objects.create(name="TestCorp")
@ -273,12 +270,12 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
raise Exception("Forced save failure")
return original_save(self, *args, **kwargs)
with mock.patch.object(Document, 'save', failing_save):
with mock.patch.object(Document, "save", failing_save):
with self.assertRaises(Exception):
scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify changes were rolled back
@ -297,19 +294,19 @@ class TestAIScannerIntegrationPerformance(TestCase):
for i in range(5):
doc = Document.objects.create(
title=f"Document {i}",
content=f"Content for document {i}"
content=f"Content for document {i}",
)
documents.append(doc)
# Mock to avoid actual ML loading
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
results = []
for doc in documents:
@ -329,16 +326,16 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Business Invoice",
content="Invoice from ACME Corporation"
content="Invoice from ACME Corporation",
)
self.correspondent_acme = Correspondent.objects.create(
name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO
matching_algorithm=Correspondent.MATCH_AUTO,
)
self.correspondent_other = Correspondent.objects.create(
name="Other Company",
matching_algorithm=Correspondent.MATCH_AUTO
matching_algorithm=Correspondent.MATCH_AUTO,
)
def test_correspondent_matching_with_ner_entities(self):
@ -348,16 +345,16 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
# Mock NER to extract organization
mock_ner = mock.MagicMock()
mock_ner.extract_all.return_value = {
"organizations": [{"text": "ACME Corporation"}]
"organizations": [{"text": "ACME Corporation"}],
}
scanner._ner_extractor = mock_ner
# Mock matching to return empty (so NER-based matching is used)
with mock.patch('documents.ai_scanner.match_correspondents', return_value=[]):
with mock.patch("documents.ai_scanner.match_correspondents", return_value=[]):
result = scanner._detect_correspondent(
self.document,
self.document.content,
{"organizations": [{"text": "ACME Corporation"}]}
{"organizations": [{"text": "ACME Corporation"}]},
)
# Should detect ACME correspondent
@ -375,13 +372,13 @@ class TestAIScannerIntegrationTitleGeneration(TestCase):
document = Document.objects.create(
title="document.pdf",
content="Invoice from ACME Corp dated 2024-01-15"
content="Invoice from ACME Corp dated 2024-01-15",
)
entities = {
"document_type": "Invoice",
"organizations": [{"text": "ACME Corp"}],
"dates": [{"text": "2024-01-15"}]
"dates": [{"text": "2024-01-15"}],
}
title = scanner._suggest_title(document, document.content, entities)
@ -399,7 +396,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Test",
content="Test"
content="Test",
)
self.tag_high = Tag.objects.create(name="HighConfidence")
self.tag_medium = Tag.objects.create(name="MediumConfidence")
@ -409,7 +406,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
"""Test that only high confidence suggestions are auto-applied."""
scanner = AIDocumentScanner(
auto_apply_threshold=0.80,
suggest_threshold=0.60
suggest_threshold=0.60,
)
scan_result = AIScanResult()
@ -422,7 +419,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Verify high confidence was applied
@ -448,17 +445,17 @@ class TestAIScannerIntegrationGlobalInstance(TestCase):
# Should be functional
document = Document.objects.create(
title="Test",
content="Test content"
content="Test content",
)
with mock.patch.object(scanner1, '_extract_entities', return_value={}), \
mock.patch.object(scanner1, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner1, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner1, '_classify_document_type', return_value=None), \
mock.patch.object(scanner1, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner1, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner1, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner1, '_suggest_title', return_value=None):
with mock.patch.object(scanner1, "_extract_entities", return_value={}), \
mock.patch.object(scanner1, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner1, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner1, "_classify_document_type", return_value=None), \
mock.patch.object(scanner1, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner1, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner1, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner1, "_suggest_title", return_value=None):
result1 = scanner1.scan_document(document, document.content)
result2 = scanner2.scan_document(document, document.content)
@ -476,17 +473,17 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
document = Document.objects.create(
title="",
content=""
content="",
)
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(document, document.content)
@ -498,7 +495,7 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
document = Document.objects.create(
title="Test",
content="Test"
content="Test",
)
scan_result = AIScanResult()
@ -509,7 +506,7 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
result = scanner.apply_scan_results(
document,
scan_result,
auto_apply=True
auto_apply=True,
)
# Should not crash, just log errors
@ -521,17 +518,17 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
document = Document.objects.create(
title="Factura - España 🇪🇸",
content="Société française • 日本語 • Ελληνικά • مرحبا"
content="Société française • 日本語 • Ελληνικά • مرحبا",
)
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):
result = scanner.scan_document(document, document.content)

View file

@ -12,18 +12,17 @@ Tests cover:
from unittest import mock
from django.contrib.auth.models import Permission, User
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from rest_framework import status
from rest_framework.test import APITestCase
from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
@ -36,13 +35,13 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
# Assign view permission
@ -57,7 +56,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
# Create test document
self.document = Document.objects.create(
title="Test Document",
content="This is a test invoice from ACME Corporation"
content="This is a test invoice from ACME Corporation",
)
# Create test metadata objects
@ -70,7 +69,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@ -82,7 +81,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@ -91,7 +90,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"""Test that superusers can access the endpoint."""
self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [(self.tag.id, 0.85)]
@ -108,7 +107,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -119,7 +118,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"""Test that users with permission can access the endpoint."""
self.client.force_authenticate(user=self.user_with_permission)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []
@ -136,7 +135,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -148,7 +147,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": 99999},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -160,7 +159,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/",
{},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -175,10 +174,10 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
# Assign apply permission
@ -193,7 +192,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
# Create test document
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
# Create test metadata
@ -205,7 +204,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/apply/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@ -214,7 +213,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"""Test successfully applying tag suggestions."""
self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [(self.tag.id, 0.85)]
@ -233,9 +232,9 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_tags": True
"apply_tags": True,
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -245,7 +244,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"""Test successfully applying correspondent suggestion."""
self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []
@ -264,9 +263,9 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_correspondent": True
"apply_correspondent": True,
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -285,10 +284,10 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
def test_unauthorized_access_denied(self):
@ -309,7 +308,7 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
"""Test getting AI configuration."""
self.client.force_authenticate(user=self.superuser)
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.auto_apply_threshold = 0.80
mock_scanner_instance.suggest_threshold = 0.60
@ -331,9 +330,9 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/config/",
{
"auto_apply_threshold": 0.90,
"suggest_threshold": 0.70
"suggest_threshold": 0.70,
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -346,9 +345,9 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/config/",
{
"auto_apply_threshold": 1.5 # Invalid: > 1.0
"auto_apply_threshold": 1.5, # Invalid: > 1.0
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -363,13 +362,13 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
# Assign approval permission
@ -385,7 +384,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
self.deletion_request = DeletionRequest.objects.create(
user=self.user_with_permission,
requested_by_ai=True,
ai_reason="Document appears to be a duplicate"
ai_reason="Document appears to be a duplicate",
)
def test_unauthorized_access_denied(self):
@ -394,9 +393,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@ -409,9 +408,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@ -424,9 +423,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -436,7 +435,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
self.deletion_request.refresh_from_db()
self.assertEqual(
self.deletion_request.status,
DeletionRequest.STATUS_APPROVED
DeletionRequest.STATUS_APPROVED,
)
def test_reject_deletion_success(self):
@ -448,9 +447,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
{
"request_id": self.deletion_request.id,
"action": "reject",
"reason": "Document is still needed"
"reason": "Document is still needed",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -459,7 +458,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
self.deletion_request.refresh_from_db()
self.assertEqual(
self.deletion_request.status,
DeletionRequest.STATUS_REJECTED
DeletionRequest.STATUS_REJECTED,
)
def test_invalid_request_id(self):
@ -470,9 +469,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/",
{
"request_id": 99999,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -485,9 +484,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -502,7 +501,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
# Create user with all AI permissions
self.power_user = User.objects.create_user(
username="power_user", email="power@test.com", password="power123"
username="power_user", email="power@test.com", password="power123",
)
content_type = ContentType.objects.get_for_model(Document)
@ -525,7 +524,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
self.document = Document.objects.create(
title="Test Doc",
content="Test"
content="Test",
)
def test_power_user_can_access_all_endpoints(self):
@ -533,7 +532,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
self.client.force_authenticate(user=self.power_user)
# Test suggestions endpoint
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []
mock_scan_result.correspondent = None
@ -553,7 +552,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
response1 = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response1.status_code, status.HTTP_200_OK)
@ -562,9 +561,9 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_tags": False
"apply_tags": False,
},
format="json"
format="json",
)
self.assertEqual(response2.status_code, status.HTTP_200_OK)

View file

@ -9,14 +9,12 @@ from rest_framework import status
from rest_framework.test import APITestCase
from documents.ai_scanner import AIScanResult
from documents.models import (
AISuggestionFeedback,
Correspondent,
Document,
DocumentType,
StoragePath,
Tag,
)
from documents.models import AISuggestionFeedback
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
@ -64,12 +62,12 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_ai_suggestions_endpoint_exists(self):
"""Test that the ai-suggestions endpoint is accessible."""
response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/"
f"/api/documents/{self.document.pk}/ai-suggestions/",
)
# Should not be 404
self.assertNotEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@mock.patch('documents.ai_scanner.get_ai_scanner')
@mock.patch("documents.ai_scanner.get_ai_scanner")
def test_get_ai_suggestions_success(self, mock_get_scanner):
"""Test successfully getting AI suggestions for a document."""
# Create mock scan result
@ -87,7 +85,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Make request
response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/"
f"/api/documents/{self.document.pk}/ai-suggestions/",
)
# Verify response
@ -95,23 +93,23 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
data = response.json()
# Check tags
self.assertIn('tags', data)
self.assertEqual(len(data['tags']), 2)
self.assertEqual(data['tags'][0]['id'], self.tag1.id)
self.assertEqual(data['tags'][0]['confidence'], 0.85)
self.assertIn("tags", data)
self.assertEqual(len(data["tags"]), 2)
self.assertEqual(data["tags"][0]["id"], self.tag1.id)
self.assertEqual(data["tags"][0]["confidence"], 0.85)
# Check correspondent
self.assertIn('correspondent', data)
self.assertEqual(data['correspondent']['id'], self.correspondent.id)
self.assertEqual(data['correspondent']['confidence'], 0.90)
self.assertIn("correspondent", data)
self.assertEqual(data["correspondent"]["id"], self.correspondent.id)
self.assertEqual(data["correspondent"]["confidence"], 0.90)
# Check document type
self.assertIn('document_type', data)
self.assertEqual(data['document_type']['id'], self.doc_type.id)
self.assertIn("document_type", data)
self.assertEqual(data["document_type"]["id"], self.doc_type.id)
# Check title suggestion
self.assertIn('title_suggestion', data)
self.assertEqual(data['title_suggestion']['title'], "Suggested Title")
self.assertIn("title_suggestion", data)
self.assertEqual(data["title_suggestion"]["title"], "Suggested Title")
def test_get_ai_suggestions_no_content(self):
"""Test getting AI suggestions for document without content."""
@ -126,7 +124,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
response = self.client.get(f"/api/documents/{doc.pk}/ai-suggestions/")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("no content", response.json()['detail'].lower())
self.assertIn("no content", response.json()["detail"].lower())
def test_get_ai_suggestions_document_not_found(self):
"""Test getting AI suggestions for non-existent document."""
@ -137,19 +135,19 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_tag(self):
"""Test applying a tag suggestion."""
request_data = {
'suggestion_type': 'tag',
'value_id': self.tag1.id,
'confidence': 0.85,
"suggestion_type": "tag",
"value_id": self.tag1.id,
"confidence": 0.85,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()['status'], 'success')
self.assertEqual(response.json()["status"], "success")
# Verify tag was applied
self.document.refresh_from_db()
@ -158,7 +156,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
).first()
self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
@ -169,15 +167,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_correspondent(self):
"""Test applying a correspondent suggestion."""
request_data = {
'suggestion_type': 'correspondent',
'value_id': self.correspondent.id,
'confidence': 0.90,
"suggestion_type": "correspondent",
"value_id": self.correspondent.id,
"confidence": 0.90,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -189,7 +187,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter(
document=self.document,
suggestion_type='correspondent',
suggestion_type="correspondent",
).first()
self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
@ -197,15 +195,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_document_type(self):
"""Test applying a document type suggestion."""
request_data = {
'suggestion_type': 'document_type',
'value_id': self.doc_type.id,
'confidence': 0.88,
"suggestion_type": "document_type",
"value_id": self.doc_type.id,
"confidence": 0.88,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -217,35 +215,35 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_title(self):
"""Test applying a title suggestion."""
request_data = {
'suggestion_type': 'title',
'value_text': 'New Suggested Title',
'confidence': 0.80,
"suggestion_type": "title",
"value_text": "New Suggested Title",
"confidence": 0.80,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Verify title was applied
self.document.refresh_from_db()
self.assertEqual(self.document.title, 'New Suggested Title')
self.assertEqual(self.document.title, "New Suggested Title")
def test_apply_suggestion_invalid_type(self):
"""Test applying suggestion with invalid type."""
request_data = {
'suggestion_type': 'invalid_type',
'value_id': 1,
'confidence': 0.85,
"suggestion_type": "invalid_type",
"value_id": 1,
"confidence": 0.85,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -253,14 +251,14 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_missing_value(self):
"""Test applying suggestion without value_id or value_text."""
request_data = {
'suggestion_type': 'tag',
'confidence': 0.85,
"suggestion_type": "tag",
"confidence": 0.85,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@ -268,15 +266,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_apply_suggestion_nonexistent_object(self):
"""Test applying suggestion with non-existent object ID."""
request_data = {
'suggestion_type': 'tag',
'value_id': 99999,
'confidence': 0.85,
"suggestion_type": "tag",
"value_id": 99999,
"confidence": 0.85,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/apply-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -284,24 +282,24 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_reject_suggestion(self):
"""Test rejecting an AI suggestion."""
request_data = {
'suggestion_type': 'tag',
'value_id': self.tag1.id,
'confidence': 0.65,
"suggestion_type": "tag",
"value_id": self.tag1.id,
"confidence": 0.65,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/reject-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()['status'], 'success')
self.assertEqual(response.json()["status"], "success")
# Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
).first()
self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
@ -312,15 +310,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
def test_reject_suggestion_with_text(self):
"""Test rejecting a suggestion with text value."""
request_data = {
'suggestion_type': 'title',
'value_text': 'Bad Title Suggestion',
'confidence': 0.50,
"suggestion_type": "title",
"value_text": "Bad Title Suggestion",
"confidence": 0.50,
}
response = self.client.post(
f"/api/documents/{self.document.pk}/reject-suggestion/",
data=request_data,
format='json',
format="json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -328,11 +326,11 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Verify feedback was recorded
feedback = AISuggestionFeedback.objects.filter(
document=self.document,
suggestion_type='title',
suggestion_type="title",
).first()
self.assertIsNotNone(feedback)
self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
self.assertEqual(feedback.suggested_value_text, 'Bad Title Suggestion')
self.assertEqual(feedback.suggested_value_text, "Bad Title Suggestion")
def test_ai_suggestion_stats_empty(self):
"""Test getting statistics when no feedback exists."""
@ -341,17 +339,17 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data['total_suggestions'], 0)
self.assertEqual(data['total_applied'], 0)
self.assertEqual(data['total_rejected'], 0)
self.assertEqual(data['accuracy_rate'], 0)
self.assertEqual(data["total_suggestions"], 0)
self.assertEqual(data["total_applied"], 0)
self.assertEqual(data["total_rejected"], 0)
self.assertEqual(data["accuracy_rate"], 0)
def test_ai_suggestion_stats_with_data(self):
"""Test getting statistics with feedback data."""
# Create some feedback entries
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
suggested_value_id=self.tag1.id,
confidence=0.85,
status=AISuggestionFeedback.STATUS_APPLIED,
@ -359,7 +357,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
)
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
suggested_value_id=self.tag2.id,
confidence=0.70,
status=AISuggestionFeedback.STATUS_APPLIED,
@ -367,7 +365,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
)
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='correspondent',
suggestion_type="correspondent",
suggested_value_id=self.correspondent.id,
confidence=0.60,
status=AISuggestionFeedback.STATUS_REJECTED,
@ -380,25 +378,25 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
data = response.json()
# Check overall stats
self.assertEqual(data['total_suggestions'], 3)
self.assertEqual(data['total_applied'], 2)
self.assertEqual(data['total_rejected'], 1)
self.assertAlmostEqual(data['accuracy_rate'], 66.67, places=1)
self.assertEqual(data["total_suggestions"], 3)
self.assertEqual(data["total_applied"], 2)
self.assertEqual(data["total_rejected"], 1)
self.assertAlmostEqual(data["accuracy_rate"], 66.67, places=1)
# Check by_type stats
self.assertIn('by_type', data)
self.assertIn('tag', data['by_type'])
self.assertEqual(data['by_type']['tag']['total'], 2)
self.assertEqual(data['by_type']['tag']['applied'], 2)
self.assertEqual(data['by_type']['tag']['rejected'], 0)
self.assertIn("by_type", data)
self.assertIn("tag", data["by_type"])
self.assertEqual(data["by_type"]["tag"]["total"], 2)
self.assertEqual(data["by_type"]["tag"]["applied"], 2)
self.assertEqual(data["by_type"]["tag"]["rejected"], 0)
# Check confidence averages
self.assertGreater(data['average_confidence_applied'], 0)
self.assertGreater(data['average_confidence_rejected'], 0)
self.assertGreater(data["average_confidence_applied"], 0)
self.assertGreater(data["average_confidence_rejected"], 0)
# Check recent suggestions
self.assertIn('recent_suggestions', data)
self.assertEqual(len(data['recent_suggestions']), 3)
self.assertIn("recent_suggestions", data)
self.assertEqual(len(data["recent_suggestions"]), 3)
def test_ai_suggestion_stats_accuracy_calculation(self):
"""Test that accuracy rate is calculated correctly."""
@ -406,7 +404,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
for i in range(7):
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
suggested_value_id=self.tag1.id,
confidence=0.80,
status=AISuggestionFeedback.STATUS_APPLIED,
@ -416,7 +414,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
for i in range(3):
AISuggestionFeedback.objects.create(
document=self.document,
suggestion_type='tag',
suggestion_type="tag",
suggested_value_id=self.tag2.id,
confidence=0.60,
status=AISuggestionFeedback.STATUS_REJECTED,
@ -428,10 +426,10 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data['total_suggestions'], 10)
self.assertEqual(data['total_applied'], 7)
self.assertEqual(data['total_rejected'], 3)
self.assertEqual(data['accuracy_rate'], 70.0)
self.assertEqual(data["total_suggestions"], 10)
self.assertEqual(data["total_applied"], 7)
self.assertEqual(data["total_rejected"], 3)
self.assertEqual(data["accuracy_rate"], 70.0)
def test_authentication_required(self):
"""Test that authentication is required for all endpoints."""
@ -439,7 +437,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
# Test ai-suggestions endpoint
response = self.client.get(
f"/api/documents/{self.document.pk}/ai-suggestions/"
f"/api/documents/{self.document.pk}/ai-suggestions/",
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

View file

@ -11,17 +11,14 @@ Tests cover:
"""
from django.contrib.auth.models import User
from django.test import override_settings
from rest_framework import status
from rest_framework.test import APITestCase
from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
class TestDeletionRequestAPI(APITestCase):

View file

@ -1561,7 +1561,6 @@ class TestConsumerAIScannerIntegration(
Verifies that AI scanner respects database transactions and handles
rollbacks correctly.
"""
from django.db import transaction as db_transaction
tag = Tag.objects.create(name="Invoice")

View file

@ -15,13 +15,11 @@ from django.contrib.auth.models import User
from django.test import TestCase
from django.utils import timezone
from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
class TestDeletionRequestModelCreation(TestCase):

View file

@ -8,11 +8,9 @@ from unittest import mock
from django.test import TestCase
from documents.ml.model_cache import (
CacheMetrics,
LRUCache,
ModelCacheManager,
)
from documents.ml.model_cache import CacheMetrics
from documents.ml.model_cache import LRUCache
from documents.ml.model_cache import ModelCacheManager
class TestCacheMetrics(TestCase):

View file

@ -150,7 +150,6 @@ class TestMLCacheDirectory:
def test_model_cache_writable(self, tmp_path):
"""Test that we can write to model cache directory."""
import pathlib
# Use tmp_path fixture for testing
cache_dir = tmp_path / ".cache" / "huggingface"
@ -169,7 +168,6 @@ class TestMLCacheDirectory:
def test_torch_cache_directory(self, tmp_path, monkeypatch):
"""Test that PyTorch can use a custom cache directory."""
import torch
# Set custom cache directory
cache_dir = tmp_path / ".cache" / "torch"
@ -204,9 +202,10 @@ class TestMLPerformanceBasic:
def test_numpy_performance_basic(self):
"""Test basic NumPy performance with larger arrays."""
import numpy as np
import time
import numpy as np
# Create large array (10 million elements)
arr = np.random.rand(10_000_000)

View file

@ -89,6 +89,8 @@ from rest_framework.viewsets import ViewSet
from documents import bulk_edit
from documents import index
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy
@ -141,13 +143,10 @@ from documents.models import UiSettings
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.parsers import get_parser_class_for_mime_type
from documents.parsers import parse_date_generator
from documents.permissions import AcknowledgeTasksPermissions
from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission
from documents.permissions import PaperlessAdminPermissions
@ -1388,7 +1387,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
scan_result = scanner.scan_document(
document=document,
document_text=document.content,
original_file_path=document.source_path if hasattr(document, 'source_path') else None,
original_file_path=document.source_path if hasattr(document, "source_path") else None,
)
# Convert scan result to serializable format
@ -1424,43 +1423,43 @@ class UnifiedSearchViewSet(DocumentViewSet):
serializer = ApplySuggestionSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
suggestion_type = serializer.validated_data['suggestion_type']
value_id = serializer.validated_data.get('value_id')
value_text = serializer.validated_data.get('value_text')
confidence = serializer.validated_data['confidence']
suggestion_type = serializer.validated_data["suggestion_type"]
value_id = serializer.validated_data.get("value_id")
value_text = serializer.validated_data.get("value_text")
confidence = serializer.validated_data["confidence"]
# Apply the suggestion based on type
applied = False
result_message = ""
if suggestion_type == 'tag' and value_id:
if suggestion_type == "tag" and value_id:
tag = Tag.objects.get(pk=value_id)
document.tags.add(tag)
applied = True
result_message = f"Tag '{tag.name}' applied"
elif suggestion_type == 'correspondent' and value_id:
elif suggestion_type == "correspondent" and value_id:
correspondent = Correspondent.objects.get(pk=value_id)
document.correspondent = correspondent
document.save()
applied = True
result_message = f"Correspondent '{correspondent.name}' applied"
elif suggestion_type == 'document_type' and value_id:
elif suggestion_type == "document_type" and value_id:
doc_type = DocumentType.objects.get(pk=value_id)
document.document_type = doc_type
document.save()
applied = True
result_message = f"Document type '{doc_type.name}' applied"
elif suggestion_type == 'storage_path' and value_id:
elif suggestion_type == "storage_path" and value_id:
storage_path = StoragePath.objects.get(pk=value_id)
document.storage_path = storage_path
document.save()
applied = True
result_message = f"Storage path '{storage_path.name}' applied"
elif suggestion_type == 'title' and value_text:
elif suggestion_type == "title" and value_text:
document.title = value_text
document.save()
applied = True
@ -1518,10 +1517,10 @@ class UnifiedSearchViewSet(DocumentViewSet):
serializer = RejectSuggestionSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
suggestion_type = serializer.validated_data['suggestion_type']
value_id = serializer.validated_data.get('value_id')
value_text = serializer.validated_data.get('value_text')
confidence = serializer.validated_data['confidence']
suggestion_type = serializer.validated_data["suggestion_type"]
value_id = serializer.validated_data.get("value_id")
value_text = serializer.validated_data.get("value_text")
confidence = serializer.validated_data["confidence"]
# Record feedback
AISuggestionFeedback.objects.create(
@ -1554,7 +1553,10 @@ class UnifiedSearchViewSet(DocumentViewSet):
Returns aggregated data about applied vs rejected suggestions,
accuracy rates, and confidence scores.
"""
from django.db.models import Avg, Count, Q
from django.db.models import Avg
from django.db.models import Count
from django.db.models import Q
from documents.models import AISuggestionFeedback
from documents.serializers.ai_suggestions import AISuggestionStatsSerializer
@ -1562,61 +1564,63 @@ class UnifiedSearchViewSet(DocumentViewSet):
# Get overall counts
total_feedbacks = AISuggestionFeedback.objects.count()
total_applied = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_APPLIED
status=AISuggestionFeedback.STATUS_APPLIED,
).count()
total_rejected = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_REJECTED
status=AISuggestionFeedback.STATUS_REJECTED,
).count()
# Calculate accuracy rate
accuracy_rate = (total_applied / total_feedbacks * 100) if total_feedbacks > 0 else 0
# Get statistics by suggestion type using a single aggregated query
stats_by_type = AISuggestionFeedback.objects.values('suggestion_type').annotate(
total=Count('id'),
applied=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
rejected=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_REJECTED))
stats_by_type = AISuggestionFeedback.objects.values("suggestion_type").annotate(
total=Count("id"),
applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)),
)
# Build the by_type dictionary using the aggregated results
by_type = {}
for stat in stats_by_type:
suggestion_type = stat['suggestion_type']
type_total = stat['total']
type_applied = stat['applied']
type_rejected = stat['rejected']
suggestion_type = stat["suggestion_type"]
type_total = stat["total"]
type_applied = stat["applied"]
type_rejected = stat["rejected"]
by_type[suggestion_type] = {
'total': type_total,
'applied': type_applied,
'rejected': type_rejected,
'accuracy_rate': (type_applied / type_total * 100) if type_total > 0 else 0,
"total": type_total,
"applied": type_applied,
"rejected": type_rejected,
"accuracy_rate": (type_applied / type_total * 100) if type_total > 0 else 0,
}
# Get average confidence scores
avg_confidence_applied = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_APPLIED
).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
status=AISuggestionFeedback.STATUS_APPLIED,
).aggregate(Avg("confidence"))["confidence__avg"] or 0.0
avg_confidence_rejected = AISuggestionFeedback.objects.filter(
status=AISuggestionFeedback.STATUS_REJECTED
).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
status=AISuggestionFeedback.STATUS_REJECTED,
).aggregate(Avg("confidence"))["confidence__avg"] or 0.0
# Get recent suggestions (last 10)
recent_suggestions = AISuggestionFeedback.objects.order_by('-created_at')[:10]
recent_suggestions = AISuggestionFeedback.objects.order_by("-created_at")[:10]
# Build response data
from documents.serializers.ai_suggestions import AISuggestionFeedbackSerializer
from documents.serializers.ai_suggestions import (
AISuggestionFeedbackSerializer,
)
data = {
'total_suggestions': total_feedbacks,
'total_applied': total_applied,
'total_rejected': total_rejected,
'accuracy_rate': accuracy_rate,
'by_type': by_type,
'average_confidence_applied': avg_confidence_applied,
'average_confidence_rejected': avg_confidence_rejected,
'recent_suggestions': AISuggestionFeedbackSerializer(
recent_suggestions, many=True
"total_suggestions": total_feedbacks,
"total_applied": total_applied,
"total_rejected": total_rejected,
"accuracy_rate": accuracy_rate,
"by_type": by_type,
"average_confidence_applied": avg_confidence_applied,
"average_confidence_rejected": avg_confidence_rejected,
"recent_suggestions": AISuggestionFeedbackSerializer(
recent_suggestions, many=True,
).data,
}
@ -3571,21 +3575,21 @@ class AISuggestionsView(GenericAPIView):
request_serializer = AISuggestionsRequestSerializer(data=request.data)
request_serializer.is_valid(raise_exception=True)
document_id = request_serializer.validated_data['document_id']
document_id = request_serializer.validated_data["document_id"]
try:
document = Document.objects.get(pk=document_id)
except Document.DoesNotExist:
return Response(
{"error": "Document not found or you don't have permission to view it"},
status=status.HTTP_404_NOT_FOUND
status=status.HTTP_404_NOT_FOUND,
)
# Check if user has permission to view this document
if not has_perms_owner_aware(request.user, 'documents.view_document', document):
if not has_perms_owner_aware(request.user, "documents.view_document", document):
return Response(
{"error": "Permission denied"},
status=status.HTTP_403_FORBIDDEN
status=status.HTTP_403_FORBIDDEN,
)
# Get AI scanner and scan document
@ -3600,7 +3604,7 @@ class AISuggestionsView(GenericAPIView):
"document_type": None,
"storage_path": None,
"title_suggestion": scan_result.title_suggestion,
"custom_fields": {}
"custom_fields": {},
}
# Format tag suggestions
@ -3610,7 +3614,7 @@ class AISuggestionsView(GenericAPIView):
response_data["tags"].append({
"id": tag.id,
"name": tag.name,
"confidence": confidence
"confidence": confidence,
})
except Tag.DoesNotExist:
# Tag was suggested by AI but no longer exists; skip it
@ -3624,7 +3628,7 @@ class AISuggestionsView(GenericAPIView):
response_data["correspondent"] = {
"id": correspondent.id,
"name": correspondent.name,
"confidence": confidence
"confidence": confidence,
}
except Correspondent.DoesNotExist:
# Correspondent was suggested but no longer exists; skip it
@ -3638,7 +3642,7 @@ class AISuggestionsView(GenericAPIView):
response_data["document_type"] = {
"id": doc_type.id,
"name": doc_type.name,
"confidence": confidence
"confidence": confidence,
}
except DocumentType.DoesNotExist:
# Document type was suggested but no longer exists; skip it
@ -3652,7 +3656,7 @@ class AISuggestionsView(GenericAPIView):
response_data["storage_path"] = {
"id": storage_path.id,
"name": storage_path.name,
"confidence": confidence
"confidence": confidence,
}
except StoragePath.DoesNotExist:
# Storage path was suggested but no longer exists; skip it
@ -3662,7 +3666,7 @@ class AISuggestionsView(GenericAPIView):
for field_id, (value, confidence) in scan_result.custom_fields.items():
response_data["custom_fields"][str(field_id)] = {
"value": value,
"confidence": confidence
"confidence": confidence,
}
return Response(response_data)
@ -3683,21 +3687,21 @@ class ApplyAISuggestionsView(GenericAPIView):
serializer = ApplyAISuggestionsSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
document_id = serializer.validated_data['document_id']
document_id = serializer.validated_data["document_id"]
try:
document = Document.objects.get(pk=document_id)
except Document.DoesNotExist:
return Response(
{"error": "Document not found"},
status=status.HTTP_404_NOT_FOUND
status=status.HTTP_404_NOT_FOUND,
)
# Check if user has permission to change this document
if not has_perms_owner_aware(request.user, 'documents.change_document', document):
if not has_perms_owner_aware(request.user, "documents.change_document", document):
return Response(
{"error": "Permission denied"},
status=status.HTTP_403_FORBIDDEN
status=status.HTTP_403_FORBIDDEN,
)
# Get AI scanner and scan document
@ -3707,8 +3711,8 @@ class ApplyAISuggestionsView(GenericAPIView):
# Apply suggestions based on user selections
applied = []
if serializer.validated_data.get('apply_tags'):
selected_tags = serializer.validated_data.get('selected_tags', [])
if serializer.validated_data.get("apply_tags"):
selected_tags = serializer.validated_data.get("selected_tags", [])
if selected_tags:
# Apply only selected tags
tags_to_apply = [tag_id for tag_id, _ in scan_result.tags if tag_id in selected_tags]
@ -3725,7 +3729,7 @@ class ApplyAISuggestionsView(GenericAPIView):
# Tag not found; skip applying this tag
pass
if serializer.validated_data.get('apply_correspondent') and scan_result.correspondent:
if serializer.validated_data.get("apply_correspondent") and scan_result.correspondent:
corr_id, confidence = scan_result.correspondent
try:
correspondent = Correspondent.objects.get(pk=corr_id)
@ -3735,7 +3739,7 @@ class ApplyAISuggestionsView(GenericAPIView):
# Correspondent not found; skip applying
pass
if serializer.validated_data.get('apply_document_type') and scan_result.document_type:
if serializer.validated_data.get("apply_document_type") and scan_result.document_type:
type_id, confidence = scan_result.document_type
try:
doc_type = DocumentType.objects.get(pk=type_id)
@ -3745,7 +3749,7 @@ class ApplyAISuggestionsView(GenericAPIView):
# Document type not found; skip applying
pass
if serializer.validated_data.get('apply_storage_path') and scan_result.storage_path:
if serializer.validated_data.get("apply_storage_path") and scan_result.storage_path:
path_id, confidence = scan_result.storage_path
try:
storage_path = StoragePath.objects.get(pk=path_id)
@ -3755,7 +3759,7 @@ class ApplyAISuggestionsView(GenericAPIView):
# Storage path not found; skip applying
pass
if serializer.validated_data.get('apply_title') and scan_result.title_suggestion:
if serializer.validated_data.get("apply_title") and scan_result.title_suggestion:
document.title = scan_result.title_suggestion
applied.append(f"title: {scan_result.title_suggestion}")
@ -3765,7 +3769,7 @@ class ApplyAISuggestionsView(GenericAPIView):
return Response({
"status": "success",
"document_id": document.id,
"applied": applied
"applied": applied,
})
@ -3805,14 +3809,14 @@ class AIConfigurationView(GenericAPIView):
# Create new scanner with updated configuration
config = {}
if 'auto_apply_threshold' in serializer.validated_data:
config['auto_apply_threshold'] = serializer.validated_data['auto_apply_threshold']
if 'suggest_threshold' in serializer.validated_data:
config['suggest_threshold'] = serializer.validated_data['suggest_threshold']
if 'ml_enabled' in serializer.validated_data:
config['enable_ml_features'] = serializer.validated_data['ml_enabled']
if 'advanced_ocr_enabled' in serializer.validated_data:
config['enable_advanced_ocr'] = serializer.validated_data['advanced_ocr_enabled']
if "auto_apply_threshold" in serializer.validated_data:
config["auto_apply_threshold"] = serializer.validated_data["auto_apply_threshold"]
if "suggest_threshold" in serializer.validated_data:
config["suggest_threshold"] = serializer.validated_data["suggest_threshold"]
if "ml_enabled" in serializer.validated_data:
config["enable_ml_features"] = serializer.validated_data["ml_enabled"]
if "advanced_ocr_enabled" in serializer.validated_data:
config["enable_advanced_ocr"] = serializer.validated_data["advanced_ocr_enabled"]
# Update global scanner instance
# WARNING: Not thread-safe. Consider storing configuration in database
@ -3822,7 +3826,7 @@ class AIConfigurationView(GenericAPIView):
return Response({
"status": "success",
"message": "AI configuration updated. Changes may require server restart for consistency."
"message": "AI configuration updated. Changes may require server restart for consistency.",
})

View file

@ -76,7 +76,7 @@ class DeletionRequestViewSet(ModelViewSet):
# Check permissions
if not self._can_manage_request(deletion_request):
return HttpResponseForbidden(
"You don't have permission to approve this deletion request."
"You don't have permission to approve this deletion request.",
)
# Validate status
@ -114,11 +114,11 @@ class DeletionRequestViewSet(ModelViewSet):
deleted_count += 1
logger.info(
f"Deleted document {doc_id} ('{doc_title}') "
f"as part of deletion request {deletion_request.id}"
f"as part of deletion request {deletion_request.id}",
)
except Exception as e:
logger.error(
f"Failed to delete document {doc.id}: {str(e)}"
f"Failed to delete document {doc.id}: {e!s}",
)
failed_deletions.append({
"id": doc.id,
@ -138,14 +138,14 @@ class DeletionRequestViewSet(ModelViewSet):
logger.info(
f"Deletion request {deletion_request.id} completed. "
f"Deleted {deleted_count}/{len(documents)} documents."
f"Deleted {deleted_count}/{len(documents)} documents.",
)
except Exception as e:
logger.error(
f"Error executing deletion request {deletion_request.id}: {str(e)}"
f"Error executing deletion request {deletion_request.id}: {e!s}",
)
return Response(
{"error": f"Failed to execute deletion: {str(e)}"},
{"error": f"Failed to execute deletion: {e!s}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@ -176,7 +176,7 @@ class DeletionRequestViewSet(ModelViewSet):
# Check permissions
if not self._can_manage_request(deletion_request):
return HttpResponseForbidden(
"You don't have permission to reject this deletion request."
"You don't have permission to reject this deletion request.",
)
# Validate status
@ -199,7 +199,7 @@ class DeletionRequestViewSet(ModelViewSet):
)
logger.info(
f"Deletion request {deletion_request.id} rejected by user {request.user.username}"
f"Deletion request {deletion_request.id} rejected by user {request.user.username}",
)
serializer = self.get_serializer(deletion_request)
@ -228,7 +228,7 @@ class DeletionRequestViewSet(ModelViewSet):
# Check permissions
if not self._can_manage_request(deletion_request):
return HttpResponseForbidden(
"You don't have permission to cancel this deletion request."
"You don't have permission to cancel this deletion request.",
)
# Validate status
@ -249,7 +249,7 @@ class DeletionRequestViewSet(ModelViewSet):
deletion_request.save()
logger.info(
f"Deletion request {deletion_request.id} cancelled by user {request.user.username}"
f"Deletion request {deletion_request.id} cancelled by user {request.user.username}",
)
serializer = self.get_serializer(deletion_request)

View file

@ -160,7 +160,7 @@ class SecurityHeadersMiddleware:
# Store nonce in request for use in templates
# Templates can access this via {{ request.csp_nonce }}
if hasattr(request, '_csp_nonce'):
if hasattr(request, "_csp_nonce"):
request._csp_nonce = nonce
# Prevent clickjacking attacks

View file

@ -9,7 +9,6 @@ from __future__ import annotations
import hashlib
import logging
import mimetypes
import os
import re
from pathlib import Path
@ -26,39 +25,39 @@ logger = logging.getLogger("paperless.security")
# Lista explícita de tipos MIME permitidos
ALLOWED_MIME_TYPES = {
# Documentos
'application/pdf',
'application/vnd.oasis.opendocument.text',
'application/msword',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/vnd.ms-excel',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'application/vnd.oasis.opendocument.spreadsheet',
'application/vnd.oasis.opendocument.presentation',
'application/rtf',
'text/rtf',
"application/pdf",
"application/vnd.oasis.opendocument.text",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"application/vnd.oasis.opendocument.spreadsheet",
"application/vnd.oasis.opendocument.presentation",
"application/rtf",
"text/rtf",
# Imágenes
'image/jpeg',
'image/png',
'image/gif',
'image/tiff',
'image/bmp',
'image/webp',
"image/jpeg",
"image/png",
"image/gif",
"image/tiff",
"image/bmp",
"image/webp",
# Texto
'text/plain',
'text/html',
'text/csv',
'text/markdown',
"text/plain",
"text/html",
"text/csv",
"text/markdown",
}
# Maximum file size (100MB by default)
# Can be overridden by settings.MAX_UPLOAD_SIZE
try:
from django.conf import settings
MAX_FILE_SIZE = getattr(settings, 'MAX_UPLOAD_SIZE', 100 * 1024 * 1024) # 100MB por defecto
MAX_FILE_SIZE = getattr(settings, "MAX_UPLOAD_SIZE", 100 * 1024 * 1024) # 100MB por defecto
except ImportError:
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB in bytes
@ -114,7 +113,6 @@ ALLOWED_JS_PATTERNS = [
class FileValidationError(Exception):
"""Raised when file validation fails."""
pass
def has_whitelisted_javascript(content: bytes) -> bool:
@ -143,7 +141,7 @@ def validate_mime_type(mime_type: str) -> None:
if mime_type not in ALLOWED_MIME_TYPES:
raise FileValidationError(
f"MIME type '{mime_type}' is not allowed. "
f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}"
f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}",
)

View file

@ -86,7 +86,7 @@ api_router.register(r"config", ApplicationConfigurationViewSet)
api_router.register(r"processed_mail", ProcessedMailViewSet)
api_router.register(r"deletion_requests", DeletionRequestViewSet)
api_router.register(
r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests"
r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests",
)