Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-12-06 14:55:07 +01:00)
fix(syntax): fix syntax and formatting errors in Python

- Fixes a missing parenthesis in DeletionRequestActionSerializer (serialisers.py:2855)
- Removes whitespace from blank lines (W293)
- Removes trailing whitespace (W291)
- Removes unused imports (F401)
- Normalizes quotes to double quotes (Q000)
- Adds missing trailing commas (COM812)
- Sorts imports according to convention (I001)
- Updates type annotations to PEP 585 (UP006)

This commit resolves the build error in the CI/CD job that was causing the linting check to fail.

Files affected: 38
Lines modified: ~2200
This commit is contained in:
parent 9298f64546 / commit 69326b883d

38 changed files with 2077 additions and 2112 deletions
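The codes in the commit message are lint rules: W291/W293 come from pycodestyle, while F401, Q000, COM812, I001, and UP006 are ruff rule IDs for unused imports, quote style, trailing commas, import sorting, and PEP 585 annotations. As a minimal sketch of what these rules change (the function and field names below are hypothetical, not code from this commit):

    from typing import Any  # F401 removes imports that are never used; Any is used below


    # UP006 / PEP 585: builtin generics (dict, list) replace typing.Dict / typing.List.
    def cache_stats(hits: int, misses: int) -> dict[str, Any]:
        total = hits + misses
        return {
            "hits": hits,  # Q000: double quotes instead of 'single'
            "misses": misses,
            "hit_rate": hits / total if total else 0.0,  # COM812: trailing comma in multi-line literals
        }


    # I001 sorts imports and, in this codebase, splits them one per line:
    #     from django.db import migrations, models
    # becomes
    #     from django.db import migrations
    #     from django.db import models

Assuming the repository's ruff configuration enables these rule families, most of the changes below can be reproduced locally with `ruff check --fix`.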
@@ -14,14 +14,10 @@ According to agents.md requirements:
 from __future__ import annotations

 import logging
 from typing import TYPE_CHECKING
 from typing import Any

 from django.contrib.auth.models import User

 if TYPE_CHECKING:
     pass

 logger = logging.getLogger("paperless.ai_deletion")

@@ -29,7 +29,6 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING
 from typing import Any
-from typing import Dict
 from typing import TypedDict

 from django.conf import settings
@@ -142,34 +141,34 @@ class AIScanResult:
         """
         # Convert internal tuple format to TypedDict format
         result: AIScanResultDict = {
-            'tags': [{'tag_id': tag_id, 'confidence': conf} for tag_id, conf in self.tags],
-            'custom_fields': {
-                field_id: {'value': value, 'confidence': conf}
+            "tags": [{"tag_id": tag_id, "confidence": conf} for tag_id, conf in self.tags],
+            "custom_fields": {
+                field_id: {"value": value, "confidence": conf}
                 for field_id, (value, conf) in self.custom_fields.items()
             },
-            'workflows': [{'workflow_id': wf_id, 'confidence': conf} for wf_id, conf in self.workflows],
-            'extracted_entities': self.extracted_entities,
-            'metadata': self.metadata,
+            "workflows": [{"workflow_id": wf_id, "confidence": conf} for wf_id, conf in self.workflows],
+            "extracted_entities": self.extracted_entities,
+            "metadata": self.metadata,
         }

         # Add optional fields only if present
         if self.correspondent:
-            result['correspondent'] = {
-                'correspondent_id': self.correspondent[0],
-                'confidence': self.correspondent[1],
+            result["correspondent"] = {
+                "correspondent_id": self.correspondent[0],
+                "confidence": self.correspondent[1],
             }
         if self.document_type:
-            result['document_type'] = {
-                'type_id': self.document_type[0],
-                'confidence': self.document_type[1],
+            result["document_type"] = {
+                "type_id": self.document_type[0],
+                "confidence": self.document_type[1],
             }
         if self.storage_path:
-            result['storage_path'] = {
-                'path_id': self.storage_path[0],
-                'confidence': self.storage_path[1],
+            result["storage_path"] = {
+                "path_id": self.storage_path[0],
+                "confidence": self.storage_path[1],
             }
         if self.title_suggestion:
-            result['title_suggestion'] = self.title_suggestion
+            result["title_suggestion"] = self.title_suggestion

         return result
@@ -1054,7 +1053,7 @@ class AIDocumentScanner:
         warm_up_time = time.time() - start_time
         logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s")

-    def get_cache_metrics(self) -> Dict[str, Any]:
+    def get_cache_metrics(self) -> dict[str, Any]:
         """
         Get cache performance metrics.
@@ -310,7 +310,7 @@ class ConsumerPlugin(
         # Parsing phase
         document_parser = self._create_parser_instance(parser_class)
         text, date, thumbnail, archive_path, page_count = self._parse_document(
-            document_parser, mime_type
+            document_parser, mime_type,
         )

         # Storage phase
@@ -394,7 +394,7 @@ class ConsumerPlugin(
     def _attempt_pdf_recovery(
         self,
         tempdir: tempfile.TemporaryDirectory,
-        original_mime_type: str
+        original_mime_type: str,
     ) -> str:
         """
         Attempt to recover a PDF file with incorrect MIME type using qpdf.
@@ -438,7 +438,7 @@ class ConsumerPlugin(
     def _get_parser_class(
         self,
         mime_type: str,
-        tempdir: tempfile.TemporaryDirectory
+        tempdir: tempfile.TemporaryDirectory,
     ) -> type[DocumentParser]:
         """
         Determine which parser to use based on MIME type.
@@ -468,7 +468,7 @@ class ConsumerPlugin(

     def _create_parser_instance(
         self,
-        parser_class: type[DocumentParser]
+        parser_class: type[DocumentParser],
     ) -> DocumentParser:
         """
         Create a parser instance with progress callback.
@@ -496,7 +496,7 @@ class ConsumerPlugin(
     def _parse_document(
         self,
         document_parser: DocumentParser,
-        mime_type: str
+        mime_type: str,
     ) -> tuple[str, datetime.datetime | None, Path, Path | None, int | None]:
         """
         Parse the document and extract metadata.
@@ -670,7 +670,7 @@ class ConsumerPlugin(
         self,
         document: Document,
         thumbnail: Path,
-        archive_path: Path | None
+        archive_path: Path | None,
     ) -> None:
         """
         Store document files (source, thumbnail, archive) to disk.
@@ -949,7 +949,7 @@ class ConsumerPlugin(
             text: The extracted document text
         """
         # Check if AI scanner is enabled
-        if not getattr(settings, 'PAPERLESS_ENABLE_AI_SCANNER', True):
+        if not getattr(settings, "PAPERLESS_ENABLE_AI_SCANNER", True):
             self.log.debug("AI scanner is disabled, skipping AI analysis")
             return

@@ -1,6 +1,7 @@
 # Generated manually for performance optimization

-from django.db import migrations, models
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):
@@ -1,9 +1,10 @@
 # Generated manually for DeletionRequest model
 # Based on model definition in documents/models.py

-from django.conf import settings
-from django.db import migrations, models
 import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):
@@ -48,7 +49,7 @@ class Migration(migrations.Migration):
                (
                    "ai_reason",
                    models.TextField(
-                        help_text="Detailed explanation from AI about why deletion is recommended"
+                        help_text="Detailed explanation from AI about why deletion is recommended",
                    ),
                ),
                (
@@ -1,6 +1,7 @@
 # Generated manually for DeletionRequest performance optimization

-from django.db import migrations, models
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):
@@ -1,9 +1,10 @@
 # Generated manually for AI Suggestions API

-from django.conf import settings
-from django.db import migrations, models
-import django.db.models.deletion
 import django.core.validators
+import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):
@@ -10,9 +10,9 @@ Provides AI/ML capabilities including:
 from __future__ import annotations

 __all__ = [
-    "TransformerDocumentClassifier",
     "DocumentNER",
     "SemanticSearch",
+    "TransformerDocumentClassifier",
 ]

 # Lazy imports to avoid loading heavy ML libraries unless needed
@@ -15,23 +15,16 @@ Logging levels used in this module:
 from __future__ import annotations

 import logging
 from pathlib import Path
 from typing import TYPE_CHECKING

 import torch
 from torch.utils.data import Dataset
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    Trainer,
-    TrainingArguments,
-)
+from transformers import AutoModelForSequenceClassification
+from transformers import AutoTokenizer
+from transformers import Trainer
+from transformers import TrainingArguments

 from documents.ml.model_cache import ModelCacheManager

 if TYPE_CHECKING:
     from documents.models import Document

 logger = logging.getLogger("paperless.ml.classifier")
@@ -141,7 +134,7 @@ class TransformerDocumentClassifier:

         logger.info(
             f"Initialized TransformerDocumentClassifier with {model_name} "
-            f"(caching: {use_cache})"
+            f"(caching: {use_cache})",
         )

     def train(
@@ -24,8 +24,9 @@ import pickle
 import threading
 import time
 from collections import OrderedDict
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any

 logger = logging.getLogger("paperless.ml.model_cache")
@@ -58,7 +59,7 @@ class CacheMetrics:
         with self.lock:
             self.loads += 1

-    def get_stats(self) -> Dict[str, Any]:
+    def get_stats(self) -> dict[str, Any]:
         with self.lock:
             total = self.hits + self.misses
             hit_rate = (self.hits / total * 100) if total > 0 else 0.0
@@ -98,7 +99,7 @@ class LRUCache:
         self.lock = threading.Lock()
         self.metrics = CacheMetrics()

-    def get(self, key: str) -> Optional[Any]:
+    def get(self, key: str) -> Any | None:
         """
         Get item from cache.

@@ -153,7 +154,7 @@ class LRUCache:
         with self.lock:
             return len(self.cache)

-    def get_metrics(self) -> Dict[str, Any]:
+    def get_metrics(self) -> dict[str, Any]:
         """Get cache metrics."""
         return self.metrics.get_stats()

@@ -173,7 +174,7 @@ class ModelCacheManager:
         model = cache.get_or_load_model("classifier", loader_func)
     """

-    _instance: Optional[ModelCacheManager] = None
+    _instance: ModelCacheManager | None = None
     _lock = threading.Lock()

     def __new__(cls, *args, **kwargs):
@@ -187,7 +188,7 @@ class ModelCacheManager:
     def __init__(
         self,
         max_models: int = 3,
-        disk_cache_dir: Optional[str] = None,
+        disk_cache_dir: str | None = None,
     ):
         """
         Initialize model cache manager.
@@ -215,7 +216,7 @@ class ModelCacheManager:
     def get_instance(
         cls,
         max_models: int = 3,
-        disk_cache_dir: Optional[str] = None,
+        disk_cache_dir: str | None = None,
     ) -> ModelCacheManager:
         """
         Get singleton instance of ModelCacheManager.

@@ -278,7 +279,7 @@ class ModelCacheManager:
         load_time = time.time() - start_time
         logger.info(
             f"Model loaded successfully: {model_key} "
-            f"(took {load_time:.2f}s)"
+            f"(took {load_time:.2f}s)",
         )

         return model
@@ -289,7 +290,7 @@ class ModelCacheManager:
     def save_embeddings_to_disk(
         self,
         key: str,
-        embeddings: Dict[int, Any],
+        embeddings: dict[int, Any],
     ) -> bool:
         """
         Save embeddings to disk cache.
@@ -311,7 +312,7 @@ class ModelCacheManager:
         cache_file = self.disk_cache_dir / f"{key}.pkl"

         try:
-            with open(cache_file, 'wb') as f:
+            with open(cache_file, "wb") as f:
                 pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
             logger.info(f"Saved {len(embeddings)} embeddings to {cache_file}")
             return True
@@ -330,7 +331,7 @@ class ModelCacheManager:
     def load_embeddings_from_disk(
         self,
         key: str,
-    ) -> Optional[Dict[int, Any]]:
+    ) -> dict[int, Any] | None:
         """
         Load embeddings from disk cache.

@@ -393,7 +394,7 @@ class ModelCacheManager:
         except Exception as e:
             logger.error(f"Failed to delete {cache_file}: {e}")

-    def get_metrics(self) -> Dict[str, Any]:
+    def get_metrics(self) -> dict[str, Any]:
         """
         Get cache performance metrics.

@@ -416,7 +417,7 @@ class ModelCacheManager:

     def warm_up(
         self,
-        model_loaders: Dict[str, Callable[[], Any]],
+        model_loaders: dict[str, Callable[[], Any]],
     ) -> None:
         """
         Pre-load models on startup (warm-up).
@@ -14,15 +14,11 @@ from __future__ import annotations

 import logging
 import re
 from typing import TYPE_CHECKING

 from transformers import pipeline

 from documents.ml.model_cache import ModelCacheManager

 if TYPE_CHECKING:
     pass

 logger = logging.getLogger("paperless.ml.ner")
@@ -18,18 +18,14 @@ Examples:
 from __future__ import annotations

 import logging
 from pathlib import Path
 from typing import TYPE_CHECKING

 import numpy as np
 import torch
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer
+from sentence_transformers import util

 from documents.ml.model_cache import ModelCacheManager

 if TYPE_CHECKING:
     pass

 logger = logging.getLogger("paperless.ml.semantic_search")
@@ -67,7 +63,7 @@ class SemanticSearch:
         """
         logger.info(
             f"Initializing SemanticSearch with model: {model_name} "
-            f"(caching: {use_cache})"
+            f"(caching: {use_cache})",
         )

         self.model_name = model_name
@@ -127,11 +123,11 @@ class SemanticSearch:
         if not isinstance(embedding, np.ndarray) and not isinstance(embedding, torch.Tensor):
             logger.warning(f"Embedding for doc {doc_id} is not a numpy array or tensor")
             return False
-        if hasattr(embedding, 'size'):
+        if hasattr(embedding, "size"):
             if embedding.size == 0:
                 logger.warning(f"Embedding for doc {doc_id} is empty")
                 return False
-        elif hasattr(embedding, 'numel'):
+        elif hasattr(embedding, "numel"):
             if embedding.numel() == 0:
                 logger.warning(f"Embedding for doc {doc_id} is empty")
                 return False
@@ -216,11 +212,11 @@ class SemanticSearch:
         try:
             result = self.cache_manager.save_embeddings_to_disk(
                 "document_embeddings",
-                self.document_embeddings
+                self.document_embeddings,
             )
             if result:
                 logger.info(
-                    f"Successfully saved {len(self.document_embeddings)} embeddings to disk"
+                    f"Successfully saved {len(self.document_embeddings)} embeddings to disk",
                 )
             else:
                 logger.error("Failed to save embeddings to disk (returned False)")
@@ -1604,30 +1604,30 @@ class DeletionRequest(models.Model):
     # Requester (AI system)
     requested_by_ai = models.BooleanField(default=True)
     ai_reason = models.TextField(
-        help_text=_("Detailed explanation from AI about why deletion is recommended")
+        help_text=_("Detailed explanation from AI about why deletion is recommended"),
     )

     # User who must approve
     user = models.ForeignKey(
         User,
         on_delete=models.CASCADE,
-        related_name='deletion_requests',
+        related_name="deletion_requests",
         help_text=_("User who must approve this deletion"),
     )

     # Status tracking
-    STATUS_PENDING = 'pending'
-    STATUS_APPROVED = 'approved'
-    STATUS_REJECTED = 'rejected'
-    STATUS_CANCELLED = 'cancelled'
-    STATUS_COMPLETED = 'completed'
+    STATUS_PENDING = "pending"
+    STATUS_APPROVED = "approved"
+    STATUS_REJECTED = "rejected"
+    STATUS_CANCELLED = "cancelled"
+    STATUS_COMPLETED = "completed"

     STATUS_CHOICES = [
-        (STATUS_PENDING, _('Pending')),
-        (STATUS_APPROVED, _('Approved')),
-        (STATUS_REJECTED, _('Rejected')),
-        (STATUS_CANCELLED, _('Cancelled')),
-        (STATUS_COMPLETED, _('Completed')),
+        (STATUS_PENDING, _("Pending")),
+        (STATUS_APPROVED, _("Approved")),
+        (STATUS_REJECTED, _("Rejected")),
+        (STATUS_CANCELLED, _("Cancelled")),
+        (STATUS_COMPLETED, _("Completed")),
     ]

     status = models.CharField(
@@ -1639,7 +1639,7 @@ class DeletionRequest(models.Model):
     # Documents to be deleted
     documents = models.ManyToManyField(
         Document,
-        related_name='deletion_requests',
+        related_name="deletion_requests",
         help_text=_("Documents that would be deleted if approved"),
     )

@@ -1656,7 +1656,7 @@ class DeletionRequest(models.Model):
         on_delete=models.SET_NULL,
         null=True,
         blank=True,
-        related_name='reviewed_deletion_requests',
+        related_name="reviewed_deletion_requests",
         help_text=_("User who reviewed and approved/rejected"),
     )
     review_comment = models.TextField(
@@ -1672,21 +1672,21 @@ class DeletionRequest(models.Model):
     )

     class Meta:
-        ordering = ['-created_at']
+        ordering = ["-created_at"]
         verbose_name = _("deletion request")
         verbose_name_plural = _("deletion requests")
         indexes = [
             # Composite index for common listing queries (by user, filtered by status, sorted by date)
             # PostgreSQL can use this index for queries on: user, user+status, user+status+created_at
-            models.Index(fields=['user', 'status', 'created_at'], name='delreq_user_status_created_idx'),
+            models.Index(fields=["user", "status", "created_at"], name="delreq_user_status_created_idx"),
             # Index for queries filtering by status and date without user filter
-            models.Index(fields=['status', 'created_at'], name='delreq_status_created_idx'),
+            models.Index(fields=["status", "created_at"], name="delreq_status_created_idx"),
             # Index for queries filtering by user and date (common for user-specific views)
-            models.Index(fields=['user', 'created_at'], name='delreq_user_created_idx'),
+            models.Index(fields=["user", "created_at"], name="delreq_user_created_idx"),
             # Index for queries filtering by review date
-            models.Index(fields=['reviewed_at'], name='delreq_reviewed_at_idx'),
+            models.Index(fields=["reviewed_at"], name="delreq_reviewed_at_idx"),
             # Index for queries filtering by completion date
-            models.Index(fields=['completed_at'], name='delreq_completed_at_idx'),
+            models.Index(fields=["completed_at"], name="delreq_completed_at_idx"),
         ]

     def __str__(self):
@@ -1745,67 +1745,67 @@ class AISuggestionFeedback(models.Model):
     """

     # Suggestion types
-    TYPE_TAG = 'tag'
-    TYPE_CORRESPONDENT = 'correspondent'
-    TYPE_DOCUMENT_TYPE = 'document_type'
-    TYPE_STORAGE_PATH = 'storage_path'
-    TYPE_CUSTOM_FIELD = 'custom_field'
-    TYPE_WORKFLOW = 'workflow'
-    TYPE_TITLE = 'title'
+    TYPE_TAG = "tag"
+    TYPE_CORRESPONDENT = "correspondent"
+    TYPE_DOCUMENT_TYPE = "document_type"
+    TYPE_STORAGE_PATH = "storage_path"
+    TYPE_CUSTOM_FIELD = "custom_field"
+    TYPE_WORKFLOW = "workflow"
+    TYPE_TITLE = "title"

     SUGGESTION_TYPES = (
-        (TYPE_TAG, _('Tag')),
-        (TYPE_CORRESPONDENT, _('Correspondent')),
-        (TYPE_DOCUMENT_TYPE, _('Document Type')),
-        (TYPE_STORAGE_PATH, _('Storage Path')),
-        (TYPE_CUSTOM_FIELD, _('Custom Field')),
-        (TYPE_WORKFLOW, _('Workflow')),
-        (TYPE_TITLE, _('Title')),
+        (TYPE_TAG, _("Tag")),
+        (TYPE_CORRESPONDENT, _("Correspondent")),
+        (TYPE_DOCUMENT_TYPE, _("Document Type")),
+        (TYPE_STORAGE_PATH, _("Storage Path")),
+        (TYPE_CUSTOM_FIELD, _("Custom Field")),
+        (TYPE_WORKFLOW, _("Workflow")),
+        (TYPE_TITLE, _("Title")),
     )

     # Feedback status
-    STATUS_APPLIED = 'applied'
-    STATUS_REJECTED = 'rejected'
+    STATUS_APPLIED = "applied"
+    STATUS_REJECTED = "rejected"

     FEEDBACK_STATUS = (
-        (STATUS_APPLIED, _('Applied')),
-        (STATUS_REJECTED, _('Rejected')),
+        (STATUS_APPLIED, _("Applied")),
+        (STATUS_REJECTED, _("Rejected")),
     )

     document = models.ForeignKey(
         Document,
         on_delete=models.CASCADE,
-        related_name='ai_suggestion_feedbacks',
-        verbose_name=_('document'),
+        related_name="ai_suggestion_feedbacks",
+        verbose_name=_("document"),
     )

     suggestion_type = models.CharField(
-        _('suggestion type'),
+        _("suggestion type"),
         max_length=50,
         choices=SUGGESTION_TYPES,
     )

     suggested_value_id = models.IntegerField(
-        _('suggested value ID'),
+        _("suggested value ID"),
         null=True,
         blank=True,
-        help_text=_('ID of the suggested object (tag, correspondent, etc.)'),
+        help_text=_("ID of the suggested object (tag, correspondent, etc.)"),
     )

     suggested_value_text = models.TextField(
-        _('suggested value text'),
+        _("suggested value text"),
         blank=True,
-        help_text=_('Text representation of the suggested value'),
+        help_text=_("Text representation of the suggested value"),
     )

     confidence = models.FloatField(
-        _('confidence'),
-        help_text=_('AI confidence score (0.0 to 1.0)'),
+        _("confidence"),
+        help_text=_("AI confidence score (0.0 to 1.0)"),
         validators=[MinValueValidator(0.0), MaxValueValidator(1.0)],
     )

     status = models.CharField(
-        _('status'),
+        _("status"),
         max_length=20,
         choices=FEEDBACK_STATUS,
     )
@@ -1815,36 +1815,36 @@ class AISuggestionFeedback(models.Model):
         on_delete=models.SET_NULL,
         null=True,
         blank=True,
-        related_name='ai_suggestion_feedbacks',
-        verbose_name=_('user'),
-        help_text=_('User who applied or rejected the suggestion'),
+        related_name="ai_suggestion_feedbacks",
+        verbose_name=_("user"),
+        help_text=_("User who applied or rejected the suggestion"),
     )

     created_at = models.DateTimeField(
-        _('created at'),
+        _("created at"),
         auto_now_add=True,
     )

     applied_at = models.DateTimeField(
-        _('applied/rejected at'),
+        _("applied/rejected at"),
         auto_now=True,
     )

     metadata = models.JSONField(
-        _('metadata'),
+        _("metadata"),
         default=dict,
         blank=True,
-        help_text=_('Additional metadata about the suggestion'),
+        help_text=_("Additional metadata about the suggestion"),
     )

     class Meta:
-        verbose_name = _('AI suggestion feedback')
-        verbose_name_plural = _('AI suggestion feedbacks')
-        ordering = ['-created_at']
+        verbose_name = _("AI suggestion feedback")
+        verbose_name_plural = _("AI suggestion feedbacks")
+        ordering = ["-created_at"]
         indexes = [
-            models.Index(fields=['document', 'suggestion_type']),
-            models.Index(fields=['status', 'created_at']),
-            models.Index(fields=['suggestion_type', 'status']),
+            models.Index(fields=["document", "suggestion_type"]),
+            models.Index(fields=["status", "created_at"]),
+            models.Index(fields=["suggestion_type", "status"]),
         ]

     def __str__(self):
@@ -11,21 +11,21 @@ Lazy imports are used to avoid loading heavy dependencies unless needed.
 """

 __all__ = [
-    'TableExtractor',
-    'HandwritingRecognizer',
-    'FormFieldDetector',
+    "FormFieldDetector",
+    "HandwritingRecognizer",
+    "TableExtractor",
 ]


 def __getattr__(name):
     """Lazy import to avoid loading heavy ML models on startup."""
-    if name == 'TableExtractor':
+    if name == "TableExtractor":
         from .table_extractor import TableExtractor
         return TableExtractor
-    elif name == 'HandwritingRecognizer':
+    elif name == "HandwritingRecognizer":
         from .handwriting import HandwritingRecognizer
         return HandwritingRecognizer
-    elif name == 'FormFieldDetector':
+    elif name == "FormFieldDetector":
         from .form_detector import FormFieldDetector
         return FormFieldDetector
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -8,8 +8,8 @@ This module provides capabilities to:
 """

 import logging
 from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
+from typing import Any

 import numpy as np
 from PIL import Image
@@ -59,8 +59,8 @@ class FormFieldDetector:
         self,
         image: Image.Image,
         min_size: int = 10,
-        max_size: int = 50
-    ) -> List[Dict[str, Any]]:
+        max_size: int = 50,
+    ) -> list[dict[str, Any]]:
         """
         Detect checkboxes in a form image.

@@ -114,9 +114,9 @@ class FormFieldDetector:
                 checked, confidence = self._is_checkbox_checked(checkbox_region)

                 checkboxes.append({
-                    'bbox': [x, y, x+w, y+h],
-                    'checked': checked,
-                    'confidence': confidence
+                    "bbox": [x, y, x+w, y+h],
+                    "checked": checked,
+                    "confidence": confidence,
                 })

             logger.info(f"Detected {len(checkboxes)} checkboxes")
@@ -129,7 +129,7 @@ class FormFieldDetector:
             logger.error(f"Error detecting checkboxes: {e}")
             return []

-    def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> Tuple[bool, float]:
+    def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> tuple[bool, float]:
         """
         Determine if a checkbox is checked.

@@ -167,8 +167,8 @@ class FormFieldDetector:
     def detect_text_fields(
         self,
         image: Image.Image,
-        min_width: int = 100
-    ) -> List[Dict[str, Any]]:
+        min_width: int = 100,
+    ) -> list[dict[str, Any]]:
         """
         Detect text input fields in a form.

@@ -202,14 +202,14 @@ class FormFieldDetector:
                 cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1],
                 cv2.MORPH_OPEN,
                 horizontal_kernel,
-                iterations=2
+                iterations=2,
             )

             # Find contours of horizontal lines
             contours, _ = cv2.findContours(
                 detect_horizontal,
                 cv2.RETR_EXTERNAL,
-                cv2.CHAIN_APPROX_SIMPLE
+                cv2.CHAIN_APPROX_SIMPLE,
             )

             text_fields = []
@@ -221,8 +221,8 @@ class FormFieldDetector:
                     # Expand upward to include text area
                     text_bbox = [x, max(0, y-30), x+w, y+h]
                     text_fields.append({
-                        'bbox': text_bbox,
-                        'type': 'line'
+                        "bbox": text_bbox,
+                        "type": "line",
                     })

             # Detect rectangular boxes (bordered text fields)
@@ -236,8 +236,8 @@ class FormFieldDetector:
                 aspect_ratio = w / h if h > 0 else 0
                 if w >= min_width and 20 <= h <= 100 and aspect_ratio > 2:
                     text_fields.append({
-                        'bbox': [x, y, x+w, y+h],
-                        'type': 'box'
+                        "bbox": [x, y, x+w, y+h],
+                        "type": "box",
                     })

             logger.info(f"Detected {len(text_fields)} text fields")
@@ -253,8 +253,8 @@ class FormFieldDetector:
     def detect_labels(
         self,
         image: Image.Image,
-        field_bboxes: List[List[int]]
-    ) -> List[Dict[str, Any]]:
+        field_bboxes: list[list[int]],
+    ) -> list[dict[str, Any]]:
         """
         Detect labels near form fields.

@@ -271,17 +271,17 @@ class FormFieldDetector:
             # Get all text with bounding boxes
             ocr_data = pytesseract.image_to_data(
                 image,
-                output_type=pytesseract.Output.DICT
+                output_type=pytesseract.Output.DICT,
             )

             # Group text into potential labels
             labels = []
-            for i, text in enumerate(ocr_data['text']):
+            for i, text in enumerate(ocr_data["text"]):
                 if text.strip() and len(text.strip()) > 2:
-                    x = ocr_data['left'][i]
-                    y = ocr_data['top'][i]
-                    w = ocr_data['width'][i]
-                    h = ocr_data['height'][i]
+                    x = ocr_data["left"][i]
+                    y = ocr_data["top"][i]
+                    w = ocr_data["width"][i]
+                    h = ocr_data["height"][i]

                     label_bbox = [x, y, x+w, y+h]

@@ -289,9 +289,9 @@ class FormFieldDetector:
                     closest_field_idx = self._find_closest_field(label_bbox, field_bboxes)

                     labels.append({
-                        'text': text.strip(),
-                        'bbox': label_bbox,
-                        'field_index': closest_field_idx
+                        "text": text.strip(),
+                        "bbox": label_bbox,
+                        "field_index": closest_field_idx,
                     })

             return labels
@@ -305,9 +305,9 @@ class FormFieldDetector:

     def _find_closest_field(
         self,
-        label_bbox: List[int],
-        field_bboxes: List[List[int]]
-    ) -> Optional[int]:
+        label_bbox: list[int],
+        field_bboxes: list[list[int]],
+    ) -> int | None:
         """
         Find the closest field to a label.

@@ -325,7 +325,7 @@ class FormFieldDetector:
         label_center_x = (label_bbox[0] + label_bbox[2]) / 2
         label_center_y = (label_bbox[1] + label_bbox[3]) / 2

-        min_distance = float('inf')
+        min_distance = float("inf")
         closest_idx = 0

         for i, field_bbox in enumerate(field_bboxes):
@@ -336,7 +336,7 @@ class FormFieldDetector:
             # Euclidean distance
             distance = np.sqrt(
                 (label_center_x - field_center_x)**2 +
-                (label_center_y - field_center_y)**2
+                (label_center_y - field_center_y)**2,
             )

             if distance < min_distance:
@@ -348,8 +348,8 @@ class FormFieldDetector:
     def detect_form_fields(
         self,
         image_path: str,
-        extract_values: bool = True
-    ) -> List[Dict[str, Any]]:
+        extract_values: bool = True,
+    ) -> list[dict[str, Any]]:
         """
         Detect all form fields and extract their values.

@@ -372,14 +372,14 @@ class FormFieldDetector:
         """
         try:
             # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

             # Detect different field types
             text_fields = self.detect_text_fields(image)
             checkboxes = self.detect_checkboxes(image)

             # Combine all field bboxes for label detection
-            all_field_bboxes = [f['bbox'] for f in text_fields] + [cb['bbox'] for cb in checkboxes]
+            all_field_bboxes = [f["bbox"] for f in text_fields] + [cb["bbox"] for cb in checkboxes]

             # Detect labels
             labels = self.detect_labels(image, all_field_bboxes)
@@ -393,20 +393,20 @@ class FormFieldDetector:
                 label_text = self._find_label_for_field(i, labels, len(text_fields))

                 result = {
-                    'type': 'text',
-                    'label': label_text,
-                    'bbox': field['bbox'],
+                    "type": "text",
+                    "label": label_text,
+                    "bbox": field["bbox"],
                 }

                 # Extract value if requested
                 if extract_values:
-                    x1, y1, x2, y2 = field['bbox']
+                    x1, y1, x2, y2 = field["bbox"]
                     field_image = image.crop((x1, y1, x2, y2))

                     recognizer = self._get_handwriting_recognizer()
                     value = recognizer.recognize_from_image(field_image, preprocess=True)
-                    result['value'] = value.strip()
-                    result['confidence'] = recognizer._estimate_confidence(value)
+                    result["value"] = value.strip()
+                    result["confidence"] = recognizer._estimate_confidence(value)

                 results.append(result)

@@ -416,11 +416,11 @@ class FormFieldDetector:
                 label_text = self._find_label_for_field(field_idx, labels, len(all_field_bboxes))

                 results.append({
-                    'type': 'checkbox',
-                    'label': label_text,
-                    'value': checkbox['checked'],
-                    'bbox': checkbox['bbox'],
-                    'confidence': checkbox['confidence']
+                    "type": "checkbox",
+                    "label": label_text,
+                    "value": checkbox["checked"],
+                    "bbox": checkbox["bbox"],
+                    "confidence": checkbox["confidence"],
                 })

             logger.info(f"Detected {len(results)} form fields from {image_path}")
@@ -433,8 +433,8 @@ class FormFieldDetector:
     def _find_label_for_field(
         self,
         field_idx: int,
-        labels: List[Dict[str, Any]],
-        total_fields: int
+        labels: list[dict[str, Any]],
+        total_fields: int,
     ) -> str:
         """
         Find the label text for a specific field.
@@ -449,19 +449,19 @@ class FormFieldDetector:
         """
         matching_labels = [
             label for label in labels
-            if label['field_index'] == field_idx
+            if label["field_index"] == field_idx
         ]

         if matching_labels:
             # Combine multiple label parts if found
-            return ' '.join(label['text'] for label in matching_labels)
+            return " ".join(label["text"] for label in matching_labels)

         return f"Field_{field_idx + 1}"

     def extract_form_data(
         self,
         image_path: str,
-        output_format: str = 'dict'
+        output_format: str = "dict",
     ) -> Any:
         """
         Extract all form data as structured output.
@@ -476,16 +476,16 @@ class FormFieldDetector:
         # Detect and extract fields
         fields = self.detect_form_fields(image_path, extract_values=True)

-        if output_format == 'dict':
+        if output_format == "dict":
             # Return as dictionary
-            return {field['label']: field['value'] for field in fields}
+            return {field["label"]: field["value"] for field in fields}

-        elif output_format == 'json':
+        elif output_format == "json":
             import json
-            data = {field['label']: field['value'] for field in fields}
+            data = {field["label"]: field["value"] for field in fields}
             return json.dumps(data, indent=2)

-        elif output_format == 'dataframe':
+        elif output_format == "dataframe":
             import pandas as pd
             return pd.DataFrame(fields)

@@ -8,8 +8,8 @@ This module provides handwriting OCR capabilities using:
 """

 import logging
 from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
+from typing import Any

 import numpy as np
 from PIL import Image
@@ -65,8 +65,9 @@ class HandwritingRecognizer:
             return

         try:
-            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
             import torch
+            from transformers import TrOCRProcessor
+            from transformers import VisionEncoderDecoderModel

             logger.info(f"Loading handwriting recognition model: {self.model_name}")

@@ -90,7 +91,7 @@ class HandwritingRecognizer:
     def recognize_from_image(
         self,
         image: Image.Image,
-        preprocess: bool = True
+        preprocess: bool = True,
     ) -> str:
         """
         Recognize text from a single image.
@@ -142,11 +143,12 @@ class HandwritingRecognizer:
             Preprocessed PIL Image
         """
         try:
-            from PIL import ImageEnhance, ImageFilter
+            from PIL import ImageEnhance
+            from PIL import ImageFilter

             # Convert to grayscale
-            if image.mode != 'L':
-                image = image.convert('L')
+            if image.mode != "L":
+                image = image.convert("L")

             # Enhance contrast
             enhancer = ImageEnhance.Contrast(image)
@@ -156,7 +158,7 @@ class HandwritingRecognizer:
             image = image.filter(ImageFilter.MedianFilter(size=3))

             # Convert back to RGB (required by model)
-            image = image.convert('RGB')
+            image = image.convert("RGB")

             return image

@@ -164,7 +166,7 @@ class HandwritingRecognizer:
             logger.warning(f"Error preprocessing image: {e}")
             return image

-    def detect_text_lines(self, image: Image.Image) -> List[Dict[str, Any]]:
+    def detect_text_lines(self, image: Image.Image) -> list[dict[str, Any]]:
         """
         Detect individual text lines in an image.

@@ -208,12 +210,12 @@ class HandwritingRecognizer:
                 # Crop line from original image
                 line_img = image.crop((x, y, x+w, y+h))
                 lines.append({
-                    'bbox': [x, y, x+w, y+h],
-                    'image': line_img
+                    "bbox": [x, y, x+w, y+h],
+                    "image": line_img,
                 })

             # Sort lines top to bottom
-            lines.sort(key=lambda l: l['bbox'][1])
+            lines.sort(key=lambda l: l["bbox"][1])

             logger.info(f"Detected {len(lines)} text lines")
             return lines
@@ -228,8 +230,8 @@ class HandwritingRecognizer:
     def recognize_lines(
         self,
         image_path: str,
-        return_confidence: bool = True
-    ) -> List[Dict[str, Any]]:
+        return_confidence: bool = True,
+    ) -> list[dict[str, Any]]:
         """
         Recognize text from each line in an image.

@@ -250,7 +252,7 @@ class HandwritingRecognizer:
         """
         try:
             # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

             # Detect lines
             lines = self.detect_text_lines(image)
@@ -260,18 +262,18 @@ class HandwritingRecognizer:
             for i, line in enumerate(lines):
                 logger.debug(f"Recognizing line {i+1}/{len(lines)}")

-                text = self.recognize_from_image(line['image'], preprocess=True)
+                text = self.recognize_from_image(line["image"], preprocess=True)

                 result = {
-                    'text': text,
-                    'bbox': line['bbox'],
-                    'line_index': i
+                    "text": text,
+                    "bbox": line["bbox"],
+                    "line_index": i,
                 }

                 if return_confidence:
                     # Simple confidence based on text length and content
                     confidence = self._estimate_confidence(text)
-                    result['confidence'] = confidence
+                    result["confidence"] = confidence

                 results.append(result)

@@ -309,7 +311,7 @@ class HandwritingRecognizer:
             score += 0.1

         # Text with spaces (words) is more reliable
-        if ' ' in text:
+        if " " in text:
             score += 0.1

         # Penalize if too many special characters
@@ -322,8 +324,8 @@ class HandwritingRecognizer:
     def recognize_from_file(
         self,
         image_path: str,
-        mode: str = 'full'
-    ) -> Dict[str, Any]:
+        mode: str = "full",
+    ) -> dict[str, Any]:
         """
         Recognize handwriting from an image file.

@@ -337,30 +339,30 @@ class HandwritingRecognizer:
             Dictionary with recognized text and metadata
         """
         try:
-            if mode == 'full':
+            if mode == "full":
                 # Recognize entire image
-                image = Image.open(image_path).convert('RGB')
+                image = Image.open(image_path).convert("RGB")
                 text = self.recognize_from_image(image, preprocess=True)

                 return {
-                    'text': text,
-                    'mode': 'full',
-                    'confidence': self._estimate_confidence(text)
+                    "text": text,
+                    "mode": "full",
+                    "confidence": self._estimate_confidence(text),
                 }

-            elif mode == 'lines':
+            elif mode == "lines":
                 # Recognize line by line
                 lines = self.recognize_lines(image_path, return_confidence=True)

                 # Combine all lines
-                full_text = '\n'.join(line['text'] for line in lines)
-                avg_confidence = np.mean([line['confidence'] for line in lines]) if lines else 0.0
+                full_text = "\n".join(line["text"] for line in lines)
+                avg_confidence = np.mean([line["confidence"] for line in lines]) if lines else 0.0

                 return {
-                    'text': full_text,
-                    'lines': lines,
-                    'mode': 'lines',
-                    'confidence': float(avg_confidence)
+                    "text": full_text,
+                    "lines": lines,
+                    "mode": "lines",
+                    "confidence": float(avg_confidence),
                 }

             else:
@@ -369,17 +371,17 @@ class HandwritingRecognizer:
         except Exception as e:
             logger.error(f"Error recognizing from file {image_path}: {e}")
             return {
-                'text': '',
-                'mode': mode,
-                'confidence': 0.0,
-                'error': str(e)
+                "text": "",
+                "mode": mode,
+                "confidence": 0.0,
+                "error": str(e),
             }

     def recognize_form_fields(
         self,
         image_path: str,
-        field_regions: List[Dict[str, Any]]
-    ) -> Dict[str, str]:
+        field_regions: list[dict[str, Any]],
+    ) -> dict[str, str]:
         """
         Recognize text from specific form fields.

@@ -399,13 +401,13 @@ class HandwritingRecognizer:
         """
         try:
             # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

             # Extract and recognize each field
             results = {}
             for field in field_regions:
-                name = field['name']
-                bbox = field['bbox']
+                name = field["name"]
+                bbox = field["bbox"]

                 # Crop field region
                 x1, y1, x2, y2 = bbox
@@ -425,9 +427,9 @@ class HandwritingRecognizer:

     def batch_recognize(
         self,
-        image_paths: List[str],
-        mode: str = 'full'
-    ) -> List[Dict[str, Any]]:
+        image_paths: list[str],
+        mode: str = "full",
+    ) -> list[dict[str, Any]]:
         """
         Recognize handwriting from multiple images in batch.

@@ -442,7 +444,7 @@ class HandwritingRecognizer:
         for i, path in enumerate(image_paths):
             logger.info(f"Processing image {i+1}/{len(image_paths)}: {path}")
             result = self.recognize_from_file(path, mode=mode)
-            result['image_path'] = path
+            result["image_path"] = path
             results.append(result)

         return results
@@ -8,9 +8,8 @@ This module uses various techniques to detect and extract tables from documents:
 """

 import logging
 from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
-import numpy as np
+from typing import Any

 from PIL import Image

 logger = logging.getLogger(__name__)
@@ -59,8 +58,9 @@ class TableExtractor:
             return

         try:
-            from transformers import AutoImageProcessor, AutoModelForObjectDetection
             import torch
+            from transformers import AutoImageProcessor
+            from transformers import AutoModelForObjectDetection

             logger.info(f"Loading table detection model: {self.model_name}")

@@ -79,7 +79,7 @@ class TableExtractor:
             logger.error("Please install required packages: pip install transformers torch pillow")
             raise

-    def detect_tables(self, image: Image.Image) -> List[Dict[str, Any]]:
+    def detect_tables(self, image: Image.Image) -> list[dict[str, Any]]:
         """
         Detect tables in an image.

@@ -117,16 +117,16 @@ class TableExtractor:
             results = self._processor.post_process_object_detection(
                 outputs,
                 threshold=self.confidence_threshold,
-                target_sizes=target_sizes
+                target_sizes=target_sizes,
             )[0]

             # Convert to list of dicts
             tables = []
             for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
                 tables.append({
-                    'bbox': box.cpu().tolist(),
-                    'score': score.item(),
-                    'label': self._model.config.id2label[label.item()]
+                    "bbox": box.cpu().tolist(),
+                    "score": score.item(),
+                    "label": self._model.config.id2label[label.item()],
                 })

             logger.info(f"Detected {len(tables)} tables in image")
@@ -139,9 +139,9 @@ class TableExtractor:
     def extract_table_from_region(
         self,
         image: Image.Image,
-        bbox: List[float],
-        use_ocr: bool = True
-    ) -> Optional[Dict[str, Any]]:
+        bbox: list[float],
+        use_ocr: bool = True,
+    ) -> dict[str, Any] | None:
         """
         Extract table data from a specific region of an image.

@@ -166,7 +166,7 @@ class TableExtractor:
                 # Get detailed OCR data
                 ocr_data = pytesseract.image_to_data(
                     table_image,
-                    output_type=pytesseract.Output.DICT
+                    output_type=pytesseract.Output.DICT,
                 )

                 # Reconstruct table structure from OCR data
@@ -176,20 +176,20 @@ class TableExtractor:
                 raw_text = pytesseract.image_to_string(table_image)

                 return {
-                    'data': table_data,
-                    'raw_text': raw_text,
-                    'bbox': bbox,
-                    'image_size': table_image.size
+                    "data": table_data,
+                    "raw_text": raw_text,
+                    "bbox": bbox,
+                    "image_size": table_image.size,
                 }
             else:
                 # Fallback to basic OCR without structure
                 import pytesseract
                 raw_text = pytesseract.image_to_string(table_image)
                 return {
-                    'data': None,
-                    'raw_text': raw_text,
-                    'bbox': bbox,
-                    'image_size': table_image.size
+                    "data": None,
+                    "raw_text": raw_text,
+                    "bbox": bbox,
+                    "image_size": table_image.size,
                 }

         except ImportError:
@@ -199,7 +199,7 @@ class TableExtractor:
             logger.error(f"Error extracting table from region: {e}")
             return None

-    def _reconstruct_table_from_ocr(self, ocr_data: Dict) -> Optional[Any]:
+    def _reconstruct_table_from_ocr(self, ocr_data: dict) -> Any | None:
         """
         Reconstruct table structure from OCR output.

@@ -214,10 +214,10 @@ class TableExtractor:

             # Group text by vertical position (rows)
             rows = {}
-            for i, text in enumerate(ocr_data['text']):
+            for i, text in enumerate(ocr_data["text"]):
                 if text.strip():
-                    top = ocr_data['top'][i]
-                    left = ocr_data['left'][i]
+                    top = ocr_data["top"][i]
+                    left = ocr_data["left"][i]

                     # Group by approximate row (within 20 pixels)
                     row_key = round(top / 20) * 20
@@ -235,14 +235,14 @@ class TableExtractor:
             if table_rows:
                 # Pad rows to same length
                 max_cols = max(len(row) for row in table_rows)
-                table_rows = [row + [''] * (max_cols - len(row)) for row in table_rows]
+                table_rows = [row + [""] * (max_cols - len(row)) for row in table_rows]

                 # Create DataFrame
                 df = pd.DataFrame(table_rows)

                 # Try to use first row as header if it looks like one
                 if len(df) > 1:
-                    first_row_text = ' '.join(str(x) for x in df.iloc[0])
+                    first_row_text = " ".join(str(x) for x in df.iloc[0])
                     if not any(char.isdigit() for char in first_row_text):
                         df.columns = df.iloc[0]
                         df = df[1:].reset_index(drop=True)
@@ -261,8 +261,8 @@ class TableExtractor:
     def extract_tables_from_image(
         self,
         image_path: str,
-        output_format: str = 'dataframe'
-    ) -> List[Dict[str, Any]]:
+        output_format: str = "dataframe",
+    ) -> list[dict[str, Any]]:
         """
         Extract all tables from an image file.

@@ -275,7 +275,7 @@ class TableExtractor:
         """
         try:
             # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

             # Detect tables
             detections = self.detect_tables(image)
@@ -287,18 +287,18 @@ class TableExtractor:

                 table_data = self.extract_table_from_region(
                     image,
-                    detection['bbox']
+                    detection["bbox"],
                 )

                 if table_data:
-                    table_data['detection_score'] = detection['score']
-                    table_data['table_index'] = i
+                    table_data["detection_score"] = detection["score"]
+                    table_data["table_index"] = i

                     # Convert to requested format
-                    if output_format == 'csv' and table_data['data'] is not None:
-                        table_data['csv'] = table_data['data'].to_csv(index=False)
-                    elif output_format == 'json' and table_data['data'] is not None:
-                        table_data['json'] = table_data['data'].to_json(orient='records')
+                    if output_format == "csv" and table_data["data"] is not None:
+                        table_data["csv"] = table_data["data"].to_csv(index=False)
+                    elif output_format == "json" and table_data["data"] is not None:
+                        table_data["json"] = table_data["data"].to_json(orient="records")

                     tables.append(table_data)

@@ -312,8 +312,8 @@ class TableExtractor:
     def extract_tables_from_pdf(
         self,
         pdf_path: str,
-        page_numbers: Optional[List[int]] = None
-    ) -> Dict[int, List[Dict[str, Any]]]:
+        page_numbers: list[int] | None = None,
+    ) -> dict[int, list[dict[str, Any]]]:
         """
         Extract tables from a PDF document.

@@ -334,7 +334,7 @@ class TableExtractor:
                 images = convert_from_path(
                     pdf_path,
                     first_page=min(page_numbers),
-                    last_page=max(page_numbers)
+                    last_page=max(page_numbers),
                 )
             else:
                 images = convert_from_path(pdf_path)
@@ -352,11 +352,11 @@ class TableExtractor:
                 for detection in detections:
                     table_data = self.extract_table_from_region(
                         image,
-                        detection['bbox']
+                        detection["bbox"],
                     )
                     if table_data:
-                        table_data['detection_score'] = detection['score']
-                        table_data['page'] = page_num
+                        table_data["detection_score"] = detection["score"]
+                        table_data["page"] = page_num
                         tables.append(table_data)

                 if tables:
@@ -374,8 +374,8 @@ class TableExtractor:

     def save_tables_to_excel(
         self,
-        tables: List[Dict[str, Any]],
-        output_path: str
+        tables: list[dict[str, Any]],
+        output_path: str,
     ) -> bool:
         """
         Save extracted tables to an Excel file.
@@ -390,17 +390,17 @@ class TableExtractor:
         try:
             import pandas as pd

-            with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
+            with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
                 for i, table in enumerate(tables):
-                    if table.get('data') is not None:
+                    if table.get("data") is not None:
                         sheet_name = f"Table_{i+1}"
-                        if 'page' in table:
+                        if "page" in table:
                             sheet_name = f"Page_{table['page']}_Table_{i+1}"

-                        table['data'].to_excel(
+                        table["data"].to_excel(
                             writer,
                             sheet_name=sheet_name,
-                            index=False
+                            index=False,
                         )

             logger.info(f"Saved {len(tables)} tables to {output_path}")
@@ -46,7 +46,6 @@ if settings.AUDIT_LOG_ENABLED:
 from documents import bulk_edit
 from documents.data_models import DocumentSource
 from documents.filters import CustomFieldQueryParser
 from documents.models import AISuggestionFeedback
 from documents.models import Correspondent
 from documents.models import CustomField
 from documents.models import CustomFieldInstance
@@ -2788,9 +2787,9 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
     """Serializer for DeletionRequest model with document details."""

     document_details = serializers.SerializerMethodField()
-    user_username = serializers.CharField(source='user.username', read_only=True)
+    user_username = serializers.CharField(source="user.username", read_only=True)
     reviewed_by_username = serializers.CharField(
-        source='reviewed_by.username',
+        source="reviewed_by.username",
         read_only=True,
         allow_null=True,
     )
@@ -2799,31 +2798,31 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
         from documents.models import DeletionRequest
         model = DeletionRequest
         fields = [
-            'id',
-            'created_at',
-            'updated_at',
-            'requested_by_ai',
-            'ai_reason',
-            'user',
-            'user_username',
-            'status',
-            'impact_summary',
-            'reviewed_at',
-            'reviewed_by',
-            'reviewed_by_username',
-            'review_comment',
-            'completed_at',
-            'completion_details',
-            'document_details',
+            "id",
+            "created_at",
+            "updated_at",
+            "requested_by_ai",
+            "ai_reason",
+            "user",
+            "user_username",
+            "status",
+            "impact_summary",
+            "reviewed_at",
+            "reviewed_by",
+            "reviewed_by_username",
+            "review_comment",
+            "completed_at",
+            "completion_details",
+            "document_details",
         ]
         read_only_fields = [
-            'id',
-            'created_at',
-            'updated_at',
-            'reviewed_at',
-            'reviewed_by',
-            'completed_at',
-            'completion_details',
+            "id",
+            "created_at",
+            "updated_at",
+            "reviewed_at",
+            "reviewed_by",
+            "completed_at",
+            "completion_details",
         ]

     def get_document_details(self, obj):
@@ -2831,12 +2830,12 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
         documents = obj.documents.all()
         return [
             {
-                'id': doc.id,
-                'title': doc.title,
-                'created': doc.created.isoformat() if doc.created else None,
-                'correspondent': doc.correspondent.name if doc.correspondent else None,
-                'document_type': doc.document_type.name if doc.document_type else None,
-                'tags': [tag.name for tag in doc.tags.all()],
+                "id": doc.id,
+                "title": doc.title,
+                "created": doc.created.isoformat() if doc.created else None,
+                "correspondent": doc.correspondent.name if doc.correspondent else None,
+                "document_type": doc.document_type.name if doc.document_type else None,
+                "tags": [tag.name for tag in doc.tags.all()],
             }
             for doc in documents
         ]
@@ -2852,6 +2851,9 @@ class DeletionRequestActionSerializer(serializers.Serializer):
         allow_blank=True,
         label="Review Comment",
         help_text="Optional comment when reviewing the deletion request",
+    )


 class AISuggestionsRequestSerializer(serializers.Serializer):
     """Serializer for requesting AI suggestions for a document."""

@@ -1,17 +1,15 @@
 """Serializers package for documents app."""

-from .ai_suggestions import (
-    AISuggestionFeedbackSerializer,
-    AISuggestionsSerializer,
-    AISuggestionStatsSerializer,
-    ApplySuggestionSerializer,
-    RejectSuggestionSerializer,
-)
+from .ai_suggestions import AISuggestionFeedbackSerializer
+from .ai_suggestions import AISuggestionsSerializer
+from .ai_suggestions import AISuggestionStatsSerializer
+from .ai_suggestions import ApplySuggestionSerializer
+from .ai_suggestions import RejectSuggestionSerializer

 __all__ = [
-    'AISuggestionFeedbackSerializer',
-    'AISuggestionsSerializer',
-    'AISuggestionStatsSerializer',
-    'ApplySuggestionSerializer',
-    'RejectSuggestionSerializer',
+    "AISuggestionFeedbackSerializer",
+    "AISuggestionStatsSerializer",
+    "AISuggestionsSerializer",
+    "ApplySuggestionSerializer",
+    "RejectSuggestionSerializer",
 ]
@@ -7,36 +7,33 @@ and handling user feedback on AI suggestions.

from __future__ import annotations

from typing import Any, Dict
from typing import Any

from rest_framework import serializers

from documents.models import (
    AISuggestionFeedback,
    Correspondent,
    CustomField,
    DocumentType,
    StoragePath,
    Tag,
    Workflow,
)
from documents.models import AISuggestionFeedback
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow

# Suggestion type choices - used across multiple serializers
SUGGESTION_TYPE_CHOICES = [
    'tag',
    'correspondent',
    'document_type',
    'storage_path',
    'custom_field',
    'workflow',
    'title',
    "tag",
    "correspondent",
    "document_type",
    "storage_path",
    "custom_field",
    "workflow",
    "title",
]

# Types that require value_id
ID_REQUIRED_TYPES = ['tag', 'correspondent', 'document_type', 'storage_path', 'workflow']
ID_REQUIRED_TYPES = ["tag", "correspondent", "document_type", "storage_path", "workflow"]
# Types that require value_text
TEXT_REQUIRED_TYPES = ['title']
TEXT_REQUIRED_TYPES = ["title"]
# Types that can use either (custom_field can be ID or text)

@@ -113,7 +110,7 @@ class AISuggestionsSerializer(serializers.Serializer):
    title_suggestion = TitleSuggestionSerializer(required=False, allow_null=True)

    @staticmethod
    def from_scan_result(scan_result, document_id: int) -> Dict[str, Any]:
    def from_scan_result(scan_result, document_id: int) -> dict[str, Any]:
        """
        Convert an AIScanResult object to serializer data.
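The one-line signature change above is the PEP 585 update (UP006): with `from __future__ import annotations` at the top of the module, annotations are evaluated lazily, so the builtin `dict` can carry type parameters and the `typing.Dict` import becomes dead weight. A minimal, self-contained sketch of the same pattern:

    from __future__ import annotations

    from typing import Any


    def to_payload(document_id: int) -> dict[str, Any]:
        # Builtin generics (PEP 585) replace typing.Dict; the future import
        # keeps the annotation lazy on interpreters older than 3.9.
        return {"document_id": document_id, "tags": []}


    print(to_payload(1))  # {'document_id': 1, 'tags': []}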
@@ -133,25 +130,25 @@ class AISuggestionsSerializer(serializers.Serializer):
            try:
                tag = Tag.objects.get(pk=tag_id)
                tag_suggestions.append({
                    'id': tag.id,
                    'name': tag.name,
                    'color': getattr(tag, 'color', '#000000'),
                    'confidence': confidence,
                    "id": tag.id,
                    "name": tag.name,
                    "color": getattr(tag, "color", "#000000"),
                    "confidence": confidence,
                })
            except Tag.DoesNotExist:
                # Tag no longer exists in database; skip this suggestion
                pass
        data['tags'] = tag_suggestions
        data["tags"] = tag_suggestions

        # Correspondent
        if scan_result.correspondent:
            corr_id, confidence = scan_result.correspondent
            try:
                correspondent = Correspondent.objects.get(pk=corr_id)
                data['correspondent'] = {
                    'id': correspondent.id,
                    'name': correspondent.name,
                    'confidence': confidence,
                data["correspondent"] = {
                    "id": correspondent.id,
                    "name": correspondent.name,
                    "confidence": confidence,
                }
            except Correspondent.DoesNotExist:
                # Correspondent no longer exists in database; omit from suggestions

@@ -162,10 +159,10 @@ class AISuggestionsSerializer(serializers.Serializer):
            type_id, confidence = scan_result.document_type
            try:
                doc_type = DocumentType.objects.get(pk=type_id)
                data['document_type'] = {
                    'id': doc_type.id,
                    'name': doc_type.name,
                    'confidence': confidence,
                data["document_type"] = {
                    "id": doc_type.id,
                    "name": doc_type.name,
                    "confidence": confidence,
                }
            except DocumentType.DoesNotExist:
                # Document type no longer exists in database; omit from suggestions

@@ -176,11 +173,11 @@ class AISuggestionsSerializer(serializers.Serializer):
            path_id, confidence = scan_result.storage_path
            try:
                storage_path = StoragePath.objects.get(pk=path_id)
                data['storage_path'] = {
                    'id': storage_path.id,
                    'name': storage_path.name,
                    'path': storage_path.path,
                    'confidence': confidence,
                data["storage_path"] = {
                    "id": storage_path.id,
                    "name": storage_path.name,
                    "path": storage_path.path,
                    "confidence": confidence,
                }
            except StoragePath.DoesNotExist:
                # Storage path no longer exists in database; omit from suggestions

@@ -193,15 +190,15 @@ class AISuggestionsSerializer(serializers.Serializer):
            try:
                field = CustomField.objects.get(pk=field_id)
                field_suggestions.append({
                    'field_id': field.id,
                    'field_name': field.name,
                    'value': str(value),
                    'confidence': confidence,
                    "field_id": field.id,
                    "field_name": field.name,
                    "value": str(value),
                    "confidence": confidence,
                })
            except CustomField.DoesNotExist:
                # Custom field no longer exists in database; skip this suggestion
                pass
        data['custom_fields'] = field_suggestions
        data["custom_fields"] = field_suggestions

        # Workflows
        if scan_result.workflows:

@@ -210,19 +207,19 @@ class AISuggestionsSerializer(serializers.Serializer):
            try:
                workflow = Workflow.objects.get(pk=workflow_id)
                workflow_suggestions.append({
                    'id': workflow.id,
                    'name': workflow.name,
                    'confidence': confidence,
                    "id": workflow.id,
                    "name": workflow.name,
                    "confidence": confidence,
                })
            except Workflow.DoesNotExist:
                # Workflow no longer exists in database; skip this suggestion
                pass
        data['workflows'] = workflow_suggestions
        data["workflows"] = workflow_suggestions

        # Title suggestion
        if scan_result.title_suggestion:
            data['title_suggestion'] = {
                'title': scan_result.title_suggestion,
            data["title_suggestion"] = {
                "title": scan_result.title_suggestion,
            }

        return data

@@ -234,26 +231,26 @@ class SuggestionSerializerMixin:
    """
    def validate(self, attrs):
        """Validate that the correct value field is provided for the suggestion type."""
        suggestion_type = attrs.get('suggestion_type')
        value_id = attrs.get('value_id')
        value_text = attrs.get('value_text')
        suggestion_type = attrs.get("suggestion_type")
        value_id = attrs.get("value_id")
        value_text = attrs.get("value_text")

        # Types that require value_id
        if suggestion_type in ID_REQUIRED_TYPES and not value_id:
            raise serializers.ValidationError(
                f"value_id is required for suggestion_type '{suggestion_type}'"
                f"value_id is required for suggestion_type '{suggestion_type}'",
            )

        # Types that require value_text
        if suggestion_type in TEXT_REQUIRED_TYPES and not value_text:
            raise serializers.ValidationError(
                f"value_text is required for suggestion_type '{suggestion_type}'"
                f"value_text is required for suggestion_type '{suggestion_type}'",
            )

        # For custom_field, either is acceptable
        if suggestion_type == 'custom_field' and not value_id and not value_text:
        if suggestion_type == "custom_field" and not value_id and not value_text:
            raise serializers.ValidationError(
                "Either value_id or value_text must be provided for custom_field"
                "Either value_id or value_text must be provided for custom_field",
            )

        return attrs
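The mixin's rule table is compact enough to exercise outside DRF. A standalone sketch of the same validation decision (plain Python; `check` is a hypothetical helper, not part of the codebase):

    ID_REQUIRED = {"tag", "correspondent", "document_type", "storage_path", "workflow"}
    TEXT_REQUIRED = {"title"}


    def check(suggestion_type: str, value_id=None, value_text=None) -> None:
        # Mirrors SuggestionSerializerMixin.validate: ID-backed types need
        # value_id, free-text types need value_text, custom_field takes either.
        if suggestion_type in ID_REQUIRED and not value_id:
            raise ValueError(f"value_id is required for suggestion_type '{suggestion_type}'")
        if suggestion_type in TEXT_REQUIRED and not value_text:
            raise ValueError(f"value_text is required for suggestion_type '{suggestion_type}'")
        if suggestion_type == "custom_field" and not value_id and not value_text:
            raise ValueError("Either value_id or value_text must be provided for custom_field")


    check("tag", value_id=3)                      # passes
    check("title", value_text="Invoice 2024-01")  # passes
    try:
        check("custom_field")                     # neither value given
    except ValueError as err:
        print(err)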
@@ -295,19 +292,19 @@ class AISuggestionFeedbackSerializer(serializers.ModelSerializer):
    class Meta:
        model = AISuggestionFeedback
        fields = [
            'id',
            'document',
            'suggestion_type',
            'suggested_value_id',
            'suggested_value_text',
            'confidence',
            'status',
            'user',
            'created_at',
            'applied_at',
            'metadata',
            "id",
            "document",
            "suggestion_type",
            "suggested_value_id",
            "suggested_value_text",
            "confidence",
            "status",
            "user",
            "created_at",
            "applied_at",
            "metadata",
        ]
        read_only_fields = ['id', 'created_at', 'applied_at']
        read_only_fields = ["id", "created_at", "applied_at"]


class AISuggestionStatsSerializer(serializers.Serializer):

@@ -18,13 +18,11 @@ from django.test import TestCase
from django.utils import timezone

from documents.ai_deletion_manager import AIDeletionManager
from documents.models import (
    Correspondent,
    DeletionRequest,
    Document,
    DocumentType,
    Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag


class TestAIDeletionManagerCreateRequest(TestCase):

@@ -10,24 +10,23 @@ Tests cover:
- Permission assignment and verification
"""

from django.contrib.auth.models import Group, Permission, User
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from rest_framework.test import APIRequestFactory

from documents.models import Document
from documents.permissions import (
    CanApplyAISuggestionsPermission,
    CanApproveDeletionsPermission,
    CanConfigureAIPermission,
    CanViewAISuggestionsPermission,
)
from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission


class MockView:
    """Mock view for testing permissions."""

    pass


class TestCanViewAISuggestionsPermission(TestCase):
@@ -41,13 +40,13 @@ class TestCanViewAISuggestionsPermission(TestCase):

        # Create users
        self.superuser = User.objects.create_superuser(
            username="admin", email="admin@test.com", password="admin123"
            username="admin", email="admin@test.com", password="admin123",
        )
        self.regular_user = User.objects.create_user(
            username="regular", email="regular@test.com", password="regular123"
            username="regular", email="regular@test.com", password="regular123",
        )
        self.permitted_user = User.objects.create_user(
            username="permitted", email="permitted@test.com", password="permitted123"
            username="permitted", email="permitted@test.com", password="permitted123",
        )

        # Assign permission to permitted_user

@@ -107,13 +106,13 @@ class TestCanApplyAISuggestionsPermission(TestCase):

        # Create users
        self.superuser = User.objects.create_superuser(
            username="admin", email="admin@test.com", password="admin123"
            username="admin", email="admin@test.com", password="admin123",
        )
        self.regular_user = User.objects.create_user(
            username="regular", email="regular@test.com", password="regular123"
            username="regular", email="regular@test.com", password="regular123",
        )
        self.permitted_user = User.objects.create_user(
            username="permitted", email="permitted@test.com", password="permitted123"
            username="permitted", email="permitted@test.com", password="permitted123",
        )

        # Assign permission to permitted_user

@@ -173,13 +172,13 @@ class TestCanApproveDeletionsPermission(TestCase):

        # Create users
        self.superuser = User.objects.create_superuser(
            username="admin", email="admin@test.com", password="admin123"
            username="admin", email="admin@test.com", password="admin123",
        )
        self.regular_user = User.objects.create_user(
            username="regular", email="regular@test.com", password="regular123"
            username="regular", email="regular@test.com", password="regular123",
        )
        self.permitted_user = User.objects.create_user(
            username="permitted", email="permitted@test.com", password="permitted123"
            username="permitted", email="permitted@test.com", password="permitted123",
        )

        # Assign permission to permitted_user

@@ -239,13 +238,13 @@ class TestCanConfigureAIPermission(TestCase):

        # Create users
        self.superuser = User.objects.create_superuser(
            username="admin", email="admin@test.com", password="admin123"
            username="admin", email="admin@test.com", password="admin123",
        )
        self.regular_user = User.objects.create_user(
            username="regular", email="regular@test.com", password="regular123"
            username="regular", email="regular@test.com", password="regular123",
        )
        self.permitted_user = User.objects.create_user(
            username="permitted", email="permitted@test.com", password="permitted123"
            username="permitted", email="permitted@test.com", password="permitted123",
        )

        # Assign permission to permitted_user

@@ -345,7 +344,7 @@ class TestRoleBasedAccessControl(TestCase):
    def test_viewer_role_permissions(self):
        """Test that viewer role has appropriate permissions."""
        user = User.objects.create_user(
            username="viewer", email="viewer@test.com", password="viewer123"
            username="viewer", email="viewer@test.com", password="viewer123",
        )
        user.groups.add(self.viewer_group)

@@ -360,7 +359,7 @@ class TestRoleBasedAccessControl(TestCase):
    def test_editor_role_permissions(self):
        """Test that editor role has appropriate permissions."""
        user = User.objects.create_user(
            username="editor", email="editor@test.com", password="editor123"
            username="editor", email="editor@test.com", password="editor123",
        )
        user.groups.add(self.editor_group)

@@ -375,7 +374,7 @@ class TestRoleBasedAccessControl(TestCase):
    def test_admin_role_permissions(self):
        """Test that admin role has all permissions."""
        user = User.objects.create_user(
            username="ai_admin", email="ai_admin@test.com", password="admin123"
            username="ai_admin", email="ai_admin@test.com", password="admin123",
        )
        user.groups.add(self.admin_group)

@@ -390,7 +389,7 @@ class TestRoleBasedAccessControl(TestCase):
    def test_user_with_multiple_groups(self):
        """Test that user permissions accumulate from multiple groups."""
        user = User.objects.create_user(
            username="multi_role", email="multi@test.com", password="multi123"
            username="multi_role", email="multi@test.com", password="multi123",
        )
        user.groups.add(self.viewer_group, self.editor_group)

@@ -405,7 +404,7 @@ class TestRoleBasedAccessControl(TestCase):
    def test_direct_permission_assignment_overrides_group(self):
        """Test that direct permission assignment works alongside group permissions."""
        user = User.objects.create_user(
            username="special", email="special@test.com", password="special123"
            username="special", email="special@test.com", password="special123",
        )
        user.groups.add(self.viewer_group)

@@ -428,7 +427,7 @@ class TestPermissionAssignment(TestCase):
    def setUp(self):
        """Set up test user."""
        self.user = User.objects.create_user(
            username="testuser", email="test@test.com", password="test123"
            username="testuser", email="test@test.com", password="test123",
        )
        content_type = ContentType.objects.get_for_model(Document)
        self.view_permission, _ = Permission.objects.get_or_create(

@@ -500,7 +499,7 @@ class TestPermissionEdgeCases(TestCase):
    def test_inactive_user_with_permission(self):
        """Test that inactive users are denied even with permission."""
        user = User.objects.create_user(
            username="inactive", email="inactive@test.com", password="inactive123"
            username="inactive", email="inactive@test.com", password="inactive123",
        )
        user.is_active = False
        user.save()
@@ -21,24 +21,21 @@ Tests cover:
from unittest import mock

from django.db import transaction
from django.test import TestCase, override_settings
from django.test import TestCase
from django.test import override_settings

from documents.ai_scanner import (
    AIScanResult,
    AIDocumentScanner,
    get_ai_scanner,
)
from documents.models import (
    Correspondent,
    CustomField,
    Document,
    DocumentType,
    StoragePath,
    Tag,
    Workflow,
    WorkflowTrigger,
    WorkflowAction,
)
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import AIScanResult
from documents.ai_scanner import get_ai_scanner
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger


class TestAIScanResult(TestCase):

@@ -100,7 +97,7 @@ class TestAIDocumentScannerInitialization(TestCase):
        """Test scanner initialization with custom confidence thresholds."""
        scanner = AIDocumentScanner(
            auto_apply_threshold=0.90,
            suggest_threshold=0.70
            suggest_threshold=0.70,
        )

        self.assertEqual(scanner.auto_apply_threshold, 0.90)

@@ -145,14 +142,14 @@ class TestAIDocumentScannerInitialization(TestCase):
class TestAIDocumentScannerLazyLoading(TestCase):
    """Test lazy loading of ML components."""

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_get_classifier_loads_successfully(self, mock_logger):
        """Test successful lazy loading of classifier."""
        scanner = AIDocumentScanner()

        # Mock the import and class
        mock_classifier_instance = mock.MagicMock()
        with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
        with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
                        return_value=mock_classifier_instance) as mock_classifier_class:
            classifier = scanner._get_classifier()
@@ -161,13 +158,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
        mock_classifier_class.assert_called_once()
        mock_logger.info.assert_called_with("ML classifier loaded successfully")

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_get_classifier_returns_cached_instance(self, mock_logger):
        """Test that classifier is only loaded once."""
        scanner = AIDocumentScanner()

        mock_classifier_instance = mock.MagicMock()
        with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
        with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
                        return_value=mock_classifier_instance):
            classifier1 = scanner._get_classifier()
            classifier2 = scanner._get_classifier()

@@ -175,12 +172,12 @@ class TestAIDocumentScannerLazyLoading(TestCase):
        self.assertEqual(classifier1, classifier2)
        self.assertIs(classifier1, classifier2)

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_get_classifier_handles_import_error(self, mock_logger):
        """Test that classifier loading handles import errors gracefully."""
        scanner = AIDocumentScanner()

        with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
        with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
                        side_effect=ImportError("Module not found")):
            classifier = scanner._get_classifier()

@@ -196,13 +193,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):

        self.assertIsNone(classifier)

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_get_ner_extractor_loads_successfully(self, mock_logger):
        """Test successful lazy loading of NER extractor."""
        scanner = AIDocumentScanner()

        mock_ner_instance = mock.MagicMock()
        with mock.patch('documents.ai_scanner.DocumentNER',
        with mock.patch("documents.ai_scanner.DocumentNER",
                        return_value=mock_ner_instance) as mock_ner_class:
            ner = scanner._get_ner_extractor()

@@ -211,25 +208,25 @@ class TestAIDocumentScannerLazyLoading(TestCase):
        mock_ner_class.assert_called_once()
        mock_logger.info.assert_called_with("NER extractor loaded successfully")

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_get_ner_extractor_handles_error(self, mock_logger):
        """Test NER extractor handles loading errors."""
        scanner = AIDocumentScanner()

        with mock.patch('documents.ai_scanner.DocumentNER',
        with mock.patch("documents.ai_scanner.DocumentNER",
                        side_effect=Exception("Failed to load")):
            ner = scanner._get_ner_extractor()

        self.assertIsNone(ner)
        mock_logger.warning.assert_called()

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_get_semantic_search_loads_successfully(self, mock_logger):
        """Test successful lazy loading of semantic search."""
        scanner = AIDocumentScanner()

        mock_search_instance = mock.MagicMock()
        with mock.patch('documents.ai_scanner.SemanticSearch',
        with mock.patch("documents.ai_scanner.SemanticSearch",
                        return_value=mock_search_instance) as mock_search_class:
            search = scanner._get_semantic_search()

@@ -238,13 +235,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
        mock_search_class.assert_called_once()
        mock_logger.info.assert_called_with("Semantic search loaded successfully")

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_get_table_extractor_loads_successfully(self, mock_logger):
        """Test successful lazy loading of table extractor."""
        scanner = AIDocumentScanner()

        mock_extractor_instance = mock.MagicMock()
        with mock.patch('documents.ai_scanner.TableExtractor',
        with mock.patch("documents.ai_scanner.TableExtractor",
                        return_value=mock_extractor_instance) as mock_extractor_class:
            extractor = scanner._get_table_extractor()
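Taken together, the lazy-loading tests above pin down one contract for all four ML components: construct on first access, cache the instance, and degrade to None when the import or constructor fails. A minimal sketch of that contract (hypothetical `LazyComponent` wrapper, not the scanner's actual implementation):

    import logging

    logger = logging.getLogger("lazy-sketch")

    _UNSET = object()


    class LazyComponent:
        """Load an expensive dependency once; stay None after a failure."""

        def __init__(self, factory):
            self._factory = factory
            self._instance = _UNSET

        def get(self):
            if self._instance is _UNSET:
                try:
                    self._instance = self._factory()
                    logger.info("component loaded successfully")
                except Exception:
                    logger.warning("component failed to load", exc_info=True)
                    self._instance = None
            return self._instance


    def _fail():
        raise ImportError("no model available")


    classifier = LazyComponent(object)  # stand-in for the real classifier class
    assert classifier.get() is classifier.get()  # loaded once, then cached

    broken = LazyComponent(_fail)
    assert broken.get() is None  # failure is swallowed; callers get None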
@@ -276,7 +273,7 @@ class TestExtractEntities(TestCase):
            "dates": ["2024-01-01", "2024-12-31"],
            "amounts": ["$1,000", "$500"],
            "locations": ["New York"],
            "misc": ["Invoice#123"]
            "misc": ["Invoice#123"],
        }

        scanner._ner_extractor = mock_ner

@@ -320,7 +317,7 @@ class TestExtractEntities(TestCase):

        self.assertEqual(entities, {})

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_extract_entities_handles_exception(self, mock_logger):
        """Test that entity extraction handles exceptions gracefully."""
        scanner = AIDocumentScanner()

@@ -345,10 +342,10 @@ class TestSuggestTags(TestCase):
        self.tag3 = Tag.objects.create(name="Tax", matching_algorithm=Tag.MATCH_AUTO)
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    @mock.patch('documents.ai_scanner.match_tags')
    @mock.patch("documents.ai_scanner.match_tags")
    def test_suggest_tags_with_matched_tags(self, mock_match_tags):
        """Test tag suggestions from matching."""
        scanner = AIDocumentScanner()

@@ -357,7 +354,7 @@ class TestSuggestTags(TestCase):
        suggestions = scanner._suggest_tags(
            self.document,
            "Invoice from ACME Corp",
            {}
            {},
        )

        # Should suggest both matched tags

@@ -370,14 +367,14 @@ class TestSuggestTags(TestCase):
        for _, confidence in suggestions:
            self.assertGreaterEqual(confidence, 0.6)

    @mock.patch('documents.ai_scanner.match_tags')
    @mock.patch("documents.ai_scanner.match_tags")
    def test_suggest_tags_with_organization_entities(self, mock_match_tags):
        """Test tag suggestions based on organization entities."""
        scanner = AIDocumentScanner()
        mock_match_tags.return_value = []

        entities = {
            "organizations": [{"text": "ACME Corp"}]
            "organizations": [{"text": "ACME Corp"}],
        }

        suggestions = scanner._suggest_tags(self.document, "text", entities)

@@ -386,7 +383,7 @@ class TestSuggestTags(TestCase):
        tag_ids = [tag_id for tag_id, _ in suggestions]
        self.assertIn(self.tag2.id, tag_ids)

    @mock.patch('documents.ai_scanner.match_tags')
    @mock.patch("documents.ai_scanner.match_tags")
    def test_suggest_tags_removes_duplicates(self, mock_match_tags):
        """Test that duplicate tags keep highest confidence."""
        scanner = AIDocumentScanner()

@@ -397,8 +394,8 @@ class TestSuggestTags(TestCase):

        # Implementation should remove duplicates in actual code

    @mock.patch('documents.ai_scanner.match_tags')
    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.match_tags")
    @mock.patch("documents.ai_scanner.logger")
    def test_suggest_tags_handles_exception(self, mock_logger, mock_match_tags):
        """Test tag suggestion handles exceptions."""
        scanner = AIDocumentScanner()

@@ -417,18 +414,18 @@ class TestDetectCorrespondent(TestCase):
        """Set up test correspondents."""
        self.correspondent1 = Correspondent.objects.create(
            name="ACME Corporation",
            matching_algorithm=Correspondent.MATCH_AUTO
            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        self.correspondent2 = Correspondent.objects.create(
            name="TechStart Inc",
            matching_algorithm=Correspondent.MATCH_AUTO
            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    @mock.patch('documents.ai_scanner.match_correspondents')
    @mock.patch("documents.ai_scanner.match_correspondents")
    def test_detect_correspondent_with_match(self, mock_match):
        """Test correspondent detection with successful match."""
        scanner = AIDocumentScanner()

@@ -441,7 +438,7 @@ class TestDetectCorrespondent(TestCase):
        self.assertEqual(corr_id, self.correspondent1.id)
        self.assertEqual(confidence, 0.85)

    @mock.patch('documents.ai_scanner.match_correspondents')
    @mock.patch("documents.ai_scanner.match_correspondents")
    def test_detect_correspondent_without_match(self, mock_match):
        """Test correspondent detection without match."""
        scanner = AIDocumentScanner()

@@ -451,14 +448,14 @@ class TestDetectCorrespondent(TestCase):

        self.assertIsNone(result)

    @mock.patch('documents.ai_scanner.match_correspondents')
    @mock.patch("documents.ai_scanner.match_correspondents")
    def test_detect_correspondent_from_ner_entities(self, mock_match):
        """Test correspondent detection from NER organizations."""
        scanner = AIDocumentScanner()
        mock_match.return_value = []

        entities = {
            "organizations": [{"text": "ACME Corporation"}]
            "organizations": [{"text": "ACME Corporation"}],
        }

        result = scanner._detect_correspondent(self.document, "text", entities)

@@ -468,8 +465,8 @@ class TestDetectCorrespondent(TestCase):
        self.assertEqual(corr_id, self.correspondent1.id)
        self.assertEqual(confidence, 0.70)

    @mock.patch('documents.ai_scanner.match_correspondents')
    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.match_correspondents")
    @mock.patch("documents.ai_scanner.logger")
    def test_detect_correspondent_handles_exception(self, mock_logger, mock_match):
        """Test correspondent detection handles exceptions."""
        scanner = AIDocumentScanner()
@@ -488,18 +485,18 @@ class TestClassifyDocumentType(TestCase):
        """Set up test document types."""
        self.doc_type1 = DocumentType.objects.create(
            name="Invoice",
            matching_algorithm=DocumentType.MATCH_AUTO
            matching_algorithm=DocumentType.MATCH_AUTO,
        )
        self.doc_type2 = DocumentType.objects.create(
            name="Receipt",
            matching_algorithm=DocumentType.MATCH_AUTO
            matching_algorithm=DocumentType.MATCH_AUTO,
        )
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    @mock.patch('documents.ai_scanner.match_document_types')
    @mock.patch("documents.ai_scanner.match_document_types")
    def test_classify_document_type_with_match(self, mock_match):
        """Test document type classification with match."""
        scanner = AIDocumentScanner()

@@ -512,7 +509,7 @@ class TestClassifyDocumentType(TestCase):
        self.assertEqual(type_id, self.doc_type1.id)
        self.assertEqual(confidence, 0.85)

    @mock.patch('documents.ai_scanner.match_document_types')
    @mock.patch("documents.ai_scanner.match_document_types")
    def test_classify_document_type_without_match(self, mock_match):
        """Test document type classification without match."""
        scanner = AIDocumentScanner()

@@ -522,8 +519,8 @@ class TestClassifyDocumentType(TestCase):

        self.assertIsNone(result)

    @mock.patch('documents.ai_scanner.match_document_types')
    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.match_document_types")
    @mock.patch("documents.ai_scanner.logger")
    def test_classify_document_type_handles_exception(self, mock_logger, mock_match):
        """Test classification handles exceptions."""
        scanner = AIDocumentScanner()

@@ -543,14 +540,14 @@ class TestSuggestStoragePath(TestCase):
        self.storage_path1 = StoragePath.objects.create(
            name="Invoices",
            path="/documents/invoices",
            matching_algorithm=StoragePath.MATCH_AUTO
            matching_algorithm=StoragePath.MATCH_AUTO,
        )
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    @mock.patch('documents.ai_scanner.match_storage_paths')
    @mock.patch("documents.ai_scanner.match_storage_paths")
    def test_suggest_storage_path_with_match(self, mock_match):
        """Test storage path suggestion with match."""
        scanner = AIDocumentScanner()

@@ -564,7 +561,7 @@ class TestSuggestStoragePath(TestCase):
        self.assertEqual(path_id, self.storage_path1.id)
        self.assertEqual(confidence, 0.80)

    @mock.patch('documents.ai_scanner.match_storage_paths')
    @mock.patch("documents.ai_scanner.match_storage_paths")
    def test_suggest_storage_path_without_match(self, mock_match):
        """Test storage path suggestion without match."""
        scanner = AIDocumentScanner()

@@ -575,8 +572,8 @@ class TestSuggestStoragePath(TestCase):

        self.assertIsNone(result)

    @mock.patch('documents.ai_scanner.match_storage_paths')
    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.match_storage_paths")
    @mock.patch("documents.ai_scanner.logger")
    def test_suggest_storage_path_handles_exception(self, mock_logger, mock_match):
        """Test storage path suggestion handles exceptions."""
        scanner = AIDocumentScanner()

@@ -596,19 +593,19 @@ class TestExtractCustomFields(TestCase):
        """Set up test custom fields."""
        self.field_date = CustomField.objects.create(
            name="Invoice Date",
            data_type=CustomField.FieldDataType.DATE
            data_type=CustomField.FieldDataType.DATE,
        )
        self.field_amount = CustomField.objects.create(
            name="Total Amount",
            data_type=CustomField.FieldDataType.STRING
            data_type=CustomField.FieldDataType.STRING,
        )
        self.field_email = CustomField.objects.create(
            name="Contact Email",
            data_type=CustomField.FieldDataType.STRING
            data_type=CustomField.FieldDataType.STRING,
        )
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    def test_extract_custom_fields_with_entities(self):

@@ -618,7 +615,7 @@ class TestExtractCustomFields(TestCase):
        entities = {
            "dates": [{"text": "2024-01-01"}],
            "amounts": [{"text": "$1,000"}],
            "emails": ["test@example.com"]
            "emails": ["test@example.com"],
        }

        fields = scanner._extract_custom_fields(self.document, "text", entities)

@@ -638,12 +635,12 @@ class TestExtractCustomFields(TestCase):
        # Should return empty dict
        self.assertEqual(fields, {})

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_extract_custom_fields_handles_exception(self, mock_logger):
        """Test custom field extraction handles exceptions."""
        scanner = AIDocumentScanner()

        with mock.patch.object(CustomField.objects, 'all',
        with mock.patch.object(CustomField.objects, "all",
                               side_effect=Exception("DB error")):
            fields = scanner._extract_custom_fields(self.document, "text", {})

@@ -658,31 +655,31 @@ class TestExtractFieldValue(TestCase):
        """Set up test fields."""
        self.field_date = CustomField.objects.create(
            name="Invoice Date",
            data_type=CustomField.FieldDataType.DATE
            data_type=CustomField.FieldDataType.DATE,
        )
        self.field_amount = CustomField.objects.create(
            name="Total Amount",
            data_type=CustomField.FieldDataType.STRING
            data_type=CustomField.FieldDataType.STRING,
        )
        self.field_invoice = CustomField.objects.create(
            name="Invoice Number",
            data_type=CustomField.FieldDataType.STRING
            data_type=CustomField.FieldDataType.STRING,
        )
        self.field_email = CustomField.objects.create(
            name="Contact Email",
            data_type=CustomField.FieldDataType.STRING
            data_type=CustomField.FieldDataType.STRING,
        )
        self.field_phone = CustomField.objects.create(
            name="Phone Number",
            data_type=CustomField.FieldDataType.STRING
            data_type=CustomField.FieldDataType.STRING,
        )
        self.field_person = CustomField.objects.create(
            name="Person Name",
            data_type=CustomField.FieldDataType.STRING
            data_type=CustomField.FieldDataType.STRING,
        )
        self.field_company = CustomField.objects.create(
            name="Company Name",
            data_type=CustomField.FieldDataType.STRING
            data_type=CustomField.FieldDataType.STRING,
        )

    def test_extract_field_value_date(self):
@@ -691,7 +688,7 @@ class TestExtractFieldValue(TestCase):
        entities = {"dates": [{"text": "2024-01-01"}]}

        value, confidence = scanner._extract_field_value(
            self.field_date, "text", entities
            self.field_date, "text", entities,
        )

        self.assertEqual(value, "2024-01-01")

@@ -703,7 +700,7 @@ class TestExtractFieldValue(TestCase):
        entities = {"amounts": [{"text": "$1,000"}]}

        value, confidence = scanner._extract_field_value(
            self.field_amount, "text", entities
            self.field_amount, "text", entities,
        )

        self.assertEqual(value, "$1,000")

@@ -715,7 +712,7 @@ class TestExtractFieldValue(TestCase):
        entities = {"invoice_numbers": ["INV-12345"]}

        value, confidence = scanner._extract_field_value(
            self.field_invoice, "text", entities
            self.field_invoice, "text", entities,
        )

        self.assertEqual(value, "INV-12345")

@@ -727,7 +724,7 @@ class TestExtractFieldValue(TestCase):
        entities = {"emails": ["test@example.com"]}

        value, confidence = scanner._extract_field_value(
            self.field_email, "text", entities
            self.field_email, "text", entities,
        )

        self.assertEqual(value, "test@example.com")

@@ -739,7 +736,7 @@ class TestExtractFieldValue(TestCase):
        entities = {"phones": ["+1-555-1234"]}

        value, confidence = scanner._extract_field_value(
            self.field_phone, "text", entities
            self.field_phone, "text", entities,
        )

        self.assertEqual(value, "+1-555-1234")

@@ -751,7 +748,7 @@ class TestExtractFieldValue(TestCase):
        entities = {"persons": [{"text": "John Doe"}]}

        value, confidence = scanner._extract_field_value(
            self.field_person, "text", entities
            self.field_person, "text", entities,
        )

        self.assertEqual(value, "John Doe")

@@ -763,7 +760,7 @@ class TestExtractFieldValue(TestCase):
        entities = {"organizations": [{"text": "ACME Corp"}]}

        value, confidence = scanner._extract_field_value(
            self.field_company, "text", entities
            self.field_company, "text", entities,
        )

        self.assertEqual(value, "ACME Corp")

@@ -775,7 +772,7 @@ class TestExtractFieldValue(TestCase):
        entities = {}

        value, confidence = scanner._extract_field_value(
            self.field_date, "text", entities
            self.field_date, "text", entities,
        )

        self.assertIsNone(value)

@@ -789,23 +786,23 @@ class TestSuggestWorkflows(TestCase):
        """Set up test workflows."""
        self.workflow1 = Workflow.objects.create(
            name="Invoice Processing",
            enabled=True
            enabled=True,
        )
        self.trigger1 = WorkflowTrigger.objects.create(
            workflow=self.workflow1,
            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
        )
        self.workflow2 = Workflow.objects.create(
            name="Document Archival",
            enabled=True
            enabled=True,
        )
        self.trigger2 = WorkflowTrigger.objects.create(
            workflow=self.workflow2,
            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
        )
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    def test_suggest_workflows_with_matches(self):
@@ -820,7 +817,7 @@ class TestSuggestWorkflows(TestCase):
        # Create action for workflow
        WorkflowAction.objects.create(
            workflow=self.workflow1,
            type=WorkflowAction.WorkflowActionType.ASSIGNMENT
            type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
        )

        suggestions = scanner._suggest_workflows(self.document, "text", scan_result)

@@ -841,17 +838,17 @@ class TestSuggestWorkflows(TestCase):
        # Should not suggest any (confidence too low)
        self.assertEqual(len(suggestions), 0)

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_suggest_workflows_handles_exception(self, mock_logger):
        """Test workflow suggestion handles exceptions."""
        scanner = AIDocumentScanner()

        scan_result = AIScanResult()

        with mock.patch.object(Workflow.objects, 'filter',
        with mock.patch.object(Workflow.objects, "filter",
                               side_effect=Exception("DB error")):
            suggestions = scanner._suggest_workflows(
                self.document, "text", scan_result
                self.document, "text", scan_result,
            )

        self.assertEqual(suggestions, [])

@@ -865,11 +862,11 @@ class TestEvaluateWorkflowMatch(TestCase):
        """Set up test workflow."""
        self.workflow = Workflow.objects.create(
            name="Test Workflow",
            enabled=True
            enabled=True,
        )
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    def test_evaluate_workflow_match_base_confidence(self):

@@ -878,7 +875,7 @@ class TestEvaluateWorkflowMatch(TestCase):
        scan_result = AIScanResult()

        confidence = scanner._evaluate_workflow_match(
            self.workflow, self.document, scan_result
            self.workflow, self.document, scan_result,
        )

        self.assertEqual(confidence, 0.5)

@@ -892,11 +889,11 @@ class TestEvaluateWorkflowMatch(TestCase):
        # Create action for workflow
        WorkflowAction.objects.create(
            workflow=self.workflow,
            type=WorkflowAction.WorkflowActionType.ASSIGNMENT
            type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
        )

        confidence = scanner._evaluate_workflow_match(
            self.workflow, self.document, scan_result
            self.workflow, self.document, scan_result,
        )

        self.assertGreater(confidence, 0.5)

@@ -908,7 +905,7 @@ class TestEvaluateWorkflowMatch(TestCase):
        scan_result.correspondent = (1, 0.90)

        confidence = scanner._evaluate_workflow_match(
            self.workflow, self.document, scan_result
            self.workflow, self.document, scan_result,
        )

        self.assertGreater(confidence, 0.5)

@@ -920,7 +917,7 @@ class TestEvaluateWorkflowMatch(TestCase):
        scan_result.tags = [(1, 0.80), (2, 0.75)]

        confidence = scanner._evaluate_workflow_match(
            self.workflow, self.document, scan_result
            self.workflow, self.document, scan_result,
        )

        self.assertGreater(confidence, 0.5)

@@ -936,11 +933,11 @@ class TestEvaluateWorkflowMatch(TestCase):
        # Create action
        WorkflowAction.objects.create(
            workflow=self.workflow,
            type=WorkflowAction.WorkflowActionType.ASSIGNMENT
            type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
        )

        confidence = scanner._evaluate_workflow_match(
            self.workflow, self.document, scan_result
            self.workflow, self.document, scan_result,
        )

        self.assertLessEqual(confidence, 1.0)
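The tests above fix the scoring envelope of `_evaluate_workflow_match` without fixing its exact arithmetic: a 0.5 baseline, upward boosts when the workflow has actions or the scan found a correspondent or tags, and a hard 1.0 ceiling. A sketch of a function with that envelope — the 0.15/0.1 boost sizes below are illustrative assumptions; only the base, the direction, and the cap come from the asserts:

    def workflow_confidence(has_actions: bool, correspondent_conf, tag_confs) -> float:
        confidence = 0.5  # base confidence, per the first test
        if has_actions:
            confidence += 0.15  # assumed boost size
        if correspondent_conf is not None:
            confidence += 0.15  # assumed boost size
        confidence += 0.1 * min(len(tag_confs), 2)  # assumed per-tag boost, capped
        return min(confidence, 1.0)  # never exceeds 1.0, per the last test


    print(workflow_confidence(False, None, []))           # 0.5
    print(workflow_confidence(True, 0.90, [0.80, 0.75]))  # 1.0 (capped)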
@@ -953,7 +950,7 @@ class TestSuggestTitle(TestCase):
        """Set up test document."""
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    def test_suggest_title_with_all_entities(self):

@@ -963,7 +960,7 @@ class TestSuggestTitle(TestCase):
        entities = {
            "document_type": "Invoice",
            "organizations": [{"text": "ACME Corporation"}],
            "dates": [{"text": "2024-01-01"}]
            "dates": [{"text": "2024-01-01"}],
        }

        title = scanner._suggest_title(self.document, "text", entities)

@@ -978,7 +975,7 @@ class TestSuggestTitle(TestCase):
        scanner = AIDocumentScanner()

        entities = {
            "organizations": [{"text": "TechStart Inc"}]
            "organizations": [{"text": "TechStart Inc"}],
        }

        title = scanner._suggest_title(self.document, "text", entities)

@@ -1002,7 +999,7 @@ class TestSuggestTitle(TestCase):
        long_org = "A" * 100
        entities = {
            "organizations": [{"text": long_org}],
            "dates": [{"text": "2024-01-01"}]
            "dates": [{"text": "2024-01-01"}],
        }

        title = scanner._suggest_title(self.document, "text", entities)

@@ -1010,7 +1007,7 @@ class TestSuggestTitle(TestCase):
        self.assertIsNotNone(title)
        self.assertLessEqual(len(title), 127)

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_suggest_title_handles_exception(self, mock_logger):
        """Test title suggestion handles exceptions."""
        scanner = AIDocumentScanner()
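The title tests constrain `_suggest_title` to assembling document type, organization, and date into one string and keeping the result under 128 characters (the asserts allow at most 127). A sketch of that assembly — the separator and the exact truncation bound are assumptions read off the asserts, not the scanner's code:

    def suggest_title(entities: dict, max_len: int = 128):
        parts = []
        if entities.get("document_type"):
            parts.append(entities["document_type"])
        if entities.get("organizations"):
            parts.append(entities["organizations"][0]["text"])
        if entities.get("dates"):
            parts.append(entities["dates"][0]["text"])
        if not parts:
            return None
        title = " - ".join(parts)  # assumed separator
        return title[: max_len - 1]  # len(title) <= 127, matching the asserts


    print(suggest_title({
        "document_type": "Invoice",
        "organizations": [{"text": "ACME Corporation"}],
        "dates": [{"text": "2024-01-01"}],
    }))  # Invoice - ACME Corporation - 2024-01-01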
@@ -1034,7 +1031,7 @@ class TestExtractTables(TestCase):

        mock_extractor = mock.MagicMock()
        mock_extractor.extract_tables_from_image.return_value = [
            {"data": [[1, 2], [3, 4]], "headers": ["A", "B"]}
            {"data": [[1, 2], [3, 4]], "headers": ["A", "B"]},
        ]
        scanner._table_extractor = mock_extractor

@@ -1053,7 +1050,7 @@ class TestExtractTables(TestCase):

        self.assertEqual(tables, [])

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_extract_tables_handles_exception(self, mock_logger):
        """Test table extraction handles exceptions."""
        scanner = AIDocumentScanner()

@@ -1075,17 +1072,17 @@ class TestScanDocument(TestCase):
        """Set up test document."""
        self.document = Document.objects.create(
            title="Test Document",
            content="Invoice from ACME Corporation dated 2024-01-01"
            content="Invoice from ACME Corporation dated 2024-01-01",
        )

    @mock.patch.object(AIDocumentScanner, '_extract_entities')
    @mock.patch.object(AIDocumentScanner, '_suggest_tags')
    @mock.patch.object(AIDocumentScanner, '_detect_correspondent')
    @mock.patch.object(AIDocumentScanner, '_classify_document_type')
    @mock.patch.object(AIDocumentScanner, '_suggest_storage_path')
    @mock.patch.object(AIDocumentScanner, '_extract_custom_fields')
    @mock.patch.object(AIDocumentScanner, '_suggest_workflows')
    @mock.patch.object(AIDocumentScanner, '_suggest_title')
    @mock.patch.object(AIDocumentScanner, "_extract_entities")
    @mock.patch.object(AIDocumentScanner, "_suggest_tags")
    @mock.patch.object(AIDocumentScanner, "_detect_correspondent")
    @mock.patch.object(AIDocumentScanner, "_classify_document_type")
    @mock.patch.object(AIDocumentScanner, "_suggest_storage_path")
    @mock.patch.object(AIDocumentScanner, "_extract_custom_fields")
    @mock.patch.object(AIDocumentScanner, "_suggest_workflows")
    @mock.patch.object(AIDocumentScanner, "_suggest_title")
    def test_scan_document_orchestrates_all_methods(
        self,
        mock_title,

@@ -1095,7 +1092,7 @@ class TestScanDocument(TestCase):
        mock_doc_type,
        mock_correspondent,
        mock_tags,
        mock_entities
        mock_entities,
    ):
        """Test that scan_document calls all extraction methods."""
        scanner = AIDocumentScanner()

@@ -1127,26 +1124,26 @@ class TestScanDocument(TestCase):
        self.assertEqual(result.correspondent, (1, 0.90))
        self.assertEqual(result.document_type, (1, 0.80))

    @mock.patch.object(AIDocumentScanner, '_extract_tables')
    @mock.patch.object(AIDocumentScanner, "_extract_tables")
    def test_scan_document_extracts_tables_when_enabled(self, mock_extract_tables):
        """Test that tables are extracted when OCR is enabled and file path provided."""
        scanner = AIDocumentScanner(enable_advanced_ocr=True)
        mock_extract_tables.return_value = [{"data": "test"}]

        # Mock other methods to avoid complexity
        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
             mock.patch.object(scanner, '_suggest_title', return_value=None):
        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
             mock.patch.object(scanner, "_suggest_title", return_value=None):

            result = scanner.scan_document(
                self.document,
                "Document text",
                original_file_path="/path/to/file.pdf"
                original_file_path="/path/to/file.pdf",
            )

        mock_extract_tables.assert_called_once_with("/path/to/file.pdf")
@@ -1156,15 +1153,15 @@ class TestScanDocument(TestCase):
        """Test that tables are not extracted when file path is not provided."""
        scanner = AIDocumentScanner(enable_advanced_ocr=True)

        with mock.patch.object(scanner, '_extract_tables') as mock_extract_tables, \
             mock.patch.object(scanner, '_extract_entities', return_value={}), \
             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
             mock.patch.object(scanner, '_suggest_title', return_value=None):
        with mock.patch.object(scanner, "_extract_tables") as mock_extract_tables, \
             mock.patch.object(scanner, "_extract_entities", return_value={}), \
             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
             mock.patch.object(scanner, "_suggest_title", return_value=None):

            result = scanner.scan_document(self.document, "Document text")

@@ -1183,11 +1180,11 @@ class TestApplyScanResults(TestCase):
        self.doc_type = DocumentType.objects.create(name="Invoice")
        self.storage_path = StoragePath.objects.create(
            name="Invoices",
            path="/invoices"
            path="/invoices",
        )
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    def test_apply_scan_results_auto_applies_high_confidence(self):

@@ -1203,7 +1200,7 @@ class TestApplyScanResults(TestCase):
        result = scanner.apply_scan_results(
            self.document,
            scan_result,
            auto_apply=True
            auto_apply=True,
        )

        # Verify auto-applied

@@ -1222,7 +1219,7 @@ class TestApplyScanResults(TestCase):
        """Test that medium confidence items are suggested, not applied."""
        scanner = AIDocumentScanner(
            auto_apply_threshold=0.80,
            suggest_threshold=0.60
            suggest_threshold=0.60,
        )

        scan_result = AIScanResult()

@@ -1232,7 +1229,7 @@ class TestApplyScanResults(TestCase):
        result = scanner.apply_scan_results(
            self.document,
            scan_result,
            auto_apply=True
            auto_apply=True,
        )

        # Verify suggested but not applied
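Between them, the high- and medium-confidence tests pin down a two-threshold routing rule for `apply_scan_results`: at or above `auto_apply_threshold` a value is written to the document (when `auto_apply` is on), between the two thresholds it is only surfaced as a suggestion, and below `suggest_threshold` it is dropped. A sketch of that rule in isolation (`route_suggestion` is a hypothetical helper; the >= comparisons are an assumption):

    def route_suggestion(confidence: float, *, auto_apply: bool = True,
                         auto_apply_threshold: float = 0.80,
                         suggest_threshold: float = 0.60) -> str:
        if auto_apply and confidence >= auto_apply_threshold:
            return "applied"
        if confidence >= suggest_threshold:
            return "suggested"
        return "ignored"


    print(route_suggestion(0.90))                    # applied
    print(route_suggestion(0.70))                    # suggested
    print(route_suggestion(0.90, auto_apply=False))  # suggested, never applied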
@@ -1255,7 +1252,7 @@ class TestApplyScanResults(TestCase):
        result = scanner.apply_scan_results(
            self.document,
            scan_result,
            auto_apply=False
            auto_apply=False,
        )

        # Verify nothing was applied

@@ -1268,21 +1265,21 @@ class TestApplyScanResults(TestCase):
        scan_result = AIScanResult()
        scan_result.correspondent = (self.correspondent.id, 0.90)

        with mock.patch.object(self.document, 'save',
        with mock.patch.object(self.document, "save",
                               side_effect=Exception("Save failed")):
            with self.assertRaises(Exception):
                with transaction.atomic():
                    scanner.apply_scan_results(
                        self.document,
                        scan_result,
                        auto_apply=True
                        auto_apply=True,
                    )

        # Verify transaction was rolled back
        self.document.refresh_from_db()
        self.assertIsNone(self.document.correspondent)

    @mock.patch('documents.ai_scanner.logger')
    @mock.patch("documents.ai_scanner.logger")
    def test_apply_scan_results_handles_exception(self, mock_logger):
        """Test that apply_scan_results handles exceptions gracefully."""
        scanner = AIDocumentScanner()

@@ -1293,7 +1290,7 @@ class TestApplyScanResults(TestCase):
        scanner.apply_scan_results(
            self.document,
            scan_result,
            auto_apply=True
            auto_apply=True,
        )

        mock_logger.error.assert_called()
@@ -1323,21 +1320,21 @@ class TestEdgeCasesAndErrorHandling(TestCase):
        """Set up test document."""
        self.document = Document.objects.create(
            title="Test Document",
            content="Test content"
            content="Test content",
        )

    def test_scan_document_with_empty_text(self):
        """Test scanning document with empty text."""
        scanner = AIDocumentScanner()

        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
             mock.patch.object(scanner, '_suggest_title', return_value=None):
        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
             mock.patch.object(scanner, "_suggest_title", return_value=None):

            result = scanner.scan_document(self.document, "")

@@ -1349,14 +1346,14 @@ class TestEdgeCasesAndErrorHandling(TestCase):
        scanner = AIDocumentScanner()
        long_text = "A" * 100000

        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
             mock.patch.object(scanner, '_suggest_title', return_value=None):
        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
             mock.patch.object(scanner, "_suggest_title", return_value=None):

            result = scanner.scan_document(self.document, long_text)

@@ -1367,14 +1364,14 @@ class TestEdgeCasesAndErrorHandling(TestCase):
        scanner = AIDocumentScanner()
        special_text = "Test with émojis 😀 and special chars: <>{}[]|\\`~"

        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
             mock.patch.object(scanner, '_suggest_title', return_value=None):
        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
             mock.patch.object(scanner, "_suggest_title", return_value=None):

            result = scanner.scan_document(self.document, special_text)

@@ -1388,7 +1385,7 @@ class TestEdgeCasesAndErrorHandling(TestCase):
        result = scanner.apply_scan_results(
            self.document,
            scan_result,
            auto_apply=True
            auto_apply=True,
        )

        self.assertEqual(result["applied"]["tags"], [])

@@ -1403,12 +1400,12 @@ class TestEdgeCasesAndErrorHandling(TestCase):
        # Test extreme values
        scanner_low = AIDocumentScanner(
            auto_apply_threshold=0.01,
            suggest_threshold=0.01
            suggest_threshold=0.01,
        )
        self.assertEqual(scanner_low.auto_apply_threshold, 0.01)

        scanner_high = AIDocumentScanner(
            auto_apply_threshold=0.99,
            suggest_threshold=0.80
            suggest_threshold=0.80,
        )
        self.assertEqual(scanner_high.auto_apply_threshold, 0.99)
@@ -8,24 +8,21 @@ document consumption to metadata application.

from unittest import mock

-from django.test import TestCase, TransactionTestCase
+from django.test import TestCase
+from django.test import TransactionTestCase

-from documents.ai_scanner import (
-    AIDocumentScanner,
-    AIScanResult,
-    get_ai_scanner,
-)
-from documents.models import (
-    Correspondent,
-    CustomField,
-    Document,
-    DocumentType,
-    StoragePath,
-    Tag,
-    Workflow,
-    WorkflowTrigger,
-    WorkflowAction,
-)
+from documents.ai_scanner import AIDocumentScanner
+from documents.ai_scanner import AIScanResult
+from documents.ai_scanner import get_ai_scanner
+from documents.models import Correspondent
+from documents.models import CustomField
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import StoragePath
+from documents.models import Tag
+from documents.models import Workflow
+from documents.models import WorkflowAction
+from documents.models import WorkflowTrigger


class TestAIScannerIntegrationBasic(TestCase):

@@ -35,49 +32,49 @@ class TestAIScannerIntegrationBasic(TestCase):
        """Set up test data."""
        self.document = Document.objects.create(
            title="Invoice from ACME Corporation",
-            content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000"
+            content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000",
        )

        self.tag_invoice = Tag.objects.create(
            name="Invoice",
            matching_algorithm=Tag.MATCH_AUTO,
-            match="invoice"
+            match="invoice",
        )
        self.tag_important = Tag.objects.create(
            name="Important",
            matching_algorithm=Tag.MATCH_AUTO,
-            match="total"
+            match="total",
        )

        self.correspondent = Correspondent.objects.create(
            name="ACME Corporation",
            matching_algorithm=Correspondent.MATCH_AUTO,
-            match="acme"
+            match="acme",
        )

        self.doc_type = DocumentType.objects.create(
            name="Invoice",
            matching_algorithm=DocumentType.MATCH_AUTO,
-            match="invoice"
+            match="invoice",
        )

        self.storage_path = StoragePath.objects.create(
            name="Invoices",
            path="/invoices",
            matching_algorithm=StoragePath.MATCH_AUTO,
-            match="invoice"
+            match="invoice",
        )

-    @mock.patch('documents.ai_scanner.match_tags')
-    @mock.patch('documents.ai_scanner.match_correspondents')
-    @mock.patch('documents.ai_scanner.match_document_types')
-    @mock.patch('documents.ai_scanner.match_storage_paths')
+    @mock.patch("documents.ai_scanner.match_tags")
+    @mock.patch("documents.ai_scanner.match_correspondents")
+    @mock.patch("documents.ai_scanner.match_document_types")
+    @mock.patch("documents.ai_scanner.match_storage_paths")
    def test_full_scan_and_apply_workflow(
        self,
        mock_storage,
        mock_types,
        mock_correspondents,
-        mock_tags
+        mock_tags,
    ):
        """Test complete workflow from scan to application."""
        # Mock the matching functions to return our test data

@@ -91,7 +88,7 @@ class TestAIScannerIntegrationBasic(TestCase):
        # Scan the document
        scan_result = scanner.scan_document(
            self.document,
-            self.document.content
+            self.document.content,
        )

        # Verify scan results

@@ -105,7 +102,7 @@ class TestAIScannerIntegrationBasic(TestCase):
        result = scanner.apply_scan_results(
            self.document,
            scan_result,
-            auto_apply=True
+            auto_apply=True,
        )

        # Verify application

@@ -118,7 +115,7 @@ class TestAIScannerIntegrationBasic(TestCase):
        self.assertEqual(self.document.document_type, self.doc_type)
        self.assertEqual(self.document.storage_path, self.storage_path)

-    @mock.patch('documents.ai_scanner.match_tags')
+    @mock.patch("documents.ai_scanner.match_tags")
    def test_scan_with_no_matches(self, mock_tags):
        """Test scanning when no matches are found."""
        mock_tags.return_value = []

@@ -127,7 +124,7 @@ class TestAIScannerIntegrationBasic(TestCase):

        scan_result = scanner.scan_document(
            self.document,
-            "Some random text with no matches"
+            "Some random text with no matches",
        )

        # Should return empty results

@@ -143,24 +140,24 @@ class TestAIScannerIntegrationCustomFields(TestCase):
        """Set up test data with custom fields."""
        self.document = Document.objects.create(
            title="Invoice",
-            content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com"
+            content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com",
        )

        self.field_date = CustomField.objects.create(
            name="Invoice Date",
-            data_type=CustomField.FieldDataType.DATE
+            data_type=CustomField.FieldDataType.DATE,
        )
        self.field_number = CustomField.objects.create(
            name="Invoice Number",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
        )
        self.field_amount = CustomField.objects.create(
            name="Total Amount",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
        )
        self.field_email = CustomField.objects.create(
            name="Contact Email",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
        )

    def test_custom_field_extraction_integration(self):

@@ -173,7 +170,7 @@ class TestAIScannerIntegrationCustomFields(TestCase):
            "dates": [{"text": "2024-01-01"}],
            "amounts": [{"text": "$1,500"}],
            "invoice_numbers": ["INV-123"],
-            "emails": ["john@example.com"]
+            "emails": ["john@example.com"],
        }
        scanner._ner_extractor = mock_ner

@@ -196,29 +193,29 @@ class TestAIScannerIntegrationWorkflows(TestCase):
        """Set up test workflows."""
        self.document = Document.objects.create(
            title="Invoice",
-            content="Invoice document"
+            content="Invoice document",
        )

        self.workflow1 = Workflow.objects.create(
            name="Invoice Processing",
-            enabled=True
+            enabled=True,
        )
        self.trigger1 = WorkflowTrigger.objects.create(
            workflow=self.workflow1,
-            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
+            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
        )
        self.action1 = WorkflowAction.objects.create(
            workflow=self.workflow1,
-            type=WorkflowAction.WorkflowActionType.ASSIGNMENT
+            type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
        )

        self.workflow2 = Workflow.objects.create(
            name="Archive Documents",
-            enabled=True
+            enabled=True,
        )
        self.trigger2 = WorkflowTrigger.objects.create(
            workflow=self.workflow2,
-            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
+            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
        )

    def test_workflow_suggestion_integration(self):

@@ -234,7 +231,7 @@ class TestAIScannerIntegrationWorkflows(TestCase):
        workflows = scanner._suggest_workflows(
            self.document,
            self.document.content,
-            scan_result
+            scan_result,
        )

        # Should suggest workflows

@@ -250,7 +247,7 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
        """Set up test data."""
        self.document = Document.objects.create(
            title="Test Document",
-            content="Test content"
+            content="Test content",
        )
        self.tag = Tag.objects.create(name="TestTag")
        self.correspondent = Correspondent.objects.create(name="TestCorp")

@@ -273,12 +270,12 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
                raise Exception("Forced save failure")
            return original_save(self, *args, **kwargs)

-        with mock.patch.object(Document, 'save', failing_save):
+        with mock.patch.object(Document, "save", failing_save):
            with self.assertRaises(Exception):
                scanner.apply_scan_results(
                    self.document,
                    scan_result,
-                    auto_apply=True
+                    auto_apply=True,
                )

        # Verify changes were rolled back

@@ -297,19 +294,19 @@ class TestAIScannerIntegrationPerformance(TestCase):
        for i in range(5):
            doc = Document.objects.create(
                title=f"Document {i}",
-                content=f"Content for document {i}"
+                content=f"Content for document {i}",
            )
            documents.append(doc)

        # Mock to avoid actual ML loading
-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

            results = []
            for doc in documents:

@@ -329,16 +326,16 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
        """Set up test data."""
        self.document = Document.objects.create(
            title="Business Invoice",
-            content="Invoice from ACME Corporation"
+            content="Invoice from ACME Corporation",
        )

        self.correspondent_acme = Correspondent.objects.create(
            name="ACME Corporation",
-            matching_algorithm=Correspondent.MATCH_AUTO
+            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        self.correspondent_other = Correspondent.objects.create(
            name="Other Company",
-            matching_algorithm=Correspondent.MATCH_AUTO
+            matching_algorithm=Correspondent.MATCH_AUTO,
        )

    def test_correspondent_matching_with_ner_entities(self):

@@ -348,16 +345,16 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
        # Mock NER to extract organization
        mock_ner = mock.MagicMock()
        mock_ner.extract_all.return_value = {
-            "organizations": [{"text": "ACME Corporation"}]
+            "organizations": [{"text": "ACME Corporation"}],
        }
        scanner._ner_extractor = mock_ner

        # Mock matching to return empty (so NER-based matching is used)
-        with mock.patch('documents.ai_scanner.match_correspondents', return_value=[]):
+        with mock.patch("documents.ai_scanner.match_correspondents", return_value=[]):
            result = scanner._detect_correspondent(
                self.document,
                self.document.content,
-                {"organizations": [{"text": "ACME Corporation"}]}
+                {"organizations": [{"text": "ACME Corporation"}]},
            )

        # Should detect ACME correspondent

@@ -375,13 +372,13 @@ class TestAIScannerIntegrationTitleGeneration(TestCase):

        document = Document.objects.create(
            title="document.pdf",
-            content="Invoice from ACME Corp dated 2024-01-15"
+            content="Invoice from ACME Corp dated 2024-01-15",
        )

        entities = {
            "document_type": "Invoice",
            "organizations": [{"text": "ACME Corp"}],
-            "dates": [{"text": "2024-01-15"}]
+            "dates": [{"text": "2024-01-15"}],
        }

        title = scanner._suggest_title(document, document.content, entities)

@@ -399,7 +396,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
        """Set up test data."""
        self.document = Document.objects.create(
            title="Test",
-            content="Test"
+            content="Test",
        )
        self.tag_high = Tag.objects.create(name="HighConfidence")
        self.tag_medium = Tag.objects.create(name="MediumConfidence")

@@ -409,7 +406,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
        """Test that only high confidence suggestions are auto-applied."""
        scanner = AIDocumentScanner(
            auto_apply_threshold=0.80,
-            suggest_threshold=0.60
+            suggest_threshold=0.60,
        )

        scan_result = AIScanResult()

@@ -422,7 +419,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
        result = scanner.apply_scan_results(
            self.document,
            scan_result,
-            auto_apply=True
+            auto_apply=True,
        )

        # Verify high confidence was applied

@@ -448,17 +445,17 @@ class TestAIScannerIntegrationGlobalInstance(TestCase):
        # Should be functional
        document = Document.objects.create(
            title="Test",
-            content="Test content"
+            content="Test content",
        )

-        with mock.patch.object(scanner1, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner1, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner1, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner1, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner1, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner1, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner1, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner1, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner1, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner1, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner1, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner1, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner1, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner1, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner1, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner1, "_suggest_title", return_value=None):

            result1 = scanner1.scan_document(document, document.content)
            result2 = scanner2.scan_document(document, document.content)

@@ -476,17 +473,17 @@ class TestAIScannerIntegrationEdgeCases(TestCase):

        document = Document.objects.create(
            title="",
-            content=""
+            content="",
        )

-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

            result = scanner.scan_document(document, document.content)

@@ -498,7 +495,7 @@ class TestAIScannerIntegrationEdgeCases(TestCase):

        document = Document.objects.create(
            title="Test",
-            content="Test"
+            content="Test",
        )

        scan_result = AIScanResult()

@@ -509,7 +506,7 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
        result = scanner.apply_scan_results(
            document,
            scan_result,
-            auto_apply=True
+            auto_apply=True,
        )

        # Should not crash, just log errors

@@ -521,17 +518,17 @@ class TestAIScannerIntegrationEdgeCases(TestCase):

        document = Document.objects.create(
            title="Factura - España 🇪🇸",
-            content="Société française • 日本語 • Ελληνικά • مرحبا"
+            content="Société française • 日本語 • Ελληνικά • مرحبا",
        )

-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

            result = scanner.scan_document(document, document.content)

@@ -12,18 +12,17 @@ Tests cover:

from unittest import mock

-from django.contrib.auth.models import Permission, User
+from django.contrib.auth.models import Permission
+from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from rest_framework import status
from rest_framework.test import APITestCase

-from documents.models import (
-    Correspondent,
-    DeletionRequest,
-    Document,
-    DocumentType,
-    Tag,
-)
+from documents.models import Correspondent
+from documents.models import DeletionRequest
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import Tag
from documents.tests.utils import DirectoriesMixin


@@ -36,13 +35,13 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):

        # Create users
        self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
        )
        self.user_with_permission = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
        )
        self.user_without_permission = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
        )

        # Assign view permission

@@ -57,7 +56,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        # Create test document
        self.document = Document.objects.create(
            title="Test Document",
-            content="This is a test invoice from ACME Corporation"
+            content="This is a test invoice from ACME Corporation",
        )

        # Create test metadata objects

@@ -70,7 +69,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        response = self.client.post(
            "/api/ai/suggestions/",
            {"document_id": self.document.id},
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

@@ -82,7 +81,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        response = self.client.post(
            "/api/ai/suggestions/",
            {"document_id": self.document.id},
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

@@ -91,7 +90,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        """Test that superusers can access the endpoint."""
        self.client.force_authenticate(user=self.superuser)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
            # Mock the scanner response
            mock_scan_result = mock.MagicMock()
            mock_scan_result.tags = [(self.tag.id, 0.85)]

@@ -108,7 +107,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
            response = self.client.post(
                "/api/ai/suggestions/",
                {"document_id": self.document.id},
-                format="json"
+                format="json",
            )

            self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -119,7 +118,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        """Test that users with permission can access the endpoint."""
        self.client.force_authenticate(user=self.user_with_permission)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
            # Mock the scanner response
            mock_scan_result = mock.MagicMock()
            mock_scan_result.tags = []

@@ -136,7 +135,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
            response = self.client.post(
                "/api/ai/suggestions/",
                {"document_id": self.document.id},
-                format="json"
+                format="json",
            )

            self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -148,7 +147,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        response = self.client.post(
            "/api/ai/suggestions/",
            {"document_id": 99999},
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

@@ -160,7 +159,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        response = self.client.post(
            "/api/ai/suggestions/",
            {},
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@@ -175,10 +174,10 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):

        # Create users
        self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
        )
        self.user_with_permission = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
        )

        # Assign apply permission

@@ -193,7 +192,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        # Create test document
        self.document = Document.objects.create(
            title="Test Document",
-            content="Test content"
+            content="Test content",
        )

        # Create test metadata

@@ -205,7 +204,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        response = self.client.post(
            "/api/ai/suggestions/apply/",
            {"document_id": self.document.id},
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

@@ -214,7 +213,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        """Test successfully applying tag suggestions."""
        self.client.force_authenticate(user=self.superuser)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
            # Mock the scanner response
            mock_scan_result = mock.MagicMock()
            mock_scan_result.tags = [(self.tag.id, 0.85)]

@@ -233,9 +232,9 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
                "/api/ai/suggestions/apply/",
                {
                    "document_id": self.document.id,
-                    "apply_tags": True
+                    "apply_tags": True,
                },
-                format="json"
+                format="json",
            )

            self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -245,7 +244,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
        """Test successfully applying correspondent suggestion."""
        self.client.force_authenticate(user=self.superuser)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
            # Mock the scanner response
            mock_scan_result = mock.MagicMock()
            mock_scan_result.tags = []

@@ -264,9 +263,9 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
                "/api/ai/suggestions/apply/",
                {
                    "document_id": self.document.id,
-                    "apply_correspondent": True
+                    "apply_correspondent": True,
                },
-                format="json"
+                format="json",
            )

            self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -285,10 +284,10 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):

        # Create users
        self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
        )
        self.user_without_permission = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
        )

    def test_unauthorized_access_denied(self):

@@ -309,7 +308,7 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
        """Test getting AI configuration."""
        self.client.force_authenticate(user=self.superuser)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
            mock_scanner_instance = mock.MagicMock()
            mock_scanner_instance.auto_apply_threshold = 0.80
            mock_scanner_instance.suggest_threshold = 0.60

@@ -331,9 +330,9 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
            "/api/ai/config/",
            {
                "auto_apply_threshold": 0.90,
-                "suggest_threshold": 0.70
+                "suggest_threshold": 0.70,
            },
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -346,9 +345,9 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
        response = self.client.post(
            "/api/ai/config/",
            {
-                "auto_apply_threshold": 1.5 # Invalid: > 1.0
+                "auto_apply_threshold": 1.5,  # Invalid: > 1.0
            },
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@@ -363,13 +362,13 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):

        # Create users
        self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
        )
        self.user_with_permission = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
        )
        self.user_without_permission = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
        )

        # Assign approval permission

@@ -385,7 +384,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
        self.deletion_request = DeletionRequest.objects.create(
            user=self.user_with_permission,
            requested_by_ai=True,
-            ai_reason="Document appears to be a duplicate"
+            ai_reason="Document appears to be a duplicate",
        )

    def test_unauthorized_access_denied(self):

@@ -394,9 +393,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
            "/api/ai/deletions/approve/",
            {
                "request_id": self.deletion_request.id,
-                "action": "approve"
+                "action": "approve",
            },
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

@@ -409,9 +408,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
            "/api/ai/deletions/approve/",
            {
                "request_id": self.deletion_request.id,
-                "action": "approve"
+                "action": "approve",
            },
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

@@ -424,9 +423,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
            "/api/ai/deletions/approve/",
            {
                "request_id": self.deletion_request.id,
-                "action": "approve"
+                "action": "approve",
            },
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -436,7 +435,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
        self.deletion_request.refresh_from_db()
        self.assertEqual(
            self.deletion_request.status,
-            DeletionRequest.STATUS_APPROVED
+            DeletionRequest.STATUS_APPROVED,
        )

    def test_reject_deletion_success(self):

@@ -448,9 +447,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
            {
                "request_id": self.deletion_request.id,
                "action": "reject",
-                "reason": "Document is still needed"
+                "reason": "Document is still needed",
            },
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -459,7 +458,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
        self.deletion_request.refresh_from_db()
        self.assertEqual(
            self.deletion_request.status,
-            DeletionRequest.STATUS_REJECTED
+            DeletionRequest.STATUS_REJECTED,
        )

    def test_invalid_request_id(self):

@@ -470,9 +469,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
            "/api/ai/deletions/approve/",
            {
                "request_id": 99999,
-                "action": "approve"
+                "action": "approve",
            },
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

@@ -485,9 +484,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
            "/api/ai/deletions/approve/",
            {
                "request_id": self.deletion_request.id,
-                "action": "approve"
+                "action": "approve",
            },
-            format="json"
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -502,7 +501,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):

        # Create user with all AI permissions
        self.power_user = User.objects.create_user(
-            username="power_user", email="power@test.com", password="power123"
+            username="power_user", email="power@test.com", password="power123",
        )

        content_type = ContentType.objects.get_for_model(Document)

@@ -525,7 +524,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):

        self.document = Document.objects.create(
            title="Test Doc",
-            content="Test"
+            content="Test",
        )

    def test_power_user_can_access_all_endpoints(self):

@@ -533,7 +532,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
        self.client.force_authenticate(user=self.power_user)

        # Test suggestions endpoint
-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
            mock_scan_result = mock.MagicMock()
            mock_scan_result.tags = []
            mock_scan_result.correspondent = None

@@ -553,7 +552,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
            response1 = self.client.post(
                "/api/ai/suggestions/",
                {"document_id": self.document.id},
-                format="json"
+                format="json",
            )
            self.assertEqual(response1.status_code, status.HTTP_200_OK)

@@ -562,9 +561,9 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
                "/api/ai/suggestions/apply/",
                {
                    "document_id": self.document.id,
-                    "apply_tags": False
+                    "apply_tags": False,
                },
-                format="json"
+                format="json",
            )
            self.assertEqual(response2.status_code, status.HTTP_200_OK)

@@ -9,14 +9,12 @@ from rest_framework import status
from rest_framework.test import APITestCase

from documents.ai_scanner import AIScanResult
-from documents.models import (
-    AISuggestionFeedback,
-    Correspondent,
-    Document,
-    DocumentType,
-    StoragePath,
-    Tag,
-)
+from documents.models import AISuggestionFeedback
+from documents.models import Correspondent
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import StoragePath
+from documents.models import Tag
from documents.tests.utils import DirectoriesMixin


@@ -64,12 +62,12 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_ai_suggestions_endpoint_exists(self):
        """Test that the ai-suggestions endpoint is accessible."""
        response = self.client.get(
-            f"/api/documents/{self.document.pk}/ai-suggestions/"
+            f"/api/documents/{self.document.pk}/ai-suggestions/",
        )
        # Should not be 404
        self.assertNotEqual(response.status_code, status.HTTP_404_NOT_FOUND)

-    @mock.patch('documents.ai_scanner.get_ai_scanner')
+    @mock.patch("documents.ai_scanner.get_ai_scanner")
    def test_get_ai_suggestions_success(self, mock_get_scanner):
        """Test successfully getting AI suggestions for a document."""
        # Create mock scan result

@@ -87,7 +85,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

        # Make request
        response = self.client.get(
-            f"/api/documents/{self.document.pk}/ai-suggestions/"
+            f"/api/documents/{self.document.pk}/ai-suggestions/",
        )

        # Verify response

@@ -95,23 +93,23 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        data = response.json()

        # Check tags
-        self.assertIn('tags', data)
-        self.assertEqual(len(data['tags']), 2)
-        self.assertEqual(data['tags'][0]['id'], self.tag1.id)
-        self.assertEqual(data['tags'][0]['confidence'], 0.85)
+        self.assertIn("tags", data)
+        self.assertEqual(len(data["tags"]), 2)
+        self.assertEqual(data["tags"][0]["id"], self.tag1.id)
+        self.assertEqual(data["tags"][0]["confidence"], 0.85)

        # Check correspondent
-        self.assertIn('correspondent', data)
-        self.assertEqual(data['correspondent']['id'], self.correspondent.id)
-        self.assertEqual(data['correspondent']['confidence'], 0.90)
+        self.assertIn("correspondent", data)
+        self.assertEqual(data["correspondent"]["id"], self.correspondent.id)
+        self.assertEqual(data["correspondent"]["confidence"], 0.90)

        # Check document type
-        self.assertIn('document_type', data)
-        self.assertEqual(data['document_type']['id'], self.doc_type.id)
+        self.assertIn("document_type", data)
+        self.assertEqual(data["document_type"]["id"], self.doc_type.id)

        # Check title suggestion
-        self.assertIn('title_suggestion', data)
-        self.assertEqual(data['title_suggestion']['title'], "Suggested Title")
+        self.assertIn("title_suggestion", data)
+        self.assertEqual(data["title_suggestion"]["title"], "Suggested Title")

    def test_get_ai_suggestions_no_content(self):
        """Test getting AI suggestions for document without content."""

@@ -126,7 +124,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        response = self.client.get(f"/api/documents/{doc.pk}/ai-suggestions/")

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertIn("no content", response.json()['detail'].lower())
+        self.assertIn("no content", response.json()["detail"].lower())

    def test_get_ai_suggestions_document_not_found(self):
        """Test getting AI suggestions for non-existent document."""

@@ -137,19 +135,19 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_apply_suggestion_tag(self):
        """Test applying a tag suggestion."""
        request_data = {
-            'suggestion_type': 'tag',
-            'value_id': self.tag1.id,
-            'confidence': 0.85,
+            "suggestion_type": "tag",
+            "value_id": self.tag1.id,
+            "confidence": 0.85,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertEqual(response.json()['status'], 'success')
+        self.assertEqual(response.json()["status"], "success")

        # Verify tag was applied
        self.document.refresh_from_db()

@@ -158,7 +156,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        # Verify feedback was recorded
        feedback = AISuggestionFeedback.objects.filter(
            document=self.document,
-            suggestion_type='tag',
+            suggestion_type="tag",
        ).first()
        self.assertIsNotNone(feedback)
        self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)

@@ -169,15 +167,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_apply_suggestion_correspondent(self):
        """Test applying a correspondent suggestion."""
        request_data = {
-            'suggestion_type': 'correspondent',
-            'value_id': self.correspondent.id,
-            'confidence': 0.90,
+            "suggestion_type": "correspondent",
+            "value_id": self.correspondent.id,
+            "confidence": 0.90,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -189,7 +187,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        # Verify feedback was recorded
        feedback = AISuggestionFeedback.objects.filter(
            document=self.document,
-            suggestion_type='correspondent',
+            suggestion_type="correspondent",
        ).first()
        self.assertIsNotNone(feedback)
        self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)

@@ -197,15 +195,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_apply_suggestion_document_type(self):
        """Test applying a document type suggestion."""
        request_data = {
-            'suggestion_type': 'document_type',
-            'value_id': self.doc_type.id,
-            'confidence': 0.88,
+            "suggestion_type": "document_type",
+            "value_id": self.doc_type.id,
+            "confidence": 0.88,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -217,35 +215,35 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_apply_suggestion_title(self):
        """Test applying a title suggestion."""
        request_data = {
-            'suggestion_type': 'title',
-            'value_text': 'New Suggested Title',
-            'confidence': 0.80,
+            "suggestion_type": "title",
+            "value_text": "New Suggested Title",
+            "confidence": 0.80,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Verify title was applied
        self.document.refresh_from_db()
-        self.assertEqual(self.document.title, 'New Suggested Title')
+        self.assertEqual(self.document.title, "New Suggested Title")

    def test_apply_suggestion_invalid_type(self):
        """Test applying suggestion with invalid type."""
        request_data = {
-            'suggestion_type': 'invalid_type',
-            'value_id': 1,
-            'confidence': 0.85,
+            "suggestion_type": "invalid_type",
+            "value_id": 1,
+            "confidence": 0.85,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@@ -253,14 +251,14 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_apply_suggestion_missing_value(self):
        """Test applying suggestion without value_id or value_text."""
        request_data = {
-            'suggestion_type': 'tag',
-            'confidence': 0.85,
+            "suggestion_type": "tag",
+            "confidence": 0.85,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@@ -268,15 +266,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_apply_suggestion_nonexistent_object(self):
        """Test applying suggestion with non-existent object ID."""
        request_data = {
-            'suggestion_type': 'tag',
-            'value_id': 99999,
-            'confidence': 0.85,
+            "suggestion_type": "tag",
+            "value_id": 99999,
+            "confidence": 0.85,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

@@ -284,24 +282,24 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_reject_suggestion(self):
        """Test rejecting an AI suggestion."""
        request_data = {
-            'suggestion_type': 'tag',
-            'value_id': self.tag1.id,
-            'confidence': 0.65,
+            "suggestion_type": "tag",
+            "value_id": self.tag1.id,
+            "confidence": 0.65,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/reject-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertEqual(response.json()['status'], 'success')
+        self.assertEqual(response.json()["status"], "success")

        # Verify feedback was recorded
        feedback = AISuggestionFeedback.objects.filter(
            document=self.document,
-            suggestion_type='tag',
+            suggestion_type="tag",
        ).first()
        self.assertIsNotNone(feedback)
        self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)

@@ -312,15 +310,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
    def test_reject_suggestion_with_text(self):
        """Test rejecting a suggestion with text value."""
        request_data = {
-            'suggestion_type': 'title',
-            'value_text': 'Bad Title Suggestion',
-            'confidence': 0.50,
+            "suggestion_type": "title",
+            "value_text": "Bad Title Suggestion",
+            "confidence": 0.50,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/reject-suggestion/",
            data=request_data,
-            format='json',
+            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

@@ -328,11 +326,11 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        # Verify feedback was recorded
        feedback = AISuggestionFeedback.objects.filter(
            document=self.document,
-            suggestion_type='title',
+            suggestion_type="title",
        ).first()
        self.assertIsNotNone(feedback)
        self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
-        self.assertEqual(feedback.suggested_value_text, 'Bad Title Suggestion')
+        self.assertEqual(feedback.suggested_value_text, "Bad Title Suggestion")

    def test_ai_suggestion_stats_empty(self):
        """Test getting statistics when no feedback exists."""

@@ -341,17 +339,17 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        data = response.json()

-        self.assertEqual(data['total_suggestions'], 0)
-        self.assertEqual(data['total_applied'], 0)
-        self.assertEqual(data['total_rejected'], 0)
-        self.assertEqual(data['accuracy_rate'], 0)
+        self.assertEqual(data["total_suggestions"], 0)
+        self.assertEqual(data["total_applied"], 0)
+        self.assertEqual(data["total_rejected"], 0)
+        self.assertEqual(data["accuracy_rate"], 0)

    def test_ai_suggestion_stats_with_data(self):
        """Test getting statistics with feedback data."""
        # Create some feedback entries
        AISuggestionFeedback.objects.create(
            document=self.document,
-            suggestion_type='tag',
+            suggestion_type="tag",
            suggested_value_id=self.tag1.id,
            confidence=0.85,
            status=AISuggestionFeedback.STATUS_APPLIED,

@@ -359,7 +357,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        )
        AISuggestionFeedback.objects.create(
            document=self.document,
-            suggestion_type='tag',
+            suggestion_type="tag",
            suggested_value_id=self.tag2.id,
            confidence=0.70,
            status=AISuggestionFeedback.STATUS_APPLIED,

@@ -367,7 +365,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        )
        AISuggestionFeedback.objects.create(
            document=self.document,
-            suggestion_type='correspondent',
+            suggestion_type="correspondent",
            suggested_value_id=self.correspondent.id,
            confidence=0.60,
            status=AISuggestionFeedback.STATUS_REJECTED,

@@ -380,25 +378,25 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        data = response.json()

        # Check overall stats
-        self.assertEqual(data['total_suggestions'], 3)
-        self.assertEqual(data['total_applied'], 2)
-        self.assertEqual(data['total_rejected'], 1)
-        self.assertAlmostEqual(data['accuracy_rate'], 66.67, places=1)
+        self.assertEqual(data["total_suggestions"], 3)
+        self.assertEqual(data["total_applied"], 2)
+        self.assertEqual(data["total_rejected"], 1)
+        self.assertAlmostEqual(data["accuracy_rate"], 66.67, places=1)

        # Check by_type stats
-        self.assertIn('by_type', data)
-        self.assertIn('tag', data['by_type'])
-        self.assertEqual(data['by_type']['tag']['total'], 2)
-        self.assertEqual(data['by_type']['tag']['applied'], 2)
-        self.assertEqual(data['by_type']['tag']['rejected'], 0)
+        self.assertIn("by_type", data)
+        self.assertIn("tag", data["by_type"])
+        self.assertEqual(data["by_type"]["tag"]["total"], 2)
+        self.assertEqual(data["by_type"]["tag"]["applied"], 2)
+        self.assertEqual(data["by_type"]["tag"]["rejected"], 0)

        # Check confidence averages
-        self.assertGreater(data['average_confidence_applied'], 0)
-        self.assertGreater(data['average_confidence_rejected'], 0)
+        self.assertGreater(data["average_confidence_applied"], 0)
+        self.assertGreater(data["average_confidence_rejected"], 0)

        # Check recent suggestions
-        self.assertIn('recent_suggestions', data)
-        self.assertEqual(len(data['recent_suggestions']), 3)
+        self.assertIn("recent_suggestions", data)
+        self.assertEqual(len(data["recent_suggestions"]), 3)

    def test_ai_suggestion_stats_accuracy_calculation(self):
        """Test that accuracy rate is calculated correctly."""

@@ -406,7 +404,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        for i in range(7):
            AISuggestionFeedback.objects.create(
                document=self.document,
-                suggestion_type='tag',
+                suggestion_type="tag",
                suggested_value_id=self.tag1.id,
                confidence=0.80,
                status=AISuggestionFeedback.STATUS_APPLIED,

@@ -416,7 +414,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        for i in range(3):
            AISuggestionFeedback.objects.create(
                document=self.document,
-                suggestion_type='tag',
+                suggestion_type="tag",
                suggested_value_id=self.tag2.id,
                confidence=0.60,
                status=AISuggestionFeedback.STATUS_REJECTED,

@@ -428,10 +426,10 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        data = response.json()

-        self.assertEqual(data['total_suggestions'], 10)
-        self.assertEqual(data['total_applied'], 7)
-        self.assertEqual(data['total_rejected'], 3)
-        self.assertEqual(data['accuracy_rate'], 70.0)
+        self.assertEqual(data["total_suggestions"], 10)
+        self.assertEqual(data["total_applied"], 7)
+        self.assertEqual(data["total_rejected"], 3)
+        self.assertEqual(data["accuracy_rate"], 70.0)

    def test_authentication_required(self):
        """Test that authentication is required for all endpoints."""

@@ -439,7 +437,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

        # Test ai-suggestions endpoint
        response = self.client.get(
-            f"/api/documents/{self.document.pk}/ai-suggestions/"
+            f"/api/documents/{self.document.pk}/ai-suggestions/",
        )
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

@@ -11,17 +11,14 @@ Tests cover:
"""

from django.contrib.auth.models import User
from django.test import override_settings
from rest_framework import status
from rest_framework.test import APITestCase

-from documents.models import (
-    Correspondent,
-    DeletionRequest,
-    Document,
-    DocumentType,
-    Tag,
-)
+from documents.models import Correspondent
+from documents.models import DeletionRequest
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import Tag


class TestDeletionRequestAPI(APITestCase):

@@ -1561,7 +1561,6 @@ class TestConsumerAIScannerIntegration(
        Verifies that AI scanner respects database transactions and handles
        rollbacks correctly.
        """
        from django.db import transaction as db_transaction

        tag = Tag.objects.create(name="Invoice")

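
As a reminder of the pattern this test exercises: Django rolls back every write inside an atomic block if an exception escapes it. A minimal illustrative sketch, not part of the patch:

    from django.db import transaction

    try:
        with transaction.atomic():
            tag = Tag.objects.create(name="Invoice")
            raise RuntimeError("simulated failure after the write")
    except RuntimeError:
        pass

    # The rollback undoes the create, so the tag must not exist afterwards.
    assert not Tag.objects.filter(name="Invoice").exists()
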
@@ -15,13 +15,11 @@ from django.contrib.auth.models import User
from django.test import TestCase
from django.utils import timezone

from documents.models import (
    Correspondent,
    DeletionRequest,
    Document,
    DocumentType,
    Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag


class TestDeletionRequestModelCreation(TestCase):

@@ -8,11 +8,9 @@ from unittest import mock

from django.test import TestCase

from documents.ml.model_cache import (
    CacheMetrics,
    LRUCache,
    ModelCacheManager,
)
from documents.ml.model_cache import CacheMetrics
from documents.ml.model_cache import LRUCache
from documents.ml.model_cache import ModelCacheManager


class TestCacheMetrics(TestCase):

@@ -150,7 +150,6 @@ class TestMLCacheDirectory:

    def test_model_cache_writable(self, tmp_path):
        """Test that we can write to model cache directory."""
        import pathlib

        # Use tmp_path fixture for testing
        cache_dir = tmp_path / ".cache" / "huggingface"

@@ -169,7 +168,6 @@ class TestMLCacheDirectory:

    def test_torch_cache_directory(self, tmp_path, monkeypatch):
        """Test that PyTorch can use a custom cache directory."""
        import torch

        # Set custom cache directory
        cache_dir = tmp_path / ".cache" / "torch"

@@ -204,9 +202,10 @@ class TestMLPerformanceBasic:

    def test_numpy_performance_basic(self):
        """Test basic NumPy performance with larger arrays."""
        import numpy as np
        import time

        import numpy as np

        # Create large array (10 million elements)
        arr = np.random.rand(10_000_000)

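
The reordering above follows ruff's I001 grouping: stdlib imports first, a blank line, then third-party imports. A plausible shape for what the test presumably measures with time and the 10-million-element array (an illustrative sketch, not the actual test body):

    import time

    import numpy as np

    arr = np.random.rand(10_000_000)

    start = time.perf_counter()
    total = arr.sum()  # a single vectorized pass over all 10M elements
    elapsed = time.perf_counter() - start

    # Assumption: a vectorized reduction over 10M floats finishes well under a second.
    assert elapsed < 1.0
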
@@ -89,6 +89,8 @@ from rest_framework.viewsets import ViewSet

from documents import bulk_edit
from documents import index
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy

@@ -141,13 +143,10 @@ from documents.models import UiSettings
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.parsers import get_parser_class_for_mime_type
from documents.parsers import parse_date_generator
from documents.permissions import AcknowledgeTasksPermissions
from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission
from documents.permissions import PaperlessAdminPermissions

@@ -1388,7 +1387,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
        scan_result = scanner.scan_document(
            document=document,
            document_text=document.content,
            original_file_path=document.source_path if hasattr(document, 'source_path') else None,
            original_file_path=document.source_path if hasattr(document, "source_path") else None,
        )

        # Convert scan result to serializable format

@@ -1424,43 +1423,43 @@ class UnifiedSearchViewSet(DocumentViewSet):
        serializer = ApplySuggestionSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        suggestion_type = serializer.validated_data['suggestion_type']
        value_id = serializer.validated_data.get('value_id')
        value_text = serializer.validated_data.get('value_text')
        confidence = serializer.validated_data['confidence']
        suggestion_type = serializer.validated_data["suggestion_type"]
        value_id = serializer.validated_data.get("value_id")
        value_text = serializer.validated_data.get("value_text")
        confidence = serializer.validated_data["confidence"]

        # Apply the suggestion based on type
        applied = False
        result_message = ""

        if suggestion_type == 'tag' and value_id:
        if suggestion_type == "tag" and value_id:
            tag = Tag.objects.get(pk=value_id)
            document.tags.add(tag)
            applied = True
            result_message = f"Tag '{tag.name}' applied"

        elif suggestion_type == 'correspondent' and value_id:
        elif suggestion_type == "correspondent" and value_id:
            correspondent = Correspondent.objects.get(pk=value_id)
            document.correspondent = correspondent
            document.save()
            applied = True
            result_message = f"Correspondent '{correspondent.name}' applied"

        elif suggestion_type == 'document_type' and value_id:
        elif suggestion_type == "document_type" and value_id:
            doc_type = DocumentType.objects.get(pk=value_id)
            document.document_type = doc_type
            document.save()
            applied = True
            result_message = f"Document type '{doc_type.name}' applied"

        elif suggestion_type == 'storage_path' and value_id:
        elif suggestion_type == "storage_path" and value_id:
            storage_path = StoragePath.objects.get(pk=value_id)
            document.storage_path = storage_path
            document.save()
            applied = True
            result_message = f"Storage path '{storage_path.name}' applied"

        elif suggestion_type == 'title' and value_text:
        elif suggestion_type == "title" and value_text:
            document.title = value_text
            document.save()
            applied = True

@@ -1518,10 +1517,10 @@ class UnifiedSearchViewSet(DocumentViewSet):
        serializer = RejectSuggestionSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        suggestion_type = serializer.validated_data['suggestion_type']
        value_id = serializer.validated_data.get('value_id')
        value_text = serializer.validated_data.get('value_text')
        confidence = serializer.validated_data['confidence']
        suggestion_type = serializer.validated_data["suggestion_type"]
        value_id = serializer.validated_data.get("value_id")
        value_text = serializer.validated_data.get("value_text")
        confidence = serializer.validated_data["confidence"]

        # Record feedback
        AISuggestionFeedback.objects.create(

@@ -1554,7 +1553,10 @@ class UnifiedSearchViewSet(DocumentViewSet):
        Returns aggregated data about applied vs rejected suggestions,
        accuracy rates, and confidence scores.
        """
        from django.db.models import Avg, Count, Q
        from django.db.models import Avg
        from django.db.models import Count
        from django.db.models import Q

        from documents.models import AISuggestionFeedback
        from documents.serializers.ai_suggestions import AISuggestionStatsSerializer

@@ -1562,61 +1564,63 @@ class UnifiedSearchViewSet(DocumentViewSet):
        # Get overall counts
        total_feedbacks = AISuggestionFeedback.objects.count()
        total_applied = AISuggestionFeedback.objects.filter(
            status=AISuggestionFeedback.STATUS_APPLIED
            status=AISuggestionFeedback.STATUS_APPLIED,
        ).count()
        total_rejected = AISuggestionFeedback.objects.filter(
            status=AISuggestionFeedback.STATUS_REJECTED
            status=AISuggestionFeedback.STATUS_REJECTED,
        ).count()

        # Calculate accuracy rate
        accuracy_rate = (total_applied / total_feedbacks * 100) if total_feedbacks > 0 else 0

        # Get statistics by suggestion type using a single aggregated query
        stats_by_type = AISuggestionFeedback.objects.values('suggestion_type').annotate(
            total=Count('id'),
            applied=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
            rejected=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_REJECTED))
        stats_by_type = AISuggestionFeedback.objects.values("suggestion_type").annotate(
            total=Count("id"),
            applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
            rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)),
        )

        # Build the by_type dictionary using the aggregated results
        by_type = {}
        for stat in stats_by_type:
            suggestion_type = stat['suggestion_type']
            type_total = stat['total']
            type_applied = stat['applied']
            type_rejected = stat['rejected']
            suggestion_type = stat["suggestion_type"]
            type_total = stat["total"]
            type_applied = stat["applied"]
            type_rejected = stat["rejected"]

            by_type[suggestion_type] = {
                'total': type_total,
                'applied': type_applied,
                'rejected': type_rejected,
                'accuracy_rate': (type_applied / type_total * 100) if type_total > 0 else 0,
                "total": type_total,
                "applied": type_applied,
                "rejected": type_rejected,
                "accuracy_rate": (type_applied / type_total * 100) if type_total > 0 else 0,
            }

        # Get average confidence scores
        avg_confidence_applied = AISuggestionFeedback.objects.filter(
            status=AISuggestionFeedback.STATUS_APPLIED
        ).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
            status=AISuggestionFeedback.STATUS_APPLIED,
        ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0

        avg_confidence_rejected = AISuggestionFeedback.objects.filter(
            status=AISuggestionFeedback.STATUS_REJECTED
        ).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
            status=AISuggestionFeedback.STATUS_REJECTED,
        ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0

        # Get recent suggestions (last 10)
        recent_suggestions = AISuggestionFeedback.objects.order_by('-created_at')[:10]
        recent_suggestions = AISuggestionFeedback.objects.order_by("-created_at")[:10]

        # Build response data
        from documents.serializers.ai_suggestions import AISuggestionFeedbackSerializer
        from documents.serializers.ai_suggestions import (
            AISuggestionFeedbackSerializer,
        )
        data = {
            'total_suggestions': total_feedbacks,
            'total_applied': total_applied,
            'total_rejected': total_rejected,
            'accuracy_rate': accuracy_rate,
            'by_type': by_type,
            'average_confidence_applied': avg_confidence_applied,
            'average_confidence_rejected': avg_confidence_rejected,
            'recent_suggestions': AISuggestionFeedbackSerializer(
                recent_suggestions, many=True
            "total_suggestions": total_feedbacks,
            "total_applied": total_applied,
            "total_rejected": total_rejected,
            "accuracy_rate": accuracy_rate,
            "by_type": by_type,
            "average_confidence_applied": avg_confidence_applied,
            "average_confidence_rejected": avg_confidence_rejected,
            "recent_suggestions": AISuggestionFeedbackSerializer(
                recent_suggestions, many=True,
            ).data,
        }

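
The hunk above collapses what would otherwise be two COUNT queries per suggestion type into a single GROUP BY: Count("id", filter=Q(...)) compiles to conditional aggregation in SQL. A condensed sketch of the same pattern in isolation, simplified from the code above:

    from django.db.models import Count, Q

    from documents.models import AISuggestionFeedback

    # One GROUP BY suggestion_type query instead of two counts per type.
    stats = AISuggestionFeedback.objects.values("suggestion_type").annotate(
        total=Count("id"),
        applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
    )
    for row in stats:
        print(row["suggestion_type"], row["applied"], "of", row["total"])
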
@@ -3571,21 +3575,21 @@ class AISuggestionsView(GenericAPIView):
        request_serializer = AISuggestionsRequestSerializer(data=request.data)
        request_serializer.is_valid(raise_exception=True)

        document_id = request_serializer.validated_data['document_id']
        document_id = request_serializer.validated_data["document_id"]

        try:
            document = Document.objects.get(pk=document_id)
        except Document.DoesNotExist:
            return Response(
                {"error": "Document not found or you don't have permission to view it"},
                status=status.HTTP_404_NOT_FOUND
                status=status.HTTP_404_NOT_FOUND,
            )

        # Check if user has permission to view this document
        if not has_perms_owner_aware(request.user, 'documents.view_document', document):
        if not has_perms_owner_aware(request.user, "documents.view_document", document):
            return Response(
                {"error": "Permission denied"},
                status=status.HTTP_403_FORBIDDEN
                status=status.HTTP_403_FORBIDDEN,
            )

        # Get AI scanner and scan document

@@ -3600,7 +3604,7 @@ class AISuggestionsView(GenericAPIView):
            "document_type": None,
            "storage_path": None,
            "title_suggestion": scan_result.title_suggestion,
            "custom_fields": {}
            "custom_fields": {},
        }

        # Format tag suggestions

@@ -3610,7 +3614,7 @@ class AISuggestionsView(GenericAPIView):
                response_data["tags"].append({
                    "id": tag.id,
                    "name": tag.name,
                    "confidence": confidence
                    "confidence": confidence,
                })
            except Tag.DoesNotExist:
                # Tag was suggested by AI but no longer exists; skip it

@@ -3624,7 +3628,7 @@ class AISuggestionsView(GenericAPIView):
                response_data["correspondent"] = {
                    "id": correspondent.id,
                    "name": correspondent.name,
                    "confidence": confidence
                    "confidence": confidence,
                }
            except Correspondent.DoesNotExist:
                # Correspondent was suggested but no longer exists; skip it

@@ -3638,7 +3642,7 @@ class AISuggestionsView(GenericAPIView):
                response_data["document_type"] = {
                    "id": doc_type.id,
                    "name": doc_type.name,
                    "confidence": confidence
                    "confidence": confidence,
                }
            except DocumentType.DoesNotExist:
                # Document type was suggested but no longer exists; skip it

@@ -3652,7 +3656,7 @@ class AISuggestionsView(GenericAPIView):
                response_data["storage_path"] = {
                    "id": storage_path.id,
                    "name": storage_path.name,
                    "confidence": confidence
                    "confidence": confidence,
                }
            except StoragePath.DoesNotExist:
                # Storage path was suggested but no longer exists; skip it

@@ -3662,7 +3666,7 @@ class AISuggestionsView(GenericAPIView):
        for field_id, (value, confidence) in scan_result.custom_fields.items():
            response_data["custom_fields"][str(field_id)] = {
                "value": value,
                "confidence": confidence
                "confidence": confidence,
            }

        return Response(response_data)

@@ -3683,21 +3687,21 @@ class ApplyAISuggestionsView(GenericAPIView):
        serializer = ApplyAISuggestionsSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        document_id = serializer.validated_data['document_id']
        document_id = serializer.validated_data["document_id"]

        try:
            document = Document.objects.get(pk=document_id)
        except Document.DoesNotExist:
            return Response(
                {"error": "Document not found"},
                status=status.HTTP_404_NOT_FOUND
                status=status.HTTP_404_NOT_FOUND,
            )

        # Check if user has permission to change this document
        if not has_perms_owner_aware(request.user, 'documents.change_document', document):
        if not has_perms_owner_aware(request.user, "documents.change_document", document):
            return Response(
                {"error": "Permission denied"},
                status=status.HTTP_403_FORBIDDEN
                status=status.HTTP_403_FORBIDDEN,
            )

        # Get AI scanner and scan document

@@ -3707,8 +3711,8 @@ class ApplyAISuggestionsView(GenericAPIView):
        # Apply suggestions based on user selections
        applied = []

        if serializer.validated_data.get('apply_tags'):
            selected_tags = serializer.validated_data.get('selected_tags', [])
        if serializer.validated_data.get("apply_tags"):
            selected_tags = serializer.validated_data.get("selected_tags", [])
            if selected_tags:
                # Apply only selected tags
                tags_to_apply = [tag_id for tag_id, _ in scan_result.tags if tag_id in selected_tags]

@@ -3725,7 +3729,7 @@ class ApplyAISuggestionsView(GenericAPIView):
                    # Tag not found; skip applying this tag
                    pass

        if serializer.validated_data.get('apply_correspondent') and scan_result.correspondent:
        if serializer.validated_data.get("apply_correspondent") and scan_result.correspondent:
            corr_id, confidence = scan_result.correspondent
            try:
                correspondent = Correspondent.objects.get(pk=corr_id)

@@ -3735,7 +3739,7 @@ class ApplyAISuggestionsView(GenericAPIView):
                # Correspondent not found; skip applying
                pass

        if serializer.validated_data.get('apply_document_type') and scan_result.document_type:
        if serializer.validated_data.get("apply_document_type") and scan_result.document_type:
            type_id, confidence = scan_result.document_type
            try:
                doc_type = DocumentType.objects.get(pk=type_id)

@@ -3745,7 +3749,7 @@ class ApplyAISuggestionsView(GenericAPIView):
                # Document type not found; skip applying
                pass

        if serializer.validated_data.get('apply_storage_path') and scan_result.storage_path:
        if serializer.validated_data.get("apply_storage_path") and scan_result.storage_path:
            path_id, confidence = scan_result.storage_path
            try:
                storage_path = StoragePath.objects.get(pk=path_id)

@@ -3755,7 +3759,7 @@ class ApplyAISuggestionsView(GenericAPIView):
                # Storage path not found; skip applying
                pass

        if serializer.validated_data.get('apply_title') and scan_result.title_suggestion:
        if serializer.validated_data.get("apply_title") and scan_result.title_suggestion:
            document.title = scan_result.title_suggestion
            applied.append(f"title: {scan_result.title_suggestion}")

@@ -3765,7 +3769,7 @@ class ApplyAISuggestionsView(GenericAPIView):
        return Response({
            "status": "success",
            "document_id": document.id,
            "applied": applied
            "applied": applied,
        })

@@ -3805,14 +3809,14 @@ class AIConfigurationView(GenericAPIView):

        # Create new scanner with updated configuration
        config = {}
        if 'auto_apply_threshold' in serializer.validated_data:
            config['auto_apply_threshold'] = serializer.validated_data['auto_apply_threshold']
        if 'suggest_threshold' in serializer.validated_data:
            config['suggest_threshold'] = serializer.validated_data['suggest_threshold']
        if 'ml_enabled' in serializer.validated_data:
            config['enable_ml_features'] = serializer.validated_data['ml_enabled']
        if 'advanced_ocr_enabled' in serializer.validated_data:
            config['enable_advanced_ocr'] = serializer.validated_data['advanced_ocr_enabled']
        if "auto_apply_threshold" in serializer.validated_data:
            config["auto_apply_threshold"] = serializer.validated_data["auto_apply_threshold"]
        if "suggest_threshold" in serializer.validated_data:
            config["suggest_threshold"] = serializer.validated_data["suggest_threshold"]
        if "ml_enabled" in serializer.validated_data:
            config["enable_ml_features"] = serializer.validated_data["ml_enabled"]
        if "advanced_ocr_enabled" in serializer.validated_data:
            config["enable_advanced_ocr"] = serializer.validated_data["advanced_ocr_enabled"]

        # Update global scanner instance
        # WARNING: Not thread-safe. Consider storing configuration in database

@@ -3822,7 +3826,7 @@ class AIConfigurationView(GenericAPIView):

        return Response({
            "status": "success",
            "message": "AI configuration updated. Changes may require server restart for consistency."
            "message": "AI configuration updated. Changes may require server restart for consistency.",
        })

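
The two thresholds configured above suggest a three-way split per suggestion. A minimal sketch of how such a split is typically applied; the semantics are assumed here, this is not code from the patch:

    def route_suggestion(confidence: float, auto_apply: float = 0.9, suggest: float = 0.6) -> str:
        # Above the high threshold: apply without asking. Between the two:
        # surface as a suggestion. Below the low threshold: drop silently.
        if confidence >= auto_apply:
            return "auto_apply"
        if confidence >= suggest:
            return "suggest"
        return "ignore"

    assert route_suggestion(0.95) == "auto_apply"
    assert route_suggestion(0.75) == "suggest"
    assert route_suggestion(0.40) == "ignore"
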
@@ -76,7 +76,7 @@ class DeletionRequestViewSet(ModelViewSet):
        # Check permissions
        if not self._can_manage_request(deletion_request):
            return HttpResponseForbidden(
                "You don't have permission to approve this deletion request."
                "You don't have permission to approve this deletion request.",
            )

        # Validate status

@@ -114,11 +114,11 @@ class DeletionRequestViewSet(ModelViewSet):
                    deleted_count += 1
                    logger.info(
                        f"Deleted document {doc_id} ('{doc_title}') "
                        f"as part of deletion request {deletion_request.id}"
                        f"as part of deletion request {deletion_request.id}",
                    )
                except Exception as e:
                    logger.error(
                        f"Failed to delete document {doc.id}: {str(e)}"
                        f"Failed to delete document {doc.id}: {e!s}",
                    )
                    failed_deletions.append({
                        "id": doc.id,

@@ -138,14 +138,14 @@ class DeletionRequestViewSet(ModelViewSet):

            logger.info(
                f"Deletion request {deletion_request.id} completed. "
                f"Deleted {deleted_count}/{len(documents)} documents."
                f"Deleted {deleted_count}/{len(documents)} documents.",
            )
        except Exception as e:
            logger.error(
                f"Error executing deletion request {deletion_request.id}: {str(e)}"
                f"Error executing deletion request {deletion_request.id}: {e!s}",
            )
            return Response(
                {"error": f"Failed to execute deletion: {str(e)}"},
                {"error": f"Failed to execute deletion: {e!s}"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

@@ -176,7 +176,7 @@ class DeletionRequestViewSet(ModelViewSet):
        # Check permissions
        if not self._can_manage_request(deletion_request):
            return HttpResponseForbidden(
                "You don't have permission to reject this deletion request."
                "You don't have permission to reject this deletion request.",
            )

        # Validate status

@@ -199,7 +199,7 @@ class DeletionRequestViewSet(ModelViewSet):
        )

        logger.info(
            f"Deletion request {deletion_request.id} rejected by user {request.user.username}"
            f"Deletion request {deletion_request.id} rejected by user {request.user.username}",
        )

        serializer = self.get_serializer(deletion_request)

@@ -228,7 +228,7 @@ class DeletionRequestViewSet(ModelViewSet):
        # Check permissions
        if not self._can_manage_request(deletion_request):
            return HttpResponseForbidden(
                "You don't have permission to cancel this deletion request."
                "You don't have permission to cancel this deletion request.",
            )

        # Validate status

@@ -249,7 +249,7 @@ class DeletionRequestViewSet(ModelViewSet):
        deletion_request.save()

        logger.info(
            f"Deletion request {deletion_request.id} cancelled by user {request.user.username}"
            f"Deletion request {deletion_request.id} cancelled by user {request.user.username}",
        )

        serializer = self.get_serializer(deletion_request)

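
The {str(e)} to {e!s} rewrites above are ruff's explicit-conversion fix (likely RUF010): the !s flag applies str() inside the f-string without a nested call. A quick standalone check of the equivalence:

    err = ValueError("disk full")

    # The !s conversion is equivalent to calling str() on the value.
    assert f"failed: {err!s}" == f"failed: {str(err)}" == "failed: disk full"
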
@@ -160,7 +160,7 @@ class SecurityHeadersMiddleware:

        # Store nonce in request for use in templates
        # Templates can access this via {{ request.csp_nonce }}
        if hasattr(request, '_csp_nonce'):
        if hasattr(request, "_csp_nonce"):
            request._csp_nonce = nonce

        # Prevent clickjacking attacks

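
For context on the nonce handled above: a CSP nonce is a fresh random token generated once per response and echoed in the Content-Security-Policy header, so only script tags carrying the same value may run. A minimal sketch of the usual generation step; the middleware's actual code is outside this hunk, so this is an assumed shape:

    import secrets

    # One fresh, unguessable token per response.
    nonce = secrets.token_urlsafe(16)

    # The same value must appear in the header and on each allowed <script> tag.
    csp_header = f"script-src 'self' 'nonce-{nonce}'"
    script_tag = f'<script nonce="{nonce}">console.log(1)</script>'
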
@@ -9,7 +9,6 @@ from __future__ import annotations

import hashlib
import logging
import mimetypes
import os
import re
from pathlib import Path

@@ -26,39 +25,39 @@ logger = logging.getLogger("paperless.security")
# Explicit list of allowed MIME types
ALLOWED_MIME_TYPES = {
    # Documents
    'application/pdf',
    'application/vnd.oasis.opendocument.text',
    'application/msword',
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    'application/vnd.ms-excel',
    'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    'application/vnd.ms-powerpoint',
    'application/vnd.openxmlformats-officedocument.presentationml.presentation',
    'application/vnd.oasis.opendocument.spreadsheet',
    'application/vnd.oasis.opendocument.presentation',
    'application/rtf',
    'text/rtf',
    "application/pdf",
    "application/vnd.oasis.opendocument.text",
    "application/msword",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "application/vnd.ms-excel",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    "application/vnd.ms-powerpoint",
    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    "application/vnd.oasis.opendocument.spreadsheet",
    "application/vnd.oasis.opendocument.presentation",
    "application/rtf",
    "text/rtf",

    # Images
    'image/jpeg',
    'image/png',
    'image/gif',
    'image/tiff',
    'image/bmp',
    'image/webp',
    "image/jpeg",
    "image/png",
    "image/gif",
    "image/tiff",
    "image/bmp",
    "image/webp",

    # Text
    'text/plain',
    'text/html',
    'text/csv',
    'text/markdown',
    "text/plain",
    "text/html",
    "text/csv",
    "text/markdown",
}

# Maximum file size (100MB by default)
# Can be overridden by settings.MAX_UPLOAD_SIZE
try:
    from django.conf import settings
    MAX_FILE_SIZE = getattr(settings, 'MAX_UPLOAD_SIZE', 100 * 1024 * 1024)  # 100MB by default
    MAX_FILE_SIZE = getattr(settings, "MAX_UPLOAD_SIZE", 100 * 1024 * 1024)  # 100MB by default
except ImportError:
    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB in bytes

@@ -114,7 +113,6 @@ ALLOWED_JS_PATTERNS = [
class FileValidationError(Exception):
    """Raised when file validation fails."""

    pass


def has_whitelisted_javascript(content: bytes) -> bool:

@@ -143,7 +141,7 @@ def validate_mime_type(mime_type: str) -> None:
    if mime_type not in ALLOWED_MIME_TYPES:
        raise FileValidationError(
            f"MIME type '{mime_type}' is not allowed. "
            f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}"
            f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}",
        )

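
Putting validate_mime_type together with stdlib type detection, a caller might guard an upload like this. A minimal sketch, assuming the type is guessed from the filename; production code would sniff the content instead:

    import mimetypes

    def check_upload(filename: str) -> str:
        # Guess the MIME type from the extension and run it through the whitelist.
        mime_type, _ = mimetypes.guess_type(filename)
        if mime_type is None:
            raise FileValidationError(f"Could not determine MIME type of '{filename}'")
        validate_mime_type(mime_type)  # raises FileValidationError if not allowed
        return mime_type

    check_upload("invoice.pdf")  # returns "application/pdf"
    # check_upload("malware.exe") would raise FileValidationError
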
@@ -86,7 +86,7 @@ api_router.register(r"config", ApplicationConfigurationViewSet)
api_router.register(r"processed_mail", ProcessedMailViewSet)
api_router.register(r"deletion_requests", DeletionRequestViewSet)
api_router.register(
    r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests"
    r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests",
)