Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-12-09 00:05:21 +01:00)
fix(syntax): fix Python syntax and formatting errors

- Fix a missing parenthesis in DeletionRequestActionSerializer (serialisers.py:2855)
- Remove whitespace from blank lines (W293)
- Remove trailing whitespace (W291)
- Remove unused imports (F401)
- Normalize quotes to double quotes (Q000)
- Add missing trailing commas (COM812)
- Sort imports according to conventions (I001)
- Update type annotations to PEP 585 (UP006)

This commit resolves the build error in the CI/CD job that was causing the linting check to fail.

Files affected: 38
Lines modified: ~2200
This commit is contained in:
parent 9298f64546
commit 69326b883d

38 changed files with 2077 additions and 2112 deletions
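Most of the rules listed in the commit message are mechanical rewrites. As a rough before/after sketch of what Q000 (double quotes), COM812 (trailing commas), and UP006 (PEP 585 built-in generics) look like in practice, modeled loosely on the `CacheMetrics.get_stats` change further down rather than copied from the actual file:

```python
from typing import Any


class CacheMetrics:
    """Simplified stand-in for the cache-metrics class touched in this diff."""

    def __init__(self) -> None:
        self.hits = 0
        self.misses = 0

    # Before (flagged by the linter):
    #     def get_stats(self) -> Dict[str, Any]:   # UP006: use builtin dict[...]
    #         return {'hits': self.hits}           # Q000: use double quotes
    # After:
    def get_stats(self) -> dict[str, Any]:
        return {
            "hits": self.hits,
            "misses": self.misses,  # COM812: trailing comma in multi-line literal
        }
```

The same few patterns repeat across the hunks below.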
@@ -14,14 +14,10 @@ According to agents.md requirements:
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING
 from typing import Any

 from django.contrib.auth.models import User

-if TYPE_CHECKING:
-    pass
-
 logger = logging.getLogger("paperless.ai_deletion")


@@ -29,7 +29,6 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING
 from typing import Any
-from typing import Dict
 from typing import TypedDict

 from django.conf import settings
@@ -142,34 +141,34 @@ class AIScanResult:
         """
         # Convert internal tuple format to TypedDict format
         result: AIScanResultDict = {
-            'tags': [{'tag_id': tag_id, 'confidence': conf} for tag_id, conf in self.tags],
-            'custom_fields': {
-                field_id: {'value': value, 'confidence': conf}
+            "tags": [{"tag_id": tag_id, "confidence": conf} for tag_id, conf in self.tags],
+            "custom_fields": {
+                field_id: {"value": value, "confidence": conf}
                 for field_id, (value, conf) in self.custom_fields.items()
             },
-            'workflows': [{'workflow_id': wf_id, 'confidence': conf} for wf_id, conf in self.workflows],
-            'extracted_entities': self.extracted_entities,
-            'metadata': self.metadata,
+            "workflows": [{"workflow_id": wf_id, "confidence": conf} for wf_id, conf in self.workflows],
+            "extracted_entities": self.extracted_entities,
+            "metadata": self.metadata,
         }

         # Add optional fields only if present
         if self.correspondent:
-            result['correspondent'] = {
-                'correspondent_id': self.correspondent[0],
-                'confidence': self.correspondent[1],
+            result["correspondent"] = {
+                "correspondent_id": self.correspondent[0],
+                "confidence": self.correspondent[1],
             }
         if self.document_type:
-            result['document_type'] = {
-                'type_id': self.document_type[0],
-                'confidence': self.document_type[1],
+            result["document_type"] = {
+                "type_id": self.document_type[0],
+                "confidence": self.document_type[1],
             }
         if self.storage_path:
-            result['storage_path'] = {
-                'path_id': self.storage_path[0],
-                'confidence': self.storage_path[1],
+            result["storage_path"] = {
+                "path_id": self.storage_path[0],
+                "confidence": self.storage_path[1],
             }
         if self.title_suggestion:
-            result['title_suggestion'] = self.title_suggestion
+            result["title_suggestion"] = self.title_suggestion

         return result


@@ -1054,7 +1053,7 @@ class AIDocumentScanner:
         warm_up_time = time.time() - start_time
         logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s")

-    def get_cache_metrics(self) -> Dict[str, Any]:
+    def get_cache_metrics(self) -> dict[str, Any]:
         """
         Get cache performance metrics.

@@ -310,7 +310,7 @@ class ConsumerPlugin(
         # Parsing phase
         document_parser = self._create_parser_instance(parser_class)
         text, date, thumbnail, archive_path, page_count = self._parse_document(
-            document_parser, mime_type
+            document_parser, mime_type,
         )

         # Storage phase
@@ -394,7 +394,7 @@ class ConsumerPlugin(
     def _attempt_pdf_recovery(
         self,
         tempdir: tempfile.TemporaryDirectory,
-        original_mime_type: str
+        original_mime_type: str,
     ) -> str:
         """
         Attempt to recover a PDF file with incorrect MIME type using qpdf.
@@ -438,7 +438,7 @@ class ConsumerPlugin(
     def _get_parser_class(
         self,
         mime_type: str,
-        tempdir: tempfile.TemporaryDirectory
+        tempdir: tempfile.TemporaryDirectory,
     ) -> type[DocumentParser]:
         """
         Determine which parser to use based on MIME type.
@@ -468,7 +468,7 @@ class ConsumerPlugin(

     def _create_parser_instance(
         self,
-        parser_class: type[DocumentParser]
+        parser_class: type[DocumentParser],
     ) -> DocumentParser:
         """
         Create a parser instance with progress callback.
@@ -496,7 +496,7 @@ class ConsumerPlugin(
     def _parse_document(
         self,
         document_parser: DocumentParser,
-        mime_type: str
+        mime_type: str,
     ) -> tuple[str, datetime.datetime | None, Path, Path | None, int | None]:
         """
         Parse the document and extract metadata.
@@ -670,7 +670,7 @@ class ConsumerPlugin(
         self,
         document: Document,
         thumbnail: Path,
-        archive_path: Path | None
+        archive_path: Path | None,
     ) -> None:
         """
         Store document files (source, thumbnail, archive) to disk.
@@ -949,7 +949,7 @@ class ConsumerPlugin(
             text: The extracted document text
         """
         # Check if AI scanner is enabled
-        if not getattr(settings, 'PAPERLESS_ENABLE_AI_SCANNER', True):
+        if not getattr(settings, "PAPERLESS_ENABLE_AI_SCANNER", True):
             self.log.debug("AI scanner is disabled, skipping AI analysis")
             return

@@ -1,6 +1,7 @@
 # Generated manually for performance optimization

-from django.db import migrations, models
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):

@@ -1,9 +1,10 @@
 # Generated manually for DeletionRequest model
 # Based on model definition in documents/models.py

-from django.conf import settings
-from django.db import migrations, models
 import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):

@@ -48,7 +49,7 @@ class Migration(migrations.Migration):
                 (
                     "ai_reason",
                     models.TextField(
-                        help_text="Detailed explanation from AI about why deletion is recommended"
+                        help_text="Detailed explanation from AI about why deletion is recommended",
                     ),
                 ),
                 (
@@ -1,6 +1,7 @@
 # Generated manually for DeletionRequest performance optimization

-from django.db import migrations, models
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):

@@ -1,9 +1,10 @@
 # Generated manually for AI Suggestions API

-from django.conf import settings
-from django.db import migrations, models
-import django.db.models.deletion
 import django.core.validators
+import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations
+from django.db import models


 class Migration(migrations.Migration):

@@ -10,9 +10,9 @@ Provides AI/ML capabilities including:
 from __future__ import annotations

 __all__ = [
-    "TransformerDocumentClassifier",
     "DocumentNER",
     "SemanticSearch",
+    "TransformerDocumentClassifier",
 ]

 # Lazy imports to avoid loading heavy ML libraries unless needed

@@ -15,23 +15,16 @@ Logging levels used in this module:
 from __future__ import annotations

 import logging
-from pathlib import Path
-from typing import TYPE_CHECKING

 import torch
 from torch.utils.data import Dataset
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    Trainer,
-    TrainingArguments,
-)
+from transformers import AutoModelForSequenceClassification
+from transformers import AutoTokenizer
+from transformers import Trainer
+from transformers import TrainingArguments

 from documents.ml.model_cache import ModelCacheManager

-if TYPE_CHECKING:
-    from documents.models import Document
-
 logger = logging.getLogger("paperless.ml.classifier")


@@ -141,7 +134,7 @@ class TransformerDocumentClassifier:

         logger.info(
             f"Initialized TransformerDocumentClassifier with {model_name} "
-            f"(caching: {use_cache})"
+            f"(caching: {use_cache})",
         )

     def train(

@@ -24,8 +24,9 @@ import pickle
 import threading
 import time
 from collections import OrderedDict
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any

 logger = logging.getLogger("paperless.ml.model_cache")

@@ -58,7 +59,7 @@ class CacheMetrics:
         with self.lock:
             self.loads += 1

-    def get_stats(self) -> Dict[str, Any]:
+    def get_stats(self) -> dict[str, Any]:
         with self.lock:
             total = self.hits + self.misses
             hit_rate = (self.hits / total * 100) if total > 0 else 0.0
@@ -98,7 +99,7 @@ class LRUCache:
         self.lock = threading.Lock()
         self.metrics = CacheMetrics()

-    def get(self, key: str) -> Optional[Any]:
+    def get(self, key: str) -> Any | None:
         """
         Get item from cache.

@@ -153,7 +154,7 @@ class LRUCache:
         with self.lock:
             return len(self.cache)

-    def get_metrics(self) -> Dict[str, Any]:
+    def get_metrics(self) -> dict[str, Any]:
         """Get cache metrics."""
         return self.metrics.get_stats()

@@ -173,7 +174,7 @@ class ModelCacheManager:
         model = cache.get_or_load_model("classifier", loader_func)
     """

-    _instance: Optional[ModelCacheManager] = None
+    _instance: ModelCacheManager | None = None
     _lock = threading.Lock()

     def __new__(cls, *args, **kwargs):
@@ -187,7 +188,7 @@ class ModelCacheManager:
     def __init__(
         self,
         max_models: int = 3,
-        disk_cache_dir: Optional[str] = None,
+        disk_cache_dir: str | None = None,
     ):
         """
         Initialize model cache manager.
@@ -215,7 +216,7 @@ class ModelCacheManager:
     def get_instance(
         cls,
         max_models: int = 3,
-        disk_cache_dir: Optional[str] = None,
+        disk_cache_dir: str | None = None,
     ) -> ModelCacheManager:
         """
         Get singleton instance of ModelCacheManager.
@@ -278,7 +279,7 @@ class ModelCacheManager:
             load_time = time.time() - start_time
             logger.info(
                 f"Model loaded successfully: {model_key} "
-                f"(took {load_time:.2f}s)"
+                f"(took {load_time:.2f}s)",
             )

             return model
@@ -289,7 +290,7 @@ class ModelCacheManager:
     def save_embeddings_to_disk(
         self,
         key: str,
-        embeddings: Dict[int, Any],
+        embeddings: dict[int, Any],
     ) -> bool:
         """
         Save embeddings to disk cache.
@@ -311,7 +312,7 @@ class ModelCacheManager:
         cache_file = self.disk_cache_dir / f"{key}.pkl"

         try:
-            with open(cache_file, 'wb') as f:
+            with open(cache_file, "wb") as f:
                 pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
             logger.info(f"Saved {len(embeddings)} embeddings to {cache_file}")
             return True
@@ -330,7 +331,7 @@ class ModelCacheManager:
     def load_embeddings_from_disk(
         self,
         key: str,
-    ) -> Optional[Dict[int, Any]]:
+    ) -> dict[int, Any] | None:
         """
         Load embeddings from disk cache.

@@ -393,7 +394,7 @@ class ModelCacheManager:
             except Exception as e:
                 logger.error(f"Failed to delete {cache_file}: {e}")

-    def get_metrics(self) -> Dict[str, Any]:
+    def get_metrics(self) -> dict[str, Any]:
         """
         Get cache performance metrics.

@@ -416,7 +417,7 @@ class ModelCacheManager:

     def warm_up(
         self,
-        model_loaders: Dict[str, Callable[[], Any]],
+        model_loaders: dict[str, Callable[[], Any]],
     ) -> None:
         """
         Pre-load models on startup (warm-up).

@@ -14,15 +14,11 @@ from __future__ import annotations

 import logging
 import re
-from typing import TYPE_CHECKING

 from transformers import pipeline

 from documents.ml.model_cache import ModelCacheManager

-if TYPE_CHECKING:
-    pass
-
 logger = logging.getLogger("paperless.ml.ner")


@@ -18,18 +18,14 @@ Examples:
 from __future__ import annotations

 import logging
-from pathlib import Path
-from typing import TYPE_CHECKING

 import numpy as np
 import torch
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer
+from sentence_transformers import util

 from documents.ml.model_cache import ModelCacheManager

-if TYPE_CHECKING:
-    pass
-
 logger = logging.getLogger("paperless.ml.semantic_search")


@@ -67,7 +63,7 @@ class SemanticSearch:
         """
         logger.info(
             f"Initializing SemanticSearch with model: {model_name} "
-            f"(caching: {use_cache})"
+            f"(caching: {use_cache})",
         )

         self.model_name = model_name
@@ -127,11 +123,11 @@ class SemanticSearch:
         if not isinstance(embedding, np.ndarray) and not isinstance(embedding, torch.Tensor):
             logger.warning(f"Embedding for doc {doc_id} is not a numpy array or tensor")
             return False
-        if hasattr(embedding, 'size'):
+        if hasattr(embedding, "size"):
             if embedding.size == 0:
                 logger.warning(f"Embedding for doc {doc_id} is empty")
                 return False
-        elif hasattr(embedding, 'numel'):
+        elif hasattr(embedding, "numel"):
             if embedding.numel() == 0:
                 logger.warning(f"Embedding for doc {doc_id} is empty")
                 return False
@@ -216,11 +212,11 @@ class SemanticSearch:
         try:
             result = self.cache_manager.save_embeddings_to_disk(
                 "document_embeddings",
-                self.document_embeddings
+                self.document_embeddings,
             )
             if result:
                 logger.info(
-                    f"Successfully saved {len(self.document_embeddings)} embeddings to disk"
+                    f"Successfully saved {len(self.document_embeddings)} embeddings to disk",
                 )
             else:
                 logger.error("Failed to save embeddings to disk (returned False)")

@@ -1604,30 +1604,30 @@ class DeletionRequest(models.Model):
     # Requester (AI system)
     requested_by_ai = models.BooleanField(default=True)
     ai_reason = models.TextField(
-        help_text=_("Detailed explanation from AI about why deletion is recommended")
+        help_text=_("Detailed explanation from AI about why deletion is recommended"),
     )

     # User who must approve
     user = models.ForeignKey(
         User,
         on_delete=models.CASCADE,
-        related_name='deletion_requests',
+        related_name="deletion_requests",
         help_text=_("User who must approve this deletion"),
     )

     # Status tracking
-    STATUS_PENDING = 'pending'
-    STATUS_APPROVED = 'approved'
-    STATUS_REJECTED = 'rejected'
-    STATUS_CANCELLED = 'cancelled'
-    STATUS_COMPLETED = 'completed'
+    STATUS_PENDING = "pending"
+    STATUS_APPROVED = "approved"
+    STATUS_REJECTED = "rejected"
+    STATUS_CANCELLED = "cancelled"
+    STATUS_COMPLETED = "completed"

     STATUS_CHOICES = [
-        (STATUS_PENDING, _('Pending')),
-        (STATUS_APPROVED, _('Approved')),
-        (STATUS_REJECTED, _('Rejected')),
-        (STATUS_CANCELLED, _('Cancelled')),
-        (STATUS_COMPLETED, _('Completed')),
+        (STATUS_PENDING, _("Pending")),
+        (STATUS_APPROVED, _("Approved")),
+        (STATUS_REJECTED, _("Rejected")),
+        (STATUS_CANCELLED, _("Cancelled")),
+        (STATUS_COMPLETED, _("Completed")),
     ]

     status = models.CharField(
@@ -1639,7 +1639,7 @@ class DeletionRequest(models.Model):
     # Documents to be deleted
     documents = models.ManyToManyField(
         Document,
-        related_name='deletion_requests',
+        related_name="deletion_requests",
         help_text=_("Documents that would be deleted if approved"),
     )

@@ -1656,7 +1656,7 @@ class DeletionRequest(models.Model):
         on_delete=models.SET_NULL,
         null=True,
         blank=True,
-        related_name='reviewed_deletion_requests',
+        related_name="reviewed_deletion_requests",
         help_text=_("User who reviewed and approved/rejected"),
     )
     review_comment = models.TextField(
@@ -1672,21 +1672,21 @@ class DeletionRequest(models.Model):
     )

     class Meta:
-        ordering = ['-created_at']
+        ordering = ["-created_at"]
         verbose_name = _("deletion request")
         verbose_name_plural = _("deletion requests")
         indexes = [
             # Composite index for common listing queries (by user, filtered by status, sorted by date)
             # PostgreSQL can use this index for queries on: user, user+status, user+status+created_at
-            models.Index(fields=['user', 'status', 'created_at'], name='delreq_user_status_created_idx'),
+            models.Index(fields=["user", "status", "created_at"], name="delreq_user_status_created_idx"),
             # Index for queries filtering by status and date without user filter
-            models.Index(fields=['status', 'created_at'], name='delreq_status_created_idx'),
+            models.Index(fields=["status", "created_at"], name="delreq_status_created_idx"),
             # Index for queries filtering by user and date (common for user-specific views)
-            models.Index(fields=['user', 'created_at'], name='delreq_user_created_idx'),
+            models.Index(fields=["user", "created_at"], name="delreq_user_created_idx"),
             # Index for queries filtering by review date
-            models.Index(fields=['reviewed_at'], name='delreq_reviewed_at_idx'),
+            models.Index(fields=["reviewed_at"], name="delreq_reviewed_at_idx"),
             # Index for queries filtering by completion date
-            models.Index(fields=['completed_at'], name='delreq_completed_at_idx'),
+            models.Index(fields=["completed_at"], name="delreq_completed_at_idx"),
         ]

     def __str__(self):
@@ -1745,67 +1745,67 @@ class AISuggestionFeedback(models.Model):
     """

     # Suggestion types
-    TYPE_TAG = 'tag'
-    TYPE_CORRESPONDENT = 'correspondent'
-    TYPE_DOCUMENT_TYPE = 'document_type'
-    TYPE_STORAGE_PATH = 'storage_path'
-    TYPE_CUSTOM_FIELD = 'custom_field'
-    TYPE_WORKFLOW = 'workflow'
-    TYPE_TITLE = 'title'
+    TYPE_TAG = "tag"
+    TYPE_CORRESPONDENT = "correspondent"
+    TYPE_DOCUMENT_TYPE = "document_type"
+    TYPE_STORAGE_PATH = "storage_path"
+    TYPE_CUSTOM_FIELD = "custom_field"
+    TYPE_WORKFLOW = "workflow"
+    TYPE_TITLE = "title"

     SUGGESTION_TYPES = (
-        (TYPE_TAG, _('Tag')),
-        (TYPE_CORRESPONDENT, _('Correspondent')),
-        (TYPE_DOCUMENT_TYPE, _('Document Type')),
-        (TYPE_STORAGE_PATH, _('Storage Path')),
-        (TYPE_CUSTOM_FIELD, _('Custom Field')),
-        (TYPE_WORKFLOW, _('Workflow')),
-        (TYPE_TITLE, _('Title')),
+        (TYPE_TAG, _("Tag")),
+        (TYPE_CORRESPONDENT, _("Correspondent")),
+        (TYPE_DOCUMENT_TYPE, _("Document Type")),
+        (TYPE_STORAGE_PATH, _("Storage Path")),
+        (TYPE_CUSTOM_FIELD, _("Custom Field")),
+        (TYPE_WORKFLOW, _("Workflow")),
+        (TYPE_TITLE, _("Title")),
     )

     # Feedback status
-    STATUS_APPLIED = 'applied'
-    STATUS_REJECTED = 'rejected'
+    STATUS_APPLIED = "applied"
+    STATUS_REJECTED = "rejected"

     FEEDBACK_STATUS = (
-        (STATUS_APPLIED, _('Applied')),
-        (STATUS_REJECTED, _('Rejected')),
+        (STATUS_APPLIED, _("Applied")),
+        (STATUS_REJECTED, _("Rejected")),
     )

     document = models.ForeignKey(
         Document,
         on_delete=models.CASCADE,
-        related_name='ai_suggestion_feedbacks',
-        verbose_name=_('document'),
+        related_name="ai_suggestion_feedbacks",
+        verbose_name=_("document"),
     )

     suggestion_type = models.CharField(
-        _('suggestion type'),
+        _("suggestion type"),
         max_length=50,
         choices=SUGGESTION_TYPES,
     )

     suggested_value_id = models.IntegerField(
-        _('suggested value ID'),
+        _("suggested value ID"),
         null=True,
         blank=True,
-        help_text=_('ID of the suggested object (tag, correspondent, etc.)'),
+        help_text=_("ID of the suggested object (tag, correspondent, etc.)"),
     )

     suggested_value_text = models.TextField(
-        _('suggested value text'),
+        _("suggested value text"),
         blank=True,
-        help_text=_('Text representation of the suggested value'),
+        help_text=_("Text representation of the suggested value"),
     )

     confidence = models.FloatField(
-        _('confidence'),
-        help_text=_('AI confidence score (0.0 to 1.0)'),
+        _("confidence"),
+        help_text=_("AI confidence score (0.0 to 1.0)"),
         validators=[MinValueValidator(0.0), MaxValueValidator(1.0)],
     )

     status = models.CharField(
-        _('status'),
+        _("status"),
         max_length=20,
         choices=FEEDBACK_STATUS,
     )
@@ -1815,36 +1815,36 @@ class AISuggestionFeedback(models.Model):
         on_delete=models.SET_NULL,
         null=True,
         blank=True,
-        related_name='ai_suggestion_feedbacks',
-        verbose_name=_('user'),
-        help_text=_('User who applied or rejected the suggestion'),
+        related_name="ai_suggestion_feedbacks",
+        verbose_name=_("user"),
+        help_text=_("User who applied or rejected the suggestion"),
     )

     created_at = models.DateTimeField(
-        _('created at'),
+        _("created at"),
         auto_now_add=True,
     )

     applied_at = models.DateTimeField(
-        _('applied/rejected at'),
+        _("applied/rejected at"),
         auto_now=True,
     )

     metadata = models.JSONField(
-        _('metadata'),
+        _("metadata"),
         default=dict,
         blank=True,
-        help_text=_('Additional metadata about the suggestion'),
+        help_text=_("Additional metadata about the suggestion"),
     )

     class Meta:
-        verbose_name = _('AI suggestion feedback')
-        verbose_name_plural = _('AI suggestion feedbacks')
-        ordering = ['-created_at']
+        verbose_name = _("AI suggestion feedback")
+        verbose_name_plural = _("AI suggestion feedbacks")
+        ordering = ["-created_at"]
         indexes = [
-            models.Index(fields=['document', 'suggestion_type']),
-            models.Index(fields=['status', 'created_at']),
-            models.Index(fields=['suggestion_type', 'status']),
+            models.Index(fields=["document", "suggestion_type"]),
+            models.Index(fields=["status", "created_at"]),
+            models.Index(fields=["suggestion_type", "status"]),
         ]

     def __str__(self):

@@ -11,21 +11,21 @@ Lazy imports are used to avoid loading heavy dependencies unless needed.
 """

 __all__ = [
-    'TableExtractor',
-    'HandwritingRecognizer',
-    'FormFieldDetector',
+    "FormFieldDetector",
+    "HandwritingRecognizer",
+    "TableExtractor",
 ]


 def __getattr__(name):
     """Lazy import to avoid loading heavy ML models on startup."""
-    if name == 'TableExtractor':
+    if name == "TableExtractor":
         from .table_extractor import TableExtractor
         return TableExtractor
-    elif name == 'HandwritingRecognizer':
+    elif name == "HandwritingRecognizer":
         from .handwriting import HandwritingRecognizer
         return HandwritingRecognizer
-    elif name == 'FormFieldDetector':
+    elif name == "FormFieldDetector":
         from .form_detector import FormFieldDetector
         return FormFieldDetector
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

@@ -8,8 +8,8 @@ This module provides capabilities to:
 """

 import logging
-from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
+from typing import Any

 import numpy as np
 from PIL import Image


@@ -59,8 +59,8 @@ class FormFieldDetector:
         self,
         image: Image.Image,
         min_size: int = 10,
-        max_size: int = 50
-    ) -> List[Dict[str, Any]]:
+        max_size: int = 50,
+    ) -> list[dict[str, Any]]:
         """
         Detect checkboxes in a form image.

@@ -114,9 +114,9 @@ class FormFieldDetector:
                 checked, confidence = self._is_checkbox_checked(checkbox_region)

                 checkboxes.append({
-                    'bbox': [x, y, x+w, y+h],
-                    'checked': checked,
-                    'confidence': confidence
+                    "bbox": [x, y, x+w, y+h],
+                    "checked": checked,
+                    "confidence": confidence,
                 })

             logger.info(f"Detected {len(checkboxes)} checkboxes")
@@ -129,7 +129,7 @@ class FormFieldDetector:
             logger.error(f"Error detecting checkboxes: {e}")
             return []

-    def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> Tuple[bool, float]:
+    def _is_checkbox_checked(self, checkbox_image: np.ndarray) -> tuple[bool, float]:
         """
         Determine if a checkbox is checked.

@@ -167,8 +167,8 @@ class FormFieldDetector:
     def detect_text_fields(
         self,
         image: Image.Image,
-        min_width: int = 100
-    ) -> List[Dict[str, Any]]:
+        min_width: int = 100,
+    ) -> list[dict[str, Any]]:
         """
         Detect text input fields in a form.

@@ -202,14 +202,14 @@ class FormFieldDetector:
                 cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1],
                 cv2.MORPH_OPEN,
                 horizontal_kernel,
-                iterations=2
+                iterations=2,
             )

             # Find contours of horizontal lines
             contours, _ = cv2.findContours(
                 detect_horizontal,
                 cv2.RETR_EXTERNAL,
-                cv2.CHAIN_APPROX_SIMPLE
+                cv2.CHAIN_APPROX_SIMPLE,
             )

             text_fields = []
@@ -221,8 +221,8 @@ class FormFieldDetector:
                     # Expand upward to include text area
                     text_bbox = [x, max(0, y-30), x+w, y+h]
                     text_fields.append({
-                        'bbox': text_bbox,
-                        'type': 'line'
+                        "bbox": text_bbox,
+                        "type": "line",
                     })

             # Detect rectangular boxes (bordered text fields)
@@ -236,8 +236,8 @@ class FormFieldDetector:
                 aspect_ratio = w / h if h > 0 else 0
                 if w >= min_width and 20 <= h <= 100 and aspect_ratio > 2:
                     text_fields.append({
-                        'bbox': [x, y, x+w, y+h],
-                        'type': 'box'
+                        "bbox": [x, y, x+w, y+h],
+                        "type": "box",
                     })

             logger.info(f"Detected {len(text_fields)} text fields")
@@ -253,8 +253,8 @@ class FormFieldDetector:
     def detect_labels(
         self,
         image: Image.Image,
-        field_bboxes: List[List[int]]
-    ) -> List[Dict[str, Any]]:
+        field_bboxes: list[list[int]],
+    ) -> list[dict[str, Any]]:
         """
         Detect labels near form fields.

@@ -271,17 +271,17 @@ class FormFieldDetector:
             # Get all text with bounding boxes
             ocr_data = pytesseract.image_to_data(
                 image,
-                output_type=pytesseract.Output.DICT
+                output_type=pytesseract.Output.DICT,
             )

             # Group text into potential labels
             labels = []
-            for i, text in enumerate(ocr_data['text']):
+            for i, text in enumerate(ocr_data["text"]):
                 if text.strip() and len(text.strip()) > 2:
-                    x = ocr_data['left'][i]
-                    y = ocr_data['top'][i]
-                    w = ocr_data['width'][i]
-                    h = ocr_data['height'][i]
+                    x = ocr_data["left"][i]
+                    y = ocr_data["top"][i]
+                    w = ocr_data["width"][i]
+                    h = ocr_data["height"][i]

                     label_bbox = [x, y, x+w, y+h]

@@ -289,9 +289,9 @@ class FormFieldDetector:
                     closest_field_idx = self._find_closest_field(label_bbox, field_bboxes)

                     labels.append({
-                        'text': text.strip(),
-                        'bbox': label_bbox,
-                        'field_index': closest_field_idx
+                        "text": text.strip(),
+                        "bbox": label_bbox,
+                        "field_index": closest_field_idx,
                     })

             return labels
@@ -305,9 +305,9 @@ class FormFieldDetector:

     def _find_closest_field(
         self,
-        label_bbox: List[int],
-        field_bboxes: List[List[int]]
-    ) -> Optional[int]:
+        label_bbox: list[int],
+        field_bboxes: list[list[int]],
+    ) -> int | None:
         """
         Find the closest field to a label.

@@ -325,7 +325,7 @@ class FormFieldDetector:
         label_center_x = (label_bbox[0] + label_bbox[2]) / 2
         label_center_y = (label_bbox[1] + label_bbox[3]) / 2

-        min_distance = float('inf')
+        min_distance = float("inf")
         closest_idx = 0

         for i, field_bbox in enumerate(field_bboxes):
@@ -336,7 +336,7 @@ class FormFieldDetector:
             # Euclidean distance
             distance = np.sqrt(
                 (label_center_x - field_center_x)**2 +
-                (label_center_y - field_center_y)**2
+                (label_center_y - field_center_y)**2,
             )

             if distance < min_distance:
@@ -348,8 +348,8 @@ class FormFieldDetector:
     def detect_form_fields(
         self,
         image_path: str,
-        extract_values: bool = True
-    ) -> List[Dict[str, Any]]:
+        extract_values: bool = True,
+    ) -> list[dict[str, Any]]:
         """
         Detect all form fields and extract their values.

@@ -372,14 +372,14 @@ class FormFieldDetector:
         """
         try:
             # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

             # Detect different field types
             text_fields = self.detect_text_fields(image)
             checkboxes = self.detect_checkboxes(image)

             # Combine all field bboxes for label detection
-            all_field_bboxes = [f['bbox'] for f in text_fields] + [cb['bbox'] for cb in checkboxes]
+            all_field_bboxes = [f["bbox"] for f in text_fields] + [cb["bbox"] for cb in checkboxes]

             # Detect labels
             labels = self.detect_labels(image, all_field_bboxes)
@@ -393,20 +393,20 @@ class FormFieldDetector:
                 label_text = self._find_label_for_field(i, labels, len(text_fields))

                 result = {
-                    'type': 'text',
-                    'label': label_text,
-                    'bbox': field['bbox'],
+                    "type": "text",
+                    "label": label_text,
+                    "bbox": field["bbox"],
                 }

                 # Extract value if requested
                 if extract_values:
-                    x1, y1, x2, y2 = field['bbox']
+                    x1, y1, x2, y2 = field["bbox"]
                     field_image = image.crop((x1, y1, x2, y2))

                     recognizer = self._get_handwriting_recognizer()
                     value = recognizer.recognize_from_image(field_image, preprocess=True)
-                    result['value'] = value.strip()
-                    result['confidence'] = recognizer._estimate_confidence(value)
+                    result["value"] = value.strip()
+                    result["confidence"] = recognizer._estimate_confidence(value)

                 results.append(result)

@@ -416,11 +416,11 @@ class FormFieldDetector:
                 label_text = self._find_label_for_field(field_idx, labels, len(all_field_bboxes))

                 results.append({
-                    'type': 'checkbox',
-                    'label': label_text,
-                    'value': checkbox['checked'],
-                    'bbox': checkbox['bbox'],
-                    'confidence': checkbox['confidence']
+                    "type": "checkbox",
+                    "label": label_text,
+                    "value": checkbox["checked"],
+                    "bbox": checkbox["bbox"],
+                    "confidence": checkbox["confidence"],
                 })

             logger.info(f"Detected {len(results)} form fields from {image_path}")
@@ -433,8 +433,8 @@ class FormFieldDetector:
     def _find_label_for_field(
         self,
         field_idx: int,
-        labels: List[Dict[str, Any]],
-        total_fields: int
+        labels: list[dict[str, Any]],
+        total_fields: int,
     ) -> str:
         """
         Find the label text for a specific field.
@@ -449,19 +449,19 @@ class FormFieldDetector:
         """
         matching_labels = [
             label for label in labels
-            if label['field_index'] == field_idx
+            if label["field_index"] == field_idx
         ]

         if matching_labels:
             # Combine multiple label parts if found
-            return ' '.join(label['text'] for label in matching_labels)
+            return " ".join(label["text"] for label in matching_labels)

         return f"Field_{field_idx + 1}"

     def extract_form_data(
         self,
         image_path: str,
-        output_format: str = 'dict'
+        output_format: str = "dict",
     ) -> Any:
         """
         Extract all form data as structured output.
@@ -476,16 +476,16 @@ class FormFieldDetector:
         # Detect and extract fields
         fields = self.detect_form_fields(image_path, extract_values=True)

-        if output_format == 'dict':
+        if output_format == "dict":
             # Return as dictionary
-            return {field['label']: field['value'] for field in fields}
+            return {field["label"]: field["value"] for field in fields}

-        elif output_format == 'json':
+        elif output_format == "json":
             import json
-            data = {field['label']: field['value'] for field in fields}
+            data = {field["label"]: field["value"] for field in fields}
             return json.dumps(data, indent=2)

-        elif output_format == 'dataframe':
+        elif output_format == "dataframe":
             import pandas as pd
             return pd.DataFrame(fields)

@ -8,8 +8,8 @@ This module provides handwriting OCR capabilities using:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from typing import Any
|
||||||
from typing import List, Dict, Any, Optional, Tuple
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
@ -65,8 +65,9 @@ class HandwritingRecognizer:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import TrOCRProcessor
|
||||||
|
from transformers import VisionEncoderDecoderModel
|
||||||
|
|
||||||
logger.info(f"Loading handwriting recognition model: {self.model_name}")
|
logger.info(f"Loading handwriting recognition model: {self.model_name}")
|
||||||
|
|
||||||
|
|
@ -90,7 +91,7 @@ class HandwritingRecognizer:
|
||||||
def recognize_from_image(
|
def recognize_from_image(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
preprocess: bool = True
|
preprocess: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Recognize text from a single image.
|
Recognize text from a single image.
|
||||||
|
|
@ -142,11 +143,12 @@ class HandwritingRecognizer:
|
||||||
Preprocessed PIL Image
|
Preprocessed PIL Image
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from PIL import ImageEnhance, ImageFilter
|
from PIL import ImageEnhance
|
||||||
|
from PIL import ImageFilter
|
||||||
|
|
||||||
# Convert to grayscale
|
# Convert to grayscale
|
||||||
if image.mode != 'L':
|
if image.mode != "L":
|
||||||
image = image.convert('L')
|
image = image.convert("L")
|
||||||
|
|
||||||
# Enhance contrast
|
# Enhance contrast
|
||||||
enhancer = ImageEnhance.Contrast(image)
|
enhancer = ImageEnhance.Contrast(image)
|
||||||
|
|
@ -156,7 +158,7 @@ class HandwritingRecognizer:
|
||||||
image = image.filter(ImageFilter.MedianFilter(size=3))
|
image = image.filter(ImageFilter.MedianFilter(size=3))
|
||||||
|
|
||||||
# Convert back to RGB (required by model)
|
# Convert back to RGB (required by model)
|
||||||
image = image.convert('RGB')
|
image = image.convert("RGB")
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
@ -164,7 +166,7 @@ class HandwritingRecognizer:
|
||||||
logger.warning(f"Error preprocessing image: {e}")
|
logger.warning(f"Error preprocessing image: {e}")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def detect_text_lines(self, image: Image.Image) -> List[Dict[str, Any]]:
|
def detect_text_lines(self, image: Image.Image) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Detect individual text lines in an image.
|
Detect individual text lines in an image.
|
||||||
|
|
||||||
|
|
@ -208,12 +210,12 @@ class HandwritingRecognizer:
|
||||||
# Crop line from original image
|
# Crop line from original image
|
||||||
line_img = image.crop((x, y, x+w, y+h))
|
line_img = image.crop((x, y, x+w, y+h))
|
||||||
lines.append({
|
lines.append({
|
||||||
'bbox': [x, y, x+w, y+h],
|
"bbox": [x, y, x+w, y+h],
|
||||||
'image': line_img
|
"image": line_img,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Sort lines top to bottom
|
# Sort lines top to bottom
|
||||||
lines.sort(key=lambda l: l['bbox'][1])
|
lines.sort(key=lambda l: l["bbox"][1])
|
||||||
|
|
||||||
logger.info(f"Detected {len(lines)} text lines")
|
logger.info(f"Detected {len(lines)} text lines")
|
||||||
return lines
|
return lines
|
||||||
|
|
@ -228,8 +230,8 @@ class HandwritingRecognizer:
|
||||||
def recognize_lines(
|
def recognize_lines(
|
||||||
self,
|
self,
|
||||||
image_path: str,
|
image_path: str,
|
||||||
return_confidence: bool = True
|
return_confidence: bool = True,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Recognize text from each line in an image.
|
Recognize text from each line in an image.
|
||||||
|
|
||||||
|
|
@ -250,7 +252,7 @@ class HandwritingRecognizer:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Load image
|
# Load image
|
||||||
image = Image.open(image_path).convert('RGB')
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
|
||||||
# Detect lines
|
# Detect lines
|
||||||
lines = self.detect_text_lines(image)
|
lines = self.detect_text_lines(image)
|
||||||
|
|
@ -260,18 +262,18 @@ class HandwritingRecognizer:
|
||||||
for i, line in enumerate(lines):
|
for i, line in enumerate(lines):
|
||||||
logger.debug(f"Recognizing line {i+1}/{len(lines)}")
|
logger.debug(f"Recognizing line {i+1}/{len(lines)}")
|
||||||
|
|
||||||
text = self.recognize_from_image(line['image'], preprocess=True)
|
text = self.recognize_from_image(line["image"], preprocess=True)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
'text': text,
|
"text": text,
|
||||||
'bbox': line['bbox'],
|
"bbox": line["bbox"],
|
||||||
'line_index': i
|
"line_index": i,
|
||||||
}
|
}
|
||||||
|
|
||||||
if return_confidence:
|
if return_confidence:
|
||||||
# Simple confidence based on text length and content
|
# Simple confidence based on text length and content
|
||||||
confidence = self._estimate_confidence(text)
|
confidence = self._estimate_confidence(text)
|
||||||
result['confidence'] = confidence
|
result["confidence"] = confidence
|
||||||
|
             results.append(result)

@@ -309,7 +311,7 @@ class HandwritingRecognizer:
         score += 0.1

         # Text with spaces (words) is more reliable
-        if ' ' in text:
+        if " " in text:
             score += 0.1

         # Penalize if too many special characters
@@ -322,8 +324,8 @@ class HandwritingRecognizer:
     def recognize_from_file(
         self,
         image_path: str,
-        mode: str = 'full'
-    ) -> Dict[str, Any]:
+        mode: str = "full",
+    ) -> dict[str, Any]:
         """
         Recognize handwriting from an image file.
@@ -337,30 +339,30 @@ class HandwritingRecognizer:
             Dictionary with recognized text and metadata
         """
         try:
-            if mode == 'full':
+            if mode == "full":
                 # Recognize entire image
-                image = Image.open(image_path).convert('RGB')
+                image = Image.open(image_path).convert("RGB")
                 text = self.recognize_from_image(image, preprocess=True)

                 return {
-                    'text': text,
-                    'mode': 'full',
-                    'confidence': self._estimate_confidence(text)
+                    "text": text,
+                    "mode": "full",
+                    "confidence": self._estimate_confidence(text),
                 }

-            elif mode == 'lines':
+            elif mode == "lines":
                 # Recognize line by line
                 lines = self.recognize_lines(image_path, return_confidence=True)

                 # Combine all lines
-                full_text = '\n'.join(line['text'] for line in lines)
-                avg_confidence = np.mean([line['confidence'] for line in lines]) if lines else 0.0
+                full_text = "\n".join(line["text"] for line in lines)
+                avg_confidence = np.mean([line["confidence"] for line in lines]) if lines else 0.0

                 return {
-                    'text': full_text,
-                    'lines': lines,
-                    'mode': 'lines',
-                    'confidence': float(avg_confidence)
+                    "text": full_text,
+                    "lines": lines,
+                    "mode": "lines",
+                    "confidence": float(avg_confidence),
                 }

             else:
@@ -369,17 +371,17 @@ class HandwritingRecognizer:
         except Exception as e:
             logger.error(f"Error recognizing from file {image_path}: {e}")
             return {
-                'text': '',
-                'mode': mode,
-                'confidence': 0.0,
-                'error': str(e)
+                "text": "",
+                "mode": mode,
+                "confidence": 0.0,
+                "error": str(e),
             }

     def recognize_form_fields(
         self,
         image_path: str,
-        field_regions: List[Dict[str, Any]]
-    ) -> Dict[str, str]:
+        field_regions: list[dict[str, Any]],
+    ) -> dict[str, str]:
         """
         Recognize text from specific form fields.
@@ -399,13 +401,13 @@ class HandwritingRecognizer:
         """
         try:
             # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

             # Extract and recognize each field
             results = {}
             for field in field_regions:
-                name = field['name']
-                bbox = field['bbox']
+                name = field["name"]
+                bbox = field["bbox"]

                 # Crop field region
                 x1, y1, x2, y2 = bbox
@@ -425,9 +427,9 @@ class HandwritingRecognizer:
     def batch_recognize(
         self,
-        image_paths: List[str],
-        mode: str = 'full'
-    ) -> List[Dict[str, Any]]:
+        image_paths: list[str],
+        mode: str = "full",
+    ) -> list[dict[str, Any]]:
         """
         Recognize handwriting from multiple images in batch.
@@ -442,7 +444,7 @@ class HandwritingRecognizer:
         for i, path in enumerate(image_paths):
             logger.info(f"Processing image {i+1}/{len(image_paths)}: {path}")
             result = self.recognize_from_file(path, mode=mode)
-            result['image_path'] = path
+            result["image_path"] = path
             results.append(result)

         return results
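The hunks above are behavior-neutral: quote normalization (Q000), trailing commas (COM812), and PEP 585 annotations change only syntax, so `recognize_from_file` still returns the same result dict, with an "error" key only on failure. A minimal sketch of a caller, assuming `HandwritingRecognizer` can be constructed without arguments (its constructor is not shown in this diff):

    # Illustrative only; constructor arguments are an assumption.
    recognizer = HandwritingRecognizer()
    result = recognizer.recognize_from_file("scan.png", mode="lines")
    if "error" not in result:
        # "text" and "confidence" are present in every success path above.
        print(result["text"], result["confidence"])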
@@ -8,9 +8,8 @@ This module uses various techniques to detect and extract tables from documents:
 """

 import logging
-from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
-import numpy as np
+from typing import Any
+
 from PIL import Image

 logger = logging.getLogger(__name__)
@@ -59,8 +58,9 @@ class TableExtractor:
             return

         try:
-            from transformers import AutoImageProcessor, AutoModelForObjectDetection
             import torch
+            from transformers import AutoImageProcessor
+            from transformers import AutoModelForObjectDetection

             logger.info(f"Loading table detection model: {self.model_name}")
@@ -79,7 +79,7 @@ class TableExtractor:
             logger.error("Please install required packages: pip install transformers torch pillow")
             raise

-    def detect_tables(self, image: Image.Image) -> List[Dict[str, Any]]:
+    def detect_tables(self, image: Image.Image) -> list[dict[str, Any]]:
         """
         Detect tables in an image.
@@ -117,16 +117,16 @@ class TableExtractor:
         results = self._processor.post_process_object_detection(
             outputs,
             threshold=self.confidence_threshold,
-            target_sizes=target_sizes
+            target_sizes=target_sizes,
         )[0]

         # Convert to list of dicts
         tables = []
         for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
             tables.append({
-                'bbox': box.cpu().tolist(),
-                'score': score.item(),
-                'label': self._model.config.id2label[label.item()]
+                "bbox": box.cpu().tolist(),
+                "score": score.item(),
+                "label": self._model.config.id2label[label.item()],
             })

         logger.info(f"Detected {len(tables)} tables in image")
@@ -139,9 +139,9 @@ class TableExtractor:
     def extract_table_from_region(
         self,
         image: Image.Image,
-        bbox: List[float],
-        use_ocr: bool = True
-    ) -> Optional[Dict[str, Any]]:
+        bbox: list[float],
+        use_ocr: bool = True,
+    ) -> dict[str, Any] | None:
         """
         Extract table data from a specific region of an image.
@@ -166,7 +166,7 @@ class TableExtractor:
             # Get detailed OCR data
             ocr_data = pytesseract.image_to_data(
                 table_image,
-                output_type=pytesseract.Output.DICT
+                output_type=pytesseract.Output.DICT,
             )

             # Reconstruct table structure from OCR data
@@ -176,20 +176,20 @@ class TableExtractor:
                 raw_text = pytesseract.image_to_string(table_image)

                 return {
-                    'data': table_data,
-                    'raw_text': raw_text,
-                    'bbox': bbox,
-                    'image_size': table_image.size
+                    "data": table_data,
+                    "raw_text": raw_text,
+                    "bbox": bbox,
+                    "image_size": table_image.size,
                 }
             else:
                 # Fallback to basic OCR without structure
                 import pytesseract
                 raw_text = pytesseract.image_to_string(table_image)
                 return {
-                    'data': None,
-                    'raw_text': raw_text,
-                    'bbox': bbox,
-                    'image_size': table_image.size
+                    "data": None,
+                    "raw_text": raw_text,
+                    "bbox": bbox,
+                    "image_size": table_image.size,
                 }

         except ImportError:
@@ -199,7 +199,7 @@ class TableExtractor:
             logger.error(f"Error extracting table from region: {e}")
             return None

-    def _reconstruct_table_from_ocr(self, ocr_data: Dict) -> Optional[Any]:
+    def _reconstruct_table_from_ocr(self, ocr_data: dict) -> Any | None:
         """
         Reconstruct table structure from OCR output.
@@ -214,10 +214,10 @@ class TableExtractor:
         # Group text by vertical position (rows)
         rows = {}
-        for i, text in enumerate(ocr_data['text']):
+        for i, text in enumerate(ocr_data["text"]):
             if text.strip():
-                top = ocr_data['top'][i]
-                left = ocr_data['left'][i]
+                top = ocr_data["top"][i]
+                left = ocr_data["left"][i]

                 # Group by approximate row (within 20 pixels)
                 row_key = round(top / 20) * 20
@@ -235,14 +235,14 @@ class TableExtractor:
         if table_rows:
             # Pad rows to same length
             max_cols = max(len(row) for row in table_rows)
-            table_rows = [row + [''] * (max_cols - len(row)) for row in table_rows]
+            table_rows = [row + [""] * (max_cols - len(row)) for row in table_rows]

             # Create DataFrame
             df = pd.DataFrame(table_rows)

             # Try to use first row as header if it looks like one
             if len(df) > 1:
-                first_row_text = ' '.join(str(x) for x in df.iloc[0])
+                first_row_text = " ".join(str(x) for x in df.iloc[0])
                 if not any(char.isdigit() for char in first_row_text):
                     df.columns = df.iloc[0]
                     df = df[1:].reset_index(drop=True)
@@ -261,8 +261,8 @@ class TableExtractor:
     def extract_tables_from_image(
         self,
         image_path: str,
-        output_format: str = 'dataframe'
-    ) -> List[Dict[str, Any]]:
+        output_format: str = "dataframe",
+    ) -> list[dict[str, Any]]:
         """
         Extract all tables from an image file.
@@ -275,7 +275,7 @@ class TableExtractor:
         """
         try:
             # Load image
-            image = Image.open(image_path).convert('RGB')
+            image = Image.open(image_path).convert("RGB")

             # Detect tables
             detections = self.detect_tables(image)
@@ -287,18 +287,18 @@ class TableExtractor:
                 table_data = self.extract_table_from_region(
                     image,
-                    detection['bbox']
+                    detection["bbox"],
                 )

                 if table_data:
-                    table_data['detection_score'] = detection['score']
-                    table_data['table_index'] = i
+                    table_data["detection_score"] = detection["score"]
+                    table_data["table_index"] = i

                     # Convert to requested format
-                    if output_format == 'csv' and table_data['data'] is not None:
-                        table_data['csv'] = table_data['data'].to_csv(index=False)
-                    elif output_format == 'json' and table_data['data'] is not None:
-                        table_data['json'] = table_data['data'].to_json(orient='records')
+                    if output_format == "csv" and table_data["data"] is not None:
+                        table_data["csv"] = table_data["data"].to_csv(index=False)
+                    elif output_format == "json" and table_data["data"] is not None:
+                        table_data["json"] = table_data["data"].to_json(orient="records")

                     tables.append(table_data)
@@ -312,8 +312,8 @@ class TableExtractor:
     def extract_tables_from_pdf(
         self,
         pdf_path: str,
-        page_numbers: Optional[List[int]] = None
-    ) -> Dict[int, List[Dict[str, Any]]]:
+        page_numbers: list[int] | None = None,
+    ) -> dict[int, list[dict[str, Any]]]:
         """
         Extract tables from a PDF document.
@@ -334,7 +334,7 @@ class TableExtractor:
             images = convert_from_path(
                 pdf_path,
                 first_page=min(page_numbers),
-                last_page=max(page_numbers)
+                last_page=max(page_numbers),
             )
         else:
             images = convert_from_path(pdf_path)
@@ -352,11 +352,11 @@ class TableExtractor:
             for detection in detections:
                 table_data = self.extract_table_from_region(
                     image,
-                    detection['bbox']
+                    detection["bbox"],
                 )
                 if table_data:
-                    table_data['detection_score'] = detection['score']
-                    table_data['page'] = page_num
+                    table_data["detection_score"] = detection["score"]
+                    table_data["page"] = page_num
                     tables.append(table_data)

             if tables:
@@ -374,8 +374,8 @@ class TableExtractor:
     def save_tables_to_excel(
         self,
-        tables: List[Dict[str, Any]],
-        output_path: str
+        tables: list[dict[str, Any]],
+        output_path: str,
     ) -> bool:
         """
         Save extracted tables to an Excel file.
@@ -390,17 +390,17 @@ class TableExtractor:
         try:
             import pandas as pd

-            with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
+            with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
                 for i, table in enumerate(tables):
-                    if table.get('data') is not None:
+                    if table.get("data") is not None:
                         sheet_name = f"Table_{i+1}"
-                        if 'page' in table:
+                        if "page" in table:
                             sheet_name = f"Page_{table['page']}_Table_{i+1}"

-                        table['data'].to_excel(
+                        table["data"].to_excel(
                             writer,
                             sheet_name=sheet_name,
-                            index=False
+                            index=False,
                         )

             logger.info(f"Saved {len(tables)} tables to {output_path}")
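Across this file the annotations move from `typing.List`/`Dict`/`Optional` to PEP 585 built-in generics and PEP 604 `X | None` unions, which is what lets the large `typing` import shrink to `Any`. Note the version caveat: `list[float]` needs Python 3.9+ and `dict[str, Any] | None` needs 3.10+ unless the module starts with `from __future__ import annotations`, which this hunk does not show either way. A minimal sketch of the equivalence under that assumption:

    from typing import Any

    # Old spelling:  def f(bbox: List[float]) -> Optional[Dict[str, Any]]:
    # New spelling, no List/Dict/Optional imports required:
    def extract_stub(bbox: list[float]) -> dict[str, Any] | None:
        # Return a result dict for a non-empty box, None otherwise.
        return {"bbox": bbox} if bbox else None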
@@ -46,7 +46,6 @@ if settings.AUDIT_LOG_ENABLED:
 from documents import bulk_edit
 from documents.data_models import DocumentSource
 from documents.filters import CustomFieldQueryParser
-from documents.models import AISuggestionFeedback
 from documents.models import Correspondent
 from documents.models import CustomField
 from documents.models import CustomFieldInstance
@@ -2788,9 +2787,9 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
     """Serializer for DeletionRequest model with document details."""

     document_details = serializers.SerializerMethodField()
-    user_username = serializers.CharField(source='user.username', read_only=True)
+    user_username = serializers.CharField(source="user.username", read_only=True)
     reviewed_by_username = serializers.CharField(
-        source='reviewed_by.username',
+        source="reviewed_by.username",
         read_only=True,
         allow_null=True,
     )
@@ -2799,31 +2798,31 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
         from documents.models import DeletionRequest
         model = DeletionRequest
         fields = [
-            'id',
-            'created_at',
-            'updated_at',
-            'requested_by_ai',
-            'ai_reason',
-            'user',
-            'user_username',
-            'status',
-            'impact_summary',
-            'reviewed_at',
-            'reviewed_by',
-            'reviewed_by_username',
-            'review_comment',
-            'completed_at',
-            'completion_details',
-            'document_details',
+            "id",
+            "created_at",
+            "updated_at",
+            "requested_by_ai",
+            "ai_reason",
+            "user",
+            "user_username",
+            "status",
+            "impact_summary",
+            "reviewed_at",
+            "reviewed_by",
+            "reviewed_by_username",
+            "review_comment",
+            "completed_at",
+            "completion_details",
+            "document_details",
         ]
         read_only_fields = [
-            'id',
-            'created_at',
-            'updated_at',
-            'reviewed_at',
-            'reviewed_by',
-            'completed_at',
-            'completion_details',
+            "id",
+            "created_at",
+            "updated_at",
+            "reviewed_at",
+            "reviewed_by",
+            "completed_at",
+            "completion_details",
         ]

     def get_document_details(self, obj):
@@ -2831,12 +2830,12 @@ class DeletionRequestDetailSerializer(serializers.ModelSerializer):
         documents = obj.documents.all()
         return [
             {
-                'id': doc.id,
-                'title': doc.title,
-                'created': doc.created.isoformat() if doc.created else None,
-                'correspondent': doc.correspondent.name if doc.correspondent else None,
-                'document_type': doc.document_type.name if doc.document_type else None,
-                'tags': [tag.name for tag in doc.tags.all()],
+                "id": doc.id,
+                "title": doc.title,
+                "created": doc.created.isoformat() if doc.created else None,
+                "correspondent": doc.correspondent.name if doc.correspondent else None,
+                "document_type": doc.document_type.name if doc.document_type else None,
+                "tags": [tag.name for tag in doc.tags.all()],
             }
             for doc in documents
         ]
@@ -2852,6 +2851,9 @@ class DeletionRequestActionSerializer(serializers.Serializer):
         allow_blank=True,
         label="Review Comment",
         help_text="Optional comment when reviewing the deletion request",
+    )


 class AISuggestionsRequestSerializer(serializers.Serializer):
     """Serializer for requesting AI suggestions for a document."""
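The last hunk is the compile fix named in the commit message: it restores the `)` that closes the field declaration in `DeletionRequestActionSerializer` around serialisers.py:2855. Without it the module raises a SyntaxError at import time, which is why the CI lint job failed before any style check ran. A sketch of the repaired declaration; the field name and the `required=False` flag are assumptions, since only the tail of the call is visible in the hunk:

    class DeletionRequestActionSerializer(serializers.Serializer):
        # "review_comment" is a hypothetical name; the diff shows only the kwargs.
        review_comment = serializers.CharField(
            required=False,  # assumed, consistent with allow_blank=True
            allow_blank=True,
            label="Review Comment",
            help_text="Optional comment when reviewing the deletion request",
        )  # the restored closing parenthesis ends the CharField(...) call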
@@ -1,17 +1,15 @@
 """Serializers package for documents app."""

-from .ai_suggestions import (
-    AISuggestionFeedbackSerializer,
-    AISuggestionsSerializer,
-    AISuggestionStatsSerializer,
-    ApplySuggestionSerializer,
-    RejectSuggestionSerializer,
-)
+from .ai_suggestions import AISuggestionFeedbackSerializer
+from .ai_suggestions import AISuggestionsSerializer
+from .ai_suggestions import AISuggestionStatsSerializer
+from .ai_suggestions import ApplySuggestionSerializer
+from .ai_suggestions import RejectSuggestionSerializer

 __all__ = [
-    'AISuggestionFeedbackSerializer',
-    'AISuggestionsSerializer',
-    'AISuggestionStatsSerializer',
-    'ApplySuggestionSerializer',
-    'RejectSuggestionSerializer',
+    "AISuggestionFeedbackSerializer",
+    "AISuggestionStatsSerializer",
+    "AISuggestionsSerializer",
+    "ApplySuggestionSerializer",
+    "RejectSuggestionSerializer",
 ]
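This hunk rewrites the parenthesized import group as one `from ... import name` statement per line (the I001 fix, in a force-single-line style) and re-sorts `__all__`. Behavior is identical; the payoff is diff hygiene, since each later addition or removal touches exactly one line:

    # Grouped form (before): removing one name edits a multi-line construct.
    # from .ai_suggestions import (
    #     AISuggestionFeedbackSerializer,
    #     AISuggestionsSerializer,
    # )

    # Single-line form (after): each name is an independent statement.
    from .ai_suggestions import AISuggestionFeedbackSerializer
    from .ai_suggestions import AISuggestionsSerializer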
@@ -7,36 +7,33 @@ and handling user feedback on AI suggestions.
 """

 from __future__ import annotations

-from typing import Any, Dict
+from typing import Any

 from rest_framework import serializers

-from documents.models import (
-    AISuggestionFeedback,
-    Correspondent,
-    CustomField,
-    DocumentType,
-    StoragePath,
-    Tag,
-    Workflow,
-)
+from documents.models import AISuggestionFeedback
+from documents.models import Correspondent
+from documents.models import CustomField
+from documents.models import DocumentType
+from documents.models import StoragePath
+from documents.models import Tag
+from documents.models import Workflow


 # Suggestion type choices - used across multiple serializers
 SUGGESTION_TYPE_CHOICES = [
-    'tag',
-    'correspondent',
-    'document_type',
-    'storage_path',
-    'custom_field',
-    'workflow',
-    'title',
+    "tag",
+    "correspondent",
+    "document_type",
+    "storage_path",
+    "custom_field",
+    "workflow",
+    "title",
 ]

 # Types that require value_id
-ID_REQUIRED_TYPES = ['tag', 'correspondent', 'document_type', 'storage_path', 'workflow']
+ID_REQUIRED_TYPES = ["tag", "correspondent", "document_type", "storage_path", "workflow"]
 # Types that require value_text
-TEXT_REQUIRED_TYPES = ['title']
+TEXT_REQUIRED_TYPES = ["title"]
 # Types that can use either (custom_field can be ID or text)
@@ -113,7 +110,7 @@ class AISuggestionsSerializer(serializers.Serializer):
     title_suggestion = TitleSuggestionSerializer(required=False, allow_null=True)

     @staticmethod
-    def from_scan_result(scan_result, document_id: int) -> Dict[str, Any]:
+    def from_scan_result(scan_result, document_id: int) -> dict[str, Any]:
         """
         Convert an AIScanResult object to serializer data.
@@ -133,25 +130,25 @@ class AISuggestionsSerializer(serializers.Serializer):
             try:
                 tag = Tag.objects.get(pk=tag_id)
                 tag_suggestions.append({
-                    'id': tag.id,
-                    'name': tag.name,
-                    'color': getattr(tag, 'color', '#000000'),
-                    'confidence': confidence,
+                    "id": tag.id,
+                    "name": tag.name,
+                    "color": getattr(tag, "color", "#000000"),
+                    "confidence": confidence,
                 })
             except Tag.DoesNotExist:
                 # Tag no longer exists in database; skip this suggestion
                 pass
-        data['tags'] = tag_suggestions
+        data["tags"] = tag_suggestions

         # Correspondent
         if scan_result.correspondent:
             corr_id, confidence = scan_result.correspondent
             try:
                 correspondent = Correspondent.objects.get(pk=corr_id)
-                data['correspondent'] = {
-                    'id': correspondent.id,
-                    'name': correspondent.name,
-                    'confidence': confidence,
+                data["correspondent"] = {
+                    "id": correspondent.id,
+                    "name": correspondent.name,
+                    "confidence": confidence,
                 }
             except Correspondent.DoesNotExist:
                 # Correspondent no longer exists in database; omit from suggestions
@@ -162,10 +159,10 @@ class AISuggestionsSerializer(serializers.Serializer):
             type_id, confidence = scan_result.document_type
             try:
                 doc_type = DocumentType.objects.get(pk=type_id)
-                data['document_type'] = {
-                    'id': doc_type.id,
-                    'name': doc_type.name,
-                    'confidence': confidence,
+                data["document_type"] = {
+                    "id": doc_type.id,
+                    "name": doc_type.name,
+                    "confidence": confidence,
                 }
             except DocumentType.DoesNotExist:
                 # Document type no longer exists in database; omit from suggestions
@@ -176,11 +173,11 @@ class AISuggestionsSerializer(serializers.Serializer):
             path_id, confidence = scan_result.storage_path
             try:
                 storage_path = StoragePath.objects.get(pk=path_id)
-                data['storage_path'] = {
-                    'id': storage_path.id,
-                    'name': storage_path.name,
-                    'path': storage_path.path,
-                    'confidence': confidence,
+                data["storage_path"] = {
+                    "id": storage_path.id,
+                    "name": storage_path.name,
+                    "path": storage_path.path,
+                    "confidence": confidence,
                 }
             except StoragePath.DoesNotExist:
                 # Storage path no longer exists in database; omit from suggestions
@@ -193,15 +190,15 @@ class AISuggestionsSerializer(serializers.Serializer):
             try:
                 field = CustomField.objects.get(pk=field_id)
                 field_suggestions.append({
-                    'field_id': field.id,
-                    'field_name': field.name,
-                    'value': str(value),
-                    'confidence': confidence,
+                    "field_id": field.id,
+                    "field_name": field.name,
+                    "value": str(value),
+                    "confidence": confidence,
                 })
             except CustomField.DoesNotExist:
                 # Custom field no longer exists in database; skip this suggestion
                 pass
-        data['custom_fields'] = field_suggestions
+        data["custom_fields"] = field_suggestions

         # Workflows
         if scan_result.workflows:
@@ -210,19 +207,19 @@ class AISuggestionsSerializer(serializers.Serializer):
             try:
                 workflow = Workflow.objects.get(pk=workflow_id)
                 workflow_suggestions.append({
-                    'id': workflow.id,
-                    'name': workflow.name,
-                    'confidence': confidence,
+                    "id": workflow.id,
+                    "name": workflow.name,
+                    "confidence": confidence,
                 })
             except Workflow.DoesNotExist:
                 # Workflow no longer exists in database; skip this suggestion
                 pass
-        data['workflows'] = workflow_suggestions
+        data["workflows"] = workflow_suggestions

         # Title suggestion
         if scan_result.title_suggestion:
-            data['title_suggestion'] = {
-                'title': scan_result.title_suggestion,
+            data["title_suggestion"] = {
+                "title": scan_result.title_suggestion,
             }

         return data
@@ -234,26 +231,26 @@ class SuggestionSerializerMixin:
     """
     def validate(self, attrs):
         """Validate that the correct value field is provided for the suggestion type."""
-        suggestion_type = attrs.get('suggestion_type')
-        value_id = attrs.get('value_id')
-        value_text = attrs.get('value_text')
+        suggestion_type = attrs.get("suggestion_type")
+        value_id = attrs.get("value_id")
+        value_text = attrs.get("value_text")

         # Types that require value_id
         if suggestion_type in ID_REQUIRED_TYPES and not value_id:
             raise serializers.ValidationError(
-                f"value_id is required for suggestion_type '{suggestion_type}'"
+                f"value_id is required for suggestion_type '{suggestion_type}'",
             )

         # Types that require value_text
         if suggestion_type in TEXT_REQUIRED_TYPES and not value_text:
             raise serializers.ValidationError(
-                f"value_text is required for suggestion_type '{suggestion_type}'"
+                f"value_text is required for suggestion_type '{suggestion_type}'",
            )

         # For custom_field, either is acceptable
-        if suggestion_type == 'custom_field' and not value_id and not value_text:
+        if suggestion_type == "custom_field" and not value_id and not value_text:
             raise serializers.ValidationError(
-                "Either value_id or value_text must be provided for custom_field"
+                "Either value_id or value_text must be provided for custom_field",
             )

         return attrs
@@ -295,19 +292,19 @@ class AISuggestionFeedbackSerializer(serializers.ModelSerializer):
     class Meta:
         model = AISuggestionFeedback
         fields = [
-            'id',
-            'document',
-            'suggestion_type',
-            'suggested_value_id',
-            'suggested_value_text',
-            'confidence',
-            'status',
-            'user',
-            'created_at',
-            'applied_at',
-            'metadata',
+            "id",
+            "document",
+            "suggestion_type",
+            "suggested_value_id",
+            "suggested_value_text",
+            "confidence",
+            "status",
+            "user",
+            "created_at",
+            "applied_at",
+            "metadata",
         ]
-        read_only_fields = ['id', 'created_at', 'applied_at']
+        read_only_fields = ["id", "created_at", "applied_at"]


 class AISuggestionStatsSerializer(serializers.Serializer):
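Earlier in this file, `SuggestionSerializerMixin.validate` enforces the split declared in the module constants: `tag`, `correspondent`, `document_type`, `storage_path`, and `workflow` suggestions must carry `value_id`; `title` must carry `value_text`; `custom_field` accepts either. A minimal sketch of a serializer built on the mixin; the field declarations are assumptions, since the concrete serializers are not shown in this diff:

    # Hypothetical field layout for illustration only.
    class ApplySuggestionSerializer(SuggestionSerializerMixin, serializers.Serializer):
        suggestion_type = serializers.ChoiceField(choices=SUGGESTION_TYPE_CHOICES)
        value_id = serializers.IntegerField(required=False, allow_null=True)
        value_text = serializers.CharField(required=False, allow_blank=True)

    serializer = ApplySuggestionSerializer(data={"suggestion_type": "tag"})
    serializer.is_valid()  # False: "tag" is in ID_REQUIRED_TYPES but value_id is missing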
@@ -18,13 +18,11 @@ from django.test import TestCase
 from django.utils import timezone

 from documents.ai_deletion_manager import AIDeletionManager
-from documents.models import (
-    Correspondent,
-    DeletionRequest,
-    Document,
-    DocumentType,
-    Tag,
-)
+from documents.models import Correspondent
+from documents.models import DeletionRequest
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import Tag


 class TestAIDeletionManagerCreateRequest(TestCase):
@@ -10,24 +10,23 @@ Tests cover:
 - Permission assignment and verification
 """

-from django.contrib.auth.models import Group, Permission, User
+from django.contrib.auth.models import Group
+from django.contrib.auth.models import Permission
+from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.test import TestCase
 from rest_framework.test import APIRequestFactory

 from documents.models import Document
-from documents.permissions import (
-    CanApplyAISuggestionsPermission,
-    CanApproveDeletionsPermission,
-    CanConfigureAIPermission,
-    CanViewAISuggestionsPermission,
-)
+from documents.permissions import CanApplyAISuggestionsPermission
+from documents.permissions import CanApproveDeletionsPermission
+from documents.permissions import CanConfigureAIPermission
+from documents.permissions import CanViewAISuggestionsPermission


 class MockView:
     """Mock view for testing permissions."""

-    pass
-

 class TestCanViewAISuggestionsPermission(TestCase):
@@ -41,13 +40,13 @@ class TestCanViewAISuggestionsPermission(TestCase):

         # Create users
         self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
         )
         self.regular_user = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
         )
         self.permitted_user = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
         )

         # Assign permission to permitted_user
@@ -107,13 +106,13 @@ class TestCanApplyAISuggestionsPermission(TestCase):

         # Create users
         self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
         )
         self.regular_user = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
         )
         self.permitted_user = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
         )

         # Assign permission to permitted_user
@@ -173,13 +172,13 @@ class TestCanApproveDeletionsPermission(TestCase):

         # Create users
         self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
         )
         self.regular_user = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
         )
         self.permitted_user = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
         )

         # Assign permission to permitted_user
@@ -239,13 +238,13 @@ class TestCanConfigureAIPermission(TestCase):

         # Create users
         self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
         )
         self.regular_user = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
         )
         self.permitted_user = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
         )

         # Assign permission to permitted_user
@@ -345,7 +344,7 @@ class TestRoleBasedAccessControl(TestCase):
     def test_viewer_role_permissions(self):
         """Test that viewer role has appropriate permissions."""
         user = User.objects.create_user(
-            username="viewer", email="viewer@test.com", password="viewer123"
+            username="viewer", email="viewer@test.com", password="viewer123",
         )
         user.groups.add(self.viewer_group)
@@ -360,7 +359,7 @@ class TestRoleBasedAccessControl(TestCase):
     def test_editor_role_permissions(self):
         """Test that editor role has appropriate permissions."""
         user = User.objects.create_user(
-            username="editor", email="editor@test.com", password="editor123"
+            username="editor", email="editor@test.com", password="editor123",
         )
         user.groups.add(self.editor_group)
@@ -375,7 +374,7 @@ class TestRoleBasedAccessControl(TestCase):
     def test_admin_role_permissions(self):
         """Test that admin role has all permissions."""
         user = User.objects.create_user(
-            username="ai_admin", email="ai_admin@test.com", password="admin123"
+            username="ai_admin", email="ai_admin@test.com", password="admin123",
         )
         user.groups.add(self.admin_group)
@@ -390,7 +389,7 @@ class TestRoleBasedAccessControl(TestCase):
     def test_user_with_multiple_groups(self):
         """Test that user permissions accumulate from multiple groups."""
         user = User.objects.create_user(
-            username="multi_role", email="multi@test.com", password="multi123"
+            username="multi_role", email="multi@test.com", password="multi123",
         )
         user.groups.add(self.viewer_group, self.editor_group)
@@ -405,7 +404,7 @@ class TestRoleBasedAccessControl(TestCase):
     def test_direct_permission_assignment_overrides_group(self):
         """Test that direct permission assignment works alongside group permissions."""
         user = User.objects.create_user(
-            username="special", email="special@test.com", password="special123"
+            username="special", email="special@test.com", password="special123",
         )
         user.groups.add(self.viewer_group)
@@ -428,7 +427,7 @@ class TestPermissionAssignment(TestCase):
     def setUp(self):
         """Set up test user."""
         self.user = User.objects.create_user(
-            username="testuser", email="test@test.com", password="test123"
+            username="testuser", email="test@test.com", password="test123",
         )
         content_type = ContentType.objects.get_for_model(Document)
         self.view_permission, _ = Permission.objects.get_or_create(
@@ -500,7 +499,7 @@ class TestPermissionEdgeCases(TestCase):
     def test_inactive_user_with_permission(self):
         """Test that inactive users are denied even with permission."""
         user = User.objects.create_user(
-            username="inactive", email="inactive@test.com", password="inactive123"
+            username="inactive", email="inactive@test.com", password="inactive123",
        )
         user.is_active = False
         user.save()
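The repeated setUp hunks above all make the same COM812 fix: a trailing comma after the final argument of a multi-line call. With the comma in place, appending another keyword argument later is a one-line diff instead of two:

    user = User.objects.create_user(
        username="viewer",
        email="viewer@test.com",
        password="viewer123",  # trailing comma: the next kwarg added here touches one line
    )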
@ -21,24 +21,21 @@ Tests cover:
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.test import TestCase, override_settings
|
from django.test import TestCase
|
||||||
|
from django.test import override_settings
|
||||||
|
|
||||||
from documents.ai_scanner import (
|
from documents.ai_scanner import AIDocumentScanner
|
||||||
AIScanResult,
|
from documents.ai_scanner import AIScanResult
|
||||||
AIDocumentScanner,
|
from documents.ai_scanner import get_ai_scanner
|
||||||
get_ai_scanner,
|
from documents.models import Correspondent
|
||||||
)
|
from documents.models import CustomField
|
||||||
from documents.models import (
|
from documents.models import Document
|
||||||
Correspondent,
|
from documents.models import DocumentType
|
||||||
CustomField,
|
from documents.models import StoragePath
|
||||||
Document,
|
from documents.models import Tag
|
||||||
DocumentType,
|
from documents.models import Workflow
|
||||||
StoragePath,
|
from documents.models import WorkflowAction
|
||||||
Tag,
|
from documents.models import WorkflowTrigger
|
||||||
Workflow,
|
|
||||||
WorkflowTrigger,
|
|
||||||
WorkflowAction,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAIScanResult(TestCase):
|
class TestAIScanResult(TestCase):
|
||||||
|
|
@ -100,7 +97,7 @@ class TestAIDocumentScannerInitialization(TestCase):
|
||||||
"""Test scanner initialization with custom confidence thresholds."""
|
"""Test scanner initialization with custom confidence thresholds."""
|
||||||
scanner = AIDocumentScanner(
|
scanner = AIDocumentScanner(
|
||||||
auto_apply_threshold=0.90,
|
auto_apply_threshold=0.90,
|
||||||
suggest_threshold=0.70
|
suggest_threshold=0.70,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(scanner.auto_apply_threshold, 0.90)
|
self.assertEqual(scanner.auto_apply_threshold, 0.90)
|
||||||
|
|
@ -145,14 +142,14 @@ class TestAIDocumentScannerInitialization(TestCase):
|
||||||
class TestAIDocumentScannerLazyLoading(TestCase):
|
class TestAIDocumentScannerLazyLoading(TestCase):
|
||||||
"""Test lazy loading of ML components."""
|
"""Test lazy loading of ML components."""
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_get_classifier_loads_successfully(self, mock_logger):
|
def test_get_classifier_loads_successfully(self, mock_logger):
|
||||||
"""Test successful lazy loading of classifier."""
|
"""Test successful lazy loading of classifier."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
||||||
# Mock the import and class
|
# Mock the import and class
|
||||||
mock_classifier_instance = mock.MagicMock()
|
mock_classifier_instance = mock.MagicMock()
|
||||||
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
|
with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
|
||||||
return_value=mock_classifier_instance) as mock_classifier_class:
|
return_value=mock_classifier_instance) as mock_classifier_class:
|
||||||
classifier = scanner._get_classifier()
|
classifier = scanner._get_classifier()
|
||||||
|
|
||||||
|
|
@ -161,13 +158,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
|
||||||
mock_classifier_class.assert_called_once()
|
mock_classifier_class.assert_called_once()
|
||||||
mock_logger.info.assert_called_with("ML classifier loaded successfully")
|
mock_logger.info.assert_called_with("ML classifier loaded successfully")
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_get_classifier_returns_cached_instance(self, mock_logger):
|
def test_get_classifier_returns_cached_instance(self, mock_logger):
|
||||||
"""Test that classifier is only loaded once."""
|
"""Test that classifier is only loaded once."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
||||||
mock_classifier_instance = mock.MagicMock()
|
mock_classifier_instance = mock.MagicMock()
|
||||||
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
|
with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
|
||||||
return_value=mock_classifier_instance):
|
return_value=mock_classifier_instance):
|
||||||
classifier1 = scanner._get_classifier()
|
classifier1 = scanner._get_classifier()
|
||||||
classifier2 = scanner._get_classifier()
|
classifier2 = scanner._get_classifier()
|
||||||
|
|
@ -175,12 +172,12 @@ class TestAIDocumentScannerLazyLoading(TestCase):
|
||||||
self.assertEqual(classifier1, classifier2)
|
self.assertEqual(classifier1, classifier2)
|
||||||
self.assertIs(classifier1, classifier2)
|
self.assertIs(classifier1, classifier2)
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_get_classifier_handles_import_error(self, mock_logger):
|
def test_get_classifier_handles_import_error(self, mock_logger):
|
||||||
"""Test that classifier loading handles import errors gracefully."""
|
"""Test that classifier loading handles import errors gracefully."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
||||||
with mock.patch('documents.ai_scanner.TransformerDocumentClassifier',
|
with mock.patch("documents.ai_scanner.TransformerDocumentClassifier",
|
||||||
side_effect=ImportError("Module not found")):
|
side_effect=ImportError("Module not found")):
|
||||||
classifier = scanner._get_classifier()
|
classifier = scanner._get_classifier()
|
||||||
|
|
||||||
|
|
@ -196,13 +193,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
|
||||||
|
|
||||||
self.assertIsNone(classifier)
|
self.assertIsNone(classifier)
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_get_ner_extractor_loads_successfully(self, mock_logger):
|
def test_get_ner_extractor_loads_successfully(self, mock_logger):
|
||||||
"""Test successful lazy loading of NER extractor."""
|
"""Test successful lazy loading of NER extractor."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
||||||
mock_ner_instance = mock.MagicMock()
|
mock_ner_instance = mock.MagicMock()
|
||||||
with mock.patch('documents.ai_scanner.DocumentNER',
|
with mock.patch("documents.ai_scanner.DocumentNER",
|
||||||
return_value=mock_ner_instance) as mock_ner_class:
|
return_value=mock_ner_instance) as mock_ner_class:
|
||||||
ner = scanner._get_ner_extractor()
|
ner = scanner._get_ner_extractor()
|
||||||
|
|
||||||
|
|
@ -211,25 +208,25 @@ class TestAIDocumentScannerLazyLoading(TestCase):
|
||||||
mock_ner_class.assert_called_once()
|
mock_ner_class.assert_called_once()
|
||||||
mock_logger.info.assert_called_with("NER extractor loaded successfully")
|
mock_logger.info.assert_called_with("NER extractor loaded successfully")
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_get_ner_extractor_handles_error(self, mock_logger):
|
def test_get_ner_extractor_handles_error(self, mock_logger):
|
||||||
"""Test NER extractor handles loading errors."""
|
"""Test NER extractor handles loading errors."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
||||||
with mock.patch('documents.ai_scanner.DocumentNER',
|
with mock.patch("documents.ai_scanner.DocumentNER",
|
||||||
side_effect=Exception("Failed to load")):
|
side_effect=Exception("Failed to load")):
|
||||||
ner = scanner._get_ner_extractor()
|
ner = scanner._get_ner_extractor()
|
||||||
|
|
||||||
self.assertIsNone(ner)
|
self.assertIsNone(ner)
|
||||||
mock_logger.warning.assert_called()
|
mock_logger.warning.assert_called()
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_get_semantic_search_loads_successfully(self, mock_logger):
|
def test_get_semantic_search_loads_successfully(self, mock_logger):
|
||||||
"""Test successful lazy loading of semantic search."""
|
"""Test successful lazy loading of semantic search."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
||||||
mock_search_instance = mock.MagicMock()
|
mock_search_instance = mock.MagicMock()
|
||||||
with mock.patch('documents.ai_scanner.SemanticSearch',
|
with mock.patch("documents.ai_scanner.SemanticSearch",
|
||||||
return_value=mock_search_instance) as mock_search_class:
|
return_value=mock_search_instance) as mock_search_class:
|
||||||
search = scanner._get_semantic_search()
|
search = scanner._get_semantic_search()
|
||||||
|
|
||||||
|
|
@ -238,13 +235,13 @@ class TestAIDocumentScannerLazyLoading(TestCase):
|
||||||
mock_search_class.assert_called_once()
|
mock_search_class.assert_called_once()
|
||||||
mock_logger.info.assert_called_with("Semantic search loaded successfully")
|
mock_logger.info.assert_called_with("Semantic search loaded successfully")
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_get_table_extractor_loads_successfully(self, mock_logger):
|
def test_get_table_extractor_loads_successfully(self, mock_logger):
|
||||||
"""Test successful lazy loading of table extractor."""
|
"""Test successful lazy loading of table extractor."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
||||||
mock_extractor_instance = mock.MagicMock()
|
mock_extractor_instance = mock.MagicMock()
|
||||||
with mock.patch('documents.ai_scanner.TableExtractor',
|
with mock.patch("documents.ai_scanner.TableExtractor",
|
||||||
return_value=mock_extractor_instance) as mock_extractor_class:
|
return_value=mock_extractor_instance) as mock_extractor_class:
|
||||||
extractor = scanner._get_table_extractor()
|
extractor = scanner._get_table_extractor()
|
||||||
|
|
||||||
|
|
@ -276,7 +273,7 @@ class TestExtractEntities(TestCase):
|
||||||
"dates": ["2024-01-01", "2024-12-31"],
|
"dates": ["2024-01-01", "2024-12-31"],
|
||||||
"amounts": ["$1,000", "$500"],
|
"amounts": ["$1,000", "$500"],
|
||||||
"locations": ["New York"],
|
"locations": ["New York"],
|
||||||
"misc": ["Invoice#123"]
|
"misc": ["Invoice#123"],
|
||||||
}
|
}
|
||||||
|
|
||||||
scanner._ner_extractor = mock_ner
|
scanner._ner_extractor = mock_ner
|
||||||
|
|
@ -320,7 +317,7 @@ class TestExtractEntities(TestCase):
|
||||||
|
|
||||||
self.assertEqual(entities, {})
|
self.assertEqual(entities, {})
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_extract_entities_handles_exception(self, mock_logger):
|
def test_extract_entities_handles_exception(self, mock_logger):
|
||||||
"""Test that entity extraction handles exceptions gracefully."""
|
"""Test that entity extraction handles exceptions gracefully."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
@ -345,10 +342,10 @@ class TestSuggestTags(TestCase):
|
||||||
self.tag3 = Tag.objects.create(name="Tax", matching_algorithm=Tag.MATCH_AUTO)
|
self.tag3 = Tag.objects.create(name="Tax", matching_algorithm=Tag.MATCH_AUTO)
|
||||||
self.document = Document.objects.create(
|
self.document = Document.objects.create(
|
||||||
title="Test Document",
|
title="Test Document",
|
||||||
content="Test content"
|
content="Test content",
|
||||||
)
|
)
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.match_tags')
|
@mock.patch("documents.ai_scanner.match_tags")
|
||||||
def test_suggest_tags_with_matched_tags(self, mock_match_tags):
|
def test_suggest_tags_with_matched_tags(self, mock_match_tags):
|
||||||
"""Test tag suggestions from matching."""
|
"""Test tag suggestions from matching."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
@ -357,7 +354,7 @@ class TestSuggestTags(TestCase):
|
||||||
suggestions = scanner._suggest_tags(
|
suggestions = scanner._suggest_tags(
|
||||||
self.document,
|
self.document,
|
||||||
"Invoice from ACME Corp",
|
"Invoice from ACME Corp",
|
||||||
{}
|
{},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should suggest both matched tags
|
# Should suggest both matched tags
|
||||||
|
|
@ -370,14 +367,14 @@ class TestSuggestTags(TestCase):
|
||||||
for _, confidence in suggestions:
|
for _, confidence in suggestions:
|
||||||
self.assertGreaterEqual(confidence, 0.6)
|
self.assertGreaterEqual(confidence, 0.6)
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.match_tags')
|
@mock.patch("documents.ai_scanner.match_tags")
|
||||||
def test_suggest_tags_with_organization_entities(self, mock_match_tags):
|
def test_suggest_tags_with_organization_entities(self, mock_match_tags):
|
||||||
"""Test tag suggestions based on organization entities."""
|
"""Test tag suggestions based on organization entities."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
mock_match_tags.return_value = []
|
mock_match_tags.return_value = []
|
||||||
|
|
||||||
entities = {
|
entities = {
|
||||||
"organizations": [{"text": "ACME Corp"}]
|
"organizations": [{"text": "ACME Corp"}],
|
||||||
}
|
}
|
||||||
|
|
||||||
suggestions = scanner._suggest_tags(self.document, "text", entities)
|
suggestions = scanner._suggest_tags(self.document, "text", entities)
|
||||||
|
|
@ -386,7 +383,7 @@ class TestSuggestTags(TestCase):
|
||||||
tag_ids = [tag_id for tag_id, _ in suggestions]
|
tag_ids = [tag_id for tag_id, _ in suggestions]
|
||||||
self.assertIn(self.tag2.id, tag_ids)
|
self.assertIn(self.tag2.id, tag_ids)
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.match_tags')
|
@mock.patch("documents.ai_scanner.match_tags")
|
||||||
def test_suggest_tags_removes_duplicates(self, mock_match_tags):
|
def test_suggest_tags_removes_duplicates(self, mock_match_tags):
|
||||||
"""Test that duplicate tags keep highest confidence."""
|
"""Test that duplicate tags keep highest confidence."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
|
|
@ -397,8 +394,8 @@ class TestSuggestTags(TestCase):
|
||||||
|
|
||||||
# Implementation should remove duplicates in actual code
|
# Implementation should remove duplicates in actual code
|
||||||
|
|
||||||
@mock.patch('documents.ai_scanner.match_tags')
|
@mock.patch("documents.ai_scanner.match_tags")
|
||||||
@mock.patch('documents.ai_scanner.logger')
|
@mock.patch("documents.ai_scanner.logger")
|
||||||
def test_suggest_tags_handles_exception(self, mock_logger, mock_match_tags):
|
def test_suggest_tags_handles_exception(self, mock_logger, mock_match_tags):
|
||||||
"""Test tag suggestion handles exceptions."""
|
"""Test tag suggestion handles exceptions."""
|
||||||
scanner = AIDocumentScanner()
|
scanner = AIDocumentScanner()
|
||||||
@@ -417,18 +414,18 @@ class TestDetectCorrespondent(TestCase):
         """Set up test correspondents."""
         self.correspondent1 = Correspondent.objects.create(
             name="ACME Corporation",
-            matching_algorithm=Correspondent.MATCH_AUTO
+            matching_algorithm=Correspondent.MATCH_AUTO,
         )
         self.correspondent2 = Correspondent.objects.create(
             name="TechStart Inc",
-            matching_algorithm=Correspondent.MATCH_AUTO
+            matching_algorithm=Correspondent.MATCH_AUTO,
         )
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

-    @mock.patch('documents.ai_scanner.match_correspondents')
+    @mock.patch("documents.ai_scanner.match_correspondents")
     def test_detect_correspondent_with_match(self, mock_match):
         """Test correspondent detection with successful match."""
         scanner = AIDocumentScanner()
@@ -441,7 +438,7 @@ class TestDetectCorrespondent(TestCase):
         self.assertEqual(corr_id, self.correspondent1.id)
         self.assertEqual(confidence, 0.85)

-    @mock.patch('documents.ai_scanner.match_correspondents')
+    @mock.patch("documents.ai_scanner.match_correspondents")
     def test_detect_correspondent_without_match(self, mock_match):
         """Test correspondent detection without match."""
         scanner = AIDocumentScanner()
@@ -451,14 +448,14 @@ class TestDetectCorrespondent(TestCase):

         self.assertIsNone(result)

-    @mock.patch('documents.ai_scanner.match_correspondents')
+    @mock.patch("documents.ai_scanner.match_correspondents")
     def test_detect_correspondent_from_ner_entities(self, mock_match):
         """Test correspondent detection from NER organizations."""
         scanner = AIDocumentScanner()
         mock_match.return_value = []

         entities = {
-            "organizations": [{"text": "ACME Corporation"}]
+            "organizations": [{"text": "ACME Corporation"}],
         }

         result = scanner._detect_correspondent(self.document, "text", entities)
@@ -468,8 +465,8 @@ class TestDetectCorrespondent(TestCase):
         self.assertEqual(corr_id, self.correspondent1.id)
         self.assertEqual(confidence, 0.70)

-    @mock.patch('documents.ai_scanner.match_correspondents')
-    @mock.patch('documents.ai_scanner.logger')
+    @mock.patch("documents.ai_scanner.match_correspondents")
+    @mock.patch("documents.ai_scanner.logger")
     def test_detect_correspondent_handles_exception(self, mock_logger, mock_match):
         """Test correspondent detection handles exceptions."""
         scanner = AIDocumentScanner()
@@ -488,18 +485,18 @@ class TestClassifyDocumentType(TestCase):
         """Set up test document types."""
         self.doc_type1 = DocumentType.objects.create(
             name="Invoice",
-            matching_algorithm=DocumentType.MATCH_AUTO
+            matching_algorithm=DocumentType.MATCH_AUTO,
         )
         self.doc_type2 = DocumentType.objects.create(
             name="Receipt",
-            matching_algorithm=DocumentType.MATCH_AUTO
+            matching_algorithm=DocumentType.MATCH_AUTO,
         )
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

-    @mock.patch('documents.ai_scanner.match_document_types')
+    @mock.patch("documents.ai_scanner.match_document_types")
     def test_classify_document_type_with_match(self, mock_match):
         """Test document type classification with match."""
         scanner = AIDocumentScanner()
@@ -512,7 +509,7 @@ class TestClassifyDocumentType(TestCase):
         self.assertEqual(type_id, self.doc_type1.id)
         self.assertEqual(confidence, 0.85)

-    @mock.patch('documents.ai_scanner.match_document_types')
+    @mock.patch("documents.ai_scanner.match_document_types")
     def test_classify_document_type_without_match(self, mock_match):
         """Test document type classification without match."""
         scanner = AIDocumentScanner()
@@ -522,8 +519,8 @@ class TestClassifyDocumentType(TestCase):

         self.assertIsNone(result)

-    @mock.patch('documents.ai_scanner.match_document_types')
-    @mock.patch('documents.ai_scanner.logger')
+    @mock.patch("documents.ai_scanner.match_document_types")
+    @mock.patch("documents.ai_scanner.logger")
     def test_classify_document_type_handles_exception(self, mock_logger, mock_match):
         """Test classification handles exceptions."""
         scanner = AIDocumentScanner()
@@ -543,14 +540,14 @@ class TestSuggestStoragePath(TestCase):
         self.storage_path1 = StoragePath.objects.create(
             name="Invoices",
             path="/documents/invoices",
-            matching_algorithm=StoragePath.MATCH_AUTO
+            matching_algorithm=StoragePath.MATCH_AUTO,
         )
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

-    @mock.patch('documents.ai_scanner.match_storage_paths')
+    @mock.patch("documents.ai_scanner.match_storage_paths")
     def test_suggest_storage_path_with_match(self, mock_match):
         """Test storage path suggestion with match."""
         scanner = AIDocumentScanner()
@@ -564,7 +561,7 @@ class TestSuggestStoragePath(TestCase):
         self.assertEqual(path_id, self.storage_path1.id)
         self.assertEqual(confidence, 0.80)

-    @mock.patch('documents.ai_scanner.match_storage_paths')
+    @mock.patch("documents.ai_scanner.match_storage_paths")
     def test_suggest_storage_path_without_match(self, mock_match):
         """Test storage path suggestion without match."""
         scanner = AIDocumentScanner()
@@ -575,8 +572,8 @@ class TestSuggestStoragePath(TestCase):

         self.assertIsNone(result)

-    @mock.patch('documents.ai_scanner.match_storage_paths')
-    @mock.patch('documents.ai_scanner.logger')
+    @mock.patch("documents.ai_scanner.match_storage_paths")
+    @mock.patch("documents.ai_scanner.logger")
     def test_suggest_storage_path_handles_exception(self, mock_logger, mock_match):
         """Test storage path suggestion handles exceptions."""
         scanner = AIDocumentScanner()
@@ -596,19 +593,19 @@ class TestExtractCustomFields(TestCase):
         """Set up test custom fields."""
         self.field_date = CustomField.objects.create(
             name="Invoice Date",
-            data_type=CustomField.FieldDataType.DATE
+            data_type=CustomField.FieldDataType.DATE,
         )
         self.field_amount = CustomField.objects.create(
             name="Total Amount",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.field_email = CustomField.objects.create(
             name="Contact Email",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

     def test_extract_custom_fields_with_entities(self):
@@ -618,7 +615,7 @@ class TestExtractCustomFields(TestCase):
         entities = {
             "dates": [{"text": "2024-01-01"}],
             "amounts": [{"text": "$1,000"}],
-            "emails": ["test@example.com"]
+            "emails": ["test@example.com"],
         }

         fields = scanner._extract_custom_fields(self.document, "text", entities)
@@ -638,12 +635,12 @@ class TestExtractCustomFields(TestCase):
         # Should return empty dict
         self.assertEqual(fields, {})

-    @mock.patch('documents.ai_scanner.logger')
+    @mock.patch("documents.ai_scanner.logger")
     def test_extract_custom_fields_handles_exception(self, mock_logger):
         """Test custom field extraction handles exceptions."""
         scanner = AIDocumentScanner()

-        with mock.patch.object(CustomField.objects, 'all',
+        with mock.patch.object(CustomField.objects, "all",
                                side_effect=Exception("DB error")):
             fields = scanner._extract_custom_fields(self.document, "text", {})

@@ -658,31 +655,31 @@ class TestExtractFieldValue(TestCase):
         """Set up test fields."""
         self.field_date = CustomField.objects.create(
             name="Invoice Date",
-            data_type=CustomField.FieldDataType.DATE
+            data_type=CustomField.FieldDataType.DATE,
         )
         self.field_amount = CustomField.objects.create(
             name="Total Amount",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.field_invoice = CustomField.objects.create(
             name="Invoice Number",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.field_email = CustomField.objects.create(
             name="Contact Email",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.field_phone = CustomField.objects.create(
             name="Phone Number",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.field_person = CustomField.objects.create(
             name="Person Name",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.field_company = CustomField.objects.create(
             name="Company Name",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )

     def test_extract_field_value_date(self):
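Note: the recurring one-character additions in the setUp hunks above all add a
trailing comma after the last keyword argument of a multi-line call. The payoff
is smaller future diffs, since appending another argument then touches one line
instead of two. A minimal sketch with a hypothetical stand-in helper (not from
this diff):

    def create_field(**kwargs):
        # Hypothetical stand-in for any multi-line constructor call.
        return kwargs

    field = create_field(
        name="Invoice Date",
        data_type="date",  # trailing comma: the next kwarg is a one-line change
    )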
@@ -691,7 +688,7 @@ class TestExtractFieldValue(TestCase):
         entities = {"dates": [{"text": "2024-01-01"}]}

         value, confidence = scanner._extract_field_value(
-            self.field_date, "text", entities
+            self.field_date, "text", entities,
         )

         self.assertEqual(value, "2024-01-01")
@@ -703,7 +700,7 @@ class TestExtractFieldValue(TestCase):
         entities = {"amounts": [{"text": "$1,000"}]}

         value, confidence = scanner._extract_field_value(
-            self.field_amount, "text", entities
+            self.field_amount, "text", entities,
         )

         self.assertEqual(value, "$1,000")
@@ -715,7 +712,7 @@ class TestExtractFieldValue(TestCase):
         entities = {"invoice_numbers": ["INV-12345"]}

         value, confidence = scanner._extract_field_value(
-            self.field_invoice, "text", entities
+            self.field_invoice, "text", entities,
         )

         self.assertEqual(value, "INV-12345")
@@ -727,7 +724,7 @@ class TestExtractFieldValue(TestCase):
         entities = {"emails": ["test@example.com"]}

         value, confidence = scanner._extract_field_value(
-            self.field_email, "text", entities
+            self.field_email, "text", entities,
         )

         self.assertEqual(value, "test@example.com")
@@ -739,7 +736,7 @@ class TestExtractFieldValue(TestCase):
         entities = {"phones": ["+1-555-1234"]}

         value, confidence = scanner._extract_field_value(
-            self.field_phone, "text", entities
+            self.field_phone, "text", entities,
         )

         self.assertEqual(value, "+1-555-1234")
@@ -751,7 +748,7 @@ class TestExtractFieldValue(TestCase):
         entities = {"persons": [{"text": "John Doe"}]}

         value, confidence = scanner._extract_field_value(
-            self.field_person, "text", entities
+            self.field_person, "text", entities,
         )

         self.assertEqual(value, "John Doe")
@@ -763,7 +760,7 @@ class TestExtractFieldValue(TestCase):
         entities = {"organizations": [{"text": "ACME Corp"}]}

         value, confidence = scanner._extract_field_value(
-            self.field_company, "text", entities
+            self.field_company, "text", entities,
         )

         self.assertEqual(value, "ACME Corp")
@@ -775,7 +772,7 @@ class TestExtractFieldValue(TestCase):
         entities = {}

         value, confidence = scanner._extract_field_value(
-            self.field_date, "text", entities
+            self.field_date, "text", entities,
         )

         self.assertIsNone(value)
@@ -789,23 +786,23 @@ class TestSuggestWorkflows(TestCase):
         """Set up test workflows."""
         self.workflow1 = Workflow.objects.create(
             name="Invoice Processing",
-            enabled=True
+            enabled=True,
         )
         self.trigger1 = WorkflowTrigger.objects.create(
             workflow=self.workflow1,
-            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
+            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
         )
         self.workflow2 = Workflow.objects.create(
             name="Document Archival",
-            enabled=True
+            enabled=True,
         )
         self.trigger2 = WorkflowTrigger.objects.create(
             workflow=self.workflow2,
-            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
+            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
         )
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

     def test_suggest_workflows_with_matches(self):
@@ -820,7 +817,7 @@ class TestSuggestWorkflows(TestCase):
         # Create action for workflow
         WorkflowAction.objects.create(
             workflow=self.workflow1,
-            type=WorkflowAction.WorkflowActionType.ASSIGNMENT
+            type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
         )

         suggestions = scanner._suggest_workflows(self.document, "text", scan_result)
@@ -841,17 +838,17 @@ class TestSuggestWorkflows(TestCase):
         # Should not suggest any (confidence too low)
         self.assertEqual(len(suggestions), 0)

-    @mock.patch('documents.ai_scanner.logger')
+    @mock.patch("documents.ai_scanner.logger")
     def test_suggest_workflows_handles_exception(self, mock_logger):
         """Test workflow suggestion handles exceptions."""
         scanner = AIDocumentScanner()

         scan_result = AIScanResult()

-        with mock.patch.object(Workflow.objects, 'filter',
+        with mock.patch.object(Workflow.objects, "filter",
                                side_effect=Exception("DB error")):
             suggestions = scanner._suggest_workflows(
-                self.document, "text", scan_result
+                self.document, "text", scan_result,
             )

         self.assertEqual(suggestions, [])
@@ -865,11 +862,11 @@ class TestEvaluateWorkflowMatch(TestCase):
         """Set up test workflow."""
         self.workflow = Workflow.objects.create(
             name="Test Workflow",
-            enabled=True
+            enabled=True,
         )
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

     def test_evaluate_workflow_match_base_confidence(self):
@@ -878,7 +875,7 @@ class TestEvaluateWorkflowMatch(TestCase):
         scan_result = AIScanResult()

         confidence = scanner._evaluate_workflow_match(
-            self.workflow, self.document, scan_result
+            self.workflow, self.document, scan_result,
         )

         self.assertEqual(confidence, 0.5)
@@ -892,11 +889,11 @@ class TestEvaluateWorkflowMatch(TestCase):
         # Create action for workflow
         WorkflowAction.objects.create(
             workflow=self.workflow,
-            type=WorkflowAction.WorkflowActionType.ASSIGNMENT
+            type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
         )

         confidence = scanner._evaluate_workflow_match(
-            self.workflow, self.document, scan_result
+            self.workflow, self.document, scan_result,
         )

         self.assertGreater(confidence, 0.5)
@@ -908,7 +905,7 @@ class TestEvaluateWorkflowMatch(TestCase):
         scan_result.correspondent = (1, 0.90)

         confidence = scanner._evaluate_workflow_match(
-            self.workflow, self.document, scan_result
+            self.workflow, self.document, scan_result,
         )

         self.assertGreater(confidence, 0.5)
@@ -920,7 +917,7 @@ class TestEvaluateWorkflowMatch(TestCase):
         scan_result.tags = [(1, 0.80), (2, 0.75)]

         confidence = scanner._evaluate_workflow_match(
-            self.workflow, self.document, scan_result
+            self.workflow, self.document, scan_result,
         )

         self.assertGreater(confidence, 0.5)
@@ -936,11 +933,11 @@ class TestEvaluateWorkflowMatch(TestCase):
         # Create action
         WorkflowAction.objects.create(
             workflow=self.workflow,
-            type=WorkflowAction.WorkflowActionType.ASSIGNMENT
+            type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
         )

         confidence = scanner._evaluate_workflow_match(
-            self.workflow, self.document, scan_result
+            self.workflow, self.document, scan_result,
         )

         self.assertLessEqual(confidence, 1.0)
@@ -953,7 +950,7 @@ class TestSuggestTitle(TestCase):
         """Set up test document."""
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

     def test_suggest_title_with_all_entities(self):
@@ -963,7 +960,7 @@ class TestSuggestTitle(TestCase):
         entities = {
             "document_type": "Invoice",
             "organizations": [{"text": "ACME Corporation"}],
-            "dates": [{"text": "2024-01-01"}]
+            "dates": [{"text": "2024-01-01"}],
         }

         title = scanner._suggest_title(self.document, "text", entities)
@@ -978,7 +975,7 @@ class TestSuggestTitle(TestCase):
         scanner = AIDocumentScanner()

         entities = {
-            "organizations": [{"text": "TechStart Inc"}]
+            "organizations": [{"text": "TechStart Inc"}],
         }

         title = scanner._suggest_title(self.document, "text", entities)
@@ -1002,7 +999,7 @@ class TestSuggestTitle(TestCase):
         long_org = "A" * 100
         entities = {
             "organizations": [{"text": long_org}],
-            "dates": [{"text": "2024-01-01"}]
+            "dates": [{"text": "2024-01-01"}],
         }

         title = scanner._suggest_title(self.document, "text", entities)
@@ -1010,7 +1007,7 @@ class TestSuggestTitle(TestCase):
         self.assertIsNotNone(title)
         self.assertLessEqual(len(title), 127)

-    @mock.patch('documents.ai_scanner.logger')
+    @mock.patch("documents.ai_scanner.logger")
     def test_suggest_title_handles_exception(self, mock_logger):
         """Test title suggestion handles exceptions."""
         scanner = AIDocumentScanner()
@@ -1034,7 +1031,7 @@ class TestExtractTables(TestCase):

         mock_extractor = mock.MagicMock()
         mock_extractor.extract_tables_from_image.return_value = [
-            {"data": [[1, 2], [3, 4]], "headers": ["A", "B"]}
+            {"data": [[1, 2], [3, 4]], "headers": ["A", "B"]},
         ]
         scanner._table_extractor = mock_extractor

@@ -1053,7 +1050,7 @@ class TestExtractTables(TestCase):

         self.assertEqual(tables, [])

-    @mock.patch('documents.ai_scanner.logger')
+    @mock.patch("documents.ai_scanner.logger")
     def test_extract_tables_handles_exception(self, mock_logger):
         """Test table extraction handles exceptions."""
         scanner = AIDocumentScanner()
@@ -1075,17 +1072,17 @@ class TestScanDocument(TestCase):
         """Set up test document."""
         self.document = Document.objects.create(
             title="Test Document",
-            content="Invoice from ACME Corporation dated 2024-01-01"
+            content="Invoice from ACME Corporation dated 2024-01-01",
         )

-    @mock.patch.object(AIDocumentScanner, '_extract_entities')
-    @mock.patch.object(AIDocumentScanner, '_suggest_tags')
-    @mock.patch.object(AIDocumentScanner, '_detect_correspondent')
-    @mock.patch.object(AIDocumentScanner, '_classify_document_type')
-    @mock.patch.object(AIDocumentScanner, '_suggest_storage_path')
-    @mock.patch.object(AIDocumentScanner, '_extract_custom_fields')
-    @mock.patch.object(AIDocumentScanner, '_suggest_workflows')
-    @mock.patch.object(AIDocumentScanner, '_suggest_title')
+    @mock.patch.object(AIDocumentScanner, "_extract_entities")
+    @mock.patch.object(AIDocumentScanner, "_suggest_tags")
+    @mock.patch.object(AIDocumentScanner, "_detect_correspondent")
+    @mock.patch.object(AIDocumentScanner, "_classify_document_type")
+    @mock.patch.object(AIDocumentScanner, "_suggest_storage_path")
+    @mock.patch.object(AIDocumentScanner, "_extract_custom_fields")
+    @mock.patch.object(AIDocumentScanner, "_suggest_workflows")
+    @mock.patch.object(AIDocumentScanner, "_suggest_title")
     def test_scan_document_orchestrates_all_methods(
         self,
         mock_title,
@@ -1095,7 +1092,7 @@ class TestScanDocument(TestCase):
         mock_doc_type,
         mock_correspondent,
         mock_tags,
-        mock_entities
+        mock_entities,
     ):
         """Test that scan_document calls all extraction methods."""
         scanner = AIDocumentScanner()
@@ -1127,26 +1124,26 @@ class TestScanDocument(TestCase):
         self.assertEqual(result.correspondent, (1, 0.90))
         self.assertEqual(result.document_type, (1, 0.80))

-    @mock.patch.object(AIDocumentScanner, '_extract_tables')
+    @mock.patch.object(AIDocumentScanner, "_extract_tables")
     def test_scan_document_extracts_tables_when_enabled(self, mock_extract_tables):
         """Test that tables are extracted when OCR is enabled and file path provided."""
         scanner = AIDocumentScanner(enable_advanced_ocr=True)
         mock_extract_tables.return_value = [{"data": "test"}]

         # Mock other methods to avoid complexity
-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

             result = scanner.scan_document(
                 self.document,
                 "Document text",
-                original_file_path="/path/to/file.pdf"
+                original_file_path="/path/to/file.pdf",
             )

         mock_extract_tables.assert_called_once_with("/path/to/file.pdf")
@@ -1156,15 +1153,15 @@ class TestScanDocument(TestCase):
         """Test that tables are not extracted when file path is not provided."""
         scanner = AIDocumentScanner(enable_advanced_ocr=True)

-        with mock.patch.object(scanner, '_extract_tables') as mock_extract_tables, \
-             mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_tables") as mock_extract_tables, \
+             mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

             result = scanner.scan_document(self.document, "Document text")

@@ -1183,11 +1180,11 @@ class TestApplyScanResults(TestCase):
         self.doc_type = DocumentType.objects.create(name="Invoice")
         self.storage_path = StoragePath.objects.create(
             name="Invoices",
-            path="/invoices"
+            path="/invoices",
         )
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

     def test_apply_scan_results_auto_applies_high_confidence(self):
@@ -1203,7 +1200,7 @@ class TestApplyScanResults(TestCase):
         result = scanner.apply_scan_results(
             self.document,
             scan_result,
-            auto_apply=True
+            auto_apply=True,
         )

         # Verify auto-applied
@@ -1222,7 +1219,7 @@ class TestApplyScanResults(TestCase):
         """Test that medium confidence items are suggested, not applied."""
         scanner = AIDocumentScanner(
             auto_apply_threshold=0.80,
-            suggest_threshold=0.60
+            suggest_threshold=0.60,
         )

         scan_result = AIScanResult()
@@ -1232,7 +1229,7 @@ class TestApplyScanResults(TestCase):
         result = scanner.apply_scan_results(
             self.document,
             scan_result,
-            auto_apply=True
+            auto_apply=True,
         )

         # Verify suggested but not applied
@@ -1255,7 +1252,7 @@ class TestApplyScanResults(TestCase):
         result = scanner.apply_scan_results(
             self.document,
             scan_result,
-            auto_apply=False
+            auto_apply=False,
         )

         # Verify nothing was applied
@@ -1268,21 +1265,21 @@ class TestApplyScanResults(TestCase):
         scan_result = AIScanResult()
         scan_result.correspondent = (self.correspondent.id, 0.90)

-        with mock.patch.object(self.document, 'save',
+        with mock.patch.object(self.document, "save",
                                side_effect=Exception("Save failed")):
             with self.assertRaises(Exception):
                 with transaction.atomic():
                     scanner.apply_scan_results(
                         self.document,
                         scan_result,
-                        auto_apply=True
+                        auto_apply=True,
                     )

         # Verify transaction was rolled back
         self.document.refresh_from_db()
         self.assertIsNone(self.document.correspondent)

-    @mock.patch('documents.ai_scanner.logger')
+    @mock.patch("documents.ai_scanner.logger")
     def test_apply_scan_results_handles_exception(self, mock_logger):
         """Test that apply_scan_results handles exceptions gracefully."""
         scanner = AIDocumentScanner()
@@ -1293,7 +1290,7 @@ class TestApplyScanResults(TestCase):
         scanner.apply_scan_results(
             self.document,
             scan_result,
-            auto_apply=True
+            auto_apply=True,
         )

         mock_logger.error.assert_called()
@@ -1323,21 +1320,21 @@ class TestEdgeCasesAndErrorHandling(TestCase):
         """Set up test document."""
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

     def test_scan_document_with_empty_text(self):
         """Test scanning document with empty text."""
         scanner = AIDocumentScanner()

-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

             result = scanner.scan_document(self.document, "")

@@ -1349,14 +1346,14 @@ class TestEdgeCasesAndErrorHandling(TestCase):
         scanner = AIDocumentScanner()
         long_text = "A" * 100000

-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

             result = scanner.scan_document(self.document, long_text)

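Note: the backslash-continued with-statement chains in the surrounding hunks only
have their quotes normalized; the continuation style itself is left untouched. On
Python 3.10 and later, the same stack of context managers can instead be grouped
in parentheses, avoiding the backslashes. A minimal sketch with illustrative
names (not from this diff):

    from contextlib import nullcontext

    # Python 3.10+: parentheses group multiple context managers
    # without backslash line continuations.
    with (
        nullcontext("first") as a,
        nullcontext("second") as b,
    ):
        print(a, b)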
@@ -1367,14 +1364,14 @@ class TestEdgeCasesAndErrorHandling(TestCase):
         scanner = AIDocumentScanner()
         special_text = "Test with émojis 😀 and special chars: <>{}[]|\\`~"

-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

             result = scanner.scan_document(self.document, special_text)

@@ -1388,7 +1385,7 @@ class TestEdgeCasesAndErrorHandling(TestCase):
         result = scanner.apply_scan_results(
             self.document,
             scan_result,
-            auto_apply=True
+            auto_apply=True,
         )

         self.assertEqual(result["applied"]["tags"], [])
@@ -1403,12 +1400,12 @@ class TestEdgeCasesAndErrorHandling(TestCase):
         # Test extreme values
         scanner_low = AIDocumentScanner(
             auto_apply_threshold=0.01,
-            suggest_threshold=0.01
+            suggest_threshold=0.01,
        )
         self.assertEqual(scanner_low.auto_apply_threshold, 0.01)

         scanner_high = AIDocumentScanner(
             auto_apply_threshold=0.99,
-            suggest_threshold=0.80
+            suggest_threshold=0.80,
         )
         self.assertEqual(scanner_high.auto_apply_threshold, 0.99)
@@ -8,24 +8,21 @@ document consumption to metadata application.

 from unittest import mock

-from django.test import TestCase, TransactionTestCase
+from django.test import TestCase
+from django.test import TransactionTestCase

-from documents.ai_scanner import (
-    AIDocumentScanner,
-    AIScanResult,
-    get_ai_scanner,
-)
-from documents.models import (
-    Correspondent,
-    CustomField,
-    Document,
-    DocumentType,
-    StoragePath,
-    Tag,
-    Workflow,
-    WorkflowTrigger,
-    WorkflowAction,
-)
+from documents.ai_scanner import AIDocumentScanner
+from documents.ai_scanner import AIScanResult
+from documents.ai_scanner import get_ai_scanner
+from documents.models import Correspondent
+from documents.models import CustomField
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import StoragePath
+from documents.models import Tag
+from documents.models import Workflow
+from documents.models import WorkflowAction
+from documents.models import WorkflowTrigger


 class TestAIScannerIntegrationBasic(TestCase):
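Note: the import hunk above rewrites grouped, parenthesized imports into one
"from ... import ..." statement per name, the force-single-line style that import
sorters such as isort (and ruff's import rules) can enforce. Sorted single-line
imports are deterministic, and adding or dropping a name is always a one-line
diff. A minimal sketch of the two shapes with stand-in modules (not from this
diff):

    # Grouped form: edits happen inside the parentheses.
    from collections import (
        OrderedDict,
        defaultdict,
    )

    # Single-line form, as adopted here: one name per import statement.
    from collections import OrderedDict
    from collections import defaultdict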
@@ -35,49 +32,49 @@ class TestAIScannerIntegrationBasic(TestCase):
         """Set up test data."""
         self.document = Document.objects.create(
             title="Invoice from ACME Corporation",
-            content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000"
+            content="Invoice #12345 from ACME Corporation dated 2024-01-01. Total: $1,000",
         )

         self.tag_invoice = Tag.objects.create(
             name="Invoice",
             matching_algorithm=Tag.MATCH_AUTO,
-            match="invoice"
+            match="invoice",
         )
         self.tag_important = Tag.objects.create(
             name="Important",
             matching_algorithm=Tag.MATCH_AUTO,
-            match="total"
+            match="total",
         )

         self.correspondent = Correspondent.objects.create(
             name="ACME Corporation",
             matching_algorithm=Correspondent.MATCH_AUTO,
-            match="acme"
+            match="acme",
         )

         self.doc_type = DocumentType.objects.create(
             name="Invoice",
             matching_algorithm=DocumentType.MATCH_AUTO,
-            match="invoice"
+            match="invoice",
         )

         self.storage_path = StoragePath.objects.create(
             name="Invoices",
             path="/invoices",
             matching_algorithm=StoragePath.MATCH_AUTO,
-            match="invoice"
+            match="invoice",
         )

-    @mock.patch('documents.ai_scanner.match_tags')
-    @mock.patch('documents.ai_scanner.match_correspondents')
-    @mock.patch('documents.ai_scanner.match_document_types')
-    @mock.patch('documents.ai_scanner.match_storage_paths')
+    @mock.patch("documents.ai_scanner.match_tags")
+    @mock.patch("documents.ai_scanner.match_correspondents")
+    @mock.patch("documents.ai_scanner.match_document_types")
+    @mock.patch("documents.ai_scanner.match_storage_paths")
     def test_full_scan_and_apply_workflow(
         self,
         mock_storage,
         mock_types,
         mock_correspondents,
-        mock_tags
+        mock_tags,
     ):
         """Test complete workflow from scan to application."""
         # Mock the matching functions to return our test data
@@ -91,7 +88,7 @@ class TestAIScannerIntegrationBasic(TestCase):
         # Scan the document
         scan_result = scanner.scan_document(
             self.document,
-            self.document.content
+            self.document.content,
         )

         # Verify scan results
@@ -105,7 +102,7 @@ class TestAIScannerIntegrationBasic(TestCase):
         result = scanner.apply_scan_results(
             self.document,
             scan_result,
-            auto_apply=True
+            auto_apply=True,
         )

         # Verify application
@@ -118,7 +115,7 @@ class TestAIScannerIntegrationBasic(TestCase):
         self.assertEqual(self.document.document_type, self.doc_type)
         self.assertEqual(self.document.storage_path, self.storage_path)

-    @mock.patch('documents.ai_scanner.match_tags')
+    @mock.patch("documents.ai_scanner.match_tags")
     def test_scan_with_no_matches(self, mock_tags):
         """Test scanning when no matches are found."""
         mock_tags.return_value = []
@@ -127,7 +124,7 @@ class TestAIScannerIntegrationBasic(TestCase):

         scan_result = scanner.scan_document(
             self.document,
-            "Some random text with no matches"
+            "Some random text with no matches",
         )

         # Should return empty results
@@ -143,24 +140,24 @@ class TestAIScannerIntegrationCustomFields(TestCase):
         """Set up test data with custom fields."""
         self.document = Document.objects.create(
             title="Invoice",
-            content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com"
+            content="Invoice #INV-123 dated 2024-01-01. Amount: $1,500. Contact: john@example.com",
         )

         self.field_date = CustomField.objects.create(
             name="Invoice Date",
-            data_type=CustomField.FieldDataType.DATE
+            data_type=CustomField.FieldDataType.DATE,
         )
         self.field_number = CustomField.objects.create(
             name="Invoice Number",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.field_amount = CustomField.objects.create(
             name="Total Amount",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )
         self.field_email = CustomField.objects.create(
             name="Contact Email",
-            data_type=CustomField.FieldDataType.STRING
+            data_type=CustomField.FieldDataType.STRING,
         )

     def test_custom_field_extraction_integration(self):
@@ -173,7 +170,7 @@ class TestAIScannerIntegrationCustomFields(TestCase):
             "dates": [{"text": "2024-01-01"}],
             "amounts": [{"text": "$1,500"}],
             "invoice_numbers": ["INV-123"],
-            "emails": ["john@example.com"]
+            "emails": ["john@example.com"],
         }
         scanner._ner_extractor = mock_ner

@@ -196,29 +193,29 @@ class TestAIScannerIntegrationWorkflows(TestCase):
         """Set up test workflows."""
         self.document = Document.objects.create(
             title="Invoice",
-            content="Invoice document"
+            content="Invoice document",
         )

         self.workflow1 = Workflow.objects.create(
             name="Invoice Processing",
-            enabled=True
+            enabled=True,
         )
         self.trigger1 = WorkflowTrigger.objects.create(
             workflow=self.workflow1,
-            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
+            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
         )
         self.action1 = WorkflowAction.objects.create(
             workflow=self.workflow1,
-            type=WorkflowAction.WorkflowActionType.ASSIGNMENT
+            type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
         )

         self.workflow2 = Workflow.objects.create(
             name="Archive Documents",
-            enabled=True
+            enabled=True,
         )
         self.trigger2 = WorkflowTrigger.objects.create(
             workflow=self.workflow2,
-            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION
+            type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
         )

     def test_workflow_suggestion_integration(self):
@@ -234,7 +231,7 @@ class TestAIScannerIntegrationWorkflows(TestCase):
         workflows = scanner._suggest_workflows(
             self.document,
             self.document.content,
-            scan_result
+            scan_result,
         )

         # Should suggest workflows
|
|
@ -250,7 +247,7 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
|
||||||
"""Set up test data."""
|
"""Set up test data."""
|
||||||
self.document = Document.objects.create(
|
self.document = Document.objects.create(
|
||||||
title="Test Document",
|
title="Test Document",
|
||||||
content="Test content"
|
content="Test content",
|
||||||
)
|
)
|
||||||
self.tag = Tag.objects.create(name="TestTag")
|
self.tag = Tag.objects.create(name="TestTag")
|
||||||
self.correspondent = Correspondent.objects.create(name="TestCorp")
|
self.correspondent = Correspondent.objects.create(name="TestCorp")
|
||||||
|
|
@ -273,12 +270,12 @@ class TestAIScannerIntegrationTransactions(TransactionTestCase):
|
||||||
raise Exception("Forced save failure")
|
raise Exception("Forced save failure")
|
||||||
return original_save(self, *args, **kwargs)
|
return original_save(self, *args, **kwargs)
|
||||||
|
|
||||||
with mock.patch.object(Document, 'save', failing_save):
|
with mock.patch.object(Document, "save", failing_save):
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(Exception):
|
||||||
scanner.apply_scan_results(
|
scanner.apply_scan_results(
|
||||||
self.document,
|
self.document,
|
||||||
scan_result,
|
scan_result,
|
||||||
auto_apply=True
|
auto_apply=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify changes were rolled back
|
# Verify changes were rolled back
|
||||||
|
|
@@ -297,19 +294,19 @@ class TestAIScannerIntegrationPerformance(TestCase):
         for i in range(5):
             doc = Document.objects.create(
                 title=f"Document {i}",
-                content=f"Content for document {i}"
+                content=f"Content for document {i}",
             )
             documents.append(doc)

         # Mock to avoid actual ML loading
-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

             results = []
             for doc in documents:
@@ -329,16 +326,16 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
         """Set up test data."""
         self.document = Document.objects.create(
             title="Business Invoice",
-            content="Invoice from ACME Corporation"
+            content="Invoice from ACME Corporation",
         )

         self.correspondent_acme = Correspondent.objects.create(
             name="ACME Corporation",
-            matching_algorithm=Correspondent.MATCH_AUTO
+            matching_algorithm=Correspondent.MATCH_AUTO,
         )
         self.correspondent_other = Correspondent.objects.create(
             name="Other Company",
-            matching_algorithm=Correspondent.MATCH_AUTO
+            matching_algorithm=Correspondent.MATCH_AUTO,
         )

     def test_correspondent_matching_with_ner_entities(self):
@@ -348,16 +345,16 @@ class TestAIScannerIntegrationEntityMatching(TestCase):
         # Mock NER to extract organization
         mock_ner = mock.MagicMock()
         mock_ner.extract_all.return_value = {
-            "organizations": [{"text": "ACME Corporation"}]
+            "organizations": [{"text": "ACME Corporation"}],
         }
         scanner._ner_extractor = mock_ner

         # Mock matching to return empty (so NER-based matching is used)
-        with mock.patch('documents.ai_scanner.match_correspondents', return_value=[]):
+        with mock.patch("documents.ai_scanner.match_correspondents", return_value=[]):
             result = scanner._detect_correspondent(
                 self.document,
                 self.document.content,
-                {"organizations": [{"text": "ACME Corporation"}]}
+                {"organizations": [{"text": "ACME Corporation"}]},
             )

         # Should detect ACME correspondent
@@ -375,13 +372,13 @@ class TestAIScannerIntegrationTitleGeneration(TestCase):

         document = Document.objects.create(
             title="document.pdf",
-            content="Invoice from ACME Corp dated 2024-01-15"
+            content="Invoice from ACME Corp dated 2024-01-15",
         )

         entities = {
             "document_type": "Invoice",
             "organizations": [{"text": "ACME Corp"}],
-            "dates": [{"text": "2024-01-15"}]
+            "dates": [{"text": "2024-01-15"}],
         }

         title = scanner._suggest_title(document, document.content, entities)
@@ -399,7 +396,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
         """Set up test data."""
         self.document = Document.objects.create(
             title="Test",
-            content="Test"
+            content="Test",
         )
         self.tag_high = Tag.objects.create(name="HighConfidence")
         self.tag_medium = Tag.objects.create(name="MediumConfidence")
@@ -409,7 +406,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
         """Test that only high confidence suggestions are auto-applied."""
         scanner = AIDocumentScanner(
             auto_apply_threshold=0.80,
-            suggest_threshold=0.60
+            suggest_threshold=0.60,
         )

         scan_result = AIScanResult()
@@ -422,7 +419,7 @@ class TestAIScannerIntegrationConfidenceLevels(TestCase):
         result = scanner.apply_scan_results(
             self.document,
             scan_result,
-            auto_apply=True
+            auto_apply=True,
         )

         # Verify high confidence was applied
@@ -448,17 +445,17 @@ class TestAIScannerIntegrationGlobalInstance(TestCase):
         # Should be functional
         document = Document.objects.create(
             title="Test",
-            content="Test content"
+            content="Test content",
         )

-        with mock.patch.object(scanner1, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner1, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner1, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner1, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner1, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner1, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner1, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner1, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner1, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner1, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner1, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner1, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner1, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner1, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner1, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner1, "_suggest_title", return_value=None):

             result1 = scanner1.scan_document(document, document.content)
             result2 = scanner2.scan_document(document, document.content)
@@ -476,17 +473,17 @@ class TestAIScannerIntegrationEdgeCases(TestCase):

         document = Document.objects.create(
             title="",
-            content=""
+            content="",
         )

-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

             result = scanner.scan_document(document, document.content)

@@ -498,7 +495,7 @@ class TestAIScannerIntegrationEdgeCases(TestCase):

         document = Document.objects.create(
             title="Test",
-            content="Test"
+            content="Test",
         )

         scan_result = AIScanResult()
@@ -509,7 +506,7 @@ class TestAIScannerIntegrationEdgeCases(TestCase):
         result = scanner.apply_scan_results(
             document,
             scan_result,
-            auto_apply=True
+            auto_apply=True,
         )

         # Should not crash, just log errors
@@ -521,17 +518,17 @@ class TestAIScannerIntegrationEdgeCases(TestCase):

         document = Document.objects.create(
             title="Factura - España 🇪🇸",
-            content="Société française • 日本語 • Ελληνικά • مرحبا"
+            content="Société française • 日本語 • Ελληνικά • مرحبا",
         )

-        with mock.patch.object(scanner, '_extract_entities', return_value={}), \
-             mock.patch.object(scanner, '_suggest_tags', return_value=[]), \
-             mock.patch.object(scanner, '_detect_correspondent', return_value=None), \
-             mock.patch.object(scanner, '_classify_document_type', return_value=None), \
-             mock.patch.object(scanner, '_suggest_storage_path', return_value=None), \
-             mock.patch.object(scanner, '_extract_custom_fields', return_value={}), \
-             mock.patch.object(scanner, '_suggest_workflows', return_value=[]), \
-             mock.patch.object(scanner, '_suggest_title', return_value=None):
+        with mock.patch.object(scanner, "_extract_entities", return_value={}), \
+             mock.patch.object(scanner, "_suggest_tags", return_value=[]), \
+             mock.patch.object(scanner, "_detect_correspondent", return_value=None), \
+             mock.patch.object(scanner, "_classify_document_type", return_value=None), \
+             mock.patch.object(scanner, "_suggest_storage_path", return_value=None), \
+             mock.patch.object(scanner, "_extract_custom_fields", return_value={}), \
+             mock.patch.object(scanner, "_suggest_workflows", return_value=[]), \
+             mock.patch.object(scanner, "_suggest_title", return_value=None):

             result = scanner.scan_document(document, document.content)

@@ -12,18 +12,17 @@ Tests cover:

 from unittest import mock

-from django.contrib.auth.models import Permission, User
+from django.contrib.auth.models import Permission
+from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from rest_framework import status
 from rest_framework.test import APITestCase

-from documents.models import (
-    Correspondent,
-    DeletionRequest,
-    Document,
-    DocumentType,
-    Tag,
-)
+from documents.models import Correspondent
+from documents.models import DeletionRequest
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import Tag
 from documents.tests.utils import DirectoriesMixin


@@ -36,13 +35,13 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):

         # Create users
         self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
         )
         self.user_with_permission = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
         )
         self.user_without_permission = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
         )

         # Assign view permission
@@ -57,7 +56,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         # Create test document
         self.document = Document.objects.create(
             title="Test Document",
-            content="This is a test invoice from ACME Corporation"
+            content="This is a test invoice from ACME Corporation",
         )

         # Create test metadata objects
@@ -70,7 +69,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         response = self.client.post(
             "/api/ai/suggestions/",
             {"document_id": self.document.id},
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@@ -82,7 +81,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         response = self.client.post(
             "/api/ai/suggestions/",
             {"document_id": self.document.id},
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@@ -91,7 +90,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         """Test that superusers can access the endpoint."""
         self.client.force_authenticate(user=self.superuser)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
             # Mock the scanner response
             mock_scan_result = mock.MagicMock()
             mock_scan_result.tags = [(self.tag.id, 0.85)]
@@ -108,7 +107,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
             response = self.client.post(
                 "/api/ai/suggestions/",
                 {"document_id": self.document.id},
-                format="json"
+                format="json",
             )

             self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -119,7 +118,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         """Test that users with permission can access the endpoint."""
         self.client.force_authenticate(user=self.user_with_permission)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
             # Mock the scanner response
             mock_scan_result = mock.MagicMock()
             mock_scan_result.tags = []
@@ -136,7 +135,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
             response = self.client.post(
                 "/api/ai/suggestions/",
                 {"document_id": self.document.id},
-                format="json"
+                format="json",
             )

             self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -148,7 +147,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         response = self.client.post(
             "/api/ai/suggestions/",
             {"document_id": 99999},
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@@ -160,7 +159,7 @@ class TestAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         response = self.client.post(
             "/api/ai/suggestions/",
             {},
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@@ -175,10 +174,10 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):

         # Create users
         self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
         )
         self.user_with_permission = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
         )

         # Assign apply permission
@@ -193,7 +192,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         # Create test document
         self.document = Document.objects.create(
             title="Test Document",
-            content="Test content"
+            content="Test content",
         )

         # Create test metadata
@@ -205,7 +204,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         response = self.client.post(
             "/api/ai/suggestions/apply/",
             {"document_id": self.document.id},
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@@ -214,7 +213,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         """Test successfully applying tag suggestions."""
         self.client.force_authenticate(user=self.superuser)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
             # Mock the scanner response
             mock_scan_result = mock.MagicMock()
             mock_scan_result.tags = [(self.tag.id, 0.85)]
@@ -233,9 +232,9 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
                 "/api/ai/suggestions/apply/",
                 {
                     "document_id": self.document.id,
-                    "apply_tags": True
+                    "apply_tags": True,
                 },
-                format="json"
+                format="json",
             )

             self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -245,7 +244,7 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
         """Test successfully applying correspondent suggestion."""
         self.client.force_authenticate(user=self.superuser)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
             # Mock the scanner response
             mock_scan_result = mock.MagicMock()
             mock_scan_result.tags = []
@@ -264,9 +263,9 @@ class TestApplyAISuggestionsEndpoint(DirectoriesMixin, APITestCase):
                 "/api/ai/suggestions/apply/",
                 {
                     "document_id": self.document.id,
-                    "apply_correspondent": True
+                    "apply_correspondent": True,
                 },
-                format="json"
+                format="json",
             )

             self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -285,10 +284,10 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):

         # Create users
         self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
         )
         self.user_without_permission = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
         )

     def test_unauthorized_access_denied(self):
@@ -309,7 +308,7 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
         """Test getting AI configuration."""
         self.client.force_authenticate(user=self.superuser)

-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
             mock_scanner_instance = mock.MagicMock()
             mock_scanner_instance.auto_apply_threshold = 0.80
             mock_scanner_instance.suggest_threshold = 0.60
@@ -331,9 +330,9 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
             "/api/ai/config/",
             {
                 "auto_apply_threshold": 0.90,
-                "suggest_threshold": 0.70
+                "suggest_threshold": 0.70,
             },
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -346,9 +345,9 @@ class TestAIConfigurationEndpoint(DirectoriesMixin, APITestCase):
         response = self.client.post(
             "/api/ai/config/",
             {
-                "auto_apply_threshold": 1.5  # Invalid: > 1.0
+                "auto_apply_threshold": 1.5,  # Invalid: > 1.0
             },
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@@ -363,13 +362,13 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):

         # Create users
         self.superuser = User.objects.create_superuser(
-            username="admin", email="admin@test.com", password="admin123"
+            username="admin", email="admin@test.com", password="admin123",
         )
         self.user_with_permission = User.objects.create_user(
-            username="permitted", email="permitted@test.com", password="permitted123"
+            username="permitted", email="permitted@test.com", password="permitted123",
         )
         self.user_without_permission = User.objects.create_user(
-            username="regular", email="regular@test.com", password="regular123"
+            username="regular", email="regular@test.com", password="regular123",
         )

         # Assign approval permission
@@ -385,7 +384,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
         self.deletion_request = DeletionRequest.objects.create(
             user=self.user_with_permission,
             requested_by_ai=True,
-            ai_reason="Document appears to be a duplicate"
+            ai_reason="Document appears to be a duplicate",
         )

     def test_unauthorized_access_denied(self):
@@ -394,9 +393,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
             "/api/ai/deletions/approve/",
             {
                 "request_id": self.deletion_request.id,
-                "action": "approve"
+                "action": "approve",
             },
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@@ -409,9 +408,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
             "/api/ai/deletions/approve/",
             {
                 "request_id": self.deletion_request.id,
-                "action": "approve"
+                "action": "approve",
             },
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@@ -424,9 +423,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
             "/api/ai/deletions/approve/",
             {
                 "request_id": self.deletion_request.id,
-                "action": "approve"
+                "action": "approve",
             },
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -436,7 +435,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
         self.deletion_request.refresh_from_db()
         self.assertEqual(
             self.deletion_request.status,
-            DeletionRequest.STATUS_APPROVED
+            DeletionRequest.STATUS_APPROVED,
         )

     def test_reject_deletion_success(self):
@@ -448,9 +447,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
             {
                 "request_id": self.deletion_request.id,
                 "action": "reject",
-                "reason": "Document is still needed"
+                "reason": "Document is still needed",
             },
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -459,7 +458,7 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
         self.deletion_request.refresh_from_db()
         self.assertEqual(
             self.deletion_request.status,
-            DeletionRequest.STATUS_REJECTED
+            DeletionRequest.STATUS_REJECTED,
         )

     def test_invalid_request_id(self):
@@ -470,9 +469,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
             "/api/ai/deletions/approve/",
             {
                 "request_id": 99999,
-                "action": "approve"
+                "action": "approve",
             },
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@@ -485,9 +484,9 @@ class TestDeletionApprovalEndpoint(DirectoriesMixin, APITestCase):
             "/api/ai/deletions/approve/",
             {
                 "request_id": self.deletion_request.id,
-                "action": "approve"
+                "action": "approve",
             },
-            format="json"
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -502,7 +501,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):

         # Create user with all AI permissions
         self.power_user = User.objects.create_user(
-            username="power_user", email="power@test.com", password="power123"
+            username="power_user", email="power@test.com", password="power123",
         )

         content_type = ContentType.objects.get_for_model(Document)
@@ -525,7 +524,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):

         self.document = Document.objects.create(
             title="Test Doc",
-            content="Test"
+            content="Test",
         )

     def test_power_user_can_access_all_endpoints(self):
@@ -533,7 +532,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
         self.client.force_authenticate(user=self.power_user)

         # Test suggestions endpoint
-        with mock.patch('documents.views.get_ai_scanner') as mock_scanner:
+        with mock.patch("documents.views.get_ai_scanner") as mock_scanner:
             mock_scan_result = mock.MagicMock()
             mock_scan_result.tags = []
             mock_scan_result.correspondent = None
@@ -553,7 +552,7 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
             response1 = self.client.post(
                 "/api/ai/suggestions/",
                 {"document_id": self.document.id},
-                format="json"
+                format="json",
             )
             self.assertEqual(response1.status_code, status.HTTP_200_OK)

@@ -562,9 +561,9 @@ class TestEndpointPermissionIntegration(DirectoriesMixin, APITestCase):
                 "/api/ai/suggestions/apply/",
                 {
                     "document_id": self.document.id,
-                    "apply_tags": False
+                    "apply_tags": False,
                 },
-                format="json"
+                format="json",
             )
             self.assertEqual(response2.status_code, status.HTTP_200_OK)

@@ -9,14 +9,12 @@ from rest_framework import status
 from rest_framework.test import APITestCase

 from documents.ai_scanner import AIScanResult
-from documents.models import (
-    AISuggestionFeedback,
-    Correspondent,
-    Document,
-    DocumentType,
-    StoragePath,
-    Tag,
-)
+from documents.models import AISuggestionFeedback
+from documents.models import Correspondent
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import StoragePath
+from documents.models import Tag
 from documents.tests.utils import DirectoriesMixin


@@ -64,12 +62,12 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_ai_suggestions_endpoint_exists(self):
         """Test that the ai-suggestions endpoint is accessible."""
         response = self.client.get(
-            f"/api/documents/{self.document.pk}/ai-suggestions/"
+            f"/api/documents/{self.document.pk}/ai-suggestions/",
        )
         # Should not be 404
         self.assertNotEqual(response.status_code, status.HTTP_404_NOT_FOUND)

-    @mock.patch('documents.ai_scanner.get_ai_scanner')
+    @mock.patch("documents.ai_scanner.get_ai_scanner")
     def test_get_ai_suggestions_success(self, mock_get_scanner):
         """Test successfully getting AI suggestions for a document."""
         # Create mock scan result
@@ -87,7 +85,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):

         # Make request
         response = self.client.get(
-            f"/api/documents/{self.document.pk}/ai-suggestions/"
+            f"/api/documents/{self.document.pk}/ai-suggestions/",
         )

         # Verify response
@@ -95,23 +93,23 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         data = response.json()

         # Check tags
-        self.assertIn('tags', data)
-        self.assertEqual(len(data['tags']), 2)
-        self.assertEqual(data['tags'][0]['id'], self.tag1.id)
-        self.assertEqual(data['tags'][0]['confidence'], 0.85)
+        self.assertIn("tags", data)
+        self.assertEqual(len(data["tags"]), 2)
+        self.assertEqual(data["tags"][0]["id"], self.tag1.id)
+        self.assertEqual(data["tags"][0]["confidence"], 0.85)

         # Check correspondent
-        self.assertIn('correspondent', data)
-        self.assertEqual(data['correspondent']['id'], self.correspondent.id)
-        self.assertEqual(data['correspondent']['confidence'], 0.90)
+        self.assertIn("correspondent", data)
+        self.assertEqual(data["correspondent"]["id"], self.correspondent.id)
+        self.assertEqual(data["correspondent"]["confidence"], 0.90)

         # Check document type
-        self.assertIn('document_type', data)
-        self.assertEqual(data['document_type']['id'], self.doc_type.id)
+        self.assertIn("document_type", data)
+        self.assertEqual(data["document_type"]["id"], self.doc_type.id)

         # Check title suggestion
-        self.assertIn('title_suggestion', data)
-        self.assertEqual(data['title_suggestion']['title'], "Suggested Title")
+        self.assertIn("title_suggestion", data)
+        self.assertEqual(data["title_suggestion"]["title"], "Suggested Title")

     def test_get_ai_suggestions_no_content(self):
         """Test getting AI suggestions for document without content."""
@@ -126,7 +124,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         response = self.client.get(f"/api/documents/{doc.pk}/ai-suggestions/")

         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertIn("no content", response.json()['detail'].lower())
+        self.assertIn("no content", response.json()["detail"].lower())

     def test_get_ai_suggestions_document_not_found(self):
         """Test getting AI suggestions for non-existent document."""
@@ -137,19 +135,19 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_apply_suggestion_tag(self):
         """Test applying a tag suggestion."""
         request_data = {
-            'suggestion_type': 'tag',
-            'value_id': self.tag1.id,
-            'confidence': 0.85,
+            "suggestion_type": "tag",
+            "value_id": self.tag1.id,
+            "confidence": 0.85,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/apply-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertEqual(response.json()['status'], 'success')
+        self.assertEqual(response.json()["status"], "success")

         # Verify tag was applied
         self.document.refresh_from_db()
@@ -158,7 +156,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         # Verify feedback was recorded
         feedback = AISuggestionFeedback.objects.filter(
             document=self.document,
-            suggestion_type='tag',
+            suggestion_type="tag",
         ).first()
         self.assertIsNotNone(feedback)
         self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
@@ -169,15 +167,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_apply_suggestion_correspondent(self):
         """Test applying a correspondent suggestion."""
         request_data = {
-            'suggestion_type': 'correspondent',
-            'value_id': self.correspondent.id,
-            'confidence': 0.90,
+            "suggestion_type": "correspondent",
+            "value_id": self.correspondent.id,
+            "confidence": 0.90,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/apply-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -189,7 +187,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         # Verify feedback was recorded
         feedback = AISuggestionFeedback.objects.filter(
             document=self.document,
-            suggestion_type='correspondent',
+            suggestion_type="correspondent",
         ).first()
         self.assertIsNotNone(feedback)
         self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_APPLIED)
@@ -197,15 +195,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_apply_suggestion_document_type(self):
         """Test applying a document type suggestion."""
         request_data = {
-            'suggestion_type': 'document_type',
-            'value_id': self.doc_type.id,
-            'confidence': 0.88,
+            "suggestion_type": "document_type",
+            "value_id": self.doc_type.id,
+            "confidence": 0.88,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/apply-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -217,35 +215,35 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_apply_suggestion_title(self):
         """Test applying a title suggestion."""
         request_data = {
-            'suggestion_type': 'title',
-            'value_text': 'New Suggested Title',
-            'confidence': 0.80,
+            "suggestion_type": "title",
+            "value_text": "New Suggested Title",
+            "confidence": 0.80,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/apply-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)

         # Verify title was applied
         self.document.refresh_from_db()
-        self.assertEqual(self.document.title, 'New Suggested Title')
+        self.assertEqual(self.document.title, "New Suggested Title")

     def test_apply_suggestion_invalid_type(self):
         """Test applying suggestion with invalid type."""
         request_data = {
-            'suggestion_type': 'invalid_type',
-            'value_id': 1,
-            'confidence': 0.85,
+            "suggestion_type": "invalid_type",
+            "value_id": 1,
+            "confidence": 0.85,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/apply-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@@ -253,14 +251,14 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_apply_suggestion_missing_value(self):
         """Test applying suggestion without value_id or value_text."""
         request_data = {
-            'suggestion_type': 'tag',
-            'confidence': 0.85,
+            "suggestion_type": "tag",
+            "confidence": 0.85,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/apply-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@@ -268,15 +266,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_apply_suggestion_nonexistent_object(self):
         """Test applying suggestion with non-existent object ID."""
         request_data = {
-            'suggestion_type': 'tag',
-            'value_id': 99999,
-            'confidence': 0.85,
+            "suggestion_type": "tag",
+            "value_id": 99999,
+            "confidence": 0.85,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/apply-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@@ -284,24 +282,24 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_reject_suggestion(self):
         """Test rejecting an AI suggestion."""
         request_data = {
-            'suggestion_type': 'tag',
-            'value_id': self.tag1.id,
-            'confidence': 0.65,
+            "suggestion_type": "tag",
+            "value_id": self.tag1.id,
+            "confidence": 0.65,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/reject-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertEqual(response.json()['status'], 'success')
+        self.assertEqual(response.json()["status"], "success")

         # Verify feedback was recorded
         feedback = AISuggestionFeedback.objects.filter(
             document=self.document,
-            suggestion_type='tag',
+            suggestion_type="tag",
         ).first()
         self.assertIsNotNone(feedback)
         self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
@@ -312,15 +310,15 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
     def test_reject_suggestion_with_text(self):
         """Test rejecting a suggestion with text value."""
         request_data = {
-            'suggestion_type': 'title',
-            'value_text': 'Bad Title Suggestion',
-            'confidence': 0.50,
+            "suggestion_type": "title",
+            "value_text": "Bad Title Suggestion",
+            "confidence": 0.50,
         }

         response = self.client.post(
             f"/api/documents/{self.document.pk}/reject-suggestion/",
             data=request_data,
-            format='json',
+            format="json",
         )

         self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -328,11 +326,11 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         # Verify feedback was recorded
         feedback = AISuggestionFeedback.objects.filter(
             document=self.document,
-            suggestion_type='title',
+            suggestion_type="title",
         ).first()
         self.assertIsNotNone(feedback)
         self.assertEqual(feedback.status, AISuggestionFeedback.STATUS_REJECTED)
-        self.assertEqual(feedback.suggested_value_text, 'Bad Title Suggestion')
+        self.assertEqual(feedback.suggested_value_text, "Bad Title Suggestion")

     def test_ai_suggestion_stats_empty(self):
         """Test getting statistics when no feedback exists."""
@@ -341,17 +339,17 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         data = response.json()

-        self.assertEqual(data['total_suggestions'], 0)
-        self.assertEqual(data['total_applied'], 0)
-        self.assertEqual(data['total_rejected'], 0)
-        self.assertEqual(data['accuracy_rate'], 0)
+        self.assertEqual(data["total_suggestions"], 0)
+        self.assertEqual(data["total_applied"], 0)
+        self.assertEqual(data["total_rejected"], 0)
+        self.assertEqual(data["accuracy_rate"], 0)

     def test_ai_suggestion_stats_with_data(self):
         """Test getting statistics with feedback data."""
         # Create some feedback entries
         AISuggestionFeedback.objects.create(
             document=self.document,
-            suggestion_type='tag',
+            suggestion_type="tag",
             suggested_value_id=self.tag1.id,
             confidence=0.85,
             status=AISuggestionFeedback.STATUS_APPLIED,
@@ -359,7 +357,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         )
         AISuggestionFeedback.objects.create(
             document=self.document,
-            suggestion_type='tag',
+            suggestion_type="tag",
             suggested_value_id=self.tag2.id,
             confidence=0.70,
             status=AISuggestionFeedback.STATUS_APPLIED,
@@ -367,7 +365,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         )
         AISuggestionFeedback.objects.create(
             document=self.document,
-            suggestion_type='correspondent',
+            suggestion_type="correspondent",
             suggested_value_id=self.correspondent.id,
             confidence=0.60,
             status=AISuggestionFeedback.STATUS_REJECTED,
@@ -380,25 +378,25 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
         data = response.json()

         # Check overall stats
-        self.assertEqual(data['total_suggestions'], 3)
-        self.assertEqual(data['total_applied'], 2)
-        self.assertEqual(data['total_rejected'], 1)
-        self.assertAlmostEqual(data['accuracy_rate'], 66.67, places=1)
+        self.assertEqual(data["total_suggestions"], 3)
+        self.assertEqual(data["total_applied"], 2)
+        self.assertEqual(data["total_rejected"], 1)
+        self.assertAlmostEqual(data["accuracy_rate"], 66.67, places=1)

         # Check by_type stats
-        self.assertIn('by_type', data)
-        self.assertIn('tag', data['by_type'])
-        self.assertEqual(data['by_type']['tag']['total'], 2)
-        self.assertEqual(data['by_type']['tag']['applied'], 2)
-        self.assertEqual(data['by_type']['tag']['rejected'], 0)
+        self.assertIn("by_type", data)
+        self.assertIn("tag", data["by_type"])
+        self.assertEqual(data["by_type"]["tag"]["total"], 2)
+        self.assertEqual(data["by_type"]["tag"]["applied"], 2)
+        self.assertEqual(data["by_type"]["tag"]["rejected"], 0)

         # Check confidence averages
-        self.assertGreater(data['average_confidence_applied'], 0)
-        self.assertGreater(data['average_confidence_rejected'], 0)
+        self.assertGreater(data["average_confidence_applied"], 0)
+        self.assertGreater(data["average_confidence_rejected"], 0)

         # Check recent suggestions
-        self.assertIn('recent_suggestions', data)
-        self.assertEqual(len(data['recent_suggestions']), 3)
+        self.assertIn("recent_suggestions", data)
+        self.assertEqual(len(data["recent_suggestions"]), 3)

     def test_ai_suggestion_stats_accuracy_calculation(self):
         """Test that accuracy rate is calculated correctly."""
@ -406,7 +404,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
|
||||||
for i in range(7):
|
for i in range(7):
|
||||||
AISuggestionFeedback.objects.create(
|
AISuggestionFeedback.objects.create(
|
||||||
document=self.document,
|
document=self.document,
|
||||||
suggestion_type='tag',
|
suggestion_type="tag",
|
||||||
suggested_value_id=self.tag1.id,
|
suggested_value_id=self.tag1.id,
|
||||||
confidence=0.80,
|
confidence=0.80,
|
||||||
status=AISuggestionFeedback.STATUS_APPLIED,
|
status=AISuggestionFeedback.STATUS_APPLIED,
|
||||||
|
|
@ -416,7 +414,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
AISuggestionFeedback.objects.create(
|
AISuggestionFeedback.objects.create(
|
||||||
document=self.document,
|
document=self.document,
|
||||||
suggestion_type='tag',
|
suggestion_type="tag",
|
||||||
suggested_value_id=self.tag2.id,
|
suggested_value_id=self.tag2.id,
|
||||||
confidence=0.60,
|
confidence=0.60,
|
||||||
status=AISuggestionFeedback.STATUS_REJECTED,
|
status=AISuggestionFeedback.STATUS_REJECTED,
|
||||||
|
|
@ -428,10 +426,10 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
self.assertEqual(data['total_suggestions'], 10)
|
self.assertEqual(data["total_suggestions"], 10)
|
||||||
self.assertEqual(data['total_applied'], 7)
|
self.assertEqual(data["total_applied"], 7)
|
||||||
self.assertEqual(data['total_rejected'], 3)
|
self.assertEqual(data["total_rejected"], 3)
|
||||||
self.assertEqual(data['accuracy_rate'], 70.0)
|
self.assertEqual(data["accuracy_rate"], 70.0)
|
||||||
|
|
||||||
def test_authentication_required(self):
|
def test_authentication_required(self):
|
||||||
"""Test that authentication is required for all endpoints."""
|
"""Test that authentication is required for all endpoints."""
|
||||||
|
|
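A note on the expectations above: the stats endpoint (see the `UnifiedSearchViewSet` hunks later in this diff) computes `accuracy_rate` as `applied / total * 100`, so 7 applied out of 10 gives exactly 70.0, while 2 applied out of 3 gives the 66.67 that is checked to one decimal place. A quick sketch of the arithmetic:

```python
def accuracy_rate(applied: int, rejected: int) -> float:
    """Accuracy as the stats endpoint computes it: applied / total * 100."""
    total = applied + rejected
    return (applied / total * 100) if total > 0 else 0

assert accuracy_rate(7, 3) == 70.0            # test_ai_suggestion_stats_accuracy_calculation
assert round(accuracy_rate(2, 1), 1) == 66.7  # test_ai_suggestion_stats_with_data (assertAlmostEqual, places=1)
```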
@@ -439,7 +437,7 @@ class TestAISuggestionsAPI(DirectoriesMixin, APITestCase):
 
         # Test ai-suggestions endpoint
         response = self.client.get(
-            f"/api/documents/{self.document.pk}/ai-suggestions/"
+            f"/api/documents/{self.document.pk}/ai-suggestions/",
         )
         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
 
@@ -11,17 +11,14 @@ Tests cover:
 """
 
 from django.contrib.auth.models import User
-from django.test import override_settings
 from rest_framework import status
 from rest_framework.test import APITestCase
 
-from documents.models import (
-    Correspondent,
-    DeletionRequest,
-    Document,
-    DocumentType,
-    Tag,
-)
+from documents.models import Correspondent
+from documents.models import DeletionRequest
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import Tag
 
 
 class TestDeletionRequestAPI(APITestCase):
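The import rewrites in this hunk do two things: drop the unused `override_settings` import (F401) and switch the grouped `from documents.models import (...)` block to isort's force-single-line style (I001), where each name gets its own line so a later addition or removal touches exactly one line. The rule codes this commit targets (W291/W293, F401, Q000, COM812, I001, UP006) are all implemented by ruff, so presumably most of these changes were produced mechanically with `ruff check --fix`. A small before/after sketch of the convention:

```python
# Grouped form, flagged under force-single-line isort settings:
# from documents.models import (
#     Correspondent,
#     DeletionRequest,
#     Document,
# )

# Single-line form used throughout this commit:
from documents.models import Correspondent
from documents.models import DeletionRequest
from documents.models import Document
```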
@@ -1561,7 +1561,6 @@ class TestConsumerAIScannerIntegration(
         Verifies that AI scanner respects database transactions and handles
         rollbacks correctly.
         """
-        from django.db import transaction as db_transaction
 
         tag = Tag.objects.create(name="Invoice")
 
@@ -15,13 +15,11 @@ from django.contrib.auth.models import User
 from django.test import TestCase
 from django.utils import timezone
 
-from documents.models import (
-    Correspondent,
-    DeletionRequest,
-    Document,
-    DocumentType,
-    Tag,
-)
+from documents.models import Correspondent
+from documents.models import DeletionRequest
+from documents.models import Document
+from documents.models import DocumentType
+from documents.models import Tag
 
 
 class TestDeletionRequestModelCreation(TestCase):
@@ -8,11 +8,9 @@ from unittest import mock
 
 from django.test import TestCase
 
-from documents.ml.model_cache import (
-    CacheMetrics,
-    LRUCache,
-    ModelCacheManager,
-)
+from documents.ml.model_cache import CacheMetrics
+from documents.ml.model_cache import LRUCache
+from documents.ml.model_cache import ModelCacheManager
 
 
 class TestCacheMetrics(TestCase):
@@ -150,7 +150,6 @@ class TestMLCacheDirectory:
 
     def test_model_cache_writable(self, tmp_path):
         """Test that we can write to model cache directory."""
-        import pathlib
 
         # Use tmp_path fixture for testing
         cache_dir = tmp_path / ".cache" / "huggingface"
@@ -169,7 +168,6 @@ class TestMLCacheDirectory:
 
     def test_torch_cache_directory(self, tmp_path, monkeypatch):
         """Test that PyTorch can use a custom cache directory."""
-        import torch
 
         # Set custom cache directory
        cache_dir = tmp_path / ".cache" / "torch"
@@ -204,9 +202,10 @@ class TestMLPerformanceBasic:
 
     def test_numpy_performance_basic(self):
         """Test basic NumPy performance with larger arrays."""
-        import numpy as np
         import time
 
+        import numpy as np
+
         # Create large array (10 million elements)
         arr = np.random.rand(10_000_000)
 
@@ -89,6 +89,8 @@ from rest_framework.viewsets import ViewSet
 
 from documents import bulk_edit
 from documents import index
+from documents.ai_scanner import AIDocumentScanner
+from documents.ai_scanner import get_ai_scanner
 from documents.bulk_download import ArchiveOnlyStrategy
 from documents.bulk_download import OriginalAndArchiveStrategy
 from documents.bulk_download import OriginalsOnlyStrategy
@@ -141,13 +143,10 @@ from documents.models import UiSettings
 from documents.models import Workflow
 from documents.models import WorkflowAction
 from documents.models import WorkflowTrigger
-from documents.ai_scanner import AIDocumentScanner
-from documents.ai_scanner import get_ai_scanner
 from documents.parsers import get_parser_class_for_mime_type
 from documents.parsers import parse_date_generator
 from documents.permissions import AcknowledgeTasksPermissions
 from documents.permissions import CanApplyAISuggestionsPermission
-from documents.permissions import CanApproveDeletionsPermission
 from documents.permissions import CanConfigureAIPermission
 from documents.permissions import CanViewAISuggestionsPermission
 from documents.permissions import PaperlessAdminPermissions
@@ -1388,7 +1387,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
         scan_result = scanner.scan_document(
             document=document,
             document_text=document.content,
-            original_file_path=document.source_path if hasattr(document, 'source_path') else None,
+            original_file_path=document.source_path if hasattr(document, "source_path") else None,
         )
 
         # Convert scan result to serializable format
@@ -1424,43 +1423,43 @@
         serializer = ApplySuggestionSerializer(data=request.data)
         serializer.is_valid(raise_exception=True)
 
-        suggestion_type = serializer.validated_data['suggestion_type']
-        value_id = serializer.validated_data.get('value_id')
-        value_text = serializer.validated_data.get('value_text')
-        confidence = serializer.validated_data['confidence']
+        suggestion_type = serializer.validated_data["suggestion_type"]
+        value_id = serializer.validated_data.get("value_id")
+        value_text = serializer.validated_data.get("value_text")
+        confidence = serializer.validated_data["confidence"]
 
         # Apply the suggestion based on type
         applied = False
         result_message = ""
 
-        if suggestion_type == 'tag' and value_id:
+        if suggestion_type == "tag" and value_id:
             tag = Tag.objects.get(pk=value_id)
             document.tags.add(tag)
             applied = True
             result_message = f"Tag '{tag.name}' applied"
 
-        elif suggestion_type == 'correspondent' and value_id:
+        elif suggestion_type == "correspondent" and value_id:
             correspondent = Correspondent.objects.get(pk=value_id)
             document.correspondent = correspondent
             document.save()
             applied = True
             result_message = f"Correspondent '{correspondent.name}' applied"
 
-        elif suggestion_type == 'document_type' and value_id:
+        elif suggestion_type == "document_type" and value_id:
             doc_type = DocumentType.objects.get(pk=value_id)
             document.document_type = doc_type
             document.save()
             applied = True
             result_message = f"Document type '{doc_type.name}' applied"
 
-        elif suggestion_type == 'storage_path' and value_id:
+        elif suggestion_type == "storage_path" and value_id:
             storage_path = StoragePath.objects.get(pk=value_id)
             document.storage_path = storage_path
             document.save()
             applied = True
             result_message = f"Storage path '{storage_path.name}' applied"
 
-        elif suggestion_type == 'title' and value_text:
+        elif suggestion_type == "title" and value_text:
             document.title = value_text
             document.save()
             applied = True
@@ -1518,10 +1517,10 @@
         serializer = RejectSuggestionSerializer(data=request.data)
         serializer.is_valid(raise_exception=True)
 
-        suggestion_type = serializer.validated_data['suggestion_type']
-        value_id = serializer.validated_data.get('value_id')
-        value_text = serializer.validated_data.get('value_text')
-        confidence = serializer.validated_data['confidence']
+        suggestion_type = serializer.validated_data["suggestion_type"]
+        value_id = serializer.validated_data.get("value_id")
+        value_text = serializer.validated_data.get("value_text")
+        confidence = serializer.validated_data["confidence"]
 
         # Record feedback
         AISuggestionFeedback.objects.create(
@@ -1554,7 +1553,10 @@
         Returns aggregated data about applied vs rejected suggestions,
         accuracy rates, and confidence scores.
         """
-        from django.db.models import Avg, Count, Q
+        from django.db.models import Avg
+        from django.db.models import Count
+        from django.db.models import Q
 
         from documents.models import AISuggestionFeedback
         from documents.serializers.ai_suggestions import AISuggestionStatsSerializer
@@ -1562,61 +1564,63 @@
         # Get overall counts
         total_feedbacks = AISuggestionFeedback.objects.count()
         total_applied = AISuggestionFeedback.objects.filter(
-            status=AISuggestionFeedback.STATUS_APPLIED
+            status=AISuggestionFeedback.STATUS_APPLIED,
         ).count()
         total_rejected = AISuggestionFeedback.objects.filter(
-            status=AISuggestionFeedback.STATUS_REJECTED
+            status=AISuggestionFeedback.STATUS_REJECTED,
         ).count()
 
         # Calculate accuracy rate
         accuracy_rate = (total_applied / total_feedbacks * 100) if total_feedbacks > 0 else 0
 
         # Get statistics by suggestion type using a single aggregated query
-        stats_by_type = AISuggestionFeedback.objects.values('suggestion_type').annotate(
-            total=Count('id'),
-            applied=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
-            rejected=Count('id', filter=Q(status=AISuggestionFeedback.STATUS_REJECTED))
+        stats_by_type = AISuggestionFeedback.objects.values("suggestion_type").annotate(
+            total=Count("id"),
+            applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
+            rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)),
         )
 
         # Build the by_type dictionary using the aggregated results
         by_type = {}
         for stat in stats_by_type:
-            suggestion_type = stat['suggestion_type']
-            type_total = stat['total']
-            type_applied = stat['applied']
-            type_rejected = stat['rejected']
+            suggestion_type = stat["suggestion_type"]
+            type_total = stat["total"]
+            type_applied = stat["applied"]
+            type_rejected = stat["rejected"]
 
             by_type[suggestion_type] = {
-                'total': type_total,
-                'applied': type_applied,
-                'rejected': type_rejected,
-                'accuracy_rate': (type_applied / type_total * 100) if type_total > 0 else 0,
+                "total": type_total,
+                "applied": type_applied,
+                "rejected": type_rejected,
+                "accuracy_rate": (type_applied / type_total * 100) if type_total > 0 else 0,
             }
 
         # Get average confidence scores
         avg_confidence_applied = AISuggestionFeedback.objects.filter(
-            status=AISuggestionFeedback.STATUS_APPLIED
-        ).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
+            status=AISuggestionFeedback.STATUS_APPLIED,
+        ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0
 
         avg_confidence_rejected = AISuggestionFeedback.objects.filter(
-            status=AISuggestionFeedback.STATUS_REJECTED
-        ).aggregate(Avg('confidence'))['confidence__avg'] or 0.0
+            status=AISuggestionFeedback.STATUS_REJECTED,
+        ).aggregate(Avg("confidence"))["confidence__avg"] or 0.0
 
         # Get recent suggestions (last 10)
-        recent_suggestions = AISuggestionFeedback.objects.order_by('-created_at')[:10]
+        recent_suggestions = AISuggestionFeedback.objects.order_by("-created_at")[:10]
 
         # Build response data
-        from documents.serializers.ai_suggestions import AISuggestionFeedbackSerializer
+        from documents.serializers.ai_suggestions import (
+            AISuggestionFeedbackSerializer,
+        )
         data = {
-            'total_suggestions': total_feedbacks,
-            'total_applied': total_applied,
-            'total_rejected': total_rejected,
-            'accuracy_rate': accuracy_rate,
-            'by_type': by_type,
-            'average_confidence_applied': avg_confidence_applied,
-            'average_confidence_rejected': avg_confidence_rejected,
-            'recent_suggestions': AISuggestionFeedbackSerializer(
-                recent_suggestions, many=True
+            "total_suggestions": total_feedbacks,
+            "total_applied": total_applied,
+            "total_rejected": total_rejected,
+            "accuracy_rate": accuracy_rate,
+            "by_type": by_type,
+            "average_confidence_applied": avg_confidence_applied,
+            "average_confidence_rejected": avg_confidence_rejected,
+            "recent_suggestions": AISuggestionFeedbackSerializer(
+                recent_suggestions, many=True,
            ).data,
         }
 
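The `stats_by_type` query above is Django's conditional-aggregation idiom: `Count("id", filter=Q(...))` (supported since Django 2.0) computes a filtered count per group, so one GROUP BY query replaces two extra `.filter(...).count()` round-trips per suggestion type. A self-contained sketch of the same pattern, using the model and statuses from the diff (the literal status strings in the SQL comment are illustrative):

```python
from django.db.models import Count, Q

from documents.models import AISuggestionFeedback

# One query, compiling roughly to (on PostgreSQL):
#   SELECT suggestion_type,
#          COUNT(id) AS total,
#          COUNT(id) FILTER (WHERE status = 'applied') AS applied,
#          COUNT(id) FILTER (WHERE status = 'rejected') AS rejected
#   FROM ... GROUP BY suggestion_type;
stats = AISuggestionFeedback.objects.values("suggestion_type").annotate(
    total=Count("id"),
    applied=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_APPLIED)),
    rejected=Count("id", filter=Q(status=AISuggestionFeedback.STATUS_REJECTED)),
)
for row in stats:  # each row is a dict, e.g. {"suggestion_type": "tag", "total": 2, ...}
    print(row["suggestion_type"], row["applied"], row["rejected"])
```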
@@ -3571,21 +3575,21 @@ class AISuggestionsView(GenericAPIView):
         request_serializer = AISuggestionsRequestSerializer(data=request.data)
         request_serializer.is_valid(raise_exception=True)
 
-        document_id = request_serializer.validated_data['document_id']
+        document_id = request_serializer.validated_data["document_id"]
 
         try:
             document = Document.objects.get(pk=document_id)
         except Document.DoesNotExist:
             return Response(
                 {"error": "Document not found or you don't have permission to view it"},
-                status=status.HTTP_404_NOT_FOUND
+                status=status.HTTP_404_NOT_FOUND,
             )
 
         # Check if user has permission to view this document
-        if not has_perms_owner_aware(request.user, 'documents.view_document', document):
+        if not has_perms_owner_aware(request.user, "documents.view_document", document):
             return Response(
                 {"error": "Permission denied"},
-                status=status.HTTP_403_FORBIDDEN
+                status=status.HTTP_403_FORBIDDEN,
             )
 
         # Get AI scanner and scan document
@@ -3600,7 +3604,7 @@
             "document_type": None,
             "storage_path": None,
             "title_suggestion": scan_result.title_suggestion,
-            "custom_fields": {}
+            "custom_fields": {},
         }
 
         # Format tag suggestions
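Most of the single-character changes in these views are COM812 fixes: a trailing comma after the last element of a multi-line literal or call. Besides satisfying the linter, the comma keeps future diffs minimal, since appending an entry no longer rewrites the previous line. For instance:

```python
# Without the trailing comma, adding a new key below "custom_fields" would
# also modify the "custom_fields" line, producing a two-line diff.
response_data = {
    "title_suggestion": scan_result.title_suggestion,
    "custom_fields": {},  # trailing comma: the next addition touches one line only
}
```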
@@ -3610,7 +3614,7 @@
                 response_data["tags"].append({
                     "id": tag.id,
                     "name": tag.name,
-                    "confidence": confidence
+                    "confidence": confidence,
                 })
             except Tag.DoesNotExist:
                 # Tag was suggested by AI but no longer exists; skip it
@@ -3624,7 +3628,7 @@
                 response_data["correspondent"] = {
                     "id": correspondent.id,
                     "name": correspondent.name,
-                    "confidence": confidence
+                    "confidence": confidence,
                 }
             except Correspondent.DoesNotExist:
                 # Correspondent was suggested but no longer exists; skip it
@@ -3638,7 +3642,7 @@
                 response_data["document_type"] = {
                     "id": doc_type.id,
                     "name": doc_type.name,
-                    "confidence": confidence
+                    "confidence": confidence,
                 }
             except DocumentType.DoesNotExist:
                 # Document type was suggested but no longer exists; skip it
@@ -3652,7 +3656,7 @@
                 response_data["storage_path"] = {
                     "id": storage_path.id,
                     "name": storage_path.name,
-                    "confidence": confidence
+                    "confidence": confidence,
                 }
             except StoragePath.DoesNotExist:
                 # Storage path was suggested but no longer exists; skip it
@@ -3662,7 +3666,7 @@
         for field_id, (value, confidence) in scan_result.custom_fields.items():
             response_data["custom_fields"][str(field_id)] = {
                 "value": value,
-                "confidence": confidence
+                "confidence": confidence,
             }
 
         return Response(response_data)
@@ -3683,21 +3687,21 @@ class ApplyAISuggestionsView(GenericAPIView):
         serializer = ApplyAISuggestionsSerializer(data=request.data)
         serializer.is_valid(raise_exception=True)
 
-        document_id = serializer.validated_data['document_id']
+        document_id = serializer.validated_data["document_id"]
 
         try:
             document = Document.objects.get(pk=document_id)
         except Document.DoesNotExist:
             return Response(
                 {"error": "Document not found"},
-                status=status.HTTP_404_NOT_FOUND
+                status=status.HTTP_404_NOT_FOUND,
             )
 
         # Check if user has permission to change this document
-        if not has_perms_owner_aware(request.user, 'documents.change_document', document):
+        if not has_perms_owner_aware(request.user, "documents.change_document", document):
             return Response(
                 {"error": "Permission denied"},
-                status=status.HTTP_403_FORBIDDEN
+                status=status.HTTP_403_FORBIDDEN,
             )
 
         # Get AI scanner and scan document
@@ -3707,8 +3711,8 @@
         # Apply suggestions based on user selections
         applied = []
 
-        if serializer.validated_data.get('apply_tags'):
-            selected_tags = serializer.validated_data.get('selected_tags', [])
+        if serializer.validated_data.get("apply_tags"):
+            selected_tags = serializer.validated_data.get("selected_tags", [])
             if selected_tags:
                 # Apply only selected tags
                 tags_to_apply = [tag_id for tag_id, _ in scan_result.tags if tag_id in selected_tags]
@@ -3725,7 +3729,7 @@
                 # Tag not found; skip applying this tag
                 pass
 
-        if serializer.validated_data.get('apply_correspondent') and scan_result.correspondent:
+        if serializer.validated_data.get("apply_correspondent") and scan_result.correspondent:
             corr_id, confidence = scan_result.correspondent
             try:
                 correspondent = Correspondent.objects.get(pk=corr_id)
@@ -3735,7 +3739,7 @@
                 # Correspondent not found; skip applying
                 pass
 
-        if serializer.validated_data.get('apply_document_type') and scan_result.document_type:
+        if serializer.validated_data.get("apply_document_type") and scan_result.document_type:
             type_id, confidence = scan_result.document_type
             try:
                 doc_type = DocumentType.objects.get(pk=type_id)
@@ -3745,7 +3749,7 @@
                 # Document type not found; skip applying
                 pass
 
-        if serializer.validated_data.get('apply_storage_path') and scan_result.storage_path:
+        if serializer.validated_data.get("apply_storage_path") and scan_result.storage_path:
             path_id, confidence = scan_result.storage_path
             try:
                 storage_path = StoragePath.objects.get(pk=path_id)
@@ -3755,7 +3759,7 @@
                 # Storage path not found; skip applying
                 pass
 
-        if serializer.validated_data.get('apply_title') and scan_result.title_suggestion:
+        if serializer.validated_data.get("apply_title") and scan_result.title_suggestion:
             document.title = scan_result.title_suggestion
             applied.append(f"title: {scan_result.title_suggestion}")
 
@@ -3765,7 +3769,7 @@
         return Response({
             "status": "success",
             "document_id": document.id,
-            "applied": applied
+            "applied": applied,
         })
 
 
@@ -3805,14 +3809,14 @@ class AIConfigurationView(GenericAPIView):
 
         # Create new scanner with updated configuration
         config = {}
-        if 'auto_apply_threshold' in serializer.validated_data:
-            config['auto_apply_threshold'] = serializer.validated_data['auto_apply_threshold']
-        if 'suggest_threshold' in serializer.validated_data:
-            config['suggest_threshold'] = serializer.validated_data['suggest_threshold']
-        if 'ml_enabled' in serializer.validated_data:
-            config['enable_ml_features'] = serializer.validated_data['ml_enabled']
-        if 'advanced_ocr_enabled' in serializer.validated_data:
-            config['enable_advanced_ocr'] = serializer.validated_data['advanced_ocr_enabled']
+        if "auto_apply_threshold" in serializer.validated_data:
+            config["auto_apply_threshold"] = serializer.validated_data["auto_apply_threshold"]
+        if "suggest_threshold" in serializer.validated_data:
+            config["suggest_threshold"] = serializer.validated_data["suggest_threshold"]
+        if "ml_enabled" in serializer.validated_data:
+            config["enable_ml_features"] = serializer.validated_data["ml_enabled"]
+        if "advanced_ocr_enabled" in serializer.validated_data:
+            config["enable_advanced_ocr"] = serializer.validated_data["advanced_ocr_enabled"]
 
         # Update global scanner instance
         # WARNING: Not thread-safe. Consider storing configuration in database
@@ -3822,7 +3826,7 @@
 
         return Response({
             "status": "success",
-            "message": "AI configuration updated. Changes may require server restart for consistency."
+            "message": "AI configuration updated. Changes may require server restart for consistency.",
         })
 
 
@@ -76,7 +76,7 @@ class DeletionRequestViewSet(ModelViewSet):
         # Check permissions
         if not self._can_manage_request(deletion_request):
             return HttpResponseForbidden(
-                "You don't have permission to approve this deletion request."
+                "You don't have permission to approve this deletion request.",
             )
 
         # Validate status
@@ -114,11 +114,11 @@
                 deleted_count += 1
                 logger.info(
                     f"Deleted document {doc_id} ('{doc_title}') "
-                    f"as part of deletion request {deletion_request.id}"
+                    f"as part of deletion request {deletion_request.id}",
                 )
             except Exception as e:
                 logger.error(
-                    f"Failed to delete document {doc.id}: {str(e)}"
+                    f"Failed to delete document {doc.id}: {e!s}",
                 )
                 failed_deletions.append({
                     "id": doc.id,
@@ -138,14 +138,14 @@
 
             logger.info(
                 f"Deletion request {deletion_request.id} completed. "
-                f"Deleted {deleted_count}/{len(documents)} documents."
+                f"Deleted {deleted_count}/{len(documents)} documents.",
             )
         except Exception as e:
             logger.error(
-                f"Error executing deletion request {deletion_request.id}: {str(e)}"
+                f"Error executing deletion request {deletion_request.id}: {e!s}",
             )
             return Response(
-                {"error": f"Failed to execute deletion: {str(e)}"},
+                {"error": f"Failed to execute deletion: {e!s}"},
                 status=status.HTTP_500_INTERNAL_SERVER_ERROR,
             )
 
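The `{str(e)}` → `{e!s}` rewrites above match ruff's RUF010 fix (explicit f-string type conversion): inside an f-string, the `!s` conversion flag applies `str()` to the value without a nested function call. The two spellings render identically, as a quick sketch shows:

```python
try:
    raise ValueError("disk full")
except Exception as e:
    # Both forms produce the same text; !s is the conversion-flag spelling.
    assert f"error: {str(e)}" == f"error: {e!s}" == "error: disk full"
```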
@@ -176,7 +176,7 @@
         # Check permissions
         if not self._can_manage_request(deletion_request):
             return HttpResponseForbidden(
-                "You don't have permission to reject this deletion request."
+                "You don't have permission to reject this deletion request.",
             )
 
         # Validate status
@@ -199,7 +199,7 @@
         )
 
         logger.info(
-            f"Deletion request {deletion_request.id} rejected by user {request.user.username}"
+            f"Deletion request {deletion_request.id} rejected by user {request.user.username}",
         )
 
         serializer = self.get_serializer(deletion_request)
@@ -228,7 +228,7 @@
         # Check permissions
         if not self._can_manage_request(deletion_request):
             return HttpResponseForbidden(
-                "You don't have permission to cancel this deletion request."
+                "You don't have permission to cancel this deletion request.",
             )
 
         # Validate status
@@ -249,7 +249,7 @@
         deletion_request.save()
 
         logger.info(
-            f"Deletion request {deletion_request.id} cancelled by user {request.user.username}"
+            f"Deletion request {deletion_request.id} cancelled by user {request.user.username}",
         )
 
         serializer = self.get_serializer(deletion_request)
@@ -160,7 +160,7 @@ class SecurityHeadersMiddleware:
 
         # Store nonce in request for use in templates
         # Templates can access this via {{ request.csp_nonce }}
-        if hasattr(request, '_csp_nonce'):
+        if hasattr(request, "_csp_nonce"):
             request._csp_nonce = nonce
 
         # Prevent clickjacking attacks
@@ -9,7 +9,6 @@ from __future__ import annotations
 
 import hashlib
 import logging
-import mimetypes
 import os
 import re
 from pathlib import Path
@@ -26,39 +25,39 @@ logger = logging.getLogger("paperless.security")
 # Explicit list of allowed MIME types
 ALLOWED_MIME_TYPES = {
     # Documents
-    'application/pdf',
-    'application/vnd.oasis.opendocument.text',
-    'application/msword',
-    'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
-    'application/vnd.ms-excel',
-    'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
-    'application/vnd.ms-powerpoint',
-    'application/vnd.openxmlformats-officedocument.presentationml.presentation',
-    'application/vnd.oasis.opendocument.spreadsheet',
-    'application/vnd.oasis.opendocument.presentation',
-    'application/rtf',
-    'text/rtf',
+    "application/pdf",
+    "application/vnd.oasis.opendocument.text",
+    "application/msword",
+    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+    "application/vnd.ms-excel",
+    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+    "application/vnd.ms-powerpoint",
+    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+    "application/vnd.oasis.opendocument.spreadsheet",
+    "application/vnd.oasis.opendocument.presentation",
+    "application/rtf",
+    "text/rtf",
 
     # Images
-    'image/jpeg',
-    'image/png',
-    'image/gif',
-    'image/tiff',
-    'image/bmp',
-    'image/webp',
+    "image/jpeg",
+    "image/png",
+    "image/gif",
+    "image/tiff",
+    "image/bmp",
+    "image/webp",
 
     # Text
-    'text/plain',
-    'text/html',
-    'text/csv',
-    'text/markdown',
+    "text/plain",
+    "text/html",
+    "text/csv",
+    "text/markdown",
 }
 
 # Maximum file size (100MB by default)
 # Can be overridden by settings.MAX_UPLOAD_SIZE
 try:
     from django.conf import settings
-    MAX_FILE_SIZE = getattr(settings, 'MAX_UPLOAD_SIZE', 100 * 1024 * 1024)  # 100MB by default
+    MAX_FILE_SIZE = getattr(settings, "MAX_UPLOAD_SIZE", 100 * 1024 * 1024)  # 100MB by default
 except ImportError:
     MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB in bytes
@@ -114,7 +113,6 @@ ALLOWED_JS_PATTERNS = [
 class FileValidationError(Exception):
     """Raised when file validation fails."""
 
-    pass
 
 
 def has_whitelisted_javascript(content: bytes) -> bool:
@@ -143,7 +141,7 @@ def validate_mime_type(mime_type: str) -> None:
     if mime_type not in ALLOWED_MIME_TYPES:
         raise FileValidationError(
             f"MIME type '{mime_type}' is not allowed. "
-            f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}"
+            f"Allowed types: {', '.join(sorted(ALLOWED_MIME_TYPES))}",
         )
 
 
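For context, a short usage sketch of the validator touched above. `validate_mime_type` and `FileValidationError` are the names defined in this module; the import path and the `check_upload` helper are assumptions for illustration:

```python
from documents.security import FileValidationError, validate_mime_type  # module path assumed


def check_upload(mime_type: str) -> bool:
    """Return True when the MIME type is on the allowlist."""
    try:
        validate_mime_type(mime_type)  # raises FileValidationError when not allowed
    except FileValidationError:
        return False
    return True


assert check_upload("application/pdf") is True       # on ALLOWED_MIME_TYPES
assert check_upload("application/x-dosexec") is False  # not on the allowlist
```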
@@ -86,7 +86,7 @@ api_router.register(r"config", ApplicationConfigurationViewSet)
 api_router.register(r"processed_mail", ProcessedMailViewSet)
 api_router.register(r"deletion_requests", DeletionRequestViewSet)
 api_router.register(
-    r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests"
+    r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests",
 )
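One detail worth noting in the router hunk above: `DeletionRequestViewSet` is registered under both `deletion_requests` and `deletion-requests`. When no `basename` is given, DRF derives it from the viewset's queryset, so a second registration of the same viewset would collide with the first; the explicit `basename` keeps the generated URL names distinct. A minimal sketch of the pattern (prefixes and viewset taken from the diff):

```python
from rest_framework.routers import DefaultRouter

router = DefaultRouter()
router.register(r"deletion_requests", DeletionRequestViewSet)
# A second prefix for the same viewset needs its own basename; otherwise DRF
# would derive the same basename twice and refuse the registration.
router.register(
    r"deletion-requests", DeletionRequestViewSet, basename="deletion-requests",
)
```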