mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-12-06 06:45:05 +01:00)
fix(syntax): fix Python syntax and formatting errors

- Fix a missing parenthesis in DeletionRequestActionSerializer (serialisers.py:2855)
- Strip whitespace from blank lines (W293)
- Strip trailing whitespace (W291)
- Remove unused imports (F401)
- Normalize quotes to double quotes (Q000)
- Add missing trailing commas (COM812)
- Sort imports according to convention (I001)
- Update type annotations to PEP 585 (UP006)

This commit resolves the build error in the CI/CD job that was causing the linting check to fail.

Files affected: 38
Lines changed: ~2200
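A note on the UP006 item before the diff: PEP 585 builtin generics (dict[str, Any], list[int]) evaluate at runtime from Python 3.9, and PEP 604 unions (str | None) from Python 3.10; in modules that carry from __future__ import annotations, annotations are never evaluated at all, so the rewritten style is safe even on older interpreters. A minimal sketch of the before/after shape (the function below is illustrative, not taken from this diff):

    from __future__ import annotations  # PEP 563: annotations stay unevaluated strings

    from typing import Any

    # Before this commit: Dict[str, Any] and Optional[str] from typing.
    # After: builtin generics and union syntax, as applied throughout the diff.
    def get_cache_stats(cache_dir: str | None = None) -> dict[str, Any]:
        """Illustrative signature only."""
        return {"cache_dir": cache_dir, "hits": 0, "misses": 0}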
This commit is contained in:
parent 9298f64546
commit 69326b883d

38 changed files with 2077 additions and 2112 deletions
@@ -14,14 +14,10 @@ According to agents.md requirements:
 from __future__ import annotations

 import logging
 from typing import TYPE_CHECKING
 from typing import Any

 from django.contrib.auth.models import User

 if TYPE_CHECKING:
     pass

 logger = logging.getLogger("paperless.ai_deletion")

@@ -29,7 +29,6 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING
 from typing import Any
-from typing import Dict
 from typing import TypedDict

 from django.conf import settings
@@ -142,34 +141,34 @@ class AIScanResult:
         """
         # Convert internal tuple format to TypedDict format
         result: AIScanResultDict = {
-            'tags': [{'tag_id': tag_id, 'confidence': conf} for tag_id, conf in self.tags],
-            'custom_fields': {
-                field_id: {'value': value, 'confidence': conf}
+            "tags": [{"tag_id": tag_id, "confidence": conf} for tag_id, conf in self.tags],
+            "custom_fields": {
+                field_id: {"value": value, "confidence": conf}
                 for field_id, (value, conf) in self.custom_fields.items()
             },
-            'workflows': [{'workflow_id': wf_id, 'confidence': conf} for wf_id, conf in self.workflows],
-            'extracted_entities': self.extracted_entities,
-            'metadata': self.metadata,
+            "workflows": [{"workflow_id": wf_id, "confidence": conf} for wf_id, conf in self.workflows],
+            "extracted_entities": self.extracted_entities,
+            "metadata": self.metadata,
         }

         # Add optional fields only if present
         if self.correspondent:
-            result['correspondent'] = {
-                'correspondent_id': self.correspondent[0],
-                'confidence': self.correspondent[1],
+            result["correspondent"] = {
+                "correspondent_id": self.correspondent[0],
+                "confidence": self.correspondent[1],
             }
         if self.document_type:
-            result['document_type'] = {
-                'type_id': self.document_type[0],
-                'confidence': self.document_type[1],
+            result["document_type"] = {
+                "type_id": self.document_type[0],
+                "confidence": self.document_type[1],
             }
         if self.storage_path:
-            result['storage_path'] = {
-                'path_id': self.storage_path[0],
-                'confidence': self.storage_path[1],
+            result["storage_path"] = {
+                "path_id": self.storage_path[0],
+                "confidence": self.storage_path[1],
             }
         if self.title_suggestion:
-            result['title_suggestion'] = self.title_suggestion
+            result["title_suggestion"] = self.title_suggestion

         return result
@@ -257,14 +256,14 @@ class AIDocumentScanner:
         if self._classifier is None and self.ml_enabled:
             try:
                 from documents.ml.classifier import TransformerDocumentClassifier

                 # Get model name from settings
                 model_name = getattr(
-                    settings,
+                    settings,
                     "PAPERLESS_ML_CLASSIFIER_MODEL",
                     "distilbert-base-uncased",
                 )

                 self._classifier = TransformerDocumentClassifier(
                     model_name=model_name,
                     use_cache=True,
@@ -291,14 +290,14 @@ class AIDocumentScanner:
         if self._semantic_search is None and self.ml_enabled:
             try:
                 from documents.ml.semantic_search import SemanticSearch

                 # Get cache directory from settings
                 cache_dir = getattr(
                     settings,
                     "PAPERLESS_ML_MODEL_CACHE",
                     None,
                 )

                 self._semantic_search = SemanticSearch(
                     cache_dir=cache_dir,
                     use_cache=True,
@@ -1004,13 +1003,13 @@ class AIDocumentScanner:
         import time
         logger.info("Starting ML model warm-up...")
         start_time = time.time()

         from documents.ml.model_cache import ModelCacheManager
         cache_manager = ModelCacheManager.get_instance()

         # Define model loaders
         model_loaders = {}

         # Classifier
         if self.ml_enabled:
             def load_classifier():
@@ -1025,14 +1024,14 @@
                     use_cache=True,
                 )
             model_loaders["classifier"] = load_classifier

         # NER
         if self.ml_enabled:
             def load_ner():
                 from documents.ml.ner import DocumentNER
                 return DocumentNER(use_cache=True)
             model_loaders["ner"] = load_ner

         # Semantic Search
         if self.ml_enabled:
             def load_semantic():
@@ -1040,21 +1039,21 @@
                 cache_dir = getattr(settings, "PAPERLESS_ML_MODEL_CACHE", None)
                 return SemanticSearch(cache_dir=cache_dir, use_cache=True)
             model_loaders["semantic_search"] = load_semantic

         # Table Extractor
         if self.advanced_ocr_enabled:
             def load_table():
                 from documents.ocr.table_extractor import TableExtractor
                 return TableExtractor()
             model_loaders["table_extractor"] = load_table

         # Warm up all models
         cache_manager.warm_up(model_loaders)

         warm_up_time = time.time() - start_time
         logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s")

-    def get_cache_metrics(self) -> Dict[str, Any]:
+    def get_cache_metrics(self) -> dict[str, Any]:
         """
         Get cache performance metrics.
@@ -1062,7 +1061,7 @@
             Dictionary with cache statistics
         """
         from documents.ml.model_cache import ModelCacheManager

         try:
             cache_manager = ModelCacheManager.get_instance()
             return cache_manager.get_metrics()
@@ -1075,7 +1074,7 @@
     def clear_cache(self) -> None:
         """Clear all model caches."""
         from documents.ml.model_cache import ModelCacheManager

         try:
             cache_manager = ModelCacheManager.get_instance()
             cache_manager.clear_all()

@@ -38,22 +38,22 @@ class DocumentsConfig(AppConfig):
     def _initialize_ml_cache(self):
         """Initialize ML model cache and optionally warm up models."""
         from django.conf import settings

         # Only initialize if ML features are enabled
         if not getattr(settings, "PAPERLESS_ENABLE_ML_FEATURES", False):
             return

         # Initialize cache manager with settings
         from documents.ml.model_cache import ModelCacheManager

         max_models = getattr(settings, "PAPERLESS_ML_CACHE_MAX_MODELS", 3)
         cache_dir = getattr(settings, "PAPERLESS_ML_MODEL_CACHE", None)

         cache_manager = ModelCacheManager.get_instance(
             max_models=max_models,
             disk_cache_dir=str(cache_dir) if cache_dir else None,
         )

         # Warm up models if configured
         warmup_enabled = getattr(settings, "PAPERLESS_ML_CACHE_WARMUP", False)
         if warmup_enabled:

@@ -310,7 +310,7 @@ class ConsumerPlugin(
         # Parsing phase
         document_parser = self._create_parser_instance(parser_class)
         text, date, thumbnail, archive_path, page_count = self._parse_document(
-            document_parser, mime_type
+            document_parser, mime_type,
         )

         # Storage phase
@@ -394,7 +394,7 @@ class ConsumerPlugin(
     def _attempt_pdf_recovery(
         self,
         tempdir: tempfile.TemporaryDirectory,
-        original_mime_type: str
+        original_mime_type: str,
     ) -> str:
         """
         Attempt to recover a PDF file with incorrect MIME type using qpdf.
@@ -438,7 +438,7 @@ class ConsumerPlugin(
     def _get_parser_class(
         self,
         mime_type: str,
-        tempdir: tempfile.TemporaryDirectory
+        tempdir: tempfile.TemporaryDirectory,
     ) -> type[DocumentParser]:
         """
         Determine which parser to use based on MIME type.
@@ -468,7 +468,7 @@ class ConsumerPlugin(

     def _create_parser_instance(
         self,
-        parser_class: type[DocumentParser]
+        parser_class: type[DocumentParser],
     ) -> DocumentParser:
         """
         Create a parser instance with progress callback.
@@ -496,7 +496,7 @@ class ConsumerPlugin(
     def _parse_document(
         self,
         document_parser: DocumentParser,
-        mime_type: str
+        mime_type: str,
     ) -> tuple[str, datetime.datetime | None, Path, Path | None, int | None]:
         """
         Parse the document and extract metadata.
@@ -670,7 +670,7 @@ class ConsumerPlugin(
         self,
         document: Document,
         thumbnail: Path,
-        archive_path: Path | None
+        archive_path: Path | None,
     ) -> None:
         """
         Store document files (source, thumbnail, archive) to disk.
@@ -949,7 +949,7 @@ class ConsumerPlugin(
             text: The extracted document text
         """
         # Check if AI scanner is enabled
-        if not getattr(settings, 'PAPERLESS_ENABLE_AI_SCANNER', True):
+        if not getattr(settings, "PAPERLESS_ENABLE_AI_SCANNER", True):
             self.log.debug("AI scanner is disabled, skipping AI analysis")
             return

@@ -1,6 +1,7 @@
 # Generated manually for performance optimization

-from django.db import migrations, models
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):

@@ -1,9 +1,10 @@
 # Generated manually for DeletionRequest model
 # Based on model definition in documents/models.py

-from django.conf import settings
-from django.db import migrations, models
 import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):
@@ -48,7 +49,7 @@ class Migration(migrations.Migration):
                 (
                     "ai_reason",
                     models.TextField(
-                        help_text="Detailed explanation from AI about why deletion is recommended"
+                        help_text="Detailed explanation from AI about why deletion is recommended",
                     ),
                 ),
                 (

@@ -1,6 +1,7 @@
 # Generated manually for DeletionRequest performance optimization

-from django.db import migrations, models
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):

@@ -1,9 +1,10 @@
 # Generated manually for AI Suggestions API

-from django.conf import settings
-from django.db import migrations, models
-import django.db.models.deletion
+import django.core.validators
+import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):

@@ -10,9 +10,9 @@ Provides AI/ML capabilities including:
 from __future__ import annotations

 __all__ = [
-    "TransformerDocumentClassifier",
     "DocumentNER",
     "SemanticSearch",
+    "TransformerDocumentClassifier",
 ]

 # Lazy imports to avoid loading heavy ML libraries unless needed

@@ -15,23 +15,16 @@ Logging levels used in this module:
 from __future__ import annotations

 import logging
 from pathlib import Path
 from typing import TYPE_CHECKING

 import torch
 from torch.utils.data import Dataset
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    Trainer,
-    TrainingArguments,
-)
+from transformers import AutoModelForSequenceClassification
+from transformers import AutoTokenizer
+from transformers import Trainer
+from transformers import TrainingArguments

 from documents.ml.model_cache import ModelCacheManager

 if TYPE_CHECKING:
     from documents.models import Document

 logger = logging.getLogger("paperless.ml.classifier")

@@ -129,10 +122,10 @@ class TransformerDocumentClassifier:
         self.model_name = model_name
         self.use_cache = use_cache
         self.cache_manager = ModelCacheManager.get_instance() if use_cache else None

         # Cache key for this model configuration
         self.cache_key = f"classifier_{model_name}"

         # Load tokenizer (lightweight, not cached)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = None
@@ -141,7 +134,7 @@

         logger.info(
             f"Initialized TransformerDocumentClassifier with {model_name} "
-            f"(caching: {use_cache})"
+            f"(caching: {use_cache})",
         )

     def train(
@@ -264,14 +257,14 @@
         if self.use_cache and self.cache_manager:
             # Try to get from cache first
             cache_key = f"{self.cache_key}_{model_dir}"

             def loader():
                 logger.info(f"Loading model from {model_dir}")
                 model = AutoModelForSequenceClassification.from_pretrained(model_dir)
                 tokenizer = AutoTokenizer.from_pretrained(model_dir)
                 model.eval()  # Set to evaluation mode
                 return {"model": model, "tokenizer": tokenizer}

             cached = self.cache_manager.get_or_load_model(cache_key, loader)
             self.model = cached["model"]
             self.tokenizer = cached["tokenizer"]

@@ -24,8 +24,9 @@ import pickle
 import threading
 import time
 from collections import OrderedDict
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any

 logger = logging.getLogger("paperless.ml.model_cache")

@@ -58,7 +59,7 @@ class CacheMetrics:
         with self.lock:
             self.loads += 1

-    def get_stats(self) -> Dict[str, Any]:
+    def get_stats(self) -> dict[str, Any]:
         with self.lock:
             total = self.hits + self.misses
             hit_rate = (self.hits / total * 100) if total > 0 else 0.0
@@ -98,7 +99,7 @@ class LRUCache:
         self.lock = threading.Lock()
         self.metrics = CacheMetrics()

-    def get(self, key: str) -> Optional[Any]:
+    def get(self, key: str) -> Any | None:
         """
         Get item from cache.

@@ -153,7 +154,7 @@ class LRUCache:
         with self.lock:
             return len(self.cache)

-    def get_metrics(self) -> Dict[str, Any]:
+    def get_metrics(self) -> dict[str, Any]:
         """Get cache metrics."""
         return self.metrics.get_stats()

@@ -173,7 +174,7 @@ class ModelCacheManager:
         model = cache.get_or_load_model("classifier", loader_func)
     """

-    _instance: Optional[ModelCacheManager] = None
+    _instance: ModelCacheManager | None = None
     _lock = threading.Lock()

     def __new__(cls, *args, **kwargs):
@@ -187,7 +188,7 @@ class ModelCacheManager:
     def __init__(
         self,
         max_models: int = 3,
-        disk_cache_dir: Optional[str] = None,
+        disk_cache_dir: str | None = None,
     ):
         """
         Initialize model cache manager.
@@ -215,7 +216,7 @@ class ModelCacheManager:
     def get_instance(
         cls,
         max_models: int = 3,
-        disk_cache_dir: Optional[str] = None,
+        disk_cache_dir: str | None = None,
     ) -> ModelCacheManager:
         """
         Get singleton instance of ModelCacheManager.
@@ -278,7 +279,7 @@ class ModelCacheManager:
             load_time = time.time() - start_time
             logger.info(
                 f"Model loaded successfully: {model_key} "
-                f"(took {load_time:.2f}s)"
+                f"(took {load_time:.2f}s)",
             )

             return model
@@ -289,7 +290,7 @@ class ModelCacheManager:
     def save_embeddings_to_disk(
         self,
         key: str,
-        embeddings: Dict[int, Any],
+        embeddings: dict[int, Any],
     ) -> bool:
         """
         Save embeddings to disk cache.
@@ -311,7 +312,7 @@ class ModelCacheManager:
         cache_file = self.disk_cache_dir / f"{key}.pkl"

         try:
-            with open(cache_file, 'wb') as f:
+            with open(cache_file, "wb") as f:
                 pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
             logger.info(f"Saved {len(embeddings)} embeddings to {cache_file}")
             return True
@@ -330,7 +331,7 @@ class ModelCacheManager:
     def load_embeddings_from_disk(
         self,
         key: str,
-    ) -> Optional[Dict[int, Any]]:
+    ) -> dict[int, Any] | None:
         """
         Load embeddings from disk cache.

@@ -344,7 +345,7 @@ class ModelCacheManager:
             return None

         cache_file = self.disk_cache_dir / f"{key}.pkl"

         if not cache_file.exists():
             return None

@@ -384,7 +385,7 @@ class ModelCacheManager:
     def clear_all(self) -> None:
         """Clear all caches (memory and disk)."""
         self.model_cache.clear()

         if self.disk_cache_dir and self.disk_cache_dir.exists():
             for cache_file in self.disk_cache_dir.glob("*.pkl"):
                 try:
@@ -393,7 +394,7 @@ class ModelCacheManager:
             except Exception as e:
                 logger.error(f"Failed to delete {cache_file}: {e}")

-    def get_metrics(self) -> Dict[str, Any]:
+    def get_metrics(self) -> dict[str, Any]:
         """
         Get cache performance metrics.

@@ -403,20 +404,20 @@ class ModelCacheManager:
         metrics = self.model_cache.get_metrics()
         metrics["cache_size"] = self.model_cache.size()
         metrics["max_size"] = self.model_cache.max_size

         if self.disk_cache_dir and self.disk_cache_dir.exists():
             disk_files = list(self.disk_cache_dir.glob("*.pkl"))
             metrics["disk_cache_files"] = len(disk_files)

             # Calculate total disk cache size
             total_size = sum(f.stat().st_size for f in disk_files)
             metrics["disk_cache_size_mb"] = f"{total_size / 1024 / 1024:.2f}"

         return metrics

     def warm_up(
         self,
-        model_loaders: Dict[str, Callable[[], Any]],
+        model_loaders: dict[str, Callable[[], Any]],
     ) -> None:
         """
         Pre-load models on startup (warm-up).
@@ -426,12 +427,12 @@ class ModelCacheManager:
         """
         logger.info(f"Starting model warm-up ({len(model_loaders)} models)...")
         start_time = time.time()

         for model_key, loader_func in model_loaders.items():
             try:
                 self.get_or_load_model(model_key, loader_func)
             except Exception as e:
                 logger.warning(f"Failed to warm-up model {model_key}: {e}")

         warm_up_time = time.time() - start_time
         logger.info(f"Model warm-up completed in {warm_up_time:.2f}s")

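Aside: the hunks above touch the whole ModelCacheManager surface this commit changes. A minimal usage sketch, assuming only the signatures visible in this diff (get_instance, get_or_load_model, get_metrics); the loader below is a stand-in, not code from the repo:

    from documents.ml.model_cache import ModelCacheManager

    manager = ModelCacheManager.get_instance(max_models=3)

    def load_model():
        # Stand-in loader; the real loaders in this diff build transformers models.
        return {"weights": "..."}

    model = manager.get_or_load_model("demo", load_model)  # miss: loader runs once
    model = manager.get_or_load_model("demo", load_model)  # hit: served from the LRU cache
    print(manager.get_metrics())  # hits/misses/hit rate, per CacheMetrics.get_stats() above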
@@ -14,15 +14,11 @@ from __future__ import annotations

 import logging
 import re
 from typing import TYPE_CHECKING

 from transformers import pipeline

 from documents.ml.model_cache import ModelCacheManager

 if TYPE_CHECKING:
     pass

 logger = logging.getLogger("paperless.ml.ner")

@@ -69,10 +65,10 @@ class DocumentNER:
         self.model_name = model_name
         self.use_cache = use_cache
         self.cache_manager = ModelCacheManager.get_instance() if use_cache else None

         # Cache key for this model
         cache_key = f"ner_{model_name}"

         if self.use_cache and self.cache_manager:
             # Load from cache or create new
             def loader():
@@ -81,7 +77,7 @@ class DocumentNER:
                     model=model_name,
                     aggregation_strategy="simple",
                 )

             self.ner_pipeline = self.cache_manager.get_or_load_model(
                 cache_key,
                 loader,

@@ -18,18 +18,14 @@ Examples:
 from __future__ import annotations

 import logging
 from pathlib import Path
 from typing import TYPE_CHECKING

 import numpy as np
 import torch
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer
+from sentence_transformers import util

 from documents.ml.model_cache import ModelCacheManager

 if TYPE_CHECKING:
     pass

 logger = logging.getLogger("paperless.ml.semantic_search")

@@ -67,7 +63,7 @@ class SemanticSearch:
         """
         logger.info(
             f"Initializing SemanticSearch with model: {model_name} "
-            f"(caching: {use_cache})"
+            f"(caching: {use_cache})",
         )

         self.model_name = model_name
@@ -75,10 +71,10 @@ class SemanticSearch:
         self.cache_manager = ModelCacheManager.get_instance(
             disk_cache_dir=cache_dir,
         ) if use_cache else None

         # Cache key for this model
         cache_key = f"semantic_search_{model_name}"

         if self.use_cache and self.cache_manager:
             # Load model from cache
             def loader():
@@ -127,11 +123,11 @@ class SemanticSearch:
         if not isinstance(embedding, np.ndarray) and not isinstance(embedding, torch.Tensor):
             logger.warning(f"Embedding for doc {doc_id} is not a numpy array or tensor")
             return False
-        if hasattr(embedding, 'size'):
+        if hasattr(embedding, "size"):
             if embedding.size == 0:
                 logger.warning(f"Embedding for doc {doc_id} is empty")
                 return False
-        elif hasattr(embedding, 'numel'):
+        elif hasattr(embedding, "numel"):
             if embedding.numel() == 0:
                 logger.warning(f"Embedding for doc {doc_id} is empty")
                 return False
@@ -216,11 +212,11 @@ class SemanticSearch:
         try:
             result = self.cache_manager.save_embeddings_to_disk(
                 "document_embeddings",
-                self.document_embeddings
+                self.document_embeddings,
             )
             if result:
                 logger.info(
-                    f"Successfully saved {len(self.document_embeddings)} embeddings to disk"
+                    f"Successfully saved {len(self.document_embeddings)} embeddings to disk",
                 )
             else:
                 logger.error("Failed to save embeddings to disk (returned False)")

@@ -1596,59 +1596,59 @@ class DeletionRequest(models.Model):
     This ensures no documents are deleted without explicit user consent,
     implementing the safety requirement from agents.md.
     """

     # Request metadata
     created_at = models.DateTimeField(auto_now_add=True)
     updated_at = models.DateTimeField(auto_now=True)

     # Requester (AI system)
     requested_by_ai = models.BooleanField(default=True)
     ai_reason = models.TextField(
-        help_text=_("Detailed explanation from AI about why deletion is recommended")
+        help_text=_("Detailed explanation from AI about why deletion is recommended"),
     )

     # User who must approve
     user = models.ForeignKey(
         User,
         on_delete=models.CASCADE,
-        related_name='deletion_requests',
+        related_name="deletion_requests",
         help_text=_("User who must approve this deletion"),
     )

     # Status tracking
-    STATUS_PENDING = 'pending'
-    STATUS_APPROVED = 'approved'
-    STATUS_REJECTED = 'rejected'
-    STATUS_CANCELLED = 'cancelled'
-    STATUS_COMPLETED = 'completed'
+    STATUS_PENDING = "pending"
+    STATUS_APPROVED = "approved"
+    STATUS_REJECTED = "rejected"
+    STATUS_CANCELLED = "cancelled"
+    STATUS_COMPLETED = "completed"

     STATUS_CHOICES = [
-        (STATUS_PENDING, _('Pending')),
-        (STATUS_APPROVED, _('Approved')),
-        (STATUS_REJECTED, _('Rejected')),
-        (STATUS_CANCELLED, _('Cancelled')),
-        (STATUS_COMPLETED, _('Completed')),
+        (STATUS_PENDING, _("Pending")),
+        (STATUS_APPROVED, _("Approved")),
+        (STATUS_REJECTED, _("Rejected")),
+        (STATUS_CANCELLED, _("Cancelled")),
+        (STATUS_COMPLETED, _("Completed")),
     ]

     status = models.CharField(
         max_length=20,
         choices=STATUS_CHOICES,
         default=STATUS_PENDING,
     )

     # Documents to be deleted
     documents = models.ManyToManyField(
         Document,
-        related_name='deletion_requests',
+        related_name="deletion_requests",
         help_text=_("Documents that would be deleted if approved"),
     )

     # Impact summary (JSON field with details)
     impact_summary = models.JSONField(
         default=dict,
         help_text=_("Summary of what will be affected by this deletion"),
     )

     # Approval tracking
     reviewed_at = models.DateTimeField(null=True, blank=True)
     reviewed_by = models.ForeignKey(
@@ -1656,43 +1656,43 @@ class DeletionRequest(models.Model):
         on_delete=models.SET_NULL,
         null=True,
         blank=True,
-        related_name='reviewed_deletion_requests',
+        related_name="reviewed_deletion_requests",
         help_text=_("User who reviewed and approved/rejected"),
     )
     review_comment = models.TextField(
         blank=True,
         help_text=_("User's comment when reviewing"),
     )

     # Completion tracking
     completed_at = models.DateTimeField(null=True, blank=True)
     completion_details = models.JSONField(
         default=dict,
         help_text=_("Details about the deletion execution"),
     )

     class Meta:
-        ordering = ['-created_at']
+        ordering = ["-created_at"]
         verbose_name = _("deletion request")
         verbose_name_plural = _("deletion requests")
         indexes = [
             # Composite index for common listing queries (by user, filtered by status, sorted by date)
             # PostgreSQL can use this index for queries on: user, user+status, user+status+created_at
-            models.Index(fields=['user', 'status', 'created_at'], name='delreq_user_status_created_idx'),
+            models.Index(fields=["user", "status", "created_at"], name="delreq_user_status_created_idx"),
             # Index for queries filtering by status and date without user filter
-            models.Index(fields=['status', 'created_at'], name='delreq_status_created_idx'),
+            models.Index(fields=["status", "created_at"], name="delreq_status_created_idx"),
             # Index for queries filtering by user and date (common for user-specific views)
-            models.Index(fields=['user', 'created_at'], name='delreq_user_created_idx'),
+            models.Index(fields=["user", "created_at"], name="delreq_user_created_idx"),
             # Index for queries filtering by review date
-            models.Index(fields=['reviewed_at'], name='delreq_reviewed_at_idx'),
+            models.Index(fields=["reviewed_at"], name="delreq_reviewed_at_idx"),
             # Index for queries filtering by completion date
-            models.Index(fields=['completed_at'], name='delreq_completed_at_idx'),
+            models.Index(fields=["completed_at"], name="delreq_completed_at_idx"),
         ]

     def __str__(self):
         doc_count = self.documents.count()
         return f"Deletion Request {self.id} - {doc_count} documents - {self.status}"

     def approve(self, user: User, comment: str = "") -> bool:
         """
         Approve the deletion request.
@@ -1706,15 +1706,15 @@ class DeletionRequest(models.Model):
         """
         if self.status != self.STATUS_PENDING:
             return False

         self.status = self.STATUS_APPROVED
         self.reviewed_by = user
         self.reviewed_at = timezone.now()
         self.review_comment = comment
         self.save()

         return True

     def reject(self, user: User, comment: str = "") -> bool:
         """
         Reject the deletion request.
@@ -1728,13 +1728,13 @@ class DeletionRequest(models.Model):
         """
         if self.status != self.STATUS_PENDING:
             return False

         self.status = self.STATUS_REJECTED
         self.reviewed_by = user
         self.reviewed_at = timezone.now()
         self.review_comment = comment
         self.save()

         return True

@@ -1743,109 +1743,109 @@ class AISuggestionFeedback(models.Model):
     Model to track user feedback on AI suggestions (applied/rejected).
     Used for improving AI accuracy and providing statistics.
     """

     # Suggestion types
-    TYPE_TAG = 'tag'
-    TYPE_CORRESPONDENT = 'correspondent'
-    TYPE_DOCUMENT_TYPE = 'document_type'
-    TYPE_STORAGE_PATH = 'storage_path'
-    TYPE_CUSTOM_FIELD = 'custom_field'
-    TYPE_WORKFLOW = 'workflow'
-    TYPE_TITLE = 'title'
+    TYPE_TAG = "tag"
+    TYPE_CORRESPONDENT = "correspondent"
+    TYPE_DOCUMENT_TYPE = "document_type"
+    TYPE_STORAGE_PATH = "storage_path"
+    TYPE_CUSTOM_FIELD = "custom_field"
+    TYPE_WORKFLOW = "workflow"
+    TYPE_TITLE = "title"

     SUGGESTION_TYPES = (
-        (TYPE_TAG, _('Tag')),
-        (TYPE_CORRESPONDENT, _('Correspondent')),
-        (TYPE_DOCUMENT_TYPE, _('Document Type')),
-        (TYPE_STORAGE_PATH, _('Storage Path')),
-        (TYPE_CUSTOM_FIELD, _('Custom Field')),
-        (TYPE_WORKFLOW, _('Workflow')),
-        (TYPE_TITLE, _('Title')),
+        (TYPE_TAG, _("Tag")),
+        (TYPE_CORRESPONDENT, _("Correspondent")),
+        (TYPE_DOCUMENT_TYPE, _("Document Type")),
+        (TYPE_STORAGE_PATH, _("Storage Path")),
+        (TYPE_CUSTOM_FIELD, _("Custom Field")),
+        (TYPE_WORKFLOW, _("Workflow")),
+        (TYPE_TITLE, _("Title")),
     )

     # Feedback status
-    STATUS_APPLIED = 'applied'
-    STATUS_REJECTED = 'rejected'
+    STATUS_APPLIED = "applied"
+    STATUS_REJECTED = "rejected"

     FEEDBACK_STATUS = (
-        (STATUS_APPLIED, _('Applied')),
-        (STATUS_REJECTED, _('Rejected')),
+        (STATUS_APPLIED, _("Applied")),
+        (STATUS_REJECTED, _("Rejected")),
     )

     document = models.ForeignKey(
         Document,
         on_delete=models.CASCADE,
-        related_name='ai_suggestion_feedbacks',
-        verbose_name=_('document'),
+        related_name="ai_suggestion_feedbacks",
+        verbose_name=_("document"),
     )

     suggestion_type = models.CharField(
-        _('suggestion type'),
+        _("suggestion type"),
         max_length=50,
         choices=SUGGESTION_TYPES,
     )

     suggested_value_id = models.IntegerField(
-        _('suggested value ID'),
+        _("suggested value ID"),
         null=True,
         blank=True,
-        help_text=_('ID of the suggested object (tag, correspondent, etc.)'),
+        help_text=_("ID of the suggested object (tag, correspondent, etc.)"),
     )

     suggested_value_text = models.TextField(
-        _('suggested value text'),
+        _("suggested value text"),
         blank=True,
-        help_text=_('Text representation of the suggested value'),
+        help_text=_("Text representation of the suggested value"),
     )

     confidence = models.FloatField(
-        _('confidence'),
-        help_text=_('AI confidence score (0.0 to 1.0)'),
+        _("confidence"),
+        help_text=_("AI confidence score (0.0 to 1.0)"),
         validators=[MinValueValidator(0.0), MaxValueValidator(1.0)],
     )

     status = models.CharField(
-        _('status'),
+        _("status"),
         max_length=20,
         choices=FEEDBACK_STATUS,
     )

     user = models.ForeignKey(
         User,
         on_delete=models.SET_NULL,
         null=True,
         blank=True,
-        related_name='ai_suggestion_feedbacks',
-        verbose_name=_('user'),
-        help_text=_('User who applied or rejected the suggestion'),
+        related_name="ai_suggestion_feedbacks",
+        verbose_name=_("user"),
+        help_text=_("User who applied or rejected the suggestion"),
     )

     created_at = models.DateTimeField(
-        _('created at'),
+        _("created at"),
         auto_now_add=True,
     )

     applied_at = models.DateTimeField(
-        _('applied/rejected at'),
+        _("applied/rejected at"),
         auto_now=True,
     )

     metadata = models.JSONField(
-        _('metadata'),
+        _("metadata"),
         default=dict,
         blank=True,
-        help_text=_('Additional metadata about the suggestion'),
+        help_text=_("Additional metadata about the suggestion"),
     )

     class Meta:
-        verbose_name = _('AI suggestion feedback')
-        verbose_name_plural = _('AI suggestion feedbacks')
-        ordering = ['-created_at']
+        verbose_name = _("AI suggestion feedback")
+        verbose_name_plural = _("AI suggestion feedbacks")
+        ordering = ["-created_at"]
         indexes = [
-            models.Index(fields=['document', 'suggestion_type']),
-            models.Index(fields=['status', 'created_at']),
-            models.Index(fields=['suggestion_type', 'status']),
+            models.Index(fields=["document", "suggestion_type"]),
+            models.Index(fields=["status", "created_at"]),
+            models.Index(fields=["suggestion_type", "status"]),
         ]

     def __str__(self):
         return f"{self.suggestion_type} suggestion for document {self.document_id} - {self.status}"

@@ -11,21 +11,21 @@ Lazy imports are used to avoid loading heavy dependencies unless needed.
 """

 __all__ = [
-    'TableExtractor',
-    'HandwritingRecognizer',
-    'FormFieldDetector',
+    "FormFieldDetector",
+    "HandwritingRecognizer",
+    "TableExtractor",
 ]


 def __getattr__(name):
     """Lazy import to avoid loading heavy ML models on startup."""
-    if name == 'TableExtractor':
+    if name == "TableExtractor":
         from .table_extractor import TableExtractor
         return TableExtractor
-    elif name == 'HandwritingRecognizer':
+    elif name == "HandwritingRecognizer":
         from .handwriting import HandwritingRecognizer
         return HandwritingRecognizer
-    elif name == 'FormFieldDetector':
+    elif name == "FormFieldDetector":
         from .form_detector import FormFieldDetector
         return FormFieldDetector
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

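Aside on the pattern in the file above: a module-level __getattr__ (PEP 562, Python 3.7+) runs only when normal attribute lookup fails, so from documents.ocr import TableExtractor pays the heavy import cost on first access rather than at package import time. A generic, self-contained illustration (module and attribute names here are made up, not from the repo):

    # lazy_demo.py
    def __getattr__(name):
        """Invoked only for attributes not found by normal module lookup (PEP 562)."""
        if name == "Heavy":
            from types import SimpleNamespace  # stand-in for an expensive import
            return SimpleNamespace(loaded=True)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # import lazy_demo       # cheap: nothing heavy runs yet
    # lazy_demo.Heavy        # first access triggers __getattr__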
@@ -8,8 +8,8 @@ This module provides capabilities to:
 """

 import logging
 from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
+from typing import Any

 import numpy as np
 from PIL import Image

@@ -37,7 +37,7 @@ class FormFieldDetector:
         >>> for cb in checkboxes:
         ...     print(f"{cb['label']}: {'✓' if cb['checked'] else '☐'}")
     """

     def __init__(self, use_gpu: bool = True):
         """
         Initialize the form field detector.
@@ -47,20 +47,20 @@ class FormFieldDetector:
         """
         self.use_gpu = use_gpu
         self._handwriting_recognizer = None

     def _get_handwriting_recognizer(self):
         """Lazy load handwriting recognizer for field value extraction."""
         if self._handwriting_recognizer is None:
             from .handwriting import HandwritingRecognizer
             self._handwriting_recognizer = HandwritingRecognizer(use_gpu=self.use_gpu)
         return self._handwriting_recognizer

     def detect_checkboxes(
-        self,
+        self,
         image: Image.Image,
         min_size: int = 10,
-        max_size: int = 50
-    ) -> List[Dict[str, Any]]:
+        max_size: int = 50,
+    ) -> list[dict[str, Any]]:
         """
         Detect checkboxes in a form image.

@@ -82,54 +82,54 @@ class FormFieldDetector:
         """
         try:
             import cv2

             # Convert to OpenCV format
             img_array = np.array(image)
             if len(img_array.shape) == 3:
                 gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
             else:
                 gray = img_array

             # Detect edges
             edges = cv2.Canny(gray, 50, 150)

             # Find contours
             contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

             checkboxes = []
             for contour in contours:
                 # Get bounding box
                 x, y, w, h = cv2.boundingRect(contour)

                 # Check if it looks like a checkbox (square-ish, right size)
                 aspect_ratio = w / h if h > 0 else 0
-                if (min_size <= w <= max_size and
-                    min_size <= h <= max_size and
+                if (min_size <= w <= max_size and
+                    min_size <= h <= max_size and
                     0.7 <= aspect_ratio <= 1.3):

                     # Extract checkbox region
                     checkbox_region = gray[y:y+h, x:x+w]

                     # Determine if checked (look for marks inside)
                     checked, confidence = self._is_checkbox_checked(checkbox_region)

                     checkboxes.append({
-                        'bbox': [x, y, x+w, y+h],
-                        'checked': checked,
-                        'confidence': confidence
+                        "bbox": [x, y, x+w, y+h],
+                        "checked": checked,
+                        "confidence": confidence,
                     })

             logger.info(f"Detected {len(checkboxes)} checkboxes")
             return checkboxes

         except ImportError:
             logger.error("opencv-python not installed. Install with: pip install opencv-python")
             return []
         except Exception as e:
             logger.error(f"Error detecting checkboxes: {e}")
             return []

-    def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> Tuple[bool, float]:
+    def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> tuple[bool, float]:
         """
         Determine if a checkbox is checked.

@@ -141,34 +141,34 @@ class FormFieldDetector:
         """
         try:
             import cv2

             # Binarize
             _, binary = cv2.threshold(checkbox_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

             # Count dark pixels in the center region (where mark would be)
             h, w = binary.shape
             center_region = binary[int(h*0.2):int(h*0.8), int(w*0.2):int(w*0.8)]

             if center_region.size == 0:
                 return False, 0.0

             dark_pixel_ratio = np.sum(center_region > 0) / center_region.size

             # If more than 15% of center is dark, consider it checked
             checked = dark_pixel_ratio > 0.15
             confidence = min(dark_pixel_ratio * 2, 1.0)  # Scale confidence

             return checked, confidence

         except Exception as e:
             logger.warning(f"Error checking checkbox state: {e}")
             return False, 0.0

     def detect_text_fields(
-        self,
+        self,
         image: Image.Image,
-        min_width: int = 100
-    ) -> List[Dict[str, Any]]:
+        min_width: int = 100,
+    ) -> list[dict[str, Any]]:
         """
         Detect text input fields in a form.

@@ -188,73 +188,73 @@ class FormFieldDetector:
         """
         try:
             import cv2

             # Convert to OpenCV format
             img_array = np.array(image)
             if len(img_array.shape) == 3:
                 gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
             else:
                 gray = img_array

             # Detect horizontal lines (underlines for text fields)
             horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (min_width, 1))
             detect_horizontal = cv2.morphologyEx(
                 cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1],
                 cv2.MORPH_OPEN,
                 horizontal_kernel,
-                iterations=2
+                iterations=2,
             )

             # Find contours of horizontal lines
             contours, _ = cv2.findContours(
-                detect_horizontal,
-                cv2.RETR_EXTERNAL,
-                cv2.CHAIN_APPROX_SIMPLE
+                detect_horizontal,
+                cv2.RETR_EXTERNAL,
+                cv2.CHAIN_APPROX_SIMPLE,
             )

             text_fields = []
             for contour in contours:
                 x, y, w, h = cv2.boundingRect(contour)

                 # Check if it's a horizontal line (field underline)
                 if w >= min_width and h < 10:
                     # Expand upward to include text area
                     text_bbox = [x, max(0, y-30), x+w, y+h]
                     text_fields.append({
-                        'bbox': text_bbox,
-                        'type': 'line'
+                        "bbox": text_bbox,
+                        "type": "line",
                     })

             # Detect rectangular boxes (bordered text fields)
             edges = cv2.Canny(gray, 50, 150)
             contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

             for contour in contours:
                 x, y, w, h = cv2.boundingRect(contour)

                 # Check if it's a rectangular box
                 aspect_ratio = w / h if h > 0 else 0
                 if w >= min_width and 20 <= h <= 100 and aspect_ratio > 2:
                     text_fields.append({
-                        'bbox': [x, y, x+w, y+h],
-                        'type': 'box'
+                        "bbox": [x, y, x+w, y+h],
+                        "type": "box",
                     })

             logger.info(f"Detected {len(text_fields)} text fields")
             return text_fields

         except ImportError:
             logger.error("opencv-python not installed")
             return []
         except Exception as e:
             logger.error(f"Error detecting text fields: {e}")
             return []

     def detect_labels(
-        self,
+        self,
         image: Image.Image,
-        field_bboxes: List[List[int]]
-    ) -> List[Dict[str, Any]]:
+        field_bboxes: list[list[int]],
+    ) -> list[dict[str, Any]]:
         """
         Detect labels near form fields.

@@ -267,47 +267,47 @@ class FormFieldDetector:
         """
         try:
             import pytesseract

             # Get all text with bounding boxes
             ocr_data = pytesseract.image_to_data(
-                image,
-                output_type=pytesseract.Output.DICT
+                image,
+                output_type=pytesseract.Output.DICT,
             )

             # Group text into potential labels
             labels = []
-            for i, text in enumerate(ocr_data['text']):
+            for i, text in enumerate(ocr_data["text"]):
                 if text.strip() and len(text.strip()) > 2:
-                    x = ocr_data['left'][i]
-                    y = ocr_data['top'][i]
-                    w = ocr_data['width'][i]
-                    h = ocr_data['height'][i]
+                    x = ocr_data["left"][i]
+                    y = ocr_data["top"][i]
+                    w = ocr_data["width"][i]
+                    h = ocr_data["height"][i]

                     label_bbox = [x, y, x+w, y+h]

                     # Find closest field
                     closest_field_idx = self._find_closest_field(label_bbox, field_bboxes)

                     labels.append({
-                        'text': text.strip(),
-                        'bbox': label_bbox,
-                        'field_index': closest_field_idx
+                        "text": text.strip(),
+                        "bbox": label_bbox,
+                        "field_index": closest_field_idx,
                     })

             return labels

         except ImportError:
             logger.error("pytesseract not installed")
             return []
         except Exception as e:
             logger.error(f"Error detecting labels: {e}")
             return []

     def _find_closest_field(
-        self,
-        label_bbox: List[int],
-        field_bboxes: List[List[int]]
-    ) -> Optional[int]:
+        self,
+        label_bbox: list[int],
+        field_bboxes: list[list[int]],
+    ) -> int | None:
         """
         Find the closest field to a label.

@@ -320,36 +320,36 @@ class FormFieldDetector:
         """
         if not field_bboxes:
             return None

         # Calculate center of label
         label_center_x = (label_bbox[0] + label_bbox[2]) / 2
         label_center_y = (label_bbox[1] + label_bbox[3]) / 2

-        min_distance = float('inf')
+        min_distance = float("inf")
         closest_idx = 0

         for i, field_bbox in enumerate(field_bboxes):
             # Calculate center of field
             field_center_x = (field_bbox[0] + field_bbox[2]) / 2
             field_center_y = (field_bbox[1] + field_bbox[3]) / 2

             # Euclidean distance
             distance = np.sqrt(
-                (label_center_x - field_center_x)**2 +
-                (label_center_y - field_center_y)**2
+                (label_center_x - field_center_x)**2 +
+                (label_center_y - field_center_y)**2,
             )

             if distance < min_distance:
                 min_distance = distance
                 closest_idx = i

         return closest_idx

     def detect_form_fields(
-        self,
+        self,
         image_path: str,
-        extract_values: bool = True
-    ) -> List[Dict[str, Any]]:
+        extract_values: bool = True,
+    ) -> list[dict[str, Any]]:
         """
         Detect all form fields and extract their values.

@@ -372,69 +372,69 @@ class FormFieldDetector:
         """
         try:
             # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

             # Detect different field types
             text_fields = self.detect_text_fields(image)
             checkboxes = self.detect_checkboxes(image)

             # Combine all field bboxes for label detection
-            all_field_bboxes = [f['bbox'] for f in text_fields] + [cb['bbox'] for cb in checkboxes]
+            all_field_bboxes = [f["bbox"] for f in text_fields] + [cb["bbox"] for cb in checkboxes]

             # Detect labels
             labels = self.detect_labels(image, all_field_bboxes)

             # Build results
             results = []

             # Add text fields
             for i, field in enumerate(text_fields):
                 # Find associated label
                 label_text = self._find_label_for_field(i, labels, len(text_fields))

                 result = {
-                    'type': 'text',
-                    'label': label_text,
-                    'bbox': field['bbox'],
+                    "type": "text",
+                    "label": label_text,
+                    "bbox": field["bbox"],
                 }

                 # Extract value if requested
                 if extract_values:
-                    x1, y1, x2, y2 = field['bbox']
+                    x1, y1, x2, y2 = field["bbox"]
                     field_image = image.crop((x1, y1, x2, y2))

                     recognizer = self._get_handwriting_recognizer()
                     value = recognizer.recognize_from_image(field_image, preprocess=True)
-                    result['value'] = value.strip()
-                    result['confidence'] = recognizer._estimate_confidence(value)
+                    result["value"] = value.strip()
+                    result["confidence"] = recognizer._estimate_confidence(value)

                 results.append(result)

             # Add checkboxes
             for i, checkbox in enumerate(checkboxes):
                 field_idx = len(text_fields) + i
                 label_text = self._find_label_for_field(field_idx, labels, len(all_field_bboxes))

                 results.append({
-                    'type': 'checkbox',
-                    'label': label_text,
-                    'value': checkbox['checked'],
-                    'bbox': checkbox['bbox'],
-                    'confidence': checkbox['confidence']
+                    "type": "checkbox",
+                    "label": label_text,
+                    "value": checkbox["checked"],
+                    "bbox": checkbox["bbox"],
+                    "confidence": checkbox["confidence"],
                 })

             logger.info(f"Detected {len(results)} form fields from {image_path}")
             return results

         except Exception as e:
             logger.error(f"Error detecting form fields: {e}")
             return []

     def _find_label_for_field(
-        self,
-        field_idx: int,
-        labels: List[Dict[str, Any]],
-        total_fields: int
+        self,
+        field_idx: int,
+        labels: list[dict[str, Any]],
+        total_fields: int,
     ) -> str:
         """
         Find the label text for a specific field.
@@ -448,20 +448,20 @@ class FormFieldDetector:
             Label text or empty string if not found
         """
         matching_labels = [
-            label for label in labels
-            if label['field_index'] == field_idx
+            label for label in labels
+            if label["field_index"] == field_idx
         ]

         if matching_labels:
             # Combine multiple label parts if found
-            return ' '.join(label['text'] for label in matching_labels)
+            return " ".join(label["text"] for label in matching_labels)

         return f"Field_{field_idx + 1}"

     def extract_form_data(
-        self,
+        self,
         image_path: str,
-        output_format: str = 'dict'
+        output_format: str = "dict",
     ) -> Any:
         """
         Extract all form data as structured output.
@@ -475,19 +475,19 @@ class FormFieldDetector:
         """
         # Detect and extract fields
         fields = self.detect_form_fields(image_path, extract_values=True)

-        if output_format == 'dict':
+        if output_format == "dict":
             # Return as dictionary
-            return {field['label']: field['value'] for field in fields}
+            return {field["label"]: field["value"] for field in fields}

-        elif output_format == 'json':
+        elif output_format == "json":
             import json
-            data = {field['label']: field['value'] for field in fields}
+            data = {field["label"]: field["value"] for field in fields}
             return json.dumps(data, indent=2)

-        elif output_format == 'dataframe':
+        elif output_format == "dataframe":
             import pandas as pd
             return pd.DataFrame(fields)

         else:
             raise ValueError(f"Invalid output format: {output_format}")

@@ -8,8 +8,8 @@ This module provides handwriting OCR capabilities using:
 """

 import logging
 from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
+from typing import Any

 import numpy as np
 from PIL import Image

@@ -34,7 +34,7 @@ class HandwritingRecognizer:
         >>> for line in lines:
         ...     print(f"{line['text']} (confidence: {line['confidence']:.2f})")
     """

     def __init__(
         self,
         model_name: str = "microsoft/trocr-base-handwritten",
@@ -58,39 +58,40 @@ class HandwritingRecognizer:
         self.confidence_threshold = confidence_threshold
         self._model = None
         self._processor = None

     def _load_model(self):
         """Lazy load the handwriting recognition model."""
         if self._model is not None:
             return

         try:
-            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
             import torch
+            from transformers import TrOCRProcessor
+            from transformers import VisionEncoderDecoderModel

             logger.info(f"Loading handwriting recognition model: {self.model_name}")

             self._processor = TrOCRProcessor.from_pretrained(self.model_name)
             self._model = VisionEncoderDecoderModel.from_pretrained(self.model_name)

             # Move to GPU if available and requested
             if self.use_gpu and torch.cuda.is_available():
                 self._model = self._model.cuda()
                 logger.info("Using GPU for handwriting recognition")
             else:
                 logger.info("Using CPU for handwriting recognition")

             self._model.eval()  # Set to evaluation mode

         except ImportError as e:
             logger.error(f"Failed to load handwriting model: {e}")
             logger.error("Please install: pip install transformers torch pillow")
             raise

     def recognize_from_image(
-        self,
+        self,
         image: Image.Image,
-        preprocess: bool = True
+        preprocess: bool = True,
     ) -> str:
         """
         Recognize text from a single image.
@@ -103,34 +104,34 @@ class HandwritingRecognizer:
             Recognized text string
         """
         self._load_model()

         try:
             import torch

             # Preprocess image if requested
             if preprocess:
                 image = self._preprocess_image(image)

             # Prepare image for model
             pixel_values = self._processor(images=image, return_tensors="pt").pixel_values

             if self.use_gpu and torch.cuda.is_available():
                 pixel_values = pixel_values.cuda()

             # Generate text
             with torch.no_grad():
                 generated_ids = self._model.generate(pixel_values)

             # Decode to text
             text = self._processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

             logger.debug(f"Recognized text: {text[:100]}...")
             return text

         except Exception as e:
             logger.error(f"Error recognizing handwriting: {e}")
             return ""

     def _preprocess_image(self, image: Image.Image) -> Image.Image:
         """
         Preprocess image for better recognition.
@@ -142,29 +143,30 @@ class HandwritingRecognizer:
             Preprocessed PIL Image
         """
         try:
-            from PIL import ImageEnhance, ImageFilter
+            from PIL import ImageEnhance
+            from PIL import ImageFilter

             # Convert to grayscale
-            if image.mode != 'L':
-                image = image.convert('L')
+            if image.mode != "L":
+                image = image.convert("L")

             # Enhance contrast
             enhancer = ImageEnhance.Contrast(image)
             image = enhancer.enhance(2.0)

             # Denoise
             image = image.filter(ImageFilter.MedianFilter(size=3))

             # Convert back to RGB (required by model)
-            image = image.convert('RGB')
+            image = image.convert("RGB")

             return image

         except Exception as e:
             logger.warning(f"Error preprocessing image: {e}")
             return image

-    def detect_text_lines(self, image: Image.Image) -> List[Dict[str, Any]]:
+    def detect_text_lines(self, image: Image.Image) -> list[dict[str, Any]]:
         """
         Detect individual text lines in an image.

@ -184,52 +186,52 @@ class HandwritingRecognizer:
|
|||
try:
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Convert PIL to OpenCV format
|
||||
img_array = np.array(image)
|
||||
if len(img_array.shape) == 3:
|
||||
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
gray = img_array
|
||||
|
||||
|
||||
# Binarize
|
||||
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
||||
|
||||
|
||||
# Find contours
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
|
||||
# Get bounding boxes for each contour
|
||||
lines = []
|
||||
for contour in contours:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
|
||||
|
||||
# Filter out very small regions
|
||||
if w > 20 and h > 10:
|
||||
# Crop line from original image
|
||||
line_img = image.crop((x, y, x+w, y+h))
|
||||
lines.append({
|
||||
'bbox': [x, y, x+w, y+h],
|
||||
'image': line_img
|
||||
"bbox": [x, y, x+w, y+h],
|
||||
"image": line_img,
|
||||
})
|
||||
|
||||
|
||||
# Sort lines top to bottom
|
||||
lines.sort(key=lambda l: l['bbox'][1])
|
||||
|
||||
lines.sort(key=lambda l: l["bbox"][1])
|
||||
|
||||
logger.info(f"Detected {len(lines)} text lines")
|
||||
return lines
|
||||
|
||||
|
||||
except ImportError:
|
||||
logger.error("opencv-python not installed. Install with: pip install opencv-python")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting text lines: {e}")
|
||||
return []
|
||||
|
||||
|
||||
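To make the return shape concrete, a hedged usage sketch follows; the module import path and the input file are assumptions and may not match the repository layout.

```python
# Illustrative consumer of detect_text_lines; import path is hypothetical.
from PIL import Image

from documents.handwriting import HandwritingRecognizer  # hypothetical path

recognizer = HandwritingRecognizer()
page = Image.open("scan.png").convert("RGB")
for line in recognizer.detect_text_lines(page):
    x1, y1, x2, y2 = line["bbox"]  # pixel coordinates, sorted top to bottom
    print(f"line at y={y1}, size={line['image'].size}")
```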
    def recognize_lines(
-        self,
+        self,
        image_path: str,
-        return_confidence: bool = True
-    ) -> List[Dict[str, Any]]:
+        return_confidence: bool = True,
+    ) -> list[dict[str, Any]]:
        """
        Recognize text from each line in an image.

@@ -250,38 +252,38 @@ class HandwritingRecognizer:
        """
        try:
            # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

            # Detect lines
            lines = self.detect_text_lines(image)

            # Recognize each line
            results = []
            for i, line in enumerate(lines):
                logger.debug(f"Recognizing line {i+1}/{len(lines)}")

-                text = self.recognize_from_image(line['image'], preprocess=True)
+                text = self.recognize_from_image(line["image"], preprocess=True)

                result = {
-                    'text': text,
-                    'bbox': line['bbox'],
-                    'line_index': i
+                    "text": text,
+                    "bbox": line["bbox"],
+                    "line_index": i,
                }

                if return_confidence:
                    # Simple confidence based on text length and content
                    confidence = self._estimate_confidence(text)
-                    result['confidence'] = confidence
+                    result["confidence"] = confidence

                results.append(result)

            logger.info(f"Recognized {len(results)} lines from {image_path}")
            return results

        except Exception as e:
            logger.error(f"Error recognizing lines from {image_path}: {e}")
            return []
    def _estimate_confidence(self, text: str) -> float:
        """
        Estimate confidence of recognition result.

@@ -294,36 +296,36 @@ class HandwritingRecognizer:
        """
        if not text:
            return 0.0

        # Factors that indicate good recognition
        score = 0.5  # Base score

        # Longer text tends to be more reliable
        if len(text) > 10:
            score += 0.1
        if len(text) > 20:
            score += 0.1

        # Text with alphanumeric characters is more reliable
        if any(c.isalnum() for c in text):
            score += 0.1

        # Text with spaces (words) is more reliable
-        if ' ' in text:
+        if " " in text:
            score += 0.1

        # Penalize if too many special characters
        special_chars = sum(1 for c in text if not c.isalnum() and not c.isspace())
        if special_chars / len(text) > 0.5:
            score -= 0.2

        return max(0.0, min(1.0, score))
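A worked example of the heuristic above, as a standalone sketch that mirrors the scoring rules rather than importing the actual method: "Hello world" is 11 characters (+0.1), alphanumeric (+0.1), and contains a space (+0.1), landing near 0.8; a string of only punctuation takes the special-character penalty and lands at 0.3.

```python
# Standalone mirror of the confidence heuristic shown in the diff.
def estimate_confidence(text: str) -> float:
    if not text:
        return 0.0
    score = 0.5                                  # base score
    if len(text) > 10:
        score += 0.1
    if len(text) > 20:
        score += 0.1
    if any(c.isalnum() for c in text):
        score += 0.1
    if " " in text:
        score += 0.1
    special = sum(1 for c in text if not c.isalnum() and not c.isspace())
    if special / len(text) > 0.5:
        score -= 0.2                             # mostly punctuation: penalize
    return max(0.0, min(1.0, score))

print(round(estimate_confidence("Hello world"), 2))  # 0.8
print(round(estimate_confidence("@#!%"), 2))         # 0.3
```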
    def recognize_from_file(
-        self,
+        self,
        image_path: str,
-        mode: str = 'full'
-    ) -> Dict[str, Any]:
+        mode: str = "full",
+    ) -> dict[str, Any]:
        """
        Recognize handwriting from an image file.

@@ -337,49 +339,49 @@ class HandwritingRecognizer:
            Dictionary with recognized text and metadata
        """
        try:
-            if mode == 'full':
+            if mode == "full":
                # Recognize entire image
-                image = Image.open(image_path).convert('RGB')
+                image = Image.open(image_path).convert("RGB")
                text = self.recognize_from_image(image, preprocess=True)

                return {
-                    'text': text,
-                    'mode': 'full',
-                    'confidence': self._estimate_confidence(text)
+                    "text": text,
+                    "mode": "full",
+                    "confidence": self._estimate_confidence(text),
                }

-            elif mode == 'lines':
+            elif mode == "lines":
                # Recognize line by line
                lines = self.recognize_lines(image_path, return_confidence=True)

                # Combine all lines
-                full_text = '\n'.join(line['text'] for line in lines)
-                avg_confidence = np.mean([line['confidence'] for line in lines]) if lines else 0.0
+                full_text = "\n".join(line["text"] for line in lines)
+                avg_confidence = np.mean([line["confidence"] for line in lines]) if lines else 0.0

                return {
-                    'text': full_text,
-                    'lines': lines,
-                    'mode': 'lines',
-                    'confidence': float(avg_confidence)
+                    "text": full_text,
+                    "lines": lines,
+                    "mode": "lines",
+                    "confidence": float(avg_confidence),
                }

            else:
                raise ValueError(f"Invalid mode: {mode}. Use 'full' or 'lines'")

        except Exception as e:
            logger.error(f"Error recognizing from file {image_path}: {e}")
            return {
-                'text': '',
-                'mode': mode,
-                'confidence': 0.0,
-                'error': str(e)
+                "text": "",
+                "mode": mode,
+                "confidence": 0.0,
+                "error": str(e),
            }
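The two modes differ in what the result dictionary carries: "full" returns a single text plus one confidence, while "lines" adds the per-line records. A hedged usage sketch (the module path and file name are placeholders, as before):

```python
# Illustrative call; import path and file name are assumptions.
from documents.handwriting import HandwritingRecognizer  # hypothetical path

recognizer = HandwritingRecognizer()
result = recognizer.recognize_from_file("letter.png", mode="lines")
print(result["confidence"])               # mean of the per-line confidences
for line in result["lines"]:
    print(line["line_index"], line["text"])
```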
    def recognize_form_fields(
-        self,
+        self,
        image_path: str,
-        field_regions: List[Dict[str, Any]]
-    ) -> Dict[str, str]:
+        field_regions: list[dict[str, Any]],
+    ) -> dict[str, str]:
        """
        Recognize text from specific form fields.

@@ -399,35 +401,35 @@ class HandwritingRecognizer:
        """
        try:
            # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

            # Extract and recognize each field
            results = {}
            for field in field_regions:
-                name = field['name']
-                bbox = field['bbox']
+                name = field["name"]
+                bbox = field["bbox"]

                # Crop field region
                x1, y1, x2, y2 = bbox
                field_image = image.crop((x1, y1, x2, y2))

                # Recognize text
                text = self.recognize_from_image(field_image, preprocess=True)
                results[name] = text.strip()

                logger.debug(f"Field '{name}': {text[:50]}...")

            return results

        except Exception as e:
            logger.error(f"Error recognizing form fields: {e}")
            return {}
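Each entry in `field_regions` needs a `name` and a pixel `bbox`; the method returns a name-to-text mapping. A sketch with assumed field names and coordinates:

```python
# Illustrative field map; names, pixel boxes, and import path are assumptions.
from documents.handwriting import HandwritingRecognizer  # hypothetical path

recognizer = HandwritingRecognizer()
fields = [
    {"name": "invoice_number", "bbox": [120, 80, 420, 120]},
    {"name": "total", "bbox": [120, 300, 420, 340]},
]
values = recognizer.recognize_form_fields("form.png", fields)
print(values.get("invoice_number", ""))
```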
    def batch_recognize(
-        self,
-        image_paths: List[str],
-        mode: str = 'full'
-    ) -> List[Dict[str, Any]]:
+        self,
+        image_paths: list[str],
+        mode: str = "full",
+    ) -> list[dict[str, Any]]:
        """
        Recognize handwriting from multiple images in batch.

@@ -442,7 +444,7 @@ class HandwritingRecognizer:
        for i, path in enumerate(image_paths):
            logger.info(f"Processing image {i+1}/{len(image_paths)}: {path}")
            result = self.recognize_from_file(path, mode=mode)
-            result['image_path'] = path
+            result["image_path"] = path
            results.append(result)

        return results

@@ -8,9 +8,8 @@ This module uses various techniques to detect and extract tables from documents:
"""

import logging
-from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
-import numpy as np
+from typing import Any

from PIL import Image

logger = logging.getLogger(__name__)
@@ -32,7 +31,7 @@ class TableExtractor:
        ... print(table['data'])  # pandas DataFrame
        ... print(table['bbox'])  # bounding box coordinates
    """

    def __init__(
        self,
        model_name: str = "microsoft/table-transformer-detection",

@@ -52,34 +51,35 @@ class TableExtractor:
        self.use_gpu = use_gpu
        self._model = None
        self._processor = None

    def _load_model(self):
        """Lazy load the table detection model."""
        if self._model is not None:
            return

        try:
-            from transformers import AutoImageProcessor, AutoModelForObjectDetection
            import torch
+
+            from transformers import AutoImageProcessor
+            from transformers import AutoModelForObjectDetection

            logger.info(f"Loading table detection model: {self.model_name}")

            self._processor = AutoImageProcessor.from_pretrained(self.model_name)
            self._model = AutoModelForObjectDetection.from_pretrained(self.model_name)

            # Move to GPU if available and requested
            if self.use_gpu and torch.cuda.is_available():
                self._model = self._model.cuda()
                logger.info("Using GPU for table detection")
            else:
                logger.info("Using CPU for table detection")

        except ImportError as e:
            logger.error(f"Failed to load table detection model: {e}")
            logger.error("Please install required packages: pip install transformers torch pillow")
            raise

-    def detect_tables(self, image: Image.Image) -> List[Dict[str, Any]]:
+    def detect_tables(self, image: Image.Image) -> list[dict[str, Any]]:
        """
        Detect tables in an image.

@@ -98,50 +98,50 @@ class TableExtractor:
            ]
        """
        self._load_model()

        try:
            import torch

            # Prepare image
            inputs = self._processor(images=image, return_tensors="pt")

            if self.use_gpu and torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}

            # Run detection
            with torch.no_grad():
                outputs = self._model(**inputs)

            # Post-process results
            target_sizes = torch.tensor([image.size[::-1]])
            results = self._processor.post_process_object_detection(
-                outputs,
+                outputs,
                threshold=self.confidence_threshold,
-                target_sizes=target_sizes
+                target_sizes=target_sizes,
            )[0]

            # Convert to list of dicts
            tables = []
            for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
                tables.append({
-                    'bbox': box.cpu().tolist(),
-                    'score': score.item(),
-                    'label': self._model.config.id2label[label.item()]
+                    "bbox": box.cpu().tolist(),
+                    "score": score.item(),
+                    "label": self._model.config.id2label[label.item()],
                })

            logger.info(f"Detected {len(tables)} tables in image")
            return tables

        except Exception as e:
            logger.error(f"Error detecting tables: {e}")
            return []
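The detect-and-post-process flow above condenses to a standalone sketch. The checkpoint name comes from the class default shown earlier in the diff; the input file and threshold value are assumptions.

```python
# Minimal Table Transformer detection sketch; "page.png" is a placeholder.
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection")

image = Image.open("page.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
detections = processor.post_process_object_detection(
    outputs, threshold=0.7, target_sizes=target_sizes,
)[0]
for score, box in zip(detections["scores"], detections["boxes"]):
    print(f"table at {box.tolist()} (score {score:.2f})")
```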
    def extract_table_from_region(
-        self,
-        image: Image.Image,
-        bbox: List[float],
-        use_ocr: bool = True
-    ) -> Optional[Dict[str, Any]]:
+        self,
+        image: Image.Image,
+        bbox: list[float],
+        use_ocr: bool = True,
+    ) -> dict[str, Any] | None:
        """
        Extract table data from a specific region of an image.

@@ -158,48 +158,48 @@ class TableExtractor:
            # Crop to table region
            x1, y1, x2, y2 = [int(coord) for coord in bbox]
            table_image = image.crop((x1, y1, x2, y2))

            if use_ocr:
                # Use OCR to extract text and structure
                import pytesseract

                # Get detailed OCR data
                ocr_data = pytesseract.image_to_data(
-                    table_image,
-                    output_type=pytesseract.Output.DICT
+                    table_image,
+                    output_type=pytesseract.Output.DICT,
                )

                # Reconstruct table structure from OCR data
                table_data = self._reconstruct_table_from_ocr(ocr_data)

                # Also get raw text
                raw_text = pytesseract.image_to_string(table_image)

                return {
-                    'data': table_data,
-                    'raw_text': raw_text,
-                    'bbox': bbox,
-                    'image_size': table_image.size
+                    "data": table_data,
+                    "raw_text": raw_text,
+                    "bbox": bbox,
+                    "image_size": table_image.size,
                }
            else:
                # Fallback to basic OCR without structure
                import pytesseract
                raw_text = pytesseract.image_to_string(table_image)
                return {
-                    'data': None,
-                    'raw_text': raw_text,
-                    'bbox': bbox,
-                    'image_size': table_image.size
+                    "data": None,
+                    "raw_text": raw_text,
+                    "bbox": bbox,
+                    "image_size": table_image.size,
                }

        except ImportError:
            logger.error("pytesseract not installed. Install with: pip install pytesseract")
            return None
        except Exception as e:
            logger.error(f"Error extracting table from region: {e}")
            return None
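The `image_to_data` payload consumed here is a dictionary of parallel lists, so word `i` and its pixel position share the same index. A small sketch of that structure ("cell.png" is a placeholder):

```python
# Sketch of the pytesseract payload shape consumed above.
import pytesseract
from PIL import Image

data = pytesseract.image_to_data(
    Image.open("cell.png"),
    output_type=pytesseract.Output.DICT,
)
# Parallel lists: data["text"][i] sits at (data["left"][i], data["top"][i]).
for i, word in enumerate(data["text"]):
    if word.strip():
        print(data["left"][i], data["top"][i], word)
```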
-    def _reconstruct_table_from_ocr(self, ocr_data: Dict) -> Optional[Any]:
+    def _reconstruct_table_from_ocr(self, ocr_data: dict) -> Any | None:
        """
        Reconstruct table structure from OCR output.

@@ -211,58 +211,58 @@ class TableExtractor:
        """
        try:
            import pandas as pd

            # Group text by vertical position (rows)
            rows = {}
-            for i, text in enumerate(ocr_data['text']):
+            for i, text in enumerate(ocr_data["text"]):
                if text.strip():
-                    top = ocr_data['top'][i]
-                    left = ocr_data['left'][i]
+                    top = ocr_data["top"][i]
+                    left = ocr_data["left"][i]

                    # Group by approximate row (within 20 pixels)
                    row_key = round(top / 20) * 20
                    if row_key not in rows:
                        rows[row_key] = []
                    rows[row_key].append((left, text))

            # Sort rows and create DataFrame
            table_rows = []
            for row_y in sorted(rows.keys()):
                # Sort cells by horizontal position
                cells = [text for _, text in sorted(rows[row_y])]
                table_rows.append(cells)

            if table_rows:
                # Pad rows to same length
                max_cols = max(len(row) for row in table_rows)
-                table_rows = [row + [''] * (max_cols - len(row)) for row in table_rows]
+                table_rows = [row + [""] * (max_cols - len(row)) for row in table_rows]

                # Create DataFrame
                df = pd.DataFrame(table_rows)

                # Try to use first row as header if it looks like one
                if len(df) > 1:
-                    first_row_text = ' '.join(str(x) for x in df.iloc[0])
+                    first_row_text = " ".join(str(x) for x in df.iloc[0])
                    if not any(char.isdigit() for char in first_row_text):
                        df.columns = df.iloc[0]
                        df = df[1:].reset_index(drop=True)

                return df

            return None

        except ImportError:
            logger.error("pandas not installed. Install with: pip install pandas")
            return None
        except Exception as e:
            logger.error(f"Error reconstructing table: {e}")
            return None
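The `round(top / 20) * 20` bucketing deserves a concrete check: words whose vertical positions differ by only a few pixels collapse onto the same row key. A worked example with made-up coordinates:

```python
# Worked example of the row bucketing used above.
words = [(103, 40, "Qty"), (105, 180, "Price"), (145, 180, "9.99"), (148, 40, "2")]
rows = {}
for top, left, text in words:
    row_key = round(top / 20) * 20  # 103, 105 -> 100; 145, 148 -> 140
    rows.setdefault(row_key, []).append((left, text))

for key in sorted(rows):
    print(key, [t for _, t in sorted(rows[key])])
# 100 ['Qty', 'Price']
# 140 ['2', '9.99']
```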
    def extract_tables_from_image(
-        self,
+        self,
        image_path: str,
-        output_format: str = 'dataframe'
-    ) -> List[Dict[str, Any]]:
+        output_format: str = "dataframe",
+    ) -> list[dict[str, Any]]:
        """
        Extract all tables from an image file.

@@ -275,45 +275,45 @@ class TableExtractor:
        """
        try:
            # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

            # Detect tables
            detections = self.detect_tables(image)

            # Extract data from each table
            tables = []
            for i, detection in enumerate(detections):
                logger.info(f"Extracting table {i+1}/{len(detections)}")

                table_data = self.extract_table_from_region(
-                    image,
-                    detection['bbox']
+                    image,
+                    detection["bbox"],
                )

                if table_data:
-                    table_data['detection_score'] = detection['score']
-                    table_data['table_index'] = i
+                    table_data["detection_score"] = detection["score"]
+                    table_data["table_index"] = i

                    # Convert to requested format
-                    if output_format == 'csv' and table_data['data'] is not None:
-                        table_data['csv'] = table_data['data'].to_csv(index=False)
-                    elif output_format == 'json' and table_data['data'] is not None:
-                        table_data['json'] = table_data['data'].to_json(orient='records')
+                    if output_format == "csv" and table_data["data"] is not None:
+                        table_data["csv"] = table_data["data"].to_csv(index=False)
+                    elif output_format == "json" and table_data["data"] is not None:
+                        table_data["json"] = table_data["data"].to_json(orient="records")

                    tables.append(table_data)

            logger.info(f"Successfully extracted {len(tables)} tables from {image_path}")
            return tables

        except Exception as e:
            logger.error(f"Error extracting tables from image {image_path}: {e}")
            return []
    def extract_tables_from_pdf(
-        self,
+        self,
        pdf_path: str,
-        page_numbers: Optional[List[int]] = None
-    ) -> Dict[int, List[Dict[str, Any]]]:
+        page_numbers: list[int] | None = None,
+    ) -> dict[int, list[dict[str, Any]]]:
        """
        Extract tables from a PDF document.

@@ -326,56 +326,56 @@ class TableExtractor:
        """
        try:
            from pdf2image import convert_from_path

            logger.info(f"Converting PDF to images: {pdf_path}")

            # Convert PDF pages to images
            if page_numbers:
                images = convert_from_path(
-                    pdf_path,
+                    pdf_path,
                    first_page=min(page_numbers),
-                    last_page=max(page_numbers)
+                    last_page=max(page_numbers),
                )
            else:
                images = convert_from_path(pdf_path)

            # Extract tables from each page
            results = {}
            for i, image in enumerate(images):
                page_num = page_numbers[i] if page_numbers else i + 1
                logger.info(f"Processing page {page_num}")

                # Detect and extract tables
                detections = self.detect_tables(image)
                tables = []

                for detection in detections:
                    table_data = self.extract_table_from_region(
-                        image,
-                        detection['bbox']
+                        image,
+                        detection["bbox"],
                    )
                    if table_data:
-                        table_data['detection_score'] = detection['score']
-                        table_data['page'] = page_num
+                        table_data["detection_score"] = detection["score"]
+                        table_data["page"] = page_num
                        tables.append(table_data)

                if tables:
                    results[page_num] = tables
                    logger.info(f"Found {len(tables)} tables on page {page_num}")

            return results

        except ImportError:
            logger.error("pdf2image not installed. Install with: pip install pdf2image")
            return {}
        except Exception as e:
            logger.error(f"Error extracting tables from PDF: {e}")
            return {}
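The pdf2image call used above renders a contiguous page range into PIL images. A sketch of it in isolation ("doc.pdf" is a placeholder, and poppler must be available on the host for pdf2image to work):

```python
# Standalone pdf2image sketch of the conversion step above.
from pdf2image import convert_from_path

pages = convert_from_path("doc.pdf", first_page=2, last_page=4)
for i, page in enumerate(pages, start=2):
    page.save(f"page_{i}.png")  # each page is a PIL Image
```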
    def save_tables_to_excel(
-        self,
-        tables: List[Dict[str, Any]],
-        output_path: str
+        self,
+        tables: list[dict[str, Any]],
+        output_path: str,
    ) -> bool:
        """
        Save extracted tables to an Excel file.

@@ -389,23 +389,23 @@ class TableExtractor:
        """
        try:
            import pandas as pd

-            with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
+            with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
                for i, table in enumerate(tables):
-                    if table.get('data') is not None:
+                    if table.get("data") is not None:
                        sheet_name = f"Table_{i+1}"
-                        if 'page' in table:
+                        if "page" in table:
                            sheet_name = f"Page_{table['page']}_Table_{i+1}"

-                        table['data'].to_excel(
-                            writer,
-                            sheet_name=sheet_name,
-                            index=False
+                        table["data"].to_excel(
+                            writer,
+                            sheet_name=sheet_name,
+                            index=False,
                        )

            logger.info(f"Saved {len(tables)} tables to {output_path}")
            return True

        except ImportError:
            logger.error("openpyxl not installed. Install with: pip install openpyxl")
            return False

@@ -233,11 +233,11 @@ class CanViewAISuggestionsPermission(BasePermission):
    def has_permission(self, request, view):
        if not request.user or not request.user.is_authenticated:
            return False

        # Superusers always have permission
        if request.user.is_superuser:
            return True

        # Check for specific permission
        return request.user.has_perm("documents.can_view_ai_suggestions")

@@ -253,11 +253,11 @@ class CanApplyAISuggestionsPermission(BasePermission):
    def has_permission(self, request, view):
        if not request.user or not request.user.is_authenticated:
            return False

        # Superusers always have permission
        if request.user.is_superuser:
            return True

        # Check for specific permission
        return request.user.has_perm("documents.can_apply_ai_suggestions")

@@ -273,11 +273,11 @@ class CanApproveDeletionsPermission(BasePermission):
    def has_permission(self, request, view):
        if not request.user or not request.user.is_authenticated:
            return False

        # Superusers always have permission
        if request.user.is_superuser:
            return True

        # Check for specific permission
        return request.user.has_perm("documents.can_approve_deletions")

@@ -294,10 +294,10 @@ class CanConfigureAIPermission(BasePermission):
    def has_permission(self, request, view):
        if not request.user or not request.user.is_authenticated:
            return False

        # Superusers always have permission
        if request.user.is_superuser:
            return True

        # Check for specific permission
        return request.user.has_perm("documents.can_configure_ai")

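These permission classes plug into DRF views via `permission_classes`. A hedged wiring sketch: the import is the one the test suite below also uses, but the view itself is hypothetical and only illustrates the hookup.

```python
# Hypothetical DRF view showing how the permission classes above are applied.
from rest_framework.response import Response
from rest_framework.views import APIView

from documents.permissions import CanViewAISuggestionsPermission


class AISuggestionsView(APIView):  # illustrative view, not part of this commit
    permission_classes = [CanViewAISuggestionsPermission]

    def get(self, request, pk):
        # Reached only for superusers or users holding
        # documents.can_view_ai_suggestions.
        return Response({"document": pk, "suggestions": []})
```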
@@ -46,7 +46,6 @@ if settings.AUDIT_LOG_ENABLED:
from documents import bulk_edit
from documents.data_models import DocumentSource
from documents.filters import CustomFieldQueryParser
from documents.models import AISuggestionFeedback
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import CustomFieldInstance
@@ -2786,57 +2785,57 @@ class DeletionRequestSerializer(serializers.ModelSerializer):

class DeletionRequestDetailSerializer(serializers.ModelSerializer):
    """Serializer for DeletionRequest model with document details."""

    document_details = serializers.SerializerMethodField()
-    user_username = serializers.CharField(source='user.username', read_only=True)
+    user_username = serializers.CharField(source="user.username", read_only=True)
    reviewed_by_username = serializers.CharField(
-        source='reviewed_by.username',
+        source="reviewed_by.username",
        read_only=True,
        allow_null=True,
    )

    class Meta:
        from documents.models import DeletionRequest
        model = DeletionRequest
        fields = [
-            'id',
-            'created_at',
-            'updated_at',
-            'requested_by_ai',
-            'ai_reason',
-            'user',
-            'user_username',
-            'status',
-            'impact_summary',
-            'reviewed_at',
-            'reviewed_by',
-            'reviewed_by_username',
-            'review_comment',
-            'completed_at',
-            'completion_details',
-            'document_details',
+            "id",
+            "created_at",
+            "updated_at",
+            "requested_by_ai",
+            "ai_reason",
+            "user",
+            "user_username",
+            "status",
+            "impact_summary",
+            "reviewed_at",
+            "reviewed_by",
+            "reviewed_by_username",
+            "review_comment",
+            "completed_at",
+            "completion_details",
+            "document_details",
        ]
        read_only_fields = [
-            'id',
-            'created_at',
-            'updated_at',
-            'reviewed_at',
-            'reviewed_by',
-            'completed_at',
-            'completion_details',
+            "id",
+            "created_at",
+            "updated_at",
+            "reviewed_at",
+            "reviewed_by",
+            "completed_at",
+            "completion_details",
        ]

    def get_document_details(self, obj):
        """Get details of documents in this deletion request."""
        documents = obj.documents.all()
        return [
            {
-                'id': doc.id,
-                'title': doc.title,
-                'created': doc.created.isoformat() if doc.created else None,
-                'correspondent': doc.correspondent.name if doc.correspondent else None,
-                'document_type': doc.document_type.name if doc.document_type else None,
-                'tags': [tag.name for tag in doc.tags.all()],
+                "id": doc.id,
+                "title": doc.title,
+                "created": doc.created.isoformat() if doc.created else None,
+                "correspondent": doc.correspondent.name if doc.correspondent else None,
+                "document_type": doc.document_type.name if doc.document_type else None,
+                "tags": [tag.name for tag in doc.tags.all()],
            }
            for doc in documents
        ]
@@ -2852,6 +2851,9 @@ class DeletionRequestActionSerializer(serializers.Serializer):
        allow_blank=True,
        label="Review Comment",
        help_text="Optional comment when reviewing the deletion request",
    )


class AISuggestionsRequestSerializer(serializers.Serializer):
    """Serializer for requesting AI suggestions for a document."""

@@ -1,17 +1,15 @@
"""Serializers package for documents app."""

-from .ai_suggestions import (
-    AISuggestionFeedbackSerializer,
-    AISuggestionsSerializer,
-    AISuggestionStatsSerializer,
-    ApplySuggestionSerializer,
-    RejectSuggestionSerializer,
-)
+from .ai_suggestions import AISuggestionFeedbackSerializer
+from .ai_suggestions import AISuggestionsSerializer
+from .ai_suggestions import AISuggestionStatsSerializer
+from .ai_suggestions import ApplySuggestionSerializer
+from .ai_suggestions import RejectSuggestionSerializer

__all__ = [
-    'AISuggestionFeedbackSerializer',
-    'AISuggestionsSerializer',
-    'AISuggestionStatsSerializer',
-    'ApplySuggestionSerializer',
-    'RejectSuggestionSerializer',
+    "AISuggestionFeedbackSerializer",
+    "AISuggestionStatsSerializer",
+    "AISuggestionsSerializer",
+    "ApplySuggestionSerializer",
+    "RejectSuggestionSerializer",
]

@@ -7,42 +7,39 @@ and handling user feedback on AI suggestions.

from __future__ import annotations

-from typing import Any, Dict
+from typing import Any

from rest_framework import serializers

-from documents.models import (
-    AISuggestionFeedback,
-    Correspondent,
-    CustomField,
-    DocumentType,
-    StoragePath,
-    Tag,
-    Workflow,
-)
+from documents.models import AISuggestionFeedback
+from documents.models import Correspondent
+from documents.models import CustomField
+from documents.models import DocumentType
+from documents.models import StoragePath
+from documents.models import Tag
+from documents.models import Workflow

# Suggestion type choices - used across multiple serializers
SUGGESTION_TYPE_CHOICES = [
-    'tag',
-    'correspondent',
-    'document_type',
-    'storage_path',
-    'custom_field',
-    'workflow',
-    'title',
+    "tag",
+    "correspondent",
+    "document_type",
+    "storage_path",
+    "custom_field",
+    "workflow",
+    "title",
]

# Types that require value_id
-ID_REQUIRED_TYPES = ['tag', 'correspondent', 'document_type', 'storage_path', 'workflow']
+ID_REQUIRED_TYPES = ["tag", "correspondent", "document_type", "storage_path", "workflow"]
# Types that require value_text
-TEXT_REQUIRED_TYPES = ['title']
+TEXT_REQUIRED_TYPES = ["title"]
# Types that can use either (custom_field can be ID or text)


class TagSuggestionSerializer(serializers.Serializer):
    """Serializer for tag suggestions."""

    id = serializers.IntegerField()
    name = serializers.CharField()
    color = serializers.CharField()

@@ -51,7 +48,7 @@ class TagSuggestionSerializer(serializers.Serializer):

class CorrespondentSuggestionSerializer(serializers.Serializer):
    """Serializer for correspondent suggestions."""

    id = serializers.IntegerField()
    name = serializers.CharField()
    confidence = serializers.FloatField()

@@ -59,7 +56,7 @@ class CorrespondentSuggestionSerializer(serializers.Serializer):

class DocumentTypeSuggestionSerializer(serializers.Serializer):
    """Serializer for document type suggestions."""

    id = serializers.IntegerField()
    name = serializers.CharField()
    confidence = serializers.FloatField()

@@ -67,7 +64,7 @@ class DocumentTypeSuggestionSerializer(serializers.Serializer):

class StoragePathSuggestionSerializer(serializers.Serializer):
    """Serializer for storage path suggestions."""

    id = serializers.IntegerField()
    name = serializers.CharField()
    path = serializers.CharField()

@@ -76,7 +73,7 @@ class StoragePathSuggestionSerializer(serializers.Serializer):

class CustomFieldSuggestionSerializer(serializers.Serializer):
    """Serializer for custom field suggestions."""

    field_id = serializers.IntegerField()
    field_name = serializers.CharField()
    value = serializers.CharField()

@@ -85,7 +82,7 @@ class CustomFieldSuggestionSerializer(serializers.Serializer):

class WorkflowSuggestionSerializer(serializers.Serializer):
    """Serializer for workflow suggestions."""

    id = serializers.IntegerField()
    name = serializers.CharField()
    confidence = serializers.FloatField()

@@ -93,7 +90,7 @@ class WorkflowSuggestionSerializer(serializers.Serializer):

class TitleSuggestionSerializer(serializers.Serializer):
    """Serializer for title suggestions."""

    title = serializers.CharField()


@@ -103,7 +100,7 @@ class AISuggestionsSerializer(serializers.Serializer):

    Converts AIScanResult objects to JSON format for API responses.
    """

    tags = TagSuggestionSerializer(many=True, required=False)
    correspondent = CorrespondentSuggestionSerializer(required=False, allow_null=True)
    document_type = DocumentTypeSuggestionSerializer(required=False, allow_null=True)

@@ -111,9 +108,9 @@ class AISuggestionsSerializer(serializers.Serializer):
    custom_fields = CustomFieldSuggestionSerializer(many=True, required=False)
    workflows = WorkflowSuggestionSerializer(many=True, required=False)
    title_suggestion = TitleSuggestionSerializer(required=False, allow_null=True)

    @staticmethod
-    def from_scan_result(scan_result, document_id: int) -> Dict[str, Any]:
+    def from_scan_result(scan_result, document_id: int) -> dict[str, Any]:
        """
        Convert an AIScanResult object to serializer data.

@@ -125,7 +122,7 @@ class AISuggestionsSerializer(serializers.Serializer):
            Dictionary ready for serialization
        """
        data = {}

        # Tags
        if scan_result.tags:
            tag_suggestions = []

@@ -133,59 +130,59 @@ class AISuggestionsSerializer(serializers.Serializer):
            try:
                tag = Tag.objects.get(pk=tag_id)
                tag_suggestions.append({
-                    'id': tag.id,
-                    'name': tag.name,
-                    'color': getattr(tag, 'color', '#000000'),
-                    'confidence': confidence,
+                    "id": tag.id,
+                    "name": tag.name,
+                    "color": getattr(tag, "color", "#000000"),
+                    "confidence": confidence,
                })
            except Tag.DoesNotExist:
                # Tag no longer exists in database; skip this suggestion
                pass
-        data['tags'] = tag_suggestions
+        data["tags"] = tag_suggestions

        # Correspondent
        if scan_result.correspondent:
            corr_id, confidence = scan_result.correspondent
            try:
                correspondent = Correspondent.objects.get(pk=corr_id)
-                data['correspondent'] = {
-                    'id': correspondent.id,
-                    'name': correspondent.name,
-                    'confidence': confidence,
+                data["correspondent"] = {
+                    "id": correspondent.id,
+                    "name": correspondent.name,
+                    "confidence": confidence,
                }
            except Correspondent.DoesNotExist:
                # Correspondent no longer exists in database; omit from suggestions
                pass

        # Document Type
        if scan_result.document_type:
            type_id, confidence = scan_result.document_type
            try:
                doc_type = DocumentType.objects.get(pk=type_id)
-                data['document_type'] = {
-                    'id': doc_type.id,
-                    'name': doc_type.name,
-                    'confidence': confidence,
+                data["document_type"] = {
+                    "id": doc_type.id,
+                    "name": doc_type.name,
+                    "confidence": confidence,
                }
            except DocumentType.DoesNotExist:
                # Document type no longer exists in database; omit from suggestions
                pass

        # Storage Path
        if scan_result.storage_path:
            path_id, confidence = scan_result.storage_path
            try:
                storage_path = StoragePath.objects.get(pk=path_id)
-                data['storage_path'] = {
-                    'id': storage_path.id,
-                    'name': storage_path.name,
-                    'path': storage_path.path,
-                    'confidence': confidence,
+                data["storage_path"] = {
+                    "id": storage_path.id,
+                    "name": storage_path.name,
+                    "path": storage_path.path,
+                    "confidence": confidence,
                }
            except StoragePath.DoesNotExist:
                # Storage path no longer exists in database; omit from suggestions
                pass

        # Custom Fields
        if scan_result.custom_fields:
            field_suggestions = []

@@ -193,16 +190,16 @@ class AISuggestionsSerializer(serializers.Serializer):
            try:
                field = CustomField.objects.get(pk=field_id)
                field_suggestions.append({
-                    'field_id': field.id,
-                    'field_name': field.name,
-                    'value': str(value),
-                    'confidence': confidence,
+                    "field_id": field.id,
+                    "field_name": field.name,
+                    "value": str(value),
+                    "confidence": confidence,
                })
            except CustomField.DoesNotExist:
                # Custom field no longer exists in database; skip this suggestion
                pass
-        data['custom_fields'] = field_suggestions
+        data["custom_fields"] = field_suggestions

        # Workflows
        if scan_result.workflows:
            workflow_suggestions = []

@@ -210,21 +207,21 @@ class AISuggestionsSerializer(serializers.Serializer):
            try:
                workflow = Workflow.objects.get(pk=workflow_id)
                workflow_suggestions.append({
-                    'id': workflow.id,
-                    'name': workflow.name,
-                    'confidence': confidence,
+                    "id": workflow.id,
+                    "name": workflow.name,
+                    "confidence": confidence,
                })
            except Workflow.DoesNotExist:
                # Workflow no longer exists in database; skip this suggestion
                pass
-        data['workflows'] = workflow_suggestions
+        data["workflows"] = workflow_suggestions

        # Title suggestion
        if scan_result.title_suggestion:
-            data['title_suggestion'] = {
-                'title': scan_result.title_suggestion,
+            data["title_suggestion"] = {
+                "title": scan_result.title_suggestion,
            }

        return data

@@ -234,28 +231,28 @@ class SuggestionSerializerMixin:
    """
    def validate(self, attrs):
        """Validate that the correct value field is provided for the suggestion type."""
-        suggestion_type = attrs.get('suggestion_type')
-        value_id = attrs.get('value_id')
-        value_text = attrs.get('value_text')
+        suggestion_type = attrs.get("suggestion_type")
+        value_id = attrs.get("value_id")
+        value_text = attrs.get("value_text")

        # Types that require value_id
        if suggestion_type in ID_REQUIRED_TYPES and not value_id:
            raise serializers.ValidationError(
-                f"value_id is required for suggestion_type '{suggestion_type}'"
+                f"value_id is required for suggestion_type '{suggestion_type}'",
            )

        # Types that require value_text
        if suggestion_type in TEXT_REQUIRED_TYPES and not value_text:
            raise serializers.ValidationError(
-                f"value_text is required for suggestion_type '{suggestion_type}'"
+                f"value_text is required for suggestion_type '{suggestion_type}'",
            )

        # For custom_field, either is acceptable
-        if suggestion_type == 'custom_field' and not value_id and not value_text:
+        if suggestion_type == "custom_field" and not value_id and not value_text:
            raise serializers.ValidationError(
                "Either value_id or value_text must be provided for custom_field",
            )

        return attrs

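To make the validation rules concrete, a hedged example: ID-typed suggestions must carry `value_id`, title suggestions must carry `value_text`. The import path assumes the package `__init__` shown above re-exports from `documents.serializers`.

```python
# Illustrative payloads against the mixin's validation rules; the import
# path is inferred from the serializers package __init__ in this commit.
from documents.serializers import ApplySuggestionSerializer

ok = ApplySuggestionSerializer(
    data={"suggestion_type": "tag", "value_id": 3, "confidence": 0.91},
)
assert ok.is_valid()

bad = ApplySuggestionSerializer(
    data={"suggestion_type": "title", "confidence": 0.8},  # missing value_text
)
assert not bad.is_valid()
print(bad.errors)  # "value_text is required for suggestion_type 'title'"
```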
@@ -263,12 +260,12 @@ class ApplySuggestionSerializer(SuggestionSerializerMixin, serializers.Serializer):
    """
    Serializer for applying AI suggestions.
    """

    suggestion_type = serializers.ChoiceField(
        choices=SUGGESTION_TYPE_CHOICES,
        required=True,
    )

    value_id = serializers.IntegerField(required=False, allow_null=True)
    value_text = serializers.CharField(required=False, allow_blank=True)
    confidence = serializers.FloatField(required=True)

@@ -278,12 +275,12 @@ class RejectSuggestionSerializer(SuggestionSerializerMixin, serializers.Serializer):
    """
    Serializer for rejecting AI suggestions.
    """

    suggestion_type = serializers.ChoiceField(
        choices=SUGGESTION_TYPE_CHOICES,
        required=True,
    )

    value_id = serializers.IntegerField(required=False, allow_null=True)
    value_text = serializers.CharField(required=False, allow_blank=True)
    confidence = serializers.FloatField(required=True)

@@ -291,41 +288,41 @@ class RejectSuggestionSerializer(SuggestionSerializerMixin, serializers.Serializer):

class AISuggestionFeedbackSerializer(serializers.ModelSerializer):
    """Serializer for AI suggestion feedback model."""

    class Meta:
        model = AISuggestionFeedback
        fields = [
-            'id',
-            'document',
-            'suggestion_type',
-            'suggested_value_id',
-            'suggested_value_text',
-            'confidence',
-            'status',
-            'user',
-            'created_at',
-            'applied_at',
-            'metadata',
+            "id",
+            "document",
+            "suggestion_type",
+            "suggested_value_id",
+            "suggested_value_text",
+            "confidence",
+            "status",
+            "user",
+            "created_at",
+            "applied_at",
+            "metadata",
        ]
-        read_only_fields = ['id', 'created_at', 'applied_at']
+        read_only_fields = ["id", "created_at", "applied_at"]


class AISuggestionStatsSerializer(serializers.Serializer):
    """
    Serializer for AI suggestion accuracy statistics.
    """

    total_suggestions = serializers.IntegerField()
    total_applied = serializers.IntegerField()
    total_rejected = serializers.IntegerField()
    accuracy_rate = serializers.FloatField()

    by_type = serializers.DictField(
        child=serializers.DictField(),
        help_text="Statistics broken down by suggestion type",
    )

    average_confidence_applied = serializers.FloatField()
    average_confidence_rejected = serializers.FloatField()

    recent_suggestions = AISuggestionFeedbackSerializer(many=True, required=False)

@@ -18,13 +18,11 @@ from django.test import TestCase
from django.utils import timezone

from documents.ai_deletion_manager import AIDeletionManager
-from documents.models import (
-    Correspondent,
-    DeletionRequest,
-    Document,
-    DocumentType,
-    Tag,
-)
+from documents.models import Correspondent
+from documents.models import DeletionRequest
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import Tag


class TestAIDeletionManagerCreateRequest(TestCase):

@@ -33,13 +31,13 @@ class TestAIDeletionManagerCreateRequest(TestCase):
    def setUp(self):
        """Set up test data."""
        self.user = User.objects.create_user(username="testuser", password="testpass")

        # Create test documents with various metadata
        self.correspondent = Correspondent.objects.create(name="Test Corp")
        self.doc_type = DocumentType.objects.create(name="Invoice")
        self.tag1 = Tag.objects.create(name="Important")
        self.tag2 = Tag.objects.create(name="2024")

        self.doc1 = Document.objects.create(
            title="Test Document 1",
            content="Test content 1",

@@ -49,7 +47,7 @@ class TestAIDeletionManagerCreateRequest(TestCase):
            document_type=self.doc_type,
        )
        self.doc1.tags.add(self.tag1, self.tag2)

        self.doc2 = Document.objects.create(
            title="Test Document 2",
            content="Test content 2",

@@ -63,13 +61,13 @@ class TestAIDeletionManagerCreateRequest(TestCase):
        """Test creating a basic deletion request."""
        documents = [self.doc1, self.doc2]
        reason = "Duplicate documents detected"

        request = AIDeletionManager.create_deletion_request(
            documents=documents,
            reason=reason,
            user=self.user,
        )

        self.assertIsNotNone(request)
        self.assertIsInstance(request, DeletionRequest)
        self.assertEqual(request.ai_reason, reason)

@@ -82,13 +80,13 @@ class TestAIDeletionManagerCreateRequest(TestCase):
        """Test that deletion request includes impact analysis."""
        documents = [self.doc1, self.doc2]
        reason = "Test deletion"

        request = AIDeletionManager.create_deletion_request(
            documents=documents,
            reason=reason,
            user=self.user,
        )

        impact = request.impact_summary
        self.assertIsNotNone(impact)
        self.assertEqual(impact["document_count"], 2)

@@ -106,14 +104,14 @@ class TestAIDeletionManagerCreateRequest(TestCase):
            "document_count": 1,
            "custom_field": "custom_value",
        }

        request = AIDeletionManager.create_deletion_request(
            documents=documents,
            reason=reason,
            user=self.user,
            impact_analysis=custom_impact,
        )

        self.assertEqual(request.impact_summary, custom_impact)
        self.assertEqual(request.impact_summary["custom_field"], "custom_value")

@@ -121,13 +119,13 @@ class TestAIDeletionManagerCreateRequest(TestCase):
        """Test creating request with empty document list."""
        documents = []
        reason = "Test deletion"

        request = AIDeletionManager.create_deletion_request(
            documents=documents,
            reason=reason,
            user=self.user,
        )

        self.assertIsNotNone(request)
        self.assertEqual(request.documents.count(), 0)
        self.assertEqual(request.impact_summary["document_count"], 0)

@@ -157,9 +155,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
            document_type=self.doc_type1,
        )
        doc.tags.add(self.tag1, self.tag2)

        impact = AIDeletionManager._analyze_impact([doc])

        self.assertEqual(impact["document_count"], 1)
        self.assertEqual(len(impact["documents"]), 1)
        self.assertEqual(impact["documents"][0]["id"], doc.id)

@@ -180,7 +178,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
            document_type=self.doc_type1,
        )
        doc1.tags.add(self.tag1)

        doc2 = Document.objects.create(
            title="Document 2",
            content="Content 2",

@@ -190,9 +188,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
            document_type=self.doc_type2,
        )
        doc2.tags.add(self.tag2, self.tag3)

        impact = AIDeletionManager._analyze_impact([doc1, doc2])

        self.assertEqual(impact["document_count"], 2)
        self.assertEqual(len(impact["documents"]), 2)
        self.assertIn("Corp A", impact["affected_correspondents"])

@@ -209,9 +207,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
            checksum="checksum1",
            mime_type="application/pdf",
        )

        impact = AIDeletionManager._analyze_impact([doc])

        self.assertEqual(impact["document_count"], 1)
        self.assertEqual(impact["documents"][0]["correspondent"], None)
        self.assertEqual(impact["documents"][0]["document_type"], None)

@@ -232,7 +230,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
        # Force set the created date to an earlier time
        doc1.created = timezone.make_aware(datetime(2023, 1, 1))
        doc1.save()

        doc2 = Document.objects.create(
            title="New Document",
            content="Content",

@@ -241,9 +239,9 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
        )
        doc2.created = timezone.make_aware(datetime(2024, 12, 31))
        doc2.save()

        impact = AIDeletionManager._analyze_impact([doc1, doc2])

        self.assertIsNotNone(impact["date_range"]["earliest"])
        self.assertIsNotNone(impact["date_range"]["latest"])
        # Check that dates are ISO formatted strings

@@ -253,7 +251,7 @@ class TestAIDeletionManagerAnalyzeImpact(TestCase):
    def test_analyze_impact_empty_list(self):
        """Test impact analysis with empty document list."""
        impact = AIDeletionManager._analyze_impact([])

        self.assertEqual(impact["document_count"], 0)
        self.assertEqual(len(impact["documents"]), 0)
        self.assertEqual(len(impact["affected_correspondents"]), 0)

@@ -267,11 +265,11 @@ class TestAIDeletionManagerFormatRequest(TestCase):
    def setUp(self):
        """Set up test data."""
        self.user = User.objects.create_user(username="testuser", password="testpass")

        self.correspondent = Correspondent.objects.create(name="Test Corp")
        self.doc_type = DocumentType.objects.create(name="Invoice")
        self.tag = Tag.objects.create(name="Important")

        self.doc = Document.objects.create(
            title="Test Document",
            content="Test content",

@@ -289,9 +287,9 @@ class TestAIDeletionManagerFormatRequest(TestCase):
            reason="Test reason for deletion",
            user=self.user,
        )

        message = AIDeletionManager.format_deletion_request_for_user(request)

        self.assertIsInstance(message, str)
        self.assertIn("AI DELETION REQUEST", message)
        self.assertIn("Test reason for deletion", message)

@@ -306,15 +304,15 @@ class TestAIDeletionManagerFormatRequest(TestCase):
            checksum="checksum2",
            mime_type="application/pdf",
        )

        request = AIDeletionManager.create_deletion_request(
            documents=[self.doc, doc2],
            reason="Multiple documents",
            user=self.user,
        )

        message = AIDeletionManager.format_deletion_request_for_user(request)

        self.assertIn("Number of documents: 2", message)
        self.assertIn("Test Corp", message)
        self.assertIn("Invoice", message)

@@ -328,15 +326,15 @@ class TestAIDeletionManagerFormatRequest(TestCase):
            checksum="checksum1",
            mime_type="application/pdf",
        )

        request = AIDeletionManager.create_deletion_request(
            documents=[doc],
            reason="Test deletion",
            user=self.user,
        )

        message = AIDeletionManager.format_deletion_request_for_user(request)

        self.assertIn("Basic Document", message)
        self.assertIn("None", message)  # Should show None for missing metadata

@@ -347,9 +345,9 @@ class TestAIDeletionManagerFormatRequest(TestCase):
            reason="Test",
            user=self.user,
        )

        message = AIDeletionManager.format_deletion_request_for_user(request)

        self.assertIn("explicit approval", message.lower())
        self.assertIn("no files will be deleted until you confirm", message.lower())

@@ -361,7 +359,7 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
        """Set up test data."""
        self.user1 = User.objects.create_user(username="user1", password="pass1")
        self.user2 = User.objects.create_user(username="user2", password="pass2")

        self.doc = Document.objects.create(
            title="Test Document",
            content="Content",

@@ -382,16 +380,16 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
            reason="Reason 2",
            user=self.user1,
        )

        # Create request for user2
        AIDeletionManager.create_deletion_request(
            documents=[self.doc],
            reason="Reason 3",
            user=self.user2,
        )

        pending = AIDeletionManager.get_pending_requests(self.user1)

        self.assertEqual(len(pending), 2)
        self.assertIn(req1, pending)
        self.assertIn(req2, pending)

@@ -408,12 +406,12 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
            reason="Reason 2",
            user=self.user1,
        )

        # Approve one request
        req1.approve(self.user1, "Approved")

        pending = AIDeletionManager.get_pending_requests(self.user1)

        self.assertEqual(len(pending), 1)
        self.assertNotIn(req1, pending)
        self.assertIn(req2, pending)

@@ -430,12 +428,12 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
            reason="Reason 2",
            user=self.user1,
        )

        # Reject one request
        req1.reject(self.user1, "Rejected")

        pending = AIDeletionManager.get_pending_requests(self.user1)

        self.assertEqual(len(pending), 1)
        self.assertNotIn(req1, pending)
        self.assertIn(req2, pending)

@@ -443,7 +441,7 @@ class TestAIDeletionManagerGetPendingRequests(TestCase):
    def test_get_pending_requests_empty(self):
        """Test getting pending requests when none exist."""
        pending = AIDeletionManager.get_pending_requests(self.user1)

        self.assertEqual(len(pending), 0)


@@ -454,7 +452,7 @@ class TestAIDeletionManagerSecurityConstraints(TestCase):
        """Test that AI can never delete automatically."""
        # This is a critical security test
        result = AIDeletionManager.can_ai_delete_automatically()

        self.assertFalse(result)

    def test_deletion_request_requires_pending_status(self):

@@ -466,13 +464,13 @@ class TestAIDeletionManagerSecurityConstraints(TestCase):
            checksum="checksum1",
            mime_type="application/pdf",
        )

        request = AIDeletionManager.create_deletion_request(
            documents=[doc],
            reason="Test",
            user=user,
        )

        self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)

    def test_deletion_request_marked_as_ai_initiated(self):

@@ -484,13 +482,13 @@ class TestAIDeletionManagerSecurityConstraints(TestCase):
            checksum="checksum1",
            mime_type="application/pdf",
        )

        request = AIDeletionManager.create_deletion_request(
            documents=[doc],
            reason="Test",
            user=user,
        )

        self.assertTrue(request.requested_by_ai)


@@ -501,7 +499,7 @@ class TestAIDeletionManagerWorkflow(TestCase):
        """Set up test data."""
        self.user = User.objects.create_user(username="testuser", password="pass")
        self.approver = User.objects.create_user(username="approver", password="pass")

        self.doc1 = Document.objects.create(
            title="Document 1",
            content="Content 1",

@@ -523,14 +521,14 @@ class TestAIDeletionManagerWorkflow(TestCase):
            reason="Duplicates detected",
            user=self.user,
        )

        self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)
        self.assertIsNone(request.reviewed_at)
        self.assertIsNone(request.reviewed_by)

        # Step 2: Approve request
        success = request.approve(self.approver, "Looks good")

        self.assertTrue(success)
        self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)
        self.assertIsNotNone(request.reviewed_at)

@@ -545,12 +543,12 @@ class TestAIDeletionManagerWorkflow(TestCase):
            reason="Should be deleted",
            user=self.user,
        )

        self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)

        # Step 2: Reject request
        success = request.reject(self.approver, "Not a duplicate")

        self.assertTrue(success)
        self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)
        self.assertIsNotNone(request.reviewed_at)

@@ -564,14 +562,14 @@ class TestAIDeletionManagerWorkflow(TestCase):
            reason="Test deletion",
            user=self.user,
        )

        # Record initial state
        created_at = request.created_at
        self.assertIsNotNone(created_at)

        # Approve
        request.approve(self.approver, "Approved")

        # Verify audit trail
        self.assertIsNotNone(request.created_at)
        self.assertIsNotNone(request.updated_at)

@ -10,24 +10,23 @@ Tests cover:
- Permission assignment and verification
"""

from django.contrib.auth.models import Group, Permission, User
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from rest_framework.test import APIRequestFactory

from documents.models import Document
from documents.permissions import (
CanApplyAISuggestionsPermission,
CanApproveDeletionsPermission,
CanConfigureAIPermission,
CanViewAISuggestionsPermission,
)
from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission


class MockView:
"""Mock view for testing permissions."""

pass


class TestCanViewAISuggestionsPermission(TestCase):

@ -41,13 +40,13 @@ class TestCanViewAISuggestionsPermission(TestCase):

# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)

# Assign permission to permitted_user

@ -107,13 +106,13 @@ class TestCanApplyAISuggestionsPermission(TestCase):

# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)

# Assign permission to permitted_user

@ -173,13 +172,13 @@ class TestCanApproveDeletionsPermission(TestCase):

# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)

# Assign permission to permitted_user

@ -239,13 +238,13 @@ class TestCanConfigureAIPermission(TestCase):

# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.regular_user = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)
self.permitted_user = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)

# Assign permission to permitted_user

@ -345,7 +344,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_viewer_role_permissions(self):
"""Test that viewer role has appropriate permissions."""
user = User.objects.create_user(
username="viewer", email="viewer@test.com", password="viewer123"
username="viewer", email="viewer@test.com", password="viewer123",
)
user.groups.add(self.viewer_group)

@ -360,7 +359,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_editor_role_permissions(self):
"""Test that editor role has appropriate permissions."""
user = User.objects.create_user(
username="editor", email="editor@test.com", password="editor123"
username="editor", email="editor@test.com", password="editor123",
)
user.groups.add(self.editor_group)

@ -375,7 +374,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_admin_role_permissions(self):
"""Test that admin role has all permissions."""
user = User.objects.create_user(
username="ai_admin", email="ai_admin@test.com", password="admin123"
username="ai_admin", email="ai_admin@test.com", password="admin123",
)
user.groups.add(self.admin_group)

@ -390,7 +389,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_user_with_multiple_groups(self):
"""Test that user permissions accumulate from multiple groups."""
user = User.objects.create_user(
username="multi_role", email="multi@test.com", password="multi123"
username="multi_role", email="multi@test.com", password="multi123",
)
user.groups.add(self.viewer_group, self.editor_group)

@ -405,7 +404,7 @@ class TestRoleBasedAccessControl(TestCase):
def test_direct_permission_assignment_overrides_group(self):
"""Test that direct permission assignment works alongside group permissions."""
user = User.objects.create_user(
username="special", email="special@test.com", password="special123"
username="special", email="special@test.com", password="special123",
)
user.groups.add(self.viewer_group)

@ -428,7 +427,7 @@ class TestPermissionAssignment(TestCase):
def setUp(self):
"""Set up test user."""
self.user = User.objects.create_user(
username="testuser", email="test@test.com", password="test123"
username="testuser", email="test@test.com", password="test123",
)
content_type = ContentType.objects.get_for_model(Document)
self.view_permission, _ = Permission.objects.get_or_create(

@ -500,7 +499,7 @@ class TestPermissionEdgeCases(TestCase):
def test_inactive_user_with_permission(self):
"""Test that inactive users are denied even with permission."""
user = User.objects.create_user(
username="inactive", email="inactive@test.com", password="inactive123"
username="inactive", email="inactive@test.com", password="inactive123",
)
user.is_active = False
user.save()

File diff suppressed because it is too large
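The permission tests above all follow one contract: superusers pass, inactive users fail, and ordinary users pass only when they hold the matching custom permission, directly or via a group. A minimal sketch of what one of these DRF permission classes plausibly looks like; the codename can_view_ai_suggestions is taken from the endpoint tests further below, but the implementation itself is an assumption, not the real code in documents/permissions.py:

# Sketch under assumptions; the real classes live in documents/permissions.py.
from rest_framework.permissions import BasePermission


class CanViewAISuggestionsSketch(BasePermission):
    """Allow active superusers, or active users holding the custom permission."""

    def has_permission(self, request, view):
        user = request.user
        if not user.is_authenticated or not user.is_active:
            # Inactive users are denied even with the permission,
            # matching test_inactive_user_with_permission above.
            return False
        return user.is_superuser or user.has_perm(
            "documents.can_view_ai_suggestions",
        )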
@ -8,24 +8,21 @@ document consumption to metadata application.

from unittest import mock

from django.test import TestCase, TransactionTestCase
from django.test import TestCase
from django.test import TransactionTestCase

from documents.ai_scanner import (
AIDocumentScanner,
AIScanResult,
get_ai_scanner,
)
from documents.models import (
Correspondent,
CustomField,
Document,
DocumentType,
StoragePath,
Tag,
Workflow,
WorkflowTrigger,
WorkflowAction,
)
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import AIScanResult
from documents.ai_scanner import get_ai_scanner
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger


class TestAIScannerIntegrationBasic(TestCase):

@ -35,49 +32,49 @@ class TestAIScannerIntegrationBasic(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Invoice from ACME Corporation",
content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000"
content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000",
)

self.tag_invoice = Tag.objects.create(
name="Invoice",
matching_algorithm=Tag.MATCH_AUTO,
match="invoice"
match="invoice",
)
self.tag_important = Tag.objects.create(
name="Important",
matching_algorithm=Tag.MATCH_AUTO,
match="total"
match="total",
)

self.correspondent = Correspondent.objects.create(
name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO,
match="acme"
match="acme",
)

self.doc_type = DocumentType.objects.create(
name="Invoice",
matching_algorithm=DocumentType.MATCH_AUTO,
match="invoice"
match="invoice",
)

self.storage_path = StoragePath.objects.create(
name="Invoices",
path="/invoices",
matching_algorithm=StoragePath.MATCH_AUTO,
match="invoice"
match="invoice",
)

@mock.patch('documents.ai_scanner.match_tags')
@mock.patch('documents.ai_scanner.match_correspondents')
@mock.patch('documents.ai_scanner.match_document_types')
@mock.patch('documents.ai_scanner.match_storage_paths')
@mock.patch("documents.ai_scanner.match_tags")
@mock.patch("documents.ai_scanner.match_correspondents")
@mock.patch("documents.ai_scanner.match_document_types")
@mock.patch("documents.ai_scanner.match_storage_paths")
def test_full_scan_and_apply_workflow(
self,
mock_storage,
mock_types,
mock_correspondents,
mock_tags
mock_tags,
):
"""Test complete workflow from scan to application."""
# Mock the matching functions to return our test data

@ -85,51 +82,51 @@ class TestAIScannerIntegrationBasic(TestCase):
mock_correspondents.return_value = [self.correspondent]
mock_types.return_value = [self.doc_type]
mock_storage.return_value = [self.storage_path]

scanner = AIDocumentScanner(auto_apply_threshold=0.80)

# Scan the document
scan_result = scanner.scan_document(
self.document,
self.document.content
self.document.content,
)

# Verify scan results
self.assertIsNotNone(scan_result)
self.assertGreater(len(scan_result.tags), 0)
self.assertIsNotNone(scan_result.correspondent)
self.assertIsNotNone(scan_result.document_type)
self.assertIsNotNone(scan_result.storage_path)

# Apply the results
result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)

# Verify application
self.assertGreater(len(result["applied"]["tags"]), 0)
self.assertIsNotNone(result["applied"]["correspondent"])

# Verify database changes
self.document.refresh_from_db()
self.assertEqual(self.document.correspondent, self.correspondent)
self.assertEqual(self.document.document_type, self.doc_type)
self.assertEqual(self.document.storage_path, self.storage_path)

@mock.patch('documents.ai_scanner.match_tags')
@mock.patch("documents.ai_scanner.match_tags")
def test_scan_with_no_matches(self, mock_tags):
"""Test scanning when no matches are found."""
mock_tags.return_value = []

scanner = AIDocumentScanner()

scan_result = scanner.scan_document(
self.document,
"Some random text with no matches"
"Some random text with no matches",
)

# Should return empty results
self.assertEqual(len(scan_result.tags), 0)
self.assertIsNone(scan_result.correspondent)

@ -143,46 +140,46 @@ class TestAIScannerIntegrationCustomFields(TestCase):
"""Set up test data with custom fields."""
self.document = Document.objects.create(
title="Invoice",
content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com"
content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com",
)

self.field_date = CustomField.objects.create(
name="Invoice Date",
data_type=CustomField.FieldDataType.DATE
data_type=CustomField.FieldDataType.DATE,
)
self.field_number = CustomField.objects.create(
name="Invoice Number",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_amount = CustomField.objects.create(
name="Total Amount",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)
self.field_email = CustomField.objects.create(
name="Contact Email",
data_type=CustomField.FieldDataType.STRING
data_type=CustomField.FieldDataType.STRING,
)

def test_custom_field_extraction_integration(self):
"""Test custom field extraction with mocked NER."""
scanner = AIDocumentScanner()

# Mock NER to return entities
mock_ner = mock.MagicMock()
mock_ner.extract_all.return_value = {
"dates": [{"text": "2024-01-01"}],
"amounts": [{"text": "$1,500"}],
"invoice_numbers": ["INV-123"],
"emails": ["john@example.com"]
"emails": ["john@example.com"],
}
scanner._ner_extractor = mock_ner

# Scan document
scan_result = scanner.scan_document(self.document, self.document.content)

# Verify custom fields were extracted
self.assertGreater(len(scan_result.custom_fields), 0)

# Check specific fields
extracted_field_ids = list(scan_result.custom_fields.keys())
self.assertIn(self.field_date.id, extracted_field_ids)

@ -196,47 +193,47 @@ class TestAIScannerIntegrationWorkflows(TestCase):
"""Set up test workflows."""
self.document = Document.objects.create(
title="Invoice",
content="Invoice document"
content="Invoice document",
)

self.workflow1 = Workflow.objects.create(
name="Invoice Processing",
enabled=True
enabled=True,
)
self.trigger1 = WorkflowTrigger.objects.create(
workflow=self.workflow1,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
)
self.action1 = WorkflowAction.objects.create(
workflow=self.workflow1,
type=WorkflowAction.WorkflowActionType.ASSIGNMENT
type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
)

self.workflow2 = Workflow.objects.create(
name="Archive Documents",
enabled=True
enabled=True,
)
self.trigger2 = WorkflowTrigger.objects.create(
workflow=self.workflow2,
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
)

def test_workflow_suggestion_integration(self):
"""Test workflow suggestion with real workflows."""
scanner = AIDocumentScanner(suggest_threshold=0.5)

# Create scan result with some attributes
scan_result = AIScanResult()
scan_result.document_type = (1, 0.85)
scan_result.tags = [(1, 0.80)]

# Get workflow suggestions
workflows = scanner._suggest_workflows(
self.document,
self.document.content,
scan_result
scan_result,
)

# Should suggest workflows
self.assertGreater(len(workflows), 0)
workflow_ids = [wf_id for wf_id, _ in workflows]

@ -250,7 +247,7 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)
self.tag = Tag.objects.create(name="TestTag")
self.correspondent = Correspondent.objects.create(name="TestCorp")

@ -258,29 +255,29 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
def test_transaction_rollback_on_error(self):
"""Test that transaction rolls back on error."""
scanner = AIDocumentScanner()

scan_result = AIScanResult()
scan_result.tags = [(self.tag.id, 0.90)]
scan_result.correspondent = (self.correspondent.id, 0.90)

# Force an error during save
original_save = Document.save
call_count = [0]

def failing_save(self, *args, **kwargs):
call_count[0] += 1
if call_count[0] >= 1:
raise Exception("Forced save failure")
return original_save(self, *args, **kwargs)

with mock.patch.object(Document, 'save', failing_save):
with mock.patch.object(Document, "save", failing_save):
with self.assertRaises(Exception):
scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)

# Verify changes were rolled back
self.document.refresh_from_db()
# Document should not have been modified

@ -292,30 +289,30 @@ class TestAIScannerIntegrationPerformance(TestCase):
def test_scan_multiple_documents(self):
"""Test scanning multiple documents efficiently."""
scanner = AIDocumentScanner()

documents = []
for i in range(5):
doc = Document.objects.create(
title=f"Document {i}",
content=f"Content for document {i}"
content=f"Content for document {i}",
)
documents.append(doc)

# Mock to avoid actual ML loading
with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):
with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):

results = []
for doc in documents:
result = scanner.scan_document(doc, doc.content)
results.append(result)

# Verify all scans completed
self.assertEqual(len(results), 5)
for result in results:

@ -329,37 +326,37 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Business Invoice",
content="Invoice from ACME Corporation"
content="Invoice from ACME Corporation",
)

self.correspondent_acme = Correspondent.objects.create(
name="ACME Corporation",
matching_algorithm=Correspondent.MATCH_AUTO
matching_algorithm=Correspondent.MATCH_AUTO,
)
self.correspondent_other = Correspondent.objects.create(
name="Other Company",
matching_algorithm=Correspondent.MATCH_AUTO
matching_algorithm=Correspondent.MATCH_AUTO,
)

def test_correspondent_matching_with_ner_entities(self):
"""Test that NER entities help match correspondents."""
scanner = AIDocumentScanner()

# Mock NER to extract organization
mock_ner = mock.MagicMock()
mock_ner.extract_all.return_value = {
"organizations": [{"text": "ACME Corporation"}]
"organizations": [{"text": "ACME Corporation"}],
}
scanner._ner_extractor = mock_ner

# Mock matching to return empty (so NER-based matching is used)
with mock.patch('documents.ai_scanner.match_correspondents', return_value=[]):
with mock.patch("documents.ai_scanner.match_correspondents", return_value=[]):
result = scanner._detect_correspondent(
self.document,
self.document.content,
{"organizations": [{"text": "ACME Corporation"}]}
{"organizations": [{"text": "ACME Corporation"}]},
)

# Should detect ACME correspondent
self.assertIsNotNone(result)
corr_id, confidence = result

@ -372,20 +369,20 @@ class TestAIScannerIntegrationTitleGeneration(TestCase):
def test_title_generation_with_entities(self):
"""Test title generation uses extracted entities."""
scanner = AIDocumentScanner()

document = Document.objects.create(
title="document.pdf",
content="Invoice from ACME Corp dated 2024-01-15"
content="Invoice from ACME Corp dated 2024-01-15",
)

entities = {
"document_type": "Invoice",
"organizations": [{"text": "ACME Corp"}],
"dates": [{"text": "2024-01-15"}]
"dates": [{"text": "2024-01-15"}],
}

title = scanner._suggest_title(document, document.content, entities)

self.assertIsNotNone(title)
self.assertIn("Invoice", title)
self.assertIn("ACME Corp", title)

@ -399,7 +396,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
"""Set up test data."""
self.document = Document.objects.create(
title="Test",
content="Test"
content="Test",
)
self.tag_high = Tag.objects.create(name="HighConfidence")
self.tag_medium = Tag.objects.create(name="MediumConfidence")

@ -409,26 +406,26 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
"""Test that only high confidence suggestions are auto-applied."""
scanner = AIDocumentScanner(
auto_apply_threshold=0.80,
suggest_threshold=0.60
suggest_threshold=0.60,
)

scan_result = AIScanResult()
scan_result.tags = [
(self.tag_high.id, 0.90),  # Should be applied
(self.tag_medium.id, 0.70),  # Should be suggested
(self.tag_low.id, 0.50),  # Should be ignored
]

result = scanner.apply_scan_results(
self.document,
scan_result,
auto_apply=True
auto_apply=True,
)

# Verify high confidence was applied
self.assertEqual(len(result["applied"]["tags"]), 1)
self.assertEqual(result["applied"]["tags"][0]["id"], self.tag_high.id)

# Verify medium confidence was suggested
self.assertEqual(len(result["suggestions"]["tags"]), 1)
self.assertEqual(result["suggestions"]["tags"][0]["id"], self.tag_medium.id)

@ -441,28 +438,28 @@ class TestAIScannerIntegrationGlobalInstance(TestCase):
"""Test that global scanner can be reused across multiple scans."""
scanner1 = get_ai_scanner()
scanner2 = get_ai_scanner()

# Should be the same instance
self.assertIs(scanner1, scanner2)

# Should be functional
document = Document.objects.create(
title="Test",
content="Test content"
content="Test content",
)

with mock.patch.object(scanner1, '_extract_entities', return_value={}), \
mock.patch.object(scanner1, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner1, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner1, '_classify_document_type', return_value=None), \
mock.patch.object(scanner1, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner1, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner1, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner1, '_suggest_title', return_value=None):

with mock.patch.object(scanner1, "_extract_entities", return_value={}), \
mock.patch.object(scanner1, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner1, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner1, "_classify_document_type", return_value=None), \
mock.patch.object(scanner1, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner1, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner1, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner1, "_suggest_title", return_value=None):

result1 = scanner1.scan_document(document, document.content)
result2 = scanner2.scan_document(document, document.content)

self.assertIsInstance(result1, AIScanResult)
self.assertIsInstance(result2, AIScanResult)

@ -473,66 +470,66 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
def test_scan_with_minimal_document(self):
"""Test scanning a document with minimal information."""
scanner = AIDocumentScanner()

document = Document.objects.create(
title="",
content=""
content="",
)

with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):

with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):

result = scanner.scan_document(document, document.content)

self.assertIsInstance(result, AIScanResult)

def test_apply_with_deleted_references(self):
"""Test applying results when referenced objects have been deleted."""
scanner = AIDocumentScanner()

document = Document.objects.create(
title="Test",
content="Test"
content="Test",
)

scan_result = AIScanResult()
scan_result.tags = [(9999, 0.90)]  # Non-existent tag ID
scan_result.correspondent = (9999, 0.90)  # Non-existent correspondent ID

# Should handle gracefully
result = scanner.apply_scan_results(
document,
scan_result,
auto_apply=True
auto_apply=True,
)

# Should not crash, just log errors
self.assertEqual(len(result["applied"]["tags"]), 0)

def test_scan_with_unicode_and_special_characters(self):
"""Test scanning documents with Unicode and special characters."""
scanner = AIDocumentScanner()

document = Document.objects.create(
title="Factura - España 🇪🇸",
content="Société française • 日本語 • Ελληνικά • مرحبا"
content="Société française • 日本語 • Ελληνικά • مرحبا",
)

with mock.patch.object(scanner, '_extract_entities', return_value={}), \
mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
mock.patch.object(scanner, '_classify_document_type', return_value=None), \
mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
mock.patch.object(scanner, '_suggest_title', return_value=None):

with mock.patch.object(scanner, "_extract_entities", return_value={}), \
mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
mock.patch.object(scanner, "_classify_document_type", return_value=None), \
mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
mock.patch.object(scanner, "_suggest_title", return_value=None):

result = scanner.scan_document(document, document.content)

self.assertIsInstance(result, AIScanResult)
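For contrast with the backslash-continued mock.patch.object chains repeated throughout the file above, the same patching can be expressed once with contextlib.ExitStack. A minimal sketch, assuming only a scanner object exposing the private hooks the tests already patch (_extract_entities, _suggest_tags, and so on):

# Sketch only: an alternative to the chained with-statements in these tests.
from contextlib import ExitStack
from unittest import mock


def patch_scanner_internals(scanner):
    # Returns a context manager that patches every private scan hook at once.
    stack = ExitStack()
    hooks = [
        ("_extract_entities", {}),
        ("_suggest_tags", []),
        ("_detect_correspondent", None),
        ("_classify_document_type", None),
        ("_suggest_storage_path", None),
        ("_extract_custom_fields", {}),
        ("_suggest_workflows", []),
        ("_suggest_title", None),
    ]
    for name, value in hooks:
        stack.enter_context(mock.patch.object(scanner, name, return_value=value))
    return stack

# Usage inside a test:
#     with patch_scanner_internals(scanner):
#         result = scanner.scan_document(document, document.content)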
@ -12,18 +12,17 @@ Tests cover:

from unittest import mock

from django.contrib.auth.models import Permission, User
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from rest_framework import status
from rest_framework.test import APITestCase

from documents.models import (
Correspondent,
DeletionRequest,
Document,
DocumentType,
Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag
from documents.tests.utils import DirectoriesMixin

@ -33,18 +32,18 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()

# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)

# Assign view permission
content_type = ContentType.objects.get_for_model(Document)
view_permission, _ = Permission.objects.get_or_create(

@ -53,13 +52,13 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
content_type=content_type,
)
self.user_with_permission.user_permissions.add(view_permission)

# Create test document
self.document = Document.objects.create(
title="Test Document",
content="This is a test invoice from ACME Corporation"
content="This is a test invoice from ACME Corporation",
)

# Create test metadata objects
self.tag = Tag.objects.create(name="Invoice")
self.correspondent = Correspondent.objects.create(name="ACME Corp")

@ -70,28 +69,28 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_user_without_permission_denied(self):
"""Test that users without permission are denied."""
self.client.force_authenticate(user=self.user_without_permission)

response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

def test_superuser_allowed(self):
"""Test that superusers can access the endpoint."""
self.client.force_authenticate(user=self.superuser)

with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [(self.tag.id, 0.85)]

@ -100,17 +99,17 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = "Invoice - ACME Corp"
mock_scan_result.custom_fields = {}

mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner.return_value = mock_scanner_instance

response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("document_id", response.data)
self.assertEqual(response.data["document_id"], self.document.id)

@ -118,8 +117,8 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
def test_user_with_permission_allowed(self):
"""Test that users with permission can access the endpoint."""
self.client.force_authenticate(user=self.user_with_permission)

with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []

@ -128,41 +127,41 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = None
mock_scan_result.custom_fields = {}

mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner.return_value = mock_scanner_instance

response = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_invalid_document_id(self):
"""Test handling of invalid document ID."""
self.client.force_authenticate(user=self.superuser)

response = self.client.post(
"/api/ai/suggestions/",
{"document_id": 99999},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_missing_document_id(self):
"""Test handling of missing document ID."""
self.client.force_authenticate(user=self.superuser)

response = self.client.post(
"/api/ai/suggestions/",
{},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@ -172,15 +171,15 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()

# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)

# Assign apply permission
content_type = ContentType.objects.get_for_model(Document)
apply_permission, _ = Permission.objects.get_or_create(

@ -189,13 +188,13 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
content_type=content_type,
)
self.user_with_permission.user_permissions.add(apply_permission)

# Create test document
self.document = Document.objects.create(
title="Test Document",
content="Test content"
content="Test content",
)

# Create test metadata
self.tag = Tag.objects.create(name="Test Tag")
self.correspondent = Correspondent.objects.create(name="Test Corp")

@ -205,16 +204,16 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
response = self.client.post(
"/api/ai/suggestions/apply/",
{"document_id": self.document.id},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_apply_tags_success(self):
"""Test successfully applying tag suggestions."""
self.client.force_authenticate(user=self.superuser)

with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = [(self.tag.id, 0.85)]

@ -223,29 +222,29 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = None
mock_scan_result.custom_fields = {}

mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner_instance.auto_apply_threshold = 0.80
mock_scanner.return_value = mock_scanner_instance

response = self.client.post(
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_tags": True
"apply_tags": True,
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["status"], "success")

def test_apply_correspondent_success(self):
"""Test successfully applying correspondent suggestion."""
self.client.force_authenticate(user=self.superuser)

with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
# Mock the scanner response
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []

@ -254,23 +253,23 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = None
mock_scan_result.custom_fields = {}

mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner_instance.auto_apply_threshold = 0.80
mock_scanner.return_value = mock_scanner_instance

response = self.client.post(
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_correspondent": True
"apply_correspondent": True,
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)

# Verify correspondent was applied
self.document.refresh_from_db()
self.assertEqual(self.document.correspondent, self.correspondent)

@ -282,43 +281,43 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()

# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)

def test_unauthorized_access_denied(self):
"""Test that unauthenticated users are denied."""
response = self.client.get("/api/ai/config/")

self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_user_without_permission_denied(self):
"""Test that users without permission are denied."""
self.client.force_authenticate(user=self.user_without_permission)

response = self.client.get("/api/ai/config/")

self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

def test_get_config_success(self):
"""Test getting AI configuration."""
self.client.force_authenticate(user=self.superuser)

with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.auto_apply_threshold = 0.80
mock_scanner_instance.suggest_threshold = 0.60
mock_scanner_instance.ml_enabled = True
mock_scanner_instance.advanced_ocr_enabled = True
mock_scanner.return_value = mock_scanner_instance

response = self.client.get("/api/ai/config/")

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("auto_apply_threshold", response.data)
self.assertEqual(response.data["auto_apply_threshold"], 0.80)

@ -326,31 +325,31 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
def test_update_config_success(self):
"""Test updating AI configuration."""
self.client.force_authenticate(user=self.superuser)

response = self.client.post(
"/api/ai/config/",
{
"auto_apply_threshold": 0.90,
"suggest_threshold": 0.70
"suggest_threshold": 0.70,
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["status"], "success")

def test_update_config_invalid_threshold(self):
"""Test updating with invalid threshold value."""
self.client.force_authenticate(user=self.superuser)

response = self.client.post(
"/api/ai/config/",
{
"auto_apply_threshold": 1.5  # Invalid: > 1.0
"auto_apply_threshold": 1.5,  # Invalid: > 1.0
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@ -360,18 +359,18 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()

# Create users
self.superuser = User.objects.create_superuser(
username="admin", email="admin@test.com", password="admin123"
username="admin", email="admin@test.com", password="admin123",
)
self.user_with_permission = User.objects.create_user(
username="permitted", email="permitted@test.com", password="permitted123"
username="permitted", email="permitted@test.com", password="permitted123",
)
self.user_without_permission = User.objects.create_user(
username="regular", email="regular@test.com", password="regular123"
username="regular", email="regular@test.com", password="regular123",
)

# Assign approval permission
content_type = ContentType.objects.get_for_model(Document)
approval_permission, _ = Permission.objects.get_or_create(

@ -380,12 +379,12 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
content_type=content_type,
)
self.user_with_permission.user_permissions.add(approval_permission)

# Create test deletion request
self.deletion_request = DeletionRequest.objects.create(
user=self.user_with_permission,
requested_by_ai=True,
ai_reason="Document appears to be a duplicate"
ai_reason="Document appears to be a duplicate",
)

def test_unauthorized_access_denied(self):

@ -394,102 +393,102 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_user_without_permission_denied(self):
"""Test that users without permission are denied."""
self.client.force_authenticate(user=self.user_without_permission)

response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

def test_approve_deletion_success(self):
"""Test successfully approving a deletion request."""
self.client.force_authenticate(user=self.user_with_permission)

response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["status"], "success")

# Verify status was updated
self.deletion_request.refresh_from_db()
self.assertEqual(
self.deletion_request.status,
DeletionRequest.STATUS_APPROVED
DeletionRequest.STATUS_APPROVED,
)

def test_reject_deletion_success(self):
"""Test successfully rejecting a deletion request."""
self.client.force_authenticate(user=self.user_with_permission)

response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "reject",
"reason": "Document is still needed"
"reason": "Document is still needed",
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)

# Verify status was updated
self.deletion_request.refresh_from_db()
self.assertEqual(
self.deletion_request.status,
DeletionRequest.STATUS_REJECTED
DeletionRequest.STATUS_REJECTED,
)

def test_invalid_request_id(self):
"""Test handling of invalid deletion request ID."""
self.client.force_authenticate(user=self.superuser)

response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": 99999,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_superuser_can_approve_any_request(self):
"""Test that superusers can approve any deletion request."""
self.client.force_authenticate(user=self.superuser)

response = self.client.post(
"/api/ai/deletions/approve/",
{
"request_id": self.deletion_request.id,
"action": "approve"
"action": "approve",
},
format="json"
format="json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)

@ -499,14 +498,14 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
def setUp(self):
"""Set up test data."""
super().setUp()

# Create user with all AI permissions
self.power_user = User.objects.create_user(
username="power_user", email="power@test.com", password="power123"
username="power_user", email="power@test.com", password="power123",
)

content_type = ContentType.objects.get_for_model(Document)

# Assign all AI permissions
permissions = [
"can_view_ai_suggestions",

@ -514,7 +513,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
"can_approve_deletions",
"can_configure_ai",
]

for codename in permissions:
perm, _ = Permission.objects.get_or_create(
codename=codename,

@ -522,18 +521,18 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
content_type=content_type,
)
self.power_user.user_permissions.add(perm)

self.document = Document.objects.create(
title="Test Doc",
content="Test"
content="Test",
)

def test_power_user_can_access_all_endpoints(self):
"""Test that user with all permissions can access all endpoints."""
self.client.force_authenticate(user=self.power_user)

# Test suggestions endpoint
with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
mock_scan_result = mock.MagicMock()
mock_scan_result.tags = []
mock_scan_result.correspondent = None

@ -541,7 +540,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
mock_scan_result.storage_path = None
mock_scan_result.title_suggestion = None
mock_scan_result.custom_fields = {}

mock_scanner_instance = mock.MagicMock()
mock_scanner_instance.scan_document.return_value = mock_scan_result
mock_scanner_instance.auto_apply_threshold = 0.80

@ -549,25 +548,25 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
mock_scanner_instance.ml_enabled = True
mock_scanner_instance.advanced_ocr_enabled = True
mock_scanner.return_value = mock_scanner_instance

response1 = self.client.post(
"/api/ai/suggestions/",
{"document_id": self.document.id},
format="json"
format="json",
)
self.assertEqual(response1.status_code, status.HTTP_200_OK)

# Test apply endpoint
response2 = self.client.post(
"/api/ai/suggestions/apply/",
{
"document_id": self.document.id,
"apply_tags": False
"apply_tags": False,
},
format="json"
format="json",
)
self.assertEqual(response2.status_code, status.HTTP_200_OK)

# Test config endpoint
response3 = self.client.get("/api/ai/config/")
self.assertEqual(response3.status_code, status.HTTP_200_OK)
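The deletion-approval tests above POST a request_id plus an action of "approve" or "reject" and expect 200 with {"status": "success"}, or 404 for an unknown id. A sketch of the dispatch that behaviour implies; the helper name and the error payload are assumptions, only the status codes and the approve/reject calls come from the tests:

# Hypothetical dispatch helper; the real endpoint lives in documents/views.py.
from rest_framework import status
from rest_framework.response import Response


def review_deletion_request(deletion_request, action, reviewer, reason=""):
    """Apply an approve/reject action and report the outcome."""
    if action == "approve":
        deletion_request.approve(reviewer, reason)
    elif action == "reject":
        deletion_request.reject(reviewer, reason)
    else:
        return Response(
            {"detail": "unknown action"},
            status=status.HTTP_400_BAD_REQUEST,
        )
    return Response({"status": "success"}, status=status.HTTP_200_OK)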
@ -9,14 +9,12 @@ from rest_framework import status
|
|||
from rest_framework.test import APITestCase
|
||||
|
||||
from documents.ai_scanner import AIScanResult
|
||||
from documents.models import (
|
||||
AISuggestionFeedback,
|
||||
Correspondent,
|
||||
Document,
|
||||
DocumentType,
|
||||
StoragePath,
|
||||
Tag,
|
||||
)
|
||||
from documents.models import AISuggestionFeedback
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
|
|
@ -25,11 +23,11 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
|
|||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
|
||||
# Create test user
|
||||
self.user = User.objects.create_superuser(username="test_admin")
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
|
||||
# Create test data
|
||||
self.correspondent = Correspondent.objects.create(
|
||||
name="Test Corp",
|
||||
|
|
@ -52,7 +50,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
|
|||
path="/archive/",
|
||||
pk=1,
|
||||
)
|
||||
|
||||
|
||||
        # Create test document
        self.document = Document.objects.create(
            title="Test Document",

@ -64,12 +62,12 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

    def test_ai_suggestions_endpoint_exists(self):
        """Test that the ai-suggestions endpoint is accessible."""
        response = self.client.get(
            f"/api/documents/{self.document.pk}/ai-suggestions/"
            f"/api/documents/{self.document.pk}/ai-suggestions/",
        )
        # Should not be 404
        self.assertNotEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    @mock.patch('documents.ai_scanner.get_ai_scanner')
    @mock.patch("documents.ai_scanner.get_ai_scanner")
    def test_get_ai_suggestions_success(self, mock_get_scanner):
        """Test successfully getting AI suggestions for a document."""
        # Create mock scan result

@ -79,39 +77,39 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

        scan_result.document_type = (self.doc_type.id, 0.88)
        scan_result.storage_path = (self.storage_path.id, 0.80)
        scan_result.title_suggestion = "Suggested Title"

        # Mock scanner
        mock_scanner = mock.Mock()
        mock_scanner.scan_document.return_value = scan_result
        mock_get_scanner.return_value = mock_scanner

        # Make request
        response = self.client.get(
            f"/api/documents/{self.document.pk}/ai-suggestions/"
            f"/api/documents/{self.document.pk}/ai-suggestions/",
        )

        # Verify response
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        data = response.json()

        # Check tags
        self.assertIn('tags', data)
        self.assertEqual(len(data['tags']), 2)
        self.assertEqual(data['tags'][0]['id'], self.tag1.id)
        self.assertEqual(data['tags'][0]['confidence'], 0.85)
        self.assertIn("tags", data)
        self.assertEqual(len(data["tags"]), 2)
        self.assertEqual(data["tags"][0]["id"], self.tag1.id)
        self.assertEqual(data["tags"][0]["confidence"], 0.85)

        # Check correspondent
        self.assertIn('correspondent', data)
        self.assertEqual(data['correspondent']['id'], self.correspondent.id)
        self.assertEqual(data['correspondent']['confidence'], 0.90)
        self.assertIn("correspondent", data)
        self.assertEqual(data["correspondent"]["id"], self.correspondent.id)
        self.assertEqual(data["correspondent"]["confidence"], 0.90)

        # Check document type
        self.assertIn('document_type', data)
        self.assertEqual(data['document_type']['id'], self.doc_type.id)
        self.assertIn("document_type", data)
        self.assertEqual(data["document_type"]["id"], self.doc_type.id)

        # Check title suggestion
        self.assertIn('title_suggestion', data)
        self.assertEqual(data['title_suggestion']['title'], "Suggested Title")
        self.assertIn("title_suggestion", data)
        self.assertEqual(data["title_suggestion"]["title"], "Suggested Title")

    def test_get_ai_suggestions_no_content(self):
        """Test getting AI suggestions for document without content."""

@ -122,43 +120,43 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

            checksum="empty123",
            mime_type="application/pdf",
        )

        response = self.client.get(f"/api/documents/{doc.pk}/ai-suggestions/")

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn("no content", response.json()['detail'].lower())
        self.assertIn("no content", response.json()["detail"].lower())

    def test_get_ai_suggestions_document_not_found(self):
        """Test getting AI suggestions for non-existent document."""
        response = self.client.get("/api/documents/99999/ai-suggestions/")

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_apply_suggestion_tag(self):
        """Test applying a tag suggestion."""
        request_data = {
            'suggestion_type': 'tag',
            'value_id': self.tag1.id,
            'confidence': 0.85,
            "suggestion_type": "tag",
            "value_id": self.tag1.id,
            "confidence": 0.85,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.json()['status'], 'success')
        self.assertEqual(response.json()["status"], "success")

        # Verify tag was applied
        self.document.refresh_from_db()
        self.assertIn(self.tag1, self.document.tags.all())

        # Verify feedback was recorded
        feedback = AISuggestionFeedback.objects.filter(
            document=self.document,
            suggestion_type='tag',
            suggestion_type="tag",
        ).first()
        self.assertIsNotNone(feedback)
        self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)

@ -169,27 +167,27 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

    def test_apply_suggestion_correspondent(self):
        """Test applying a correspondent suggestion."""
        request_data = {
            'suggestion_type': 'correspondent',
            'value_id': self.correspondent.id,
            'confidence': 0.90,
            "suggestion_type": "correspondent",
            "value_id": self.correspondent.id,
            "confidence": 0.90,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Verify correspondent was applied
        self.document.refresh_from_db()
        self.assertEqual(self.document.correspondent, self.correspondent)

        # Verify feedback was recorded
        feedback = AISuggestionFeedback.objects.filter(
            document=self.document,
            suggestion_type='correspondent',
            suggestion_type="correspondent",
        ).first()
        self.assertIsNotNone(feedback)
        self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
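
The apply-suggestion tests above all send the same request shape: a suggestion_type, a value_id (or value_text for titles), and a confidence. A minimal client-side sketch of the same call from outside the test suite, assuming token authentication; the host, document id, and token below are placeholders, only the endpoint path and payload keys come from the tests:

    import requests

    BASE = "http://localhost:8000"  # placeholder host
    TOKEN = "changeme"              # placeholder API token

    response = requests.post(
        f"{BASE}/api/documents/123/apply-suggestion/",
        json={"suggestion_type": "tag", "value_id": 7, "confidence": 0.85},
        headers={"Authorization": f"Token {TOKEN}"},
    )
    response.raise_for_status()
    assert response.json()["status"] == "success"
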
@ -197,19 +195,19 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

    def test_apply_suggestion_document_type(self):
        """Test applying a document type suggestion."""
        request_data = {
            'suggestion_type': 'document_type',
            'value_id': self.doc_type.id,
            'confidence': 0.88,
            "suggestion_type": "document_type",
            "value_id": self.doc_type.id,
            "confidence": 0.88,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Verify document type was applied
        self.document.refresh_from_db()
        self.assertEqual(self.document.document_type, self.doc_type)
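
The view under test is not part of this diff, but the expected dispatch can be read off the assertions: tag suggestions are added to the many-to-many set, while correspondent, document_type, and title each replace a single field. A minimal sketch of that dispatch, assuming the field names used in the tests (not the actual view code):

    def apply_suggestion(document, suggestion_type, value_id=None, value_text=None):
        # Sketch only: behaviour inferred from the test assertions above.
        if suggestion_type == "tag":
            document.tags.add(value_id)           # tests assert tag membership afterwards
        elif suggestion_type == "correspondent":
            document.correspondent_id = value_id  # tests compare document.correspondent
        elif suggestion_type == "document_type":
            document.document_type_id = value_id
        elif suggestion_type == "title":
            document.title = value_text           # tests compare document.title
        else:
            raise ValueError("unknown suggestion_type")  # surfaces as HTTP 400
        document.save()
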
@ -217,91 +215,91 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

    def test_apply_suggestion_title(self):
        """Test applying a title suggestion."""
        request_data = {
            'suggestion_type': 'title',
            'value_text': 'New Suggested Title',
            'confidence': 0.80,
            "suggestion_type": "title",
            "value_text": "New Suggested Title",
            "confidence": 0.80,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Verify title was applied
        self.document.refresh_from_db()
        self.assertEqual(self.document.title, 'New Suggested Title')
        self.assertEqual(self.document.title, "New Suggested Title")

    def test_apply_suggestion_invalid_type(self):
        """Test applying suggestion with invalid type."""
        request_data = {
            'suggestion_type': 'invalid_type',
            'value_id': 1,
            'confidence': 0.85,
            "suggestion_type": "invalid_type",
            "value_id": 1,
            "confidence": 0.85,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_apply_suggestion_missing_value(self):
        """Test applying suggestion without value_id or value_text."""
        request_data = {
            'suggestion_type': 'tag',
            'confidence': 0.85,
            "suggestion_type": "tag",
            "confidence": 0.85,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_apply_suggestion_nonexistent_object(self):
        """Test applying suggestion with non-existent object ID."""
        request_data = {
            'suggestion_type': 'tag',
            'value_id': 99999,
            'confidence': 0.85,
            "suggestion_type": "tag",
            "value_id": 99999,
            "confidence": 0.85,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_reject_suggestion(self):
        """Test rejecting an AI suggestion."""
        request_data = {
            'suggestion_type': 'tag',
            'value_id': self.tag1.id,
            'confidence': 0.65,
            "suggestion_type": "tag",
            "value_id": self.tag1.id,
            "confidence": 0.65,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/reject-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.json()['status'], 'success')
        self.assertEqual(response.json()["status"], "success")

        # Verify feedback was recorded
        feedback = AISuggestionFeedback.objects.filter(
            document=self.document,
            suggestion_type='tag',
            suggestion_type="tag",
        ).first()
        self.assertIsNotNone(feedback)
        self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
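
Rejection never mutates the document; the tests only check that a feedback row is written. A sketch of that recording step, assuming AISuggestionFeedback is importable exactly as the tests use it and that the field names match the assertions (suggested_value_id, confidence, status, user):

    def record_rejection(document, suggestion_type, value_id, confidence, user):
        # Sketch only: field names taken from the assertions above.
        return AISuggestionFeedback.objects.create(
            document=document,
            suggestion_type=suggestion_type,
            suggested_value_id=value_id,
            confidence=confidence,
            status=AISuggestionFeedback.STATUS_REJECTED,
            user=user,
        )
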
@ -312,46 +310,46 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

    def test_reject_suggestion_with_text(self):
        """Test rejecting a suggestion with text value."""
        request_data = {
            'suggestion_type': 'title',
            'value_text': 'Bad Title Suggestion',
            'confidence': 0.50,
            "suggestion_type": "title",
            "value_text": "Bad Title Suggestion",
            "confidence": 0.50,
        }

        response = self.client.post(
            f"/api/documents/{self.document.pk}/reject-suggestion/",
            data=request_data,
            format='json',
            format="json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Verify feedback was recorded
        feedback = AISuggestionFeedback.objects.filter(
            document=self.document,
            suggestion_type='title',
            suggestion_type="title",
        ).first()
        self.assertIsNotNone(feedback)
        self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
        self.assertEqual(feedback.suggested_value_text, 'Bad Title Suggestion')
        self.assertEqual(feedback.suggested_value_text, "Bad Title Suggestion")

    def test_ai_suggestion_stats_empty(self):
        """Test getting statistics when no feedback exists."""
        response = self.client.get("/api/documents/ai-suggestion-stats/")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        data = response.json()

        self.assertEqual(data['total_suggestions'], 0)
        self.assertEqual(data['total_applied'], 0)
        self.assertEqual(data['total_rejected'], 0)
        self.assertEqual(data['accuracy_rate'], 0)
        self.assertEqual(data["total_suggestions"], 0)
        self.assertEqual(data["total_applied"], 0)
        self.assertEqual(data["total_rejected"], 0)
        self.assertEqual(data["accuracy_rate"], 0)

    def test_ai_suggestion_stats_with_data(self):
        """Test getting statistics with feedback data."""
        # Create some feedback entries
        AISuggestionFeedback.objects.create(
            document=self.document,
            suggestion_type='tag',
            suggestion_type="tag",
            suggested_value_id=self.tag1.id,
            confidence=0.85,
            status=AISuggestionFeedback.STATUS_APPLIED,

@ -359,7 +357,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

        )
        AISuggestionFeedback.objects.create(
            document=self.document,
            suggestion_type='tag',
            suggestion_type="tag",
            suggested_value_id=self.tag2.id,
            confidence=0.70,
            status=AISuggestionFeedback.STATUS_APPLIED,

@ -367,38 +365,38 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

        )
        AISuggestionFeedback.objects.create(
            document=self.document,
            suggestion_type='correspondent',
            suggestion_type="correspondent",
            suggested_value_id=self.correspondent.id,
            confidence=0.60,
            status=AISuggestionFeedback.STATUS_REJECTED,
            user=self.user,
        )

        response = self.client.get("/api/documents/ai-suggestion-stats/")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        data = response.json()

        # Check overall stats
        self.assertEqual(data['total_suggestions'], 3)
        self.assertEqual(data['total_applied'], 2)
        self.assertEqual(data['total_rejected'], 1)
        self.assertAlmostEqual(data['accuracy_rate'], 66.67, places=1)
        self.assertEqual(data["total_suggestions"], 3)
        self.assertEqual(data["total_applied"], 2)
        self.assertEqual(data["total_rejected"], 1)
        self.assertAlmostEqual(data["accuracy_rate"], 66.67, places=1)

        # Check by_type stats
        self.assertIn('by_type', data)
        self.assertIn('tag', data['by_type'])
        self.assertEqual(data['by_type']['tag']['total'], 2)
        self.assertEqual(data['by_type']['tag']['applied'], 2)
        self.assertEqual(data['by_type']['tag']['rejected'], 0)
        self.assertIn("by_type", data)
        self.assertIn("tag", data["by_type"])
        self.assertEqual(data["by_type"]["tag"]["total"], 2)
        self.assertEqual(data["by_type"]["tag"]["applied"], 2)
        self.assertEqual(data["by_type"]["tag"]["rejected"], 0)

        # Check confidence averages
        self.assertGreater(data['average_confidence_applied'], 0)
        self.assertGreater(data['average_confidence_rejected'], 0)
        self.assertGreater(data["average_confidence_applied"], 0)
        self.assertGreater(data["average_confidence_rejected"], 0)

        # Check recent suggestions
        self.assertIn('recent_suggestions', data)
        self.assertEqual(len(data['recent_suggestions']), 3)
        self.assertIn("recent_suggestions", data)
        self.assertEqual(len(data["recent_suggestions"]), 3)

    def test_ai_suggestion_stats_accuracy_calculation(self):
        """Test that accuracy rate is calculated correctly."""

@ -406,57 +404,57 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

        for i in range(7):
            AISuggestionFeedback.objects.create(
                document=self.document,
                suggestion_type='tag',
                suggestion_type="tag",
                suggested_value_id=self.tag1.id,
                confidence=0.80,
                status=AISuggestionFeedback.STATUS_APPLIED,
                user=self.user,
            )

        for i in range(3):
            AISuggestionFeedback.objects.create(
                document=self.document,
                suggestion_type='tag',
                suggestion_type="tag",
                suggested_value_id=self.tag2.id,
                confidence=0.60,
                status=AISuggestionFeedback.STATUS_REJECTED,
                user=self.user,
            )

        response = self.client.get("/api/documents/ai-suggestion-stats/")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        data = response.json()

        self.assertEqual(data['total_suggestions'], 10)
        self.assertEqual(data['total_applied'], 7)
        self.assertEqual(data['total_rejected'], 3)
        self.assertEqual(data['accuracy_rate'], 70.0)
        self.assertEqual(data["total_suggestions"], 10)
        self.assertEqual(data["total_applied"], 7)
        self.assertEqual(data["total_rejected"], 3)
        self.assertEqual(data["accuracy_rate"], 70.0)

    def test_authentication_required(self):
        """Test that authentication is required for all endpoints."""
        self.client.force_authenticate(user=None)

        # Test ai-suggestions endpoint
        response = self.client.get(
            f"/api/documents/{self.document.pk}/ai-suggestions/"
            f"/api/documents/{self.document.pk}/ai-suggestions/",
        )
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        # Test apply-suggestion endpoint
        response = self.client.post(
            f"/api/documents/{self.document.pk}/apply-suggestion/",
            data={},
        )
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        # Test reject-suggestion endpoint
        response = self.client.post(
            f"/api/documents/{self.document.pk}/reject-suggestion/",
            data={},
        )
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        # Test stats endpoint
        response = self.client.get("/api/documents/ai-suggestion-stats/")
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
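
The two stats tests pin down the accuracy formula: 2 applied out of 3 total gives roughly 66.67, and 7 out of 10 gives exactly 70.0, i.e. applied / (applied + rejected) * 100, with 0 for an empty history. A minimal sketch consistent with both assertions:

    def accuracy_rate(applied: int, rejected: int) -> float:
        """Share of suggestions that were applied, as a percentage (0 when empty)."""
        total = applied + rejected
        return round(applied / total * 100, 2) if total else 0


    assert accuracy_rate(7, 3) == 70.0
    assert abs(accuracy_rate(2, 1) - 66.67) < 0.01
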
@ -11,17 +11,14 @@ Tests cover:

"""

from django.contrib.auth.models import User
from django.test import override_settings
from rest_framework import status
from rest_framework.test import APITestCase

from documents.models import (
    Correspondent,
    DeletionRequest,
    Document,
    DocumentType,
    Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag


class TestDeletionRequestAPI(APITestCase):

@ -33,7 +30,7 @@ class TestDeletionRequestAPI(APITestCase):

        self.user1 = User.objects.create_user(username="user1", password="pass123")
        self.user2 = User.objects.create_user(username="user2", password="pass123")
        self.admin = User.objects.create_superuser(username="admin", password="admin123")

        # Create test documents
        self.doc1 = Document.objects.create(
            title="Test Document 1",

@ -53,7 +50,7 @@ class TestDeletionRequestAPI(APITestCase):

            checksum="checksum3",
            mime_type="application/pdf",
        )

        # Create deletion requests
        self.request1 = DeletionRequest.objects.create(
            requested_by_ai=True,

@ -63,7 +60,7 @@ class TestDeletionRequestAPI(APITestCase):

            impact_summary={"document_count": 1},
        )
        self.request1.documents.add(self.doc1)

        self.request2 = DeletionRequest.objects.create(
            requested_by_ai=True,
            ai_reason="Low quality document",

@ -77,7 +74,7 @@ class TestDeletionRequestAPI(APITestCase):

        """Test that users can list their own deletion requests."""
        self.client.force_authenticate(user=self.user1)
        response = self.client.get("/api/deletion-requests/")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data["results"]), 1)
        self.assertEqual(response.data["results"][0]["id"], self.request1.id)

@ -86,7 +83,7 @@ class TestDeletionRequestAPI(APITestCase):

        """Test that admin can list all deletion requests."""
        self.client.force_authenticate(user=self.admin)
        response = self.client.get("/api/deletion-requests/")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data["results"]), 2)

@ -94,7 +91,7 @@ class TestDeletionRequestAPI(APITestCase):

        """Test retrieving a single deletion request."""
        self.client.force_authenticate(user=self.user1)
        response = self.client.get(f"/api/deletion-requests/{self.request1.id}/")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["id"], self.request1.id)
        self.assertEqual(response.data["ai_reason"], "Duplicate document detected")

@ -104,23 +101,23 @@ class TestDeletionRequestAPI(APITestCase):

    def test_approve_deletion_request_as_owner(self):
        """Test approving a deletion request as the owner."""
        self.client.force_authenticate(user=self.user1)

        # Verify document exists
        self.assertTrue(Document.objects.filter(id=self.doc1.id).exists())

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/approve/",
            {"comment": "Approved by owner"},
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertIn("message", response.data)
        self.assertIn("execution_result", response.data)
        self.assertEqual(response.data["execution_result"]["deleted_count"], 1)

        # Verify document was deleted
        self.assertFalse(Document.objects.filter(id=self.doc1.id).exists())

        # Verify deletion request was updated
        self.request1.refresh_from_db()
        self.assertEqual(self.request1.status, DeletionRequest.STATUS_COMPLETED)

@ -131,18 +128,18 @@ class TestDeletionRequestAPI(APITestCase):

    def test_approve_deletion_request_as_admin(self):
        """Test approving a deletion request as admin."""
        self.client.force_authenticate(user=self.admin)

        response = self.client.post(
            f"/api/deletion-requests/{self.request2.id}/approve/",
            {"comment": "Approved by admin"},
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertIn("execution_result", response.data)

        # Verify document was deleted
        self.assertFalse(Document.objects.filter(id=self.doc2.id).exists())

        # Verify deletion request was updated
        self.request2.refresh_from_db()
        self.assertEqual(self.request2.status, DeletionRequest.STATUS_COMPLETED)

@ -151,16 +148,16 @@ class TestDeletionRequestAPI(APITestCase):

    def test_approve_deletion_request_without_permission(self):
        """Test that non-owners cannot approve deletion requests."""
        self.client.force_authenticate(user=self.user2)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/approve/",
        )

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

        # Verify document was NOT deleted
        self.assertTrue(Document.objects.filter(id=self.doc1.id).exists())

        # Verify deletion request was NOT updated
        self.request1.refresh_from_db()
        self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING)

@ -169,13 +166,13 @@ class TestDeletionRequestAPI(APITestCase):

        """Test that already approved requests cannot be approved again."""
        self.request1.status = DeletionRequest.STATUS_APPROVED
        self.request1.save()

        self.client.force_authenticate(user=self.user1)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/approve/",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn("error", response.data)
        self.assertIn("pending", response.data["error"].lower())

@ -183,18 +180,18 @@ class TestDeletionRequestAPI(APITestCase):

    def test_reject_deletion_request_as_owner(self):
        """Test rejecting a deletion request as the owner."""
        self.client.force_authenticate(user=self.user1)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/reject/",
            {"comment": "Not needed"},
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertIn("message", response.data)

        # Verify document was NOT deleted
        self.assertTrue(Document.objects.filter(id=self.doc1.id).exists())

        # Verify deletion request was updated
        self.request1.refresh_from_db()
        self.assertEqual(self.request1.status, DeletionRequest.STATUS_REJECTED)

@ -205,16 +202,16 @@ class TestDeletionRequestAPI(APITestCase):

    def test_reject_deletion_request_as_admin(self):
        """Test rejecting a deletion request as admin."""
        self.client.force_authenticate(user=self.admin)

        response = self.client.post(
            f"/api/deletion-requests/{self.request2.id}/reject/",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Verify document was NOT deleted
        self.assertTrue(Document.objects.filter(id=self.doc2.id).exists())

        # Verify deletion request was updated
        self.request2.refresh_from_db()
        self.assertEqual(self.request2.status, DeletionRequest.STATUS_REJECTED)

@ -223,13 +220,13 @@ class TestDeletionRequestAPI(APITestCase):

    def test_reject_deletion_request_without_permission(self):
        """Test that non-owners cannot reject deletion requests."""
        self.client.force_authenticate(user=self.user2)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/reject/",
        )

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

        # Verify deletion request was NOT updated
        self.request1.refresh_from_db()
        self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING)

@ -238,31 +235,31 @@ class TestDeletionRequestAPI(APITestCase):

        """Test that already rejected requests cannot be rejected again."""
        self.request1.status = DeletionRequest.STATUS_REJECTED
        self.request1.save()

        self.client.force_authenticate(user=self.user1)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/reject/",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn("error", response.data)

    def test_cancel_deletion_request_as_owner(self):
        """Test canceling a deletion request as the owner."""
        self.client.force_authenticate(user=self.user1)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/cancel/",
            {"comment": "Changed my mind"},
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertIn("message", response.data)

        # Verify document was NOT deleted
        self.assertTrue(Document.objects.filter(id=self.doc1.id).exists())

        # Verify deletion request was updated
        self.request1.refresh_from_db()
        self.assertEqual(self.request1.status, DeletionRequest.STATUS_CANCELLED)

@ -273,13 +270,13 @@ class TestDeletionRequestAPI(APITestCase):

    def test_cancel_deletion_request_without_permission(self):
        """Test that non-owners cannot cancel deletion requests."""
        self.client.force_authenticate(user=self.user2)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/cancel/",
        )

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

        # Verify deletion request was NOT updated
        self.request1.refresh_from_db()
        self.assertEqual(self.request1.status, DeletionRequest.STATUS_PENDING)

@ -288,13 +285,13 @@ class TestDeletionRequestAPI(APITestCase):

        """Test that approved requests cannot be cancelled."""
        self.request1.status = DeletionRequest.STATUS_APPROVED
        self.request1.save()

        self.client.force_authenticate(user=self.user1)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/cancel/",
        )

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn("error", response.data)

@ -309,17 +306,17 @@ class TestDeletionRequestAPI(APITestCase):

            impact_summary={"document_count": 2},
        )
        multi_request.documents.add(self.doc1, self.doc3)

        self.client.force_authenticate(user=self.user1)

        response = self.client.post(
            f"/api/deletion-requests/{multi_request.id}/approve/",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["execution_result"]["deleted_count"], 2)
        self.assertEqual(response.data["execution_result"]["total_documents"], 2)

        # Verify both documents were deleted
        self.assertFalse(Document.objects.filter(id=self.doc1.id).exists())
        self.assertFalse(Document.objects.filter(id=self.doc3.id).exists())

@ -330,15 +327,15 @@ class TestDeletionRequestAPI(APITestCase):

        tag = Tag.objects.create(name="test-tag")
        correspondent = Correspondent.objects.create(name="Test Corp")
        doc_type = DocumentType.objects.create(name="Invoice")

        self.doc1.tags.add(tag)
        self.doc1.correspondent = correspondent
        self.doc1.document_type = doc_type
        self.doc1.save()

        self.client.force_authenticate(user=self.user1)
        response = self.client.get(f"/api/deletion-requests/{self.request1.id}/")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        doc_details = response.data["document_details"]
        self.assertEqual(len(doc_details), 1)

@ -352,7 +349,7 @@ class TestDeletionRequestAPI(APITestCase):

        """Test that unauthenticated users cannot access the API."""
        response = self.client.get("/api/deletion-requests/")
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

        response = self.client.post(
            f"/api/deletion-requests/{self.request1.id}/approve/",
        )
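
The API tests above, together with the model tests later in this commit, constrain the approve transition: it succeeds only from the pending state, records who reviewed the request and when, and leaves completion to the deletion executor. A minimal sketch of such a guard, assuming the status constants and review fields shown in the tests:

    from django.utils import timezone

    def approve(request_obj, reviewer, comment=""):
        # Sketch: per the tests, the transition is only legal from STATUS_PENDING.
        if request_obj.status != request_obj.STATUS_PENDING:
            return False
        request_obj.status = request_obj.STATUS_APPROVED
        request_obj.reviewed_by = reviewer
        request_obj.reviewed_at = timezone.now()
        request_obj.review_comment = comment
        request_obj.save()
        return True
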
@ -1350,18 +1350,18 @@ class TestConsumerAIScannerIntegration(

        correspondent = Correspondent.objects.create(name="Test Corp")
        doc_type = DocumentType.objects.create(name="Invoice")
        storage_path = StoragePath.objects.create(name="Invoices", path="/invoices")

        # Create mock AI scanner
        mock_scanner = MagicMock()
        mock_get_scanner.return_value = mock_scanner

        # Mock scan results
        scan_result = AIScanResult()
        scan_result.tags = [(tag1.id, 0.85), (tag2.id, 0.75)]
        scan_result.correspondent = (correspondent.id, 0.90)
        scan_result.document_type = (doc_type.id, 0.85)
        scan_result.storage_path = (storage_path.id, 0.80)

        mock_scanner.scan_document.return_value = scan_result
        mock_scanner.apply_scan_results.return_value = {
            "applied": {

@ -1381,20 +1381,20 @@ class TestConsumerAIScannerIntegration(

                "workflows": [],
            },
        }

        # Run consumer
        filename = self.get_test_file()
        with self.get_consumer(filename) as consumer:
            consumer.run()

        # Verify document was created
        document = Document.objects.first()
        self.assertIsNotNone(document)

        # Verify AI scanner was called
        mock_scanner.scan_document.assert_called_once()
        mock_scanner.apply_scan_results.assert_called_once()

        # Verify the call arguments
        call_args = mock_scanner.scan_document.call_args
        self.assertEqual(call_args[1]["document"], document)
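
The integration tests treat AI scanning as strictly best-effort: a scanner exception must never abort consumption. A sketch of the guard those tests imply, assuming the get_ai_scanner() factory patched above; the content keyword argument is an assumption, only the document keyword is confirmed by the call_args assertion:

    import logging

    logger = logging.getLogger("paperless.ai_scanner")

    def run_ai_scan(document, text):
        """Best-effort scan: failures are logged, never propagated (sketch)."""
        try:
            scanner = get_ai_scanner()  # patched in the tests above
            result = scanner.scan_document(document=document, content=text)
            scanner.apply_scan_results(document, result)
        except Exception:
            logger.exception("AI scan failed; continuing consumption")
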
@ -1412,11 +1412,11 @@ class TestConsumerAIScannerIntegration(

        demonstrating graceful degradation.
        """
        filename = self.get_test_file()

        # Consumer should complete successfully even with ML disabled
        with self.get_consumer(filename) as consumer:
            consumer.run()

        # Verify document was created
        document = Document.objects.first()
        self.assertIsNotNone(document)

@ -1435,13 +1435,13 @@ class TestConsumerAIScannerIntegration(

        mock_scanner = MagicMock()
        mock_get_scanner.return_value = mock_scanner
        mock_scanner.scan_document.side_effect = Exception("AI Scanner failed")

        filename = self.get_test_file()

        # Consumer should complete despite AI scanner failure
        with self.get_consumer(filename) as consumer:
            consumer.run()

        # Verify document was created despite AI failure
        document = Document.objects.first()
        self.assertIsNotNone(document)

@ -1457,17 +1457,17 @@ class TestConsumerAIScannerIntegration(

        """
        mock_scanner = MagicMock()
        mock_get_scanner.return_value = mock_scanner

        self.create_empty_scan_result_mock(mock_scanner)

        filename = self.get_test_file()

        with self.get_consumer(filename) as consumer:
            consumer.run()

        document = Document.objects.first()
        self.assertIsNotNone(document)

        # Verify AI scanner was called with PDF
        mock_scanner.scan_document.assert_called_once()
        call_args = mock_scanner.scan_document.call_args

@ -1488,7 +1488,7 @@ class TestConsumerAIScannerIntegration(

            self.dirs.scratch_dir,
            self.get_test_archive_file(),
        )

        with mock.patch("documents.parsers.document_consumer_declaration.send") as m:
            m.return_value = [
                (

@ -1500,21 +1500,21 @@ class TestConsumerAIScannerIntegration(

                },
                ),
            ]

        mock_scanner = MagicMock()
        mock_get_scanner.return_value = mock_scanner

        self.create_empty_scan_result_mock(mock_scanner)

        # Create a PNG file
        dst = self.get_test_file_with_name("sample.png")

        with self.get_consumer(dst) as consumer:
            consumer.run()

        document = Document.objects.first()
        self.assertIsNotNone(document)

        # Verify AI scanner was called
        mock_scanner.scan_document.assert_called_once()
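
Several of these tests rely on a create_empty_scan_result_mock() helper that is not shown in this diff. A plausible minimal implementation, inferred from how the mocks are used elsewhere in the file (the exact helper may differ):

    def create_empty_scan_result_mock(self, mock_scanner):
        # Inferred sketch: an AIScanResult carrying no suggestions at all.
        scan_result = AIScanResult()
        scan_result.tags = []
        mock_scanner.scan_document.return_value = scan_result
        mock_scanner.apply_scan_results.return_value = {
            "applied": {"tags": [], "workflows": []},
        }
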
@ -1527,26 +1527,26 @@ class TestConsumerAIScannerIntegration(

        Verifies that AI scanning adds minimal overhead to document consumption.
        """
        import time

        mock_scanner = MagicMock()
        mock_get_scanner.return_value = mock_scanner

        self.create_empty_scan_result_mock(mock_scanner)

        filename = self.get_test_file()

        start_time = time.time()
        with self.get_consumer(filename) as consumer:
            consumer.run()
        end_time = time.time()

        # Verify document was created
        document = Document.objects.first()
        self.assertIsNotNone(document)

        # Verify AI scanner was called
        mock_scanner.scan_document.assert_called_once()

        # With mocks, this should be very fast (<1s).
        # TODO: Implement proper performance testing with real ML models in integration/performance test suite.
        elapsed_time = end_time - start_time

@ -1561,33 +1561,32 @@ class TestConsumerAIScannerIntegration(

        Verifies that AI scanner respects database transactions and handles
        rollbacks correctly.
        """
        from django.db import transaction as db_transaction

        tag = Tag.objects.create(name="Invoice")

        mock_scanner = MagicMock()
        mock_get_scanner.return_value = mock_scanner

        scan_result = AIScanResult()
        scan_result.tags = [(tag.id, 0.85)]
        mock_scanner.scan_document.return_value = scan_result

        # Mock apply_scan_results to raise an exception after some work
        def apply_with_error(document, scan_result, auto_apply=True):
            # Simulate partial work
            document.tags.add(tag)
            # Then fail
            raise Exception("Simulated transaction failure")

        mock_scanner.apply_scan_results.side_effect = apply_with_error

        filename = self.get_test_file()

        # Even with AI scanner failure, the document should still be created
        # because we handle AI scanner errors gracefully
        with self.get_consumer(filename) as consumer:
            consumer.run()

        document = Document.objects.first()
        self.assertIsNotNone(document)
        # The tag addition from AI scanner should be rolled back due to exception

@ -1604,17 +1603,17 @@ class TestConsumerAIScannerIntegration(

        """
        tag1 = Tag.objects.create(name="Invoice")
        tag2 = Tag.objects.create(name="Receipt")

        mock_scanner = MagicMock()
        mock_get_scanner.return_value = mock_scanner

        # Configure scanner to return different results for each call
        scan_results = []
        for tag in [tag1, tag2]:
            scan_result = AIScanResult()
            scan_result.tags = [(tag.id, 0.85)]
            scan_results.append(scan_result)

        mock_scanner.scan_document.side_effect = scan_results
        mock_scanner.apply_scan_results.return_value = {
            "applied": {

@ -1634,20 +1633,20 @@ class TestConsumerAIScannerIntegration(

                "workflows": [],
            },
        }

        # Process multiple documents
        filenames = [self.get_test_file()]
        # Create second file
        filenames.append(self.get_test_file_with_name("sample2.pdf"))

        for filename in filenames:
            with self.get_consumer(filename) as consumer:
                consumer.run()

        # Verify both documents were created
        documents = Document.objects.all()
        self.assertEqual(documents.count(), 2)

        # Verify AI scanner was called for each document
        self.assertEqual(mock_scanner.scan_document.call_count, 2)

@ -1661,17 +1660,17 @@ class TestConsumerAIScannerIntegration(

        """
        mock_scanner = MagicMock()
        mock_get_scanner.return_value = mock_scanner

        self.create_empty_scan_result_mock(mock_scanner)

        filename = self.get_test_file()

        with self.get_consumer(filename) as consumer:
            consumer.run()

        document = Document.objects.first()
        self.assertIsNotNone(document)

        # Verify AI scanner received text content
        mock_scanner.scan_document.assert_called_once()
        call_args = mock_scanner.scan_document.call_args

@ -1686,10 +1685,10 @@ class TestConsumerAIScannerIntegration(

        the AI scanner is not invoked at all.
        """
        filename = self.get_test_file()

        with self.get_consumer(filename) as consumer:
            consumer.run()

        # Document should be created normally without AI scanning
        document = Document.objects.first()
        self.assertIsNotNone(document)
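
The import hunk below repeats the same I001 fix already applied to the API test file above: a parenthesized multi-import block becomes one import per line, the layout produced when single-line imports are forced in ruff's isort rules, and the one that keeps future diffs to one line per symbol. For example:

    # Before (grouped):
    from documents.models import (
        Correspondent,
        DeletionRequest,
    )

    # After (one symbol per line, as enforced throughout this commit):
    from documents.models import Correspondent
    from documents.models import DeletionRequest
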
@ -15,13 +15,11 @@ from django.contrib.auth.models import User

from django.test import TestCase
from django.utils import timezone

from documents.models import (
    Correspondent,
    DeletionRequest,
    Document,
    DocumentType,
    Tag,
)
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
from documents.models import DocumentType
from documents.models import Tag


class TestDeletionRequestModelCreation(TestCase):

@ -45,7 +43,7 @@ class TestDeletionRequestModelCreation(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        self.assertIsNotNone(request)
        self.assertTrue(request.requested_by_ai)
        self.assertEqual(request.ai_reason, "Test reason")

@ -59,7 +57,7 @@ class TestDeletionRequestModelCreation(TestCase):

            ai_reason="Test",
            user=self.user,
        )

        self.assertIsNotNone(request.created_at)
        self.assertIsNotNone(request.updated_at)

@ -70,7 +68,7 @@ class TestDeletionRequestModelCreation(TestCase):

            ai_reason="Test",
            user=self.user,
        )

        self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)

    def test_deletion_request_with_documents(self):

@ -80,16 +78,16 @@ class TestDeletionRequestModelCreation(TestCase):

            ai_reason="Test",
            user=self.user,
        )

        doc2 = Document.objects.create(
            title="Document 2",
            content="Content 2",
            checksum="checksum2",
            mime_type="application/pdf",
        )

        request.documents.add(self.doc, doc2)

        self.assertEqual(request.documents.count(), 2)
        self.assertIn(self.doc, request.documents.all())
        self.assertIn(doc2, request.documents.all())

@ -101,7 +99,7 @@ class TestDeletionRequestModelCreation(TestCase):

            ai_reason="Test",
            user=self.user,
        )

        self.assertIsInstance(request.impact_summary, dict)
        self.assertEqual(request.impact_summary, {})

@ -112,14 +110,14 @@ class TestDeletionRequestModelCreation(TestCase):

            "affected_tags": ["tag1", "tag2"],
            "metadata": {"key": "value"},
        }

        request = DeletionRequest.objects.create(
            requested_by_ai=True,
            ai_reason="Test",
            user=self.user,
            impact_summary=impact,
        )

        self.assertEqual(request.impact_summary["document_count"], 5)
        self.assertEqual(request.impact_summary["affected_tags"], ["tag1", "tag2"])

@ -131,9 +129,9 @@ class TestDeletionRequestModelCreation(TestCase):

            user=self.user,
        )
        request.documents.add(self.doc)

        str_repr = str(request)

        self.assertIn("Deletion Request", str_repr)
        self.assertIn(str(request.id), str_repr)
        self.assertIn("1 documents", str_repr)

@ -162,9 +160,9 @@ class TestDeletionRequestApprove(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        result = request.approve(self.approver, "Approved")

        self.assertTrue(result)
        self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)
        self.assertEqual(request.reviewed_by, self.approver)

@ -179,9 +177,9 @@ class TestDeletionRequestApprove(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        result = request.approve(self.approver)

        self.assertTrue(result)
        self.assertEqual(request.review_comment, "")

@ -195,9 +193,9 @@ class TestDeletionRequestApprove(TestCase):

            reviewed_by=self.user,
            reviewed_at=timezone.now(),
        )

        result = request.approve(self.approver, "Trying to approve again")

        self.assertFalse(result)
        self.assertEqual(request.reviewed_by, self.user)  # Should not change

@ -211,9 +209,9 @@ class TestDeletionRequestApprove(TestCase):

            reviewed_by=self.user,
            reviewed_at=timezone.now(),
        )

        result = request.approve(self.approver, "Trying to approve rejected")

        self.assertFalse(result)
        self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)

@ -225,9 +223,9 @@ class TestDeletionRequestApprove(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_CANCELLED,
        )

        result = request.approve(self.approver, "Trying to approve cancelled")

        self.assertFalse(result)
        self.assertEqual(request.status, DeletionRequest.STATUS_CANCELLED)

@ -239,9 +237,9 @@ class TestDeletionRequestApprove(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_COMPLETED,
        )

        result = request.approve(self.approver, "Trying to approve completed")

        self.assertFalse(result)
        self.assertEqual(request.status, DeletionRequest.STATUS_COMPLETED)

@ -253,11 +251,11 @@ class TestDeletionRequestApprove(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        before_approval = timezone.now()
        result = request.approve(self.approver, "Approved")
        after_approval = timezone.now()

        self.assertTrue(result)
        self.assertIsNotNone(request.reviewed_at)
        self.assertGreaterEqual(request.reviewed_at, before_approval)
@ -280,9 +278,9 @@ class TestDeletionRequestReject(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        result = request.reject(self.reviewer, "Not necessary")

        self.assertTrue(result)
        self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)
        self.assertEqual(request.reviewed_by, self.reviewer)

@ -297,9 +295,9 @@ class TestDeletionRequestReject(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        result = request.reject(self.reviewer)

        self.assertTrue(result)
        self.assertEqual(request.review_comment, "")

@ -313,9 +311,9 @@ class TestDeletionRequestReject(TestCase):

            reviewed_by=self.user,
            reviewed_at=timezone.now(),
        )

        result = request.reject(self.reviewer, "Trying to reject again")

        self.assertFalse(result)
        self.assertEqual(request.reviewed_by, self.user)  # Should not change

@ -329,9 +327,9 @@ class TestDeletionRequestReject(TestCase):

            reviewed_by=self.user,
            reviewed_at=timezone.now(),
        )

        result = request.reject(self.reviewer, "Trying to reject approved")

        self.assertFalse(result)
        self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)

@ -343,9 +341,9 @@ class TestDeletionRequestReject(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_CANCELLED,
        )

        result = request.reject(self.reviewer, "Trying to reject cancelled")

        self.assertFalse(result)
        self.assertEqual(request.status, DeletionRequest.STATUS_CANCELLED)

@ -357,9 +355,9 @@ class TestDeletionRequestReject(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_COMPLETED,
        )

        result = request.reject(self.reviewer, "Trying to reject completed")

        self.assertFalse(result)
        self.assertEqual(request.status, DeletionRequest.STATUS_COMPLETED)

@ -371,11 +369,11 @@ class TestDeletionRequestReject(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        before_rejection = timezone.now()
        result = request.reject(self.reviewer, "Rejected")
        after_rejection = timezone.now()

        self.assertTrue(result)
        self.assertIsNotNone(request.reviewed_at)
        self.assertGreaterEqual(request.reviewed_at, before_rejection)

@ -389,11 +387,11 @@ class TestDeletionRequestWorkflowScenarios(TestCase):

        """Set up test data."""
        self.user = User.objects.create_user(username="user1", password="pass")
        self.approver = User.objects.create_user(username="approver", password="pass")

        self.correspondent = Correspondent.objects.create(name="Test Corp")
        self.doc_type = DocumentType.objects.create(name="Invoice")
        self.tag = Tag.objects.create(name="Important")

        self.doc1 = Document.objects.create(
            title="Document 1",
            content="Content 1",

@ -403,7 +401,7 @@ class TestDeletionRequestWorkflowScenarios(TestCase):

            document_type=self.doc_type,
        )
        self.doc1.tags.add(self.tag)

        self.doc2 = Document.objects.create(
            title="Document 2",
            content="Content 2",

@ -421,15 +419,15 @@ class TestDeletionRequestWorkflowScenarios(TestCase):

            impact_summary={"document_count": 2},
        )
        request.documents.add(self.doc1, self.doc2)

        # Verify initial state
        self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)
        self.assertIsNone(request.reviewed_by)
        self.assertIsNone(request.reviewed_at)

        # Approve
        success = request.approve(self.approver, "Confirmed duplicates")

        # Verify final state
        self.assertTrue(success)
        self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)

@ -446,13 +444,13 @@ class TestDeletionRequestWorkflowScenarios(TestCase):

            status=DeletionRequest.STATUS_PENDING,
        )
        request.documents.add(self.doc1)

        # Verify initial state
        self.assertEqual(request.status, DeletionRequest.STATUS_PENDING)

        # Reject
        success = request.reject(self.approver, "Not duplicates")

        # Verify final state
        self.assertTrue(success)
        self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)

@ -468,14 +466,14 @@ class TestDeletionRequestWorkflowScenarios(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        # Reject first
        request.reject(self.user, "Rejected")
        self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)

        # Try to approve
        success = request.approve(self.approver, "Changed my mind")

        # Should fail
        self.assertFalse(success)
        self.assertEqual(request.status, DeletionRequest.STATUS_REJECTED)

@ -488,14 +486,14 @@ class TestDeletionRequestWorkflowScenarios(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        # Approve first
        request.approve(self.approver, "Approved")
        self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)

        # Try to reject
        success = request.reject(self.user, "Changed my mind")

        # Should fail
        self.assertFalse(success)
        self.assertEqual(request.status, DeletionRequest.STATUS_APPROVED)

@ -516,7 +514,7 @@ class TestDeletionRequestAuditTrail(TestCase):

            ai_reason="Test",
            user=self.user,
        )

        self.assertEqual(request.user, self.user)

    def test_audit_trail_records_reviewer(self):

@ -527,9 +525,9 @@ class TestDeletionRequestAuditTrail(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        request.approve(self.approver, "Approved")

        self.assertEqual(request.reviewed_by, self.approver)
        self.assertNotEqual(request.reviewed_by, request.user)

@ -540,12 +538,12 @@ class TestDeletionRequestAuditTrail(TestCase):

            ai_reason="Test",
            user=self.user,
        )

        created_at = request.created_at

        # Approve the request
        request.approve(self.approver, "Approved")

        # Verify timestamps
        self.assertIsNotNone(request.created_at)
        self.assertIsNotNone(request.updated_at)

@ -555,16 +553,16 @@ class TestDeletionRequestAuditTrail(TestCase):

    def test_audit_trail_preserves_ai_reason(self):
        """Test that AI's original reason is preserved."""
        original_reason = "AI detected duplicates based on content similarity"

        request = DeletionRequest.objects.create(
            requested_by_ai=True,
            ai_reason=original_reason,
            user=self.user,
        )

        # Approve with different comment
        request.approve(self.approver, "User confirmed")

        # Original AI reason should be preserved
        self.assertEqual(request.ai_reason, original_reason)
        self.assertEqual(request.review_comment, "User confirmed")

@ -582,7 +580,7 @@ class TestDeletionRequestAuditTrail(TestCase):

                "completed_by": "system",
            },
        )

        self.assertEqual(request.completion_details["deleted_count"], 5)
        self.assertEqual(request.completion_details["failed_count"], 0)

@ -593,17 +591,17 @@ class TestDeletionRequestAuditTrail(TestCase):

            ai_reason="Reason 1",
            user=self.user,
        )

        request2 = DeletionRequest.objects.create(
            requested_by_ai=True,
            ai_reason="Reason 2",
            user=self.user,
        )

        # Approve one, reject another
        request1.approve(self.approver, "Approved")
        request2.reject(self.approver, "Rejected")

        # Verify each has its own audit trail
        self.assertEqual(request1.status, DeletionRequest.STATUS_APPROVED)
        self.assertEqual(request2.status, DeletionRequest.STATUS_REJECTED)

@ -625,13 +623,13 @@ class TestDeletionRequestModelRelationships(TestCase):

            ai_reason="Test",
            user=self.user,
        )

        request_id = request.id
        self.assertEqual(DeletionRequest.objects.filter(id=request_id).count(), 1)

        # Delete user
        self.user.delete()

        # Request should be deleted
        self.assertEqual(DeletionRequest.objects.filter(id=request_id).count(), 0)

@ -649,15 +647,15 @@ class TestDeletionRequestModelRelationships(TestCase):

            checksum="checksum2",
            mime_type="application/pdf",
        )

        request = DeletionRequest.objects.create(
            requested_by_ai=True,
            ai_reason="Test",
            user=self.user,
        )

        request.documents.add(doc1, doc2)

        self.assertEqual(request.documents.count(), 2)
        self.assertEqual(doc1.deletion_requests.count(), 1)
        self.assertEqual(doc2.deletion_requests.count(), 1)

@ -670,29 +668,29 @@ class TestDeletionRequestModelRelationships(TestCase):

            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        self.assertIsNone(request.reviewed_by)

    def test_reviewed_by_set_null_on_delete(self):
        """Test that reviewed_by is set to null when reviewer is deleted."""
        approver = User.objects.create_user(username="approver", password="pass")

        request = DeletionRequest.objects.create(
            requested_by_ai=True,
            ai_reason="Test",
            user=self.user,
            status=DeletionRequest.STATUS_PENDING,
        )

        request.approve(approver, "Approved")
        self.assertEqual(request.reviewed_by, approver)

        # Delete approver
        approver.delete()

        # Refresh request
        request.refresh_from_db()

        # reviewed_by should be null
        self.assertIsNone(request.reviewed_by)
        # But the request should still exist
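
The relationship tests encode two different delete behaviours: deleting the requesting user cascades to the request, while deleting the reviewer only nulls out reviewed_by and leaves the request in place. A sketch of the field definitions that behaviour implies; the real model is not part of this diff:

    from django.conf import settings
    from django.db import models

    class DeletionRequestSketch(models.Model):
        # CASCADE: the request dies with its owner, per the cascade test above.
        user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
        # SET_NULL: the request survives its reviewer, per the set-null test above.
        reviewed_by = models.ForeignKey(
            settings.AUTH_USER_MODEL,
            null=True,
            blank=True,
            on_delete=models.SET_NULL,
            related_name="+",
        )
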
@ -8,11 +8,9 @@ from unittest import mock

from django.test import TestCase

from documents.ml.model_cache import (
    CacheMetrics,
    LRUCache,
    ModelCacheManager,
)
from documents.ml.model_cache import CacheMetrics
from documents.ml.model_cache import LRUCache
from documents.ml.model_cache import ModelCacheManager


class TestCacheMetrics(TestCase):

@ -22,10 +20,10 @@ class TestCacheMetrics(TestCase):

        """Test recording cache hits."""
        metrics = CacheMetrics()
        self.assertEqual(metrics.hits, 0)

        metrics.record_hit()
        self.assertEqual(metrics.hits, 1)

        metrics.record_hit()
        self.assertEqual(metrics.hits, 2)

@ -33,26 +31,26 @@ class TestCacheMetrics(TestCase):

        """Test recording cache misses."""
        metrics = CacheMetrics()
        self.assertEqual(metrics.misses, 0)

        metrics.record_miss()
        self.assertEqual(metrics.misses, 1)

    def test_get_stats(self):
        """Test getting cache statistics."""
        metrics = CacheMetrics()

        # Initial stats
        stats = metrics.get_stats()
        self.assertEqual(stats["hits"], 0)
        self.assertEqual(stats["misses"], 0)
        self.assertEqual(stats["hit_rate"], "0.00%")

        # After some hits and misses
        metrics.record_hit()
        metrics.record_hit()
        metrics.record_hit()
        metrics.record_miss()

        stats = metrics.get_stats()
        self.assertEqual(stats["hits"], 3)
        self.assertEqual(stats["misses"], 1)

@ -64,9 +62,9 @@ class TestCacheMetrics(TestCase):

        metrics = CacheMetrics()
        metrics.record_hit()
        metrics.record_miss()

        metrics.reset()

        stats = metrics.get_stats()
        self.assertEqual(stats["hits"], 0)
        self.assertEqual(stats["misses"], 0)
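
The assertions above fully determine a minimal CacheMetrics: integer hit and miss counters, a reset, and a hit_rate formatted as a two-decimal percentage string. A sketch consistent with them (the real class lives in documents.ml.model_cache and may carry more state):

    class CacheMetricsSketch:
        def __init__(self):
            self.hits = 0
            self.misses = 0

        def record_hit(self):
            self.hits += 1

        def record_miss(self):
            self.misses += 1

        def reset(self):
            self.hits = 0
            self.misses = 0

        def get_stats(self):
            total = self.hits + self.misses
            rate = (self.hits / total * 100) if total else 0.0
            return {"hits": self.hits, "misses": self.misses, "hit_rate": f"{rate:.2f}%"}
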
@ -78,28 +76,28 @@ class TestLRUCache(TestCase):

    def test_put_and_get(self):
        """Test basic cache operations."""
        cache = LRUCache(max_size=2)

        cache.put("key1", "value1")
        cache.put("key2", "value2")

        self.assertEqual(cache.get("key1"), "value1")
        self.assertEqual(cache.get("key2"), "value2")

    def test_cache_miss(self):
        """Test cache miss returns None."""
        cache = LRUCache(max_size=2)

        result = cache.get("nonexistent")
        self.assertIsNone(result)

    def test_lru_eviction(self):
        """Test LRU eviction policy."""
        cache = LRUCache(max_size=2)

        cache.put("key1", "value1")
        cache.put("key2", "value2")
        cache.put("key3", "value3")  # Should evict key1

        self.assertIsNone(cache.get("key1"))  # Evicted
        self.assertEqual(cache.get("key2"), "value2")
        self.assertEqual(cache.get("key3"), "value3")

@ -107,12 +105,12 @@ class TestLRUCache(TestCase):

    def test_lru_update_access_order(self):
        """Test that accessing an item updates its position."""
        cache = LRUCache(max_size=2)

        cache.put("key1", "value1")
        cache.put("key2", "value2")
        cache.get("key1")  # Access key1, making it most recent
        cache.put("key3", "value3")  # Should evict key2, not key1

        self.assertEqual(cache.get("key1"), "value1")
        self.assertIsNone(cache.get("key2"))  # Evicted
        self.assertEqual(cache.get("key3"), "value3")

@ -120,24 +118,24 @@ class TestLRUCache(TestCase):

    def test_cache_size(self):
        """Test cache size tracking."""
        cache = LRUCache(max_size=3)

        self.assertEqual(cache.size(), 0)

        cache.put("key1", "value1")
        self.assertEqual(cache.size(), 1)

        cache.put("key2", "value2")
        self.assertEqual(cache.size(), 2)

    def test_clear(self):
        """Test clearing cache."""
        cache = LRUCache(max_size=2)

        cache.put("key1", "value1")
        cache.put("key2", "value2")

        cache.clear()

        self.assertEqual(cache.size(), 0)
        self.assertIsNone(cache.get("key1"))
        self.assertIsNone(cache.get("key2"))
@ -155,20 +153,20 @@ class TestModelCacheManager(TestCase):
|
|||
"""Test that ModelCacheManager is a singleton."""
|
||||
instance1 = ModelCacheManager.get_instance()
|
||||
instance2 = ModelCacheManager.get_instance()
|
||||
|
||||
|
||||
self.assertIs(instance1, instance2)
|
||||
|
||||
def test_get_or_load_model_first_time(self):
|
||||
"""Test loading a model for the first time (cache miss)."""
|
||||
cache_manager = ModelCacheManager.get_instance()
|
||||
|
||||
|
||||
# Mock loader function
|
||||
mock_model = mock.Mock()
|
||||
loader = mock.Mock(return_value=mock_model)
|
||||
|
||||
|
||||
# Load model
|
||||
result = cache_manager.get_or_load_model("test_model", loader)
|
||||
|
||||
|
||||
# Verify loader was called
|
||||
loader.assert_called_once()
|
||||
self.assertIs(result, mock_model)
|
||||
|
|
@ -176,17 +174,17 @@ class TestModelCacheManager(TestCase):
|
|||
def test_get_or_load_model_cached(self):
|
||||
"""Test loading a model from cache (cache hit)."""
|
||||
cache_manager = ModelCacheManager.get_instance()
|
||||
|
||||
|
||||
# Mock loader function
|
||||
mock_model = mock.Mock()
|
||||
loader = mock.Mock(return_value=mock_model)
|
||||
|
||||
|
||||
# Load model first time
|
||||
cache_manager.get_or_load_model("test_model", loader)
|
||||
|
||||
|
||||
# Load model second time (should be cached)
|
||||
result = cache_manager.get_or_load_model("test_model", loader)
|
||||
|
||||
|
||||
# Verify loader was only called once
|
||||
loader.assert_called_once()
|
||||
self.assertIs(result, mock_model)
|
||||
|
|
@ -197,48 +195,48 @@ class TestModelCacheManager(TestCase):
|
|||
cache_manager = ModelCacheManager.get_instance(
|
||||
disk_cache_dir=tmpdir,
|
||||
)
|
||||
|
||||
|
||||
# Create test embeddings
|
||||
embeddings = {
|
||||
1: "embedding1",
|
||||
2: "embedding2",
|
||||
3: "embedding3",
|
||||
}
|
||||
|
||||
|
||||
# Save to disk
|
||||
cache_manager.save_embeddings_to_disk("test_embeddings", embeddings)
|
||||
|
||||
|
||||
# Verify file was created
|
||||
cache_file = Path(tmpdir) / "test_embeddings.pkl"
|
||||
self.assertTrue(cache_file.exists())
|
||||
|
||||
|
||||
# Load from disk
|
||||
loaded = cache_manager.load_embeddings_from_disk("test_embeddings")
|
||||
|
||||
|
||||
# Verify embeddings match
|
||||
self.assertEqual(loaded, embeddings)
|
||||
|
||||
def test_get_metrics(self):
|
||||
"""Test getting cache metrics."""
|
||||
cache_manager = ModelCacheManager.get_instance()
|
||||
|
||||
|
||||
# Mock loader
|
||||
loader = mock.Mock(return_value=mock.Mock())
|
||||
|
||||
|
||||
# Generate some cache activity
|
||||
cache_manager.get_or_load_model("model1", loader)
|
||||
cache_manager.get_or_load_model("model1", loader) # Cache hit
|
||||
cache_manager.get_or_load_model("model2", loader)
|
||||
|
||||
|
||||
# Get metrics
|
||||
metrics = cache_manager.get_metrics()
|
||||
|
||||
|
||||
# Verify metrics structure
|
||||
self.assertIn("hits", metrics)
|
||||
self.assertIn("misses", metrics)
|
||||
self.assertIn("cache_size", metrics)
|
||||
self.assertIn("max_size", metrics)
|
||||
|
||||
|
||||
# Verify hit/miss counts
|
||||
self.assertEqual(metrics["hits"], 1) # One cache hit
|
||||
self.assertEqual(metrics["misses"], 2) # Two cache misses
|
||||
|
|
@ -249,21 +247,21 @@ class TestModelCacheManager(TestCase):
|
|||
cache_manager = ModelCacheManager.get_instance(
|
||||
disk_cache_dir=tmpdir,
|
||||
)
|
||||
|
||||
|
||||
# Add some models to cache
|
||||
loader = mock.Mock(return_value=mock.Mock())
|
||||
cache_manager.get_or_load_model("model1", loader)
|
||||
|
||||
|
||||
# Add embeddings to disk
|
||||
embeddings = {1: "embedding1"}
|
||||
cache_manager.save_embeddings_to_disk("test", embeddings)
|
||||
|
||||
|
||||
# Clear all
|
||||
cache_manager.clear_all()
|
||||
|
||||
|
||||
# Verify memory cache is cleared
|
||||
self.assertEqual(cache_manager.model_cache.size(), 0)
|
||||
|
||||
|
||||
# Verify disk cache is cleared
|
||||
cache_file = Path(tmpdir) / "test.pkl"
|
||||
self.assertFalse(cache_file.exists())
|
||||
|
|
@ -271,22 +269,22 @@ class TestModelCacheManager(TestCase):
|
|||
def test_warm_up(self):
|
||||
"""Test model warm-up functionality."""
|
||||
cache_manager = ModelCacheManager.get_instance()
|
||||
|
||||
|
||||
# Create mock loaders
|
||||
model1 = mock.Mock()
|
||||
model2 = mock.Mock()
|
||||
|
||||
|
||||
loaders = {
|
||||
"model1": mock.Mock(return_value=model1),
|
||||
"model2": mock.Mock(return_value=model2),
|
||||
}
|
||||
|
||||
|
||||
# Warm up
|
||||
cache_manager.warm_up(loaders)
|
||||
|
||||
|
||||
# Verify all loaders were called
|
||||
for loader in loaders.values():
|
||||
loader.assert_called_once()
|
||||
|
||||
|
||||
# Verify models are cached
|
||||
self.assertEqual(cache_manager.model_cache.size(), 2)
|
||||
|
|
|
|||
|
|
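Editor's note: the tests above pin down a small cache API (put/get/size/clear plus LRU eviction). For readers without the source at hand, a minimal sketch of an LRU cache satisfying that surface — the real documents.ml.model_cache implementation may differ in details — looks like this:

from collections import OrderedDict

class LRUCache:
    """Minimal LRU cache sketch matching the tested API (illustrative only)."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self._data: OrderedDict = OrderedDict()

    def get(self, key):
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key, value) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = value
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least recently used

    def size(self) -> int:
        return len(self._data)

    def clear(self) -> None:
        self._data.clear()

OrderedDict keeps insertion order and supports O(1) move-to-end, which is why it is the usual backing structure for this pattern.
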
@@ -150,7 +150,6 @@ class TestMLCacheDirectory:

    def test_model_cache_writable(self, tmp_path):
        """Test that we can write to model cache directory."""
        import pathlib

        # Use tmp_path fixture for testing
        cache_dir = tmp_path / ".cache" / "huggingface"
@@ -169,7 +168,6 @@ class TestMLCacheDirectory:

    def test_torch_cache_directory(self, tmp_path, monkeypatch):
        """Test that PyTorch can use a custom cache directory."""
        import torch

        # Set custom cache directory
        cache_dir = tmp_path / ".cache" / "torch"
@@ -204,9 +202,10 @@ class TestMLPerformanceBasic:

    def test_numpy_performance_basic(self):
        """Test basic NumPy performance with larger arrays."""
        import numpy as np
        import time

        import numpy as np

        # Create large array (10 million elements)
        arr = np.random.rand(10_000_000)

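Editor's note: the hunk above only reorders imports (I001), but the underlying test pattern is worth spelling out. A self-contained version of that kind of measurement — the one-second bound is an assumption, not taken from this commit — could read:

import time

import numpy as np

def test_numpy_performance_basic():
    arr = np.random.rand(10_000_000)

    start = time.perf_counter()
    total = arr.sum()  # vectorized reduction over 10M floats
    elapsed = time.perf_counter() - start

    assert total > 0
    assert elapsed < 1.0  # generous bound; the exact threshold is assumed
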
@@ -89,6 +89,8 @@ from rest_framework.viewsets import ViewSet

from documents import bulk_edit
from documents import index
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy
@@ -141,13 +143,10 @@ from documents.models import UiSettings
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
from documents.ai_scanner import AIDocumentScanner
from documents.ai_scanner import get_ai_scanner
from documents.parsers import get_parser_class_for_mime_type
from documents.parsers import parse_date_generator
from documents.permissions import AcknowledgeTasksPermissions
from documents.permissions import CanApplyAISuggestionsPermission
from documents.permissions import CanApproveDeletionsPermission
from documents.permissions import CanConfigureAIPermission
from documents.permissions import CanViewAISuggestionsPermission
from documents.permissions import PaperlessAdminPermissions
@@ -1370,36 +1369,36 @@ class UnifiedSearchViewSet(DocumentViewSet):
        """
        from documents.ai_scanner import get_ai_scanner
        from documents.serializers.ai_suggestions import AISuggestionsSerializer

        try:
            document = self.get_object()

            # Check if document has content to scan
            if not document.content:
                return Response(
                    {"detail": "Document has no content to analyze"},
                    status=400,
                )

            # Get AI scanner instance
            scanner = get_ai_scanner()

            # Perform AI scan
            scan_result = scanner.scan_document(
                document=document,
                document_text=document.content,
                original_file_path=document.source_path if hasattr(document, 'source_path') else None,
                original_file_path=document.source_path if hasattr(document, "source_path") else None,
            )

            # Convert scan result to serializable format
            data = AISuggestionsSerializer.from_scan_result(scan_result, document.id)

            # Serialize and return
            serializer = AISuggestionsSerializer(data=data)
            serializer.is_valid(raise_exception=True)

            return Response(serializer.validated_data)

        except Exception as e:
            logger.error(f"Error getting AI suggestions for document {pk}: {e}", exc_info=True)
            return Response(
@@ -1416,56 +1415,56 @@ class UnifiedSearchViewSet(DocumentViewSet):
        """
        from documents.models import AISuggestionFeedback
        from documents.serializers.ai_suggestions import ApplySuggestionSerializer

        try:
            document = self.get_object()

            # Validate input
            serializer = ApplySuggestionSerializer(data=request.data)
            serializer.is_valid(raise_exception=True)

            suggestion_type = serializer.validated_data['suggestion_type']
            value_id = serializer.validated_data.get('value_id')
            value_text = serializer.validated_data.get('value_text')
            confidence = serializer.validated_data['confidence']
            suggestion_type = serializer.validated_data["suggestion_type"]
            value_id = serializer.validated_data.get("value_id")
            value_text = serializer.validated_data.get("value_text")
            confidence = serializer.validated_data["confidence"]

            # Apply the suggestion based on type
            applied = False
            result_message = ""

            if suggestion_type == 'tag' and value_id:
            if suggestion_type == "tag" and value_id:
                tag = Tag.objects.get(pk=value_id)
                document.tags.add(tag)
                applied = True
                result_message = f"Tag '{tag.name}' applied"

            elif suggestion_type == 'correspondent' and value_id:
            elif suggestion_type == "correspondent" and value_id:
                correspondent = Correspondent.objects.get(pk=value_id)
                document.correspondent = correspondent
                document.save()
                applied = True
                result_message = f"Correspondent '{correspondent.name}' applied"

            elif suggestion_type == 'document_type' and value_id:
            elif suggestion_type == "document_type" and value_id:
                doc_type = DocumentType.objects.get(pk=value_id)
                document.document_type = doc_type
                document.save()
                applied = True
                result_message = f"Document type '{doc_type.name}' applied"

            elif suggestion_type == 'storage_path' and value_id:
            elif suggestion_type == "storage_path" and value_id:
                storage_path = StoragePath.objects.get(pk=value_id)
                document.storage_path = storage_path
                document.save()
                applied = True
                result_message = f"Storage path '{storage_path.name}' applied"

            elif suggestion_type == 'title' and value_text:
            elif suggestion_type == "title" and value_text:
                document.title = value_text
                document.save()
                applied = True
                result_message = f"Title updated to '{value_text}'"

            if applied:
                # Record feedback
                AISuggestionFeedback.objects.create(
@@ -1477,7 +1476,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
                    status=AISuggestionFeedback.STATUS_APPLIED,
                    user=request.user,
                )

                return Response({
                    "status": "success",
                    "message": result_message,
@@ -1487,8 +1486,8 @@ class UnifiedSearchViewSet(DocumentViewSet):
                {"detail": "Invalid suggestion type or missing value"},
                status=400,
            )

        except (Tag.DoesNotExist, Correspondent.DoesNotExist,
                DocumentType.DoesNotExist, StoragePath.DoesNotExist):
            return Response(
                {"detail": "Referenced object not found"},
@@ -1510,19 +1509,19 @@ class UnifiedSearchViewSet(DocumentViewSet):
        """
        from documents.models import AISuggestionFeedback
        from documents.serializers.ai_suggestions import RejectSuggestionSerializer

        try:
            document = self.get_object()

            # Validate input
            serializer = RejectSuggestionSerializer(data=request.data)
            serializer.is_valid(raise_exception=True)

            suggestion_type = serializer.validated_data['suggestion_type']
            value_id = serializer.validated_data.get('value_id')
            value_text = serializer.validated_data.get('value_text')
            confidence = serializer.validated_data['confidence']
            suggestion_type = serializer.validated_data["suggestion_type"]
            value_id = serializer.validated_data.get("value_id")
            value_text = serializer.validated_data.get("value_text")
            confidence = serializer.validated_data["confidence"]

            # Record feedback
            AISuggestionFeedback.objects.create(
                document=document,
@@ -1533,12 +1532,12 @@ class UnifiedSearchViewSet(DocumentViewSet):
                status=AISuggestionFeedback.STATUS_REJECTED,
                user=request.user,
            )

            return Response({
                "status": "success",
                "message": "Suggestion rejected and feedback recorded",
            })

        except Exception as e:
            logger.error(f"Error rejecting suggestion for document {pk}: {e}", exc_info=True)
            return Response(
@@ -1554,78 +1553,83 @@ class UnifiedSearchViewSet(DocumentViewSet):
        Returns aggregated data about applied vs rejected suggestions,
        accuracy rates, and confidence scores.
        """
        from django.db.models import Avg, Count, Q
        from django.db.models import Avg
        from django.db.models import Count
        from django.db.models import Q

        from documents.models import AISuggestionFeedback
        from documents.serializers.ai_suggestions import AISuggestionStatsSerializer

        try:
            # Get overall counts
            total_feedbacks = AISuggestionFeedback.objects.count()
            total_applied = AISuggestionFeedback.objects.filter(
                status=AISuggestionFeedback.STATUS_APPLIED
                status=AISuggestionFeedback.STATUS_APPLIED,
            ).count()
            total_rejected = AISuggestionFeedback.objects.filter(
                status=AISuggestionFeedback.STATUS_REJECTED
                status=AISuggestionFeedback.STATUS_REJECTED,
            ).count()

            # Calculate accuracy rate
            accuracy_rate = (total_applied / total_feedbacks * 100) if total_feedbacks > 0 else 0

            # Get statistics by suggestion type using a single aggregated query
            stats_by_type = AISuggestionFeedback.objects.values('suggestion_type').annotate(
                total=Count('id'),
                applied=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
                rejected=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_REJECTED))
            stats_by_type = AISuggestionFeedback.objects.values("suggestion_type").annotate(
                total=Count("id"),
                applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
                rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)),
            )

            # Build the by_type dictionary using the aggregated results
            by_type = {}
            for stat in stats_by_type:
                suggestion_type = stat['suggestion_type']
                type_total = stat['total']
                type_applied = stat['applied']
                type_rejected = stat['rejected']
                suggestion_type = stat["suggestion_type"]
                type_total = stat["total"]
                type_applied = stat["applied"]
                type_rejected = stat["rejected"]

                by_type[suggestion_type] = {
                    'total': type_total,
                    'applied': type_applied,
                    'rejected': type_rejected,
                    'accuracy_rate': (type_applied / type_total * 100) if type_total > 0 else 0,
                    "total": type_total,
                    "applied": type_applied,
                    "rejected": type_rejected,
                    "accuracy_rate": (type_applied / type_total * 100) if type_total > 0 else 0,
                }

            # Get average confidence scores
            avg_confidence_applied = AISuggestionFeedback.objects.filter(
                status=AISuggestionFeedback.STATUS_APPLIED
            ).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
                status=AISuggestionFeedback.STATUS_APPLIED,
            ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0

            avg_confidence_rejected = AISuggestionFeedback.objects.filter(
                status=AISuggestionFeedback.STATUS_REJECTED
            ).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
                status=AISuggestionFeedback.STATUS_REJECTED,
            ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0

            # Get recent suggestions (last 10)
            recent_suggestions = AISuggestionFeedback.objects.order_by('-created_at')[:10]
            recent_suggestions = AISuggestionFeedback.objects.order_by("-created_at")[:10]

            # Build response data
            from documents.serializers.ai_suggestions import AISuggestionFeedbackSerializer
            from documents.serializers.ai_suggestions import (
                AISuggestionFeedbackSerializer,
            )
            data = {
                'total_suggestions': total_feedbacks,
                'total_applied': total_applied,
                'total_rejected': total_rejected,
                'accuracy_rate': accuracy_rate,
                'by_type': by_type,
                'average_confidence_applied': avg_confidence_applied,
                'average_confidence_rejected': avg_confidence_rejected,
                'recent_suggestions': AISuggestionFeedbackSerializer(
                    recent_suggestions, many=True
                "total_suggestions": total_feedbacks,
                "total_applied": total_applied,
                "total_rejected": total_rejected,
                "accuracy_rate": accuracy_rate,
                "by_type": by_type,
                "average_confidence_applied": avg_confidence_applied,
                "average_confidence_rejected": avg_confidence_rejected,
                "recent_suggestions": AISuggestionFeedbackSerializer(
                    recent_suggestions, many=True,
                ).data,
            }

            # Serialize and return
            serializer = AISuggestionStatsSerializer(data=data)
            serializer.is_valid(raise_exception=True)

            return Response(serializer.validated_data)

        except Exception as e:
            logger.error(f"Error getting AI suggestion statistics: {e}", exc_info=True)
            return Response(
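Editor's note: the stats endpoint above computes per-type totals in a single values().annotate() pass with filtered Count aggregates, instead of one query per suggestion type. As a hedged, stand-alone illustration of that pattern (model and field names mirror the diff; everything else is assumption):

from django.db.models import Count, Q

# One query: one row per suggestion_type, with total/applied/rejected counts.
stats_by_type = (
    AISuggestionFeedback.objects
    .values("suggestion_type")
    .annotate(
        total=Count("id"),
        applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
        rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)),
    )
)

On PostgreSQL the filter= argument compiles to SQL's COUNT(...) FILTER (WHERE ...); on backends without that syntax Django emulates it with a CASE expression, so the pattern is portable.
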
@@ -3561,37 +3565,37 @@ class AISuggestionsView(GenericAPIView):

    Requires: can_view_ai_suggestions permission
    """

    permission_classes = [IsAuthenticated, CanViewAISuggestionsPermission]
    serializer_class = AISuggestionsResponseSerializer

    def post(self, request):
        """Get AI suggestions for a document."""
        # Validate request
        request_serializer = AISuggestionsRequestSerializer(data=request.data)
        request_serializer.is_valid(raise_exception=True)

        document_id = request_serializer.validated_data['document_id']
        document_id = request_serializer.validated_data["document_id"]

        try:
            document = Document.objects.get(pk=document_id)
        except Document.DoesNotExist:
            return Response(
                {"error": "Document not found or you don't have permission to view it"},
                status=status.HTTP_404_NOT_FOUND
                status=status.HTTP_404_NOT_FOUND,
            )

        # Check if user has permission to view this document
        if not has_perms_owner_aware(request.user, 'documents.view_document', document):
        if not has_perms_owner_aware(request.user, "documents.view_document", document):
            return Response(
                {"error": "Permission denied"},
                status=status.HTTP_403_FORBIDDEN
                status=status.HTTP_403_FORBIDDEN,
            )

        # Get AI scanner and scan document
        scanner = get_ai_scanner()
        scan_result = scanner.scan_document(document, document.content or "")

        # Build response
        response_data = {
            "document_id": document.id,
@@ -3600,9 +3604,9 @@ class AISuggestionsView(GenericAPIView):
            "document_type": None,
            "storage_path": None,
            "title_suggestion": scan_result.title_suggestion,
            "custom_fields": {}
            "custom_fields": {},
        }

        # Format tag suggestions
        for tag_id, confidence in scan_result.tags:
            try:
@@ -3610,12 +3614,12 @@ class AISuggestionsView(GenericAPIView):
                response_data["tags"].append({
                    "id": tag.id,
                    "name": tag.name,
                    "confidence": confidence
                    "confidence": confidence,
                })
            except Tag.DoesNotExist:
                # Tag was suggested by AI but no longer exists; skip it
                pass

        # Format correspondent suggestion
        if scan_result.correspondent:
            corr_id, confidence = scan_result.correspondent
@@ -3624,12 +3628,12 @@ class AISuggestionsView(GenericAPIView):
                response_data["correspondent"] = {
                    "id": correspondent.id,
                    "name": correspondent.name,
                    "confidence": confidence
                    "confidence": confidence,
                }
            except Correspondent.DoesNotExist:
                # Correspondent was suggested but no longer exists; skip it
                pass

        # Format document type suggestion
        if scan_result.document_type:
            type_id, confidence = scan_result.document_type
@@ -3638,12 +3642,12 @@ class AISuggestionsView(GenericAPIView):
                response_data["document_type"] = {
                    "id": doc_type.id,
                    "name": doc_type.name,
                    "confidence": confidence
                    "confidence": confidence,
                }
            except DocumentType.DoesNotExist:
                # Document type was suggested but no longer exists; skip it
                pass

        # Format storage path suggestion
        if scan_result.storage_path:
            path_id, confidence = scan_result.storage_path
@@ -3652,19 +3656,19 @@ class AISuggestionsView(GenericAPIView):
                response_data["storage_path"] = {
                    "id": storage_path.id,
                    "name": storage_path.name,
                    "confidence": confidence
                    "confidence": confidence,
                }
            except StoragePath.DoesNotExist:
                # Storage path was suggested but no longer exists; skip it
                pass

        # Format custom fields
        for field_id, (value, confidence) in scan_result.custom_fields.items():
            response_data["custom_fields"][str(field_id)] = {
                "value": value,
                "confidence": confidence
                "confidence": confidence,
            }

        return Response(response_data)

@@ -3674,48 +3678,48 @@ class ApplyAISuggestionsView(GenericAPIView):

    Requires: can_apply_ai_suggestions permission
    """

    permission_classes = [IsAuthenticated, CanApplyAISuggestionsPermission]

    def post(self, request):
        """Apply AI suggestions to a document."""
        # Validate request
        serializer = ApplyAISuggestionsSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        document_id = serializer.validated_data['document_id']
        document_id = serializer.validated_data["document_id"]

        try:
            document = Document.objects.get(pk=document_id)
        except Document.DoesNotExist:
            return Response(
                {"error": "Document not found"},
                status=status.HTTP_404_NOT_FOUND
                status=status.HTTP_404_NOT_FOUND,
            )

        # Check if user has permission to change this document
        if not has_perms_owner_aware(request.user, 'documents.change_document', document):
        if not has_perms_owner_aware(request.user, "documents.change_document", document):
            return Response(
                {"error": "Permission denied"},
                status=status.HTTP_403_FORBIDDEN
                status=status.HTTP_403_FORBIDDEN,
            )

        # Get AI scanner and scan document
        scanner = get_ai_scanner()
        scan_result = scanner.scan_document(document, document.content or "")

        # Apply suggestions based on user selections
        applied = []

        if serializer.validated_data.get('apply_tags'):
            selected_tags = serializer.validated_data.get('selected_tags', [])
        if serializer.validated_data.get("apply_tags"):
            selected_tags = serializer.validated_data.get("selected_tags", [])
            if selected_tags:
                # Apply only selected tags
                tags_to_apply = [tag_id for tag_id, _ in scan_result.tags if tag_id in selected_tags]
            else:
                # Apply all high-confidence tags
                tags_to_apply = [tag_id for tag_id, conf in scan_result.tags if conf >= scanner.auto_apply_threshold]

            for tag_id in tags_to_apply:
                try:
                    tag = Tag.objects.get(pk=tag_id)
@@ -3724,8 +3728,8 @@ class ApplyAISuggestionsView(GenericAPIView):
                except Tag.DoesNotExist:
                    # Tag not found; skip applying this tag
                    pass

        if serializer.validated_data.get('apply_correspondent') and scan_result.correspondent:
        if serializer.validated_data.get("apply_correspondent") and scan_result.correspondent:
            corr_id, confidence = scan_result.correspondent
            try:
                correspondent = Correspondent.objects.get(pk=corr_id)
@@ -3734,8 +3738,8 @@ class ApplyAISuggestionsView(GenericAPIView):
            except Correspondent.DoesNotExist:
                # Correspondent not found; skip applying
                pass

        if serializer.validated_data.get('apply_document_type') and scan_result.document_type:
        if serializer.validated_data.get("apply_document_type") and scan_result.document_type:
            type_id, confidence = scan_result.document_type
            try:
                doc_type = DocumentType.objects.get(pk=type_id)
@@ -3744,8 +3748,8 @@ class ApplyAISuggestionsView(GenericAPIView):
            except DocumentType.DoesNotExist:
                # Document type not found; skip applying
                pass

        if serializer.validated_data.get('apply_storage_path') and scan_result.storage_path:
        if serializer.validated_data.get("apply_storage_path") and scan_result.storage_path:
            path_id, confidence = scan_result.storage_path
            try:
                storage_path = StoragePath.objects.get(pk=path_id)
@@ -3754,18 +3758,18 @@ class ApplyAISuggestionsView(GenericAPIView):
            except StoragePath.DoesNotExist:
                # Storage path not found; skip applying
                pass

        if serializer.validated_data.get('apply_title') and scan_result.title_suggestion:
        if serializer.validated_data.get("apply_title") and scan_result.title_suggestion:
            document.title = scan_result.title_suggestion
            applied.append(f"title: {scan_result.title_suggestion}")

        # Save document
        document.save()

        return Response({
            "status": "success",
            "document_id": document.id,
            "applied": applied
            "applied": applied,
        })

@@ -3775,23 +3779,23 @@ class AIConfigurationView(GenericAPIView):

    Requires: can_configure_ai permission
    """

    permission_classes = [IsAuthenticated, CanConfigureAIPermission]

    def get(self, request):
        """Get current AI configuration."""
        scanner = get_ai_scanner()

        config_data = {
            "auto_apply_threshold": scanner.auto_apply_threshold,
            "suggest_threshold": scanner.suggest_threshold,
            "ml_enabled": scanner.ml_enabled,
            "advanced_ocr_enabled": scanner.advanced_ocr_enabled,
        }

        serializer = AIConfigurationSerializer(config_data)
        return Response(serializer.data)

    def post(self, request):
        """
        Update AI configuration.
@@ -3802,27 +3806,27 @@ class AIConfigurationView(GenericAPIView):
        """
        serializer = AIConfigurationSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        # Create new scanner with updated configuration
        config = {}
        if 'auto_apply_threshold' in serializer.validated_data:
            config['auto_apply_threshold'] = serializer.validated_data['auto_apply_threshold']
        if 'suggest_threshold' in serializer.validated_data:
            config['suggest_threshold'] = serializer.validated_data['suggest_threshold']
        if 'ml_enabled' in serializer.validated_data:
            config['enable_ml_features'] = serializer.validated_data['ml_enabled']
        if 'advanced_ocr_enabled' in serializer.validated_data:
            config['enable_advanced_ocr'] = serializer.validated_data['advanced_ocr_enabled']
        if "auto_apply_threshold" in serializer.validated_data:
            config["auto_apply_threshold"] = serializer.validated_data["auto_apply_threshold"]
        if "suggest_threshold" in serializer.validated_data:
            config["suggest_threshold"] = serializer.validated_data["suggest_threshold"]
        if "ml_enabled" in serializer.validated_data:
            config["enable_ml_features"] = serializer.validated_data["ml_enabled"]
        if "advanced_ocr_enabled" in serializer.validated_data:
            config["enable_advanced_ocr"] = serializer.validated_data["advanced_ocr_enabled"]

        # Update global scanner instance
        # WARNING: Not thread-safe. Consider storing configuration in database
        # and reloading on each get_ai_scanner() call for production use
        from documents import ai_scanner
        ai_scanner._scanner_instance = AIDocumentScanner(**config)

        return Response({
            "status": "success",
            "message": "AI configuration updated. Changes may require server restart for consistency."
            "message": "AI configuration updated. Changes may require server restart for consistency.",
        })

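Editor's note: the WARNING comment above flags that reassigning ai_scanner._scanner_instance from a request handler is not thread-safe. One minimal mitigation — a sketch only, not code from this commit; the lock and helper names are assumptions — is to guard the singleton with a module-level lock:

import threading

_scanner_lock = threading.Lock()
_scanner_instance = None

def get_ai_scanner():
    # Hypothetical thread-safe accessor; lazily builds the singleton.
    global _scanner_instance
    with _scanner_lock:
        if _scanner_instance is None:
            _scanner_instance = AIDocumentScanner()
        return _scanner_instance

def replace_scanner(**config):
    # Swap the singleton under the same lock instead of assigning
    # the module attribute directly from a view.
    global _scanner_instance
    with _scanner_lock:
        _scanner_instance = AIDocumentScanner(**config)

The database-backed approach the comment suggests (persist the config, rebuild the scanner on read) also survives multi-process deployments, which a per-process lock does not.
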
@@ -30,10 +30,10 @@ class DeletionRequestViewSet(ModelViewSet):

    Provides CRUD operations plus custom actions for approval workflow.
    """

    model = DeletionRequest
    serializer_class = DeletionRequestSerializer

    def get_queryset(self):
        """
        Return deletion requests for the current user.
@@ -45,7 +45,7 @@ class DeletionRequestViewSet(ModelViewSet):
        if user.is_superuser:
            return DeletionRequest.objects.all()
        return DeletionRequest.objects.filter(user=user)

    def _can_manage_request(self, deletion_request):
        """
        Check if current user can manage (approve/reject/cancel) the request.
@@ -58,7 +58,7 @@ class DeletionRequestViewSet(ModelViewSet):
        """
        user = self.request.user
        return user.is_superuser or deletion_request.user == user

    @action(methods=["post"], detail=True)
    def approve(self, request, pk=None):
        """
@@ -72,13 +72,13 @@ class DeletionRequestViewSet(ModelViewSet):
            Response with execution results
        """
        deletion_request = self.get_object()

        # Check permissions
        if not self._can_manage_request(deletion_request):
            return HttpResponseForbidden(
                "You don't have permission to approve this deletion request."
                "You don't have permission to approve this deletion request.",
            )

        # Validate status
        if deletion_request.status != DeletionRequest.STATUS_PENDING:
            return Response(
@@ -88,9 +88,9 @@ class DeletionRequestViewSet(ModelViewSet):
                },
                status=status.HTTP_400_BAD_REQUEST,
            )

        comment = request.data.get("comment", "")

        # Execute approval and deletion in a transaction
        try:
            with transaction.atomic():
@@ -100,12 +100,12 @@ class DeletionRequestViewSet(ModelViewSet):
                        {"error": "Failed to approve deletion request."},
                        status=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    )

                # Execute the deletion
                documents = list(deletion_request.documents.all())
                deleted_count = 0
                failed_deletions = []

                for doc in documents:
                    try:
                        doc_id = doc.id
@@ -114,18 +114,18 @@ class DeletionRequestViewSet(ModelViewSet):
                        deleted_count += 1
                        logger.info(
                            f"Deleted document {doc_id} ('{doc_title}') "
                            f"as part of deletion request {deletion_request.id}"
                            f"as part of deletion request {deletion_request.id}",
                        )
                    except Exception as e:
                        logger.error(
                            f"Failed to delete document {doc.id}: {str(e)}"
                            f"Failed to delete document {doc.id}: {e!s}",
                        )
                        failed_deletions.append({
                            "id": doc.id,
                            "title": doc.title,
                            "error": str(e),
                        })

                # Update completion status
                deletion_request.status = DeletionRequest.STATUS_COMPLETED
                deletion_request.completed_at = timezone.now()
@@ -135,20 +135,20 @@ class DeletionRequestViewSet(ModelViewSet):
                    "total_documents": len(documents),
                }
                deletion_request.save()

                logger.info(
                    f"Deletion request {deletion_request.id} completed. "
                    f"Deleted {deleted_count}/{len(documents)} documents."
                    f"Deleted {deleted_count}/{len(documents)} documents.",
                )
        except Exception as e:
            logger.error(
                f"Error executing deletion request {deletion_request.id}: {str(e)}"
                f"Error executing deletion request {deletion_request.id}: {e!s}",
            )
            return Response(
                {"error": f"Failed to execute deletion: {str(e)}"},
                {"error": f"Failed to execute deletion: {e!s}"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        serializer = self.get_serializer(deletion_request)
        return Response(
            {
@@ -158,7 +158,7 @@ class DeletionRequestViewSet(ModelViewSet):
            },
            status=status.HTTP_200_OK,
        )

    @action(methods=["post"], detail=True)
    def reject(self, request, pk=None):
        """
@@ -172,13 +172,13 @@ class DeletionRequestViewSet(ModelViewSet):
            Response with updated deletion request
        """
        deletion_request = self.get_object()

        # Check permissions
        if not self._can_manage_request(deletion_request):
            return HttpResponseForbidden(
                "You don't have permission to reject this deletion request."
                "You don't have permission to reject this deletion request.",
            )

        # Validate status
        if deletion_request.status != DeletionRequest.STATUS_PENDING:
            return Response(
@@ -188,20 +188,20 @@ class DeletionRequestViewSet(ModelViewSet):
                },
                status=status.HTTP_400_BAD_REQUEST,
            )

        comment = request.data.get("comment", "")

        # Reject the request
        if not deletion_request.reject(request.user, comment):
            return Response(
                {"error": "Failed to reject deletion request."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        logger.info(
            f"Deletion request {deletion_request.id} rejected by user {request.user.username}"
            f"Deletion request {deletion_request.id} rejected by user {request.user.username}",
        )

        serializer = self.get_serializer(deletion_request)
        return Response(
            {
@@ -210,7 +210,7 @@ class DeletionRequestViewSet(ModelViewSet):
            },
            status=status.HTTP_200_OK,
        )

    @action(methods=["post"], detail=True)
    def cancel(self, request, pk=None):
        """
@@ -224,13 +224,13 @@ class DeletionRequestViewSet(ModelViewSet):
            Response with updated deletion request
        """
        deletion_request = self.get_object()

        # Check permissions
        if not self._can_manage_request(deletion_request):
            return HttpResponseForbidden(
                "You don't have permission to cancel this deletion request."
                "You don't have permission to cancel this deletion request.",
            )

        # Validate status
        if deletion_request.status != DeletionRequest.STATUS_PENDING:
            return Response(
@@ -240,18 +240,18 @@ class DeletionRequestViewSet(ModelViewSet):
                },
                status=status.HTTP_400_BAD_REQUEST,
            )

        # Cancel the request
        deletion_request.status = DeletionRequest.STATUS_CANCELLED
        deletion_request.reviewed_by = request.user
        deletion_request.reviewed_at = timezone.now()
        deletion_request.review_comment = request.data.get("comment", "Cancelled by user")
        deletion_request.save()

        logger.info(
            f"Deletion request {deletion_request.id} cancelled by user {request.user.username}"
            f"Deletion request {deletion_request.id} cancelled by user {request.user.username}",
        )

        serializer = self.get_serializer(deletion_request)
        return Response(
            {

@@ -160,7 +160,7 @@ class SecurityHeadersMiddleware:

        # Store nonce in request for use in templates
        # Templates can access this via {{ request.csp_nonce }}
        if hasattr(request, '_csp_nonce'):
        if hasattr(request, "_csp_nonce"):
            request._csp_nonce = nonce

        # Prevent clickjacking attacks

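Editor's note: for context on the nonce handling above, the usual shape of a nonce-based CSP middleware is sketched below. This is a generic pattern under stated assumptions — the real SecurityHeadersMiddleware in this commit generates and stores its nonce differently (it only sets _csp_nonce when the attribute already exists):

import secrets

class SecurityHeadersMiddleware:
    """Illustrative nonce flow, not the commit's implementation."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        nonce = secrets.token_urlsafe(16)  # fresh value per request
        request._csp_nonce = nonce         # exposed to templates
        response = self.get_response(request)
        # Scripts carrying this nonce are allowed; everything else is blocked.
        response["Content-Security-Policy"] = f"script-src 'self' 'nonce-{nonce}'"
        return response
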
@@ -9,7 +9,6 @@ from __future__ import annotations

import hashlib
import logging
import mimetypes
import os
import re
from pathlib import Path
@@ -26,39 +25,39 @@ logger = logging.getLogger("paperless.security")
# Explicit allowlist of permitted MIME types
ALLOWED_MIME_TYPES = {
    # Documents
    'application/pdf',
    'application/vnd.oasis.opendocument.text',
    'application/msword',
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    'application/vnd.ms-excel',
    'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    'application/vnd.ms-powerpoint',
    'application/vnd.openxmlformats-officedocument.presentationml.presentation',
    'application/vnd.oasis.opendocument.spreadsheet',
    'application/vnd.oasis.opendocument.presentation',
    'application/rtf',
    'text/rtf',
    "application/pdf",
    "application/vnd.oasis.opendocument.text",
    "application/msword",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "application/vnd.ms-excel",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    "application/vnd.ms-powerpoint",
    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    "application/vnd.oasis.opendocument.spreadsheet",
    "application/vnd.oasis.opendocument.presentation",
    "application/rtf",
    "text/rtf",

    # Images
    'image/jpeg',
    'image/png',
    'image/gif',
    'image/tiff',
    'image/bmp',
    'image/webp',
    "image/jpeg",
    "image/png",
    "image/gif",
    "image/tiff",
    "image/bmp",
    "image/webp",

    # Text
    'text/plain',
    'text/html',
    'text/csv',
    'text/markdown',
    "text/plain",
    "text/html",
    "text/csv",
    "text/markdown",
}

# Maximum file size (100MB by default)
# Can be overridden by settings.MAX_UPLOAD_SIZE
try:
    from django.conf import settings
    MAX_FILE_SIZE = getattr(settings, 'MAX_UPLOAD_SIZE', 100 * 1024 * 1024)  # 100MB by default
    MAX_FILE_SIZE = getattr(settings, "MAX_UPLOAD_SIZE", 100 * 1024 * 1024)  # 100MB by default
except ImportError:
    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB in bytes

@@ -114,7 +113,6 @@ ALLOWED_JS_PATTERNS = [
class FileValidationError(Exception):
    """Raised when file validation fails."""

    pass


def has_whitelisted_javascript(content: bytes) -> bool:
@@ -143,7 +141,7 @@ def validate_mime_type(mime_type: str) -> None:
    if mime_type not in ALLOWED_MIME_TYPES:
        raise FileValidationError(
            f"MIME type '{mime_type}' is not allowed. "
            f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}"
            f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}",
        )

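Editor's note: a quick usage sketch for the validator above. The mimetypes-based guessing step and the check_upload helper are assumptions about how a caller might obtain the MIME type, not code from this commit:

import mimetypes

def check_upload(filename: str) -> None:
    # Guess the type from the filename; fall back to a generic binary type.
    mime_type, _ = mimetypes.guess_type(filename)
    try:
        validate_mime_type(mime_type or "application/octet-stream")
    except FileValidationError as exc:
        # Log the rejection via the module's "paperless.security" logger
        # and let the caller surface the error to the client.
        logger.warning("Rejected upload %s: %s", filename, exc)
        raise
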
@@ -86,7 +86,7 @@ api_router.register(r"config", ApplicationConfigurationViewSet)
api_router.register(r"processed_mail", ProcessedMailViewSet)
api_router.register(r"deletion_requests", DeletionRequestViewSet)
api_router.register(
    r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests"
    r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests",
)

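Editor's note: with the router registration above and the @action(methods=["post"], detail=True) approve/reject/cancel actions on DeletionRequestViewSet, DRF exposes detail-action URLs of roughly this shape. Host, token, and request id below are placeholders:

import requests

# Hypothetical call; the URL follows DRF's <prefix>/<pk>/<action>/ convention.
resp = requests.post(
    "https://paperless.example.com/api/deletion-requests/42/approve/",
    headers={"Authorization": "Token <api-token>"},
    json={"comment": "Approved"},
)
print(resp.status_code, resp.json())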