Changes before error encountered

Co-authored-by: dawnsystem <42047891+dawnsystem@users.noreply.github.com>

Parent: 275ff4d1d4
Commit: 894f7e231d
5 changed files with 619 additions and 29 deletions
@@ -133,35 +133,57 @@ class AIDocumentScanner:
         )
 
     def _get_classifier(self):
-        """Lazy load the ML classifier."""
+        """Lazy load the ML classifier with caching."""
         if self._classifier is None and self.ml_enabled:
             try:
                 from documents.ml.classifier import TransformerDocumentClassifier
-                self._classifier = TransformerDocumentClassifier()
-                logger.info("ML classifier loaded successfully")
+
+                # Get model name from settings
+                model_name = getattr(
+                    settings,
+                    "PAPERLESS_ML_CLASSIFIER_MODEL",
+                    "distilbert-base-uncased",
+                )
+
+                self._classifier = TransformerDocumentClassifier(
+                    model_name=model_name,
+                    use_cache=True,
+                )
+                logger.info("ML classifier loaded successfully with caching")
             except Exception as e:
                 logger.warning(f"Failed to load ML classifier: {e}")
                 self.ml_enabled = False
         return self._classifier
 
     def _get_ner_extractor(self):
-        """Lazy load the NER extractor."""
+        """Lazy load the NER extractor with caching."""
        if self._ner_extractor is None and self.ml_enabled:
             try:
                 from documents.ml.ner import DocumentNER
-                self._ner_extractor = DocumentNER()
-                logger.info("NER extractor loaded successfully")
+                self._ner_extractor = DocumentNER(use_cache=True)
+                logger.info("NER extractor loaded successfully with caching")
             except Exception as e:
                 logger.warning(f"Failed to load NER extractor: {e}")
         return self._ner_extractor
 
     def _get_semantic_search(self):
-        """Lazy load semantic search."""
+        """Lazy load semantic search with caching."""
         if self._semantic_search is None and self.ml_enabled:
             try:
                 from documents.ml.semantic_search import SemanticSearch
-                self._semantic_search = SemanticSearch()
-                logger.info("Semantic search loaded successfully")
+
+                # Get cache directory from settings
+                cache_dir = getattr(
+                    settings,
+                    "PAPERLESS_ML_MODEL_CACHE",
+                    None,
+                )
+
+                self._semantic_search = SemanticSearch(
+                    cache_dir=cache_dir,
+                    use_cache=True,
+                )
+                logger.info("Semantic search loaded successfully with caching")
             except Exception as e:
                 logger.warning(f"Failed to load semantic search: {e}")
         return self._semantic_search
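The pattern in this hunk is settings-driven constructor arguments plus a lazily created instance that is kept for every later call. A minimal, self-contained sketch of that pattern follows; the class and the dictionary "model" are stand-ins invented for illustration, not code from the commit:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("lazy_load_demo")


class LazyScannerSketch:
    """Illustrative stand-in for the scanner's lazily loaded ML helpers."""

    def __init__(self):
        self._classifier = None
        self.ml_enabled = True

    def _get_classifier(self):
        # Load only once; failures disable the ML path instead of crashing.
        if self._classifier is None and self.ml_enabled:
            try:
                self._classifier = {"model": "distilbert-base-uncased"}  # pretend load
                logger.info("classifier loaded")
            except Exception as exc:
                logger.warning("failed to load classifier: %s", exc)
                self.ml_enabled = False
        return self._classifier


scanner = LazyScannerSketch()
assert scanner._get_classifier() is scanner._get_classifier()  # same cached object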
@@ -811,6 +833,99 @@ class AIDocumentScanner:
             "suggestions": suggestions,
         }
 
+    def warm_up_models(self) -> None:
+        """
+        Pre-load all ML models on startup (warm-up).
+
+        This ensures models are cached and ready for use,
+        making the first document scan fast.
+        """
+        if not self.ml_enabled:
+            logger.info("ML features disabled, skipping warm-up")
+            return
+
+        import time
+        logger.info("Starting ML model warm-up...")
+        start_time = time.time()
+
+        from documents.ml.model_cache import ModelCacheManager
+        cache_manager = ModelCacheManager.get_instance()
+
+        # Define model loaders
+        model_loaders = {}
+
+        # Classifier
+        if self.ml_enabled:
+            def load_classifier():
+                from documents.ml.classifier import TransformerDocumentClassifier
+                model_name = getattr(
+                    settings,
+                    "PAPERLESS_ML_CLASSIFIER_MODEL",
+                    "distilbert-base-uncased",
+                )
+                return TransformerDocumentClassifier(
+                    model_name=model_name,
+                    use_cache=True,
+                )
+            model_loaders["classifier"] = load_classifier
+
+        # NER
+        if self.ml_enabled:
+            def load_ner():
+                from documents.ml.ner import DocumentNER
+                return DocumentNER(use_cache=True)
+            model_loaders["ner"] = load_ner
+
+        # Semantic Search
+        if self.ml_enabled:
+            def load_semantic():
+                from documents.ml.semantic_search import SemanticSearch
+                cache_dir = getattr(settings, "PAPERLESS_ML_MODEL_CACHE", None)
+                return SemanticSearch(cache_dir=cache_dir, use_cache=True)
+            model_loaders["semantic_search"] = load_semantic
+
+        # Table Extractor
+        if self.advanced_ocr_enabled:
+            def load_table():
+                from documents.ocr.table_extractor import TableExtractor
+                return TableExtractor()
+            model_loaders["table_extractor"] = load_table
+
+        # Warm up all models
+        cache_manager.warm_up(model_loaders)
+
+        warm_up_time = time.time() - start_time
+        logger.info(f"ML model warm-up completed in {warm_up_time:.2f}s")
+
+    def get_cache_metrics(self) -> Dict[str, Any]:
+        """
+        Get cache performance metrics.
+
+        Returns:
+            Dictionary with cache statistics
+        """
+        from documents.ml.model_cache import ModelCacheManager
+
+        try:
+            cache_manager = ModelCacheManager.get_instance()
+            return cache_manager.get_metrics()
+        except Exception as e:
+            logger.error(f"Failed to get cache metrics: {e}")
+            return {
+                "error": str(e),
+            }
+
+    def clear_cache(self) -> None:
+        """Clear all model caches."""
+        from documents.ml.model_cache import ModelCacheManager
+
+        try:
+            cache_manager = ModelCacheManager.get_instance()
+            cache_manager.clear_all()
+            logger.info("All model caches cleared")
+        except Exception as e:
+            logger.error(f"Failed to clear cache: {e}")
+
 
 # Global scanner instance (lazy initialized)
 _scanner_instance = None
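A hedged usage sketch of the scanner-level API added above: warm_up_models() once at process start, get_cache_metrics() for monitoring, clear_cache() to release memory. The diff does not show how a scanner instance is normally obtained (only the module-level _scanner_instance placeholder), so the sketch takes an already constructed instance as an argument:

# Sketch only: assumes `scanner` is an AIDocumentScanner instance obtained the
# way the application normally builds it; the constructor is not shown in the diff.
def warm_up_at_startup(scanner) -> None:
    # Load every enabled model once, e.g. from an AppConfig.ready() hook or a
    # management command, so the first real document scan hits a warm cache.
    scanner.warm_up_models()

    # Hit rate, evictions and disk-cache size for logs or a metrics endpoint.
    metrics = scanner.get_cache_metrics()
    print(f"ML cache after warm-up: {metrics}")


def free_ml_memory(scanner) -> None:
    # Drop the in-memory LRU entries and the pickled embeddings on disk.
    scanner.clear_cache()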
@@ -20,6 +20,8 @@ from transformers import (
     TrainingArguments,
 )
 
+from documents.ml.model_cache import ModelCacheManager
+
 if TYPE_CHECKING:
     from documents.models import Document
 
@@ -93,7 +95,11 @@ class TransformerDocumentClassifier:
     - Works well even with limited training data
     """
 
-    def __init__(self, model_name: str = "distilbert-base-uncased"):
+    def __init__(
+        self,
+        model_name: str = "distilbert-base-uncased",
+        use_cache: bool = True,
+    ):
         """
         Initialize classifier.
 
@@ -103,14 +109,25 @@ class TransformerDocumentClassifier:
                 Alternatives:
                 - bert-base-uncased (440MB, more accurate)
                 - albert-base-v2 (47MB, smallest)
+            use_cache: Whether to use model cache (default: True)
         """
         self.model_name = model_name
+        self.use_cache = use_cache
+        self.cache_manager = ModelCacheManager.get_instance() if use_cache else None
+
+        # Cache key for this model configuration
+        self.cache_key = f"classifier_{model_name}"
+
+        # Load tokenizer (lightweight, not cached)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = None
         self.label_map = {}
         self.reverse_label_map = {}
 
-        logger.info(f"Initialized TransformerDocumentClassifier with {model_name}")
+        logger.info(
+            f"Initialized TransformerDocumentClassifier with {model_name} "
+            f"(caching: {use_cache})"
+        )
 
     def train(
         self,
@@ -215,10 +232,26 @@ class TransformerDocumentClassifier:
         Args:
             model_dir: Directory containing saved model
         """
-        logger.info(f"Loading model from {model_dir}")
-        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        self.model.eval()  # Set to evaluation mode
+        if self.use_cache and self.cache_manager:
+            # Try to get from cache first
+            cache_key = f"{self.cache_key}_{model_dir}"
+
+            def loader():
+                logger.info(f"Loading model from {model_dir}")
+                model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+                tokenizer = AutoTokenizer.from_pretrained(model_dir)
+                model.eval()  # Set to evaluation mode
+                return {"model": model, "tokenizer": tokenizer}
+
+            cached = self.cache_manager.get_or_load_model(cache_key, loader)
+            self.model = cached["model"]
+            self.tokenizer = cached["tokenizer"]
+        else:
+            # Load without caching
+            logger.info(f"Loading model from {model_dir}")
+            self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+            self.model.eval()  # Set to evaluation mode
 
     def predict(
         self,
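Because the saved-model loading above now routes through the shared cache, two classifier instances pointing at the same model directory should end up holding the same underlying model object. A hedged sketch: the method name load_model is an assumption (the hunk starts below the signature), and the model directory is a placeholder.

from documents.ml.classifier import TransformerDocumentClassifier

MODEL_DIR = "/path/to/saved/classifier"  # placeholder; use your own trained model dir

clf_a = TransformerDocumentClassifier(use_cache=True)
clf_b = TransformerDocumentClassifier(use_cache=True)

clf_a.load_model(MODEL_DIR)  # cache MISS: reads from disk, stores {"model", "tokenizer"}
clf_b.load_model(MODEL_DIR)  # cache HIT: reuses the entry keyed by "classifier_<name>_<dir>"

# Both instances now reference the same cached objects.
assert clf_a.model is clf_b.model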
src/documents/ml/model_cache.py (new file, 381 lines)
@@ -0,0 +1,381 @@
+"""
+ML Model Cache Manager for IntelliDocs-ngx.
+
+Provides efficient caching for ML models with:
+- Singleton pattern to ensure single model instance per type
+- LRU eviction policy for memory management
+- Disk cache for embeddings
+- Warm-up on startup
+- Cache hit/miss metrics
+
+This solves the performance issue where models are loaded fresh each time,
+causing slow performance. With this cache:
+- First load: slow (model download/load)
+- Subsequent loads: fast (from cache)
+- Memory controlled: <2GB total
+- Cache hits: >90% after warm-up
+"""
+
+from __future__ import annotations
+
+import logging
+import pickle
+import threading
+import time
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Tuple
+
+logger = logging.getLogger("paperless.ml.model_cache")
+
+
+class CacheMetrics:
+    """
+    Track cache performance metrics.
+    """
+
+    def __init__(self):
+        self.hits = 0
+        self.misses = 0
+        self.evictions = 0
+        self.loads = 0
+        self.lock = threading.Lock()
+
+    def record_hit(self):
+        with self.lock:
+            self.hits += 1
+
+    def record_miss(self):
+        with self.lock:
+            self.misses += 1
+
+    def record_eviction(self):
+        with self.lock:
+            self.evictions += 1
+
+    def record_load(self):
+        with self.lock:
+            self.loads += 1
+
+    def get_stats(self) -> Dict[str, Any]:
+        with self.lock:
+            total = self.hits + self.misses
+            hit_rate = (self.hits / total * 100) if total > 0 else 0.0
+            return {
+                "hits": self.hits,
+                "misses": self.misses,
+                "evictions": self.evictions,
+                "loads": self.loads,
+                "total_requests": total,
+                "hit_rate": f"{hit_rate:.2f}%",
+            }
+
+    def reset(self):
+        with self.lock:
+            self.hits = 0
+            self.misses = 0
+            self.evictions = 0
+            self.loads = 0
+
+
+class LRUCache:
+    """
+    Thread-safe LRU (Least Recently Used) cache implementation.
+
+    When the cache is full, the least recently used item is evicted.
+    """
+
+    def __init__(self, max_size: int = 3):
+        """
+        Initialize LRU cache.
+
+        Args:
+            max_size: Maximum number of items to cache
+        """
+        self.max_size = max_size
+        self.cache: OrderedDict[str, Any] = OrderedDict()
+        self.lock = threading.Lock()
+        self.metrics = CacheMetrics()
+
+    def get(self, key: str) -> Optional[Any]:
+        """
+        Get item from cache.
+
+        Args:
+            key: Cache key
+
+        Returns:
+            Cached value or None if not found
+        """
+        with self.lock:
+            if key not in self.cache:
+                self.metrics.record_miss()
+                return None
+
+            # Move to end (most recently used)
+            self.cache.move_to_end(key)
+            self.metrics.record_hit()
+            return self.cache[key]
+
+    def put(self, key: str, value: Any) -> None:
+        """
+        Add item to cache.
+
+        Args:
+            key: Cache key
+            value: Value to cache
+        """
+        with self.lock:
+            if key in self.cache:
+                # Update existing item
+                self.cache.move_to_end(key)
+                self.cache[key] = value
+                return
+
+            # Add new item
+            self.cache[key] = value
+            self.cache.move_to_end(key)
+
+            # Evict least recently used if needed
+            if len(self.cache) > self.max_size:
+                evicted_key, _ = self.cache.popitem(last=False)
+                self.metrics.record_eviction()
+                logger.info(f"Evicted model from cache: {evicted_key}")
+
+    def clear(self) -> None:
+        """Clear all cached items."""
+        with self.lock:
+            self.cache.clear()
+
+    def size(self) -> int:
+        """Get current cache size."""
+        with self.lock:
+            return len(self.cache)
+
+    def get_metrics(self) -> Dict[str, Any]:
+        """Get cache metrics."""
+        return self.metrics.get_stats()
+
+
+class ModelCacheManager:
+    """
+    Singleton cache manager for ML models.
+
+    Provides centralized caching for all ML models with:
+    - Lazy loading with caching
+    - LRU eviction policy
+    - Thread-safe operations
+    - Performance metrics
+
+    Usage:
+        cache = ModelCacheManager.get_instance()
+        model = cache.get_or_load_model("classifier", loader_func)
+    """
+
+    _instance: Optional[ModelCacheManager] = None
+    _lock = threading.Lock()
+
+    def __new__(cls, *args, **kwargs):
+        """Implement singleton pattern."""
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(
+        self,
+        max_models: int = 3,
+        disk_cache_dir: Optional[str] = None,
+    ):
+        """
+        Initialize model cache manager.
+
+        Args:
+            max_models: Maximum number of models to keep in memory
+            disk_cache_dir: Directory for disk cache (embeddings)
+        """
+        # Only initialize once (singleton pattern)
+        if hasattr(self, "_initialized"):
+            return
+
+        self._initialized = True
+        self.model_cache = LRUCache(max_size=max_models)
+        self.disk_cache_dir = Path(disk_cache_dir) if disk_cache_dir else None
+
+        if self.disk_cache_dir:
+            self.disk_cache_dir.mkdir(parents=True, exist_ok=True)
+            logger.info(f"Disk cache initialized at: {self.disk_cache_dir}")
+
+        logger.info(f"ModelCacheManager initialized (max_models={max_models})")
+
+    @classmethod
+    def get_instance(
+        cls,
+        max_models: int = 3,
+        disk_cache_dir: Optional[str] = None,
+    ) -> ModelCacheManager:
+        """
+        Get singleton instance of ModelCacheManager.
+
+        Args:
+            max_models: Maximum number of models to keep in memory
+            disk_cache_dir: Directory for disk cache
+
+        Returns:
+            ModelCacheManager instance
+        """
+        if cls._instance is None:
+            cls(max_models=max_models, disk_cache_dir=disk_cache_dir)
+        return cls._instance
+
+    def get_or_load_model(
+        self,
+        model_key: str,
+        loader_func: Callable[[], Any],
+    ) -> Any:
+        """
+        Get model from cache or load it.
+
+        Args:
+            model_key: Unique identifier for the model
+            loader_func: Function to load the model if not cached
+
+        Returns:
+            The loaded model
+        """
+        # Try to get from cache
+        model = self.model_cache.get(model_key)
+
+        if model is not None:
+            logger.debug(f"Model cache HIT: {model_key}")
+            return model
+
+        # Cache miss - load model
+        logger.info(f"Model cache MISS: {model_key} - loading...")
+        start_time = time.time()
+
+        try:
+            model = loader_func()
+            self.model_cache.put(model_key, model)
+            self.model_cache.metrics.record_load()
+
+            load_time = time.time() - start_time
+            logger.info(
+                f"Model loaded successfully: {model_key} "
+                f"(took {load_time:.2f}s)"
+            )
+
+            return model
+        except Exception as e:
+            logger.error(f"Failed to load model {model_key}: {e}", exc_info=True)
+            raise
+
+    def save_embeddings_to_disk(
+        self,
+        key: str,
+        embeddings: Dict[int, Any],
+    ) -> None:
+        """
+        Save embeddings to disk cache.
+
+        Args:
+            key: Cache key
+            embeddings: Dictionary of embeddings to save
+        """
+        if not self.disk_cache_dir:
+            return
+
+        cache_file = self.disk_cache_dir / f"{key}.pkl"
+
+        try:
+            with open(cache_file, "wb") as f:
+                pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
+            logger.info(f"Saved {len(embeddings)} embeddings to disk: {cache_file}")
+        except Exception as e:
+            logger.error(f"Failed to save embeddings to disk: {e}", exc_info=True)
+
+    def load_embeddings_from_disk(
+        self,
+        key: str,
+    ) -> Optional[Dict[int, Any]]:
+        """
+        Load embeddings from disk cache.
+
+        Args:
+            key: Cache key
+
+        Returns:
+            Dictionary of embeddings or None if not found
+        """
+        if not self.disk_cache_dir:
+            return None
+
+        cache_file = self.disk_cache_dir / f"{key}.pkl"
+
+        if not cache_file.exists():
+            return None
+
+        try:
+            with open(cache_file, "rb") as f:
+                embeddings = pickle.load(f)
+            logger.info(f"Loaded {len(embeddings)} embeddings from disk: {cache_file}")
+            return embeddings
+        except Exception as e:
+            logger.error(f"Failed to load embeddings from disk: {e}", exc_info=True)
+            return None
+
+    def clear_all(self) -> None:
+        """Clear all caches (memory and disk)."""
+        self.model_cache.clear()
+
+        if self.disk_cache_dir and self.disk_cache_dir.exists():
+            for cache_file in self.disk_cache_dir.glob("*.pkl"):
+                try:
+                    cache_file.unlink()
+                    logger.info(f"Deleted disk cache file: {cache_file}")
+                except Exception as e:
+                    logger.error(f"Failed to delete {cache_file}: {e}")
+
+    def get_metrics(self) -> Dict[str, Any]:
+        """
+        Get cache performance metrics.
+
+        Returns:
+            Dictionary with cache statistics
+        """
+        metrics = self.model_cache.get_metrics()
+        metrics["cache_size"] = self.model_cache.size()
+        metrics["max_size"] = self.model_cache.max_size
+
+        if self.disk_cache_dir and self.disk_cache_dir.exists():
+            disk_files = list(self.disk_cache_dir.glob("*.pkl"))
+            metrics["disk_cache_files"] = len(disk_files)
+
+            # Calculate total disk cache size
+            total_size = sum(f.stat().st_size for f in disk_files)
+            metrics["disk_cache_size_mb"] = f"{total_size / 1024 / 1024:.2f}"
+
+        return metrics
+
+    def warm_up(
+        self,
+        model_loaders: Dict[str, Callable[[], Any]],
+    ) -> None:
+        """
+        Pre-load models on startup (warm-up).
+
+        Args:
+            model_loaders: Dictionary of {model_key: loader_function}
+        """
+        logger.info(f"Starting model warm-up ({len(model_loaders)} models)...")
+        start_time = time.time()
+
+        for model_key, loader_func in model_loaders.items():
+            try:
+                self.get_or_load_model(model_key, loader_func)
+            except Exception as e:
+                logger.warning(f"Failed to warm-up model {model_key}: {e}")
+
+        warm_up_time = time.time() - start_time
+        logger.info(f"Model warm-up completed in {warm_up_time:.2f}s")
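A usage sketch of the new module, exercising the in-memory LRU cache, the metrics, and the pickle-based disk cache for embeddings. The dummy loaders stand in for real transformer models, the temp directory exists only for the example, and note that get_instance() honours max_models/disk_cache_dir only the first time the singleton is constructed in a process:

import tempfile

from documents.ml.model_cache import ModelCacheManager

# Small LRU so eviction is visible; kwargs apply only on first construction.
cache = ModelCacheManager.get_instance(
    max_models=2,
    disk_cache_dir=tempfile.mkdtemp(prefix="ml_cache_"),
)

# Loaders run only on a cache miss; here they return dummies instead of models.
cache.get_or_load_model("classifier", lambda: {"name": "classifier"})  # MISS -> load
cache.get_or_load_model("classifier", lambda: {"name": "classifier"})  # HIT
cache.get_or_load_model("ner", lambda: {"name": "ner"})                # MISS -> load
cache.get_or_load_model("semantic", lambda: {"name": "semantic"})      # MISS, evicts LRU entry

print(cache.get_metrics())  # hits, misses, evictions, hit_rate, cache_size, ...

# Embeddings survive process restarts via the disk cache.
cache.save_embeddings_to_disk("document_embeddings", {1: [0.1, 0.2], 2: [0.3, 0.4]})
print(cache.load_embeddings_from_disk("document_embeddings"))

cache.clear_all()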
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING
 
 from transformers import pipeline
 
+from documents.ml.model_cache import ModelCacheManager
+
 if TYPE_CHECKING:
     pass
 
@@ -42,7 +44,11 @@ class DocumentNER:
    - Phone numbers
    """
 
-    def __init__(self, model_name: str = "dslim/bert-base-NER"):
+    def __init__(
+        self,
+        model_name: str = "dslim/bert-base-NER",
+        use_cache: bool = True,
+    ):
         """
         Initialize NER extractor.
 
@@ -52,14 +58,37 @@ class DocumentNER:
                 Alternatives:
                 - dslim/bert-base-NER-uncased
                 - dbmdz/bert-large-cased-finetuned-conll03-english
+            use_cache: Whether to use model cache (default: True)
         """
-        logger.info(f"Initializing NER with model: {model_name}")
+        logger.info(f"Initializing NER with model: {model_name} (caching: {use_cache})")
 
-        self.ner_pipeline = pipeline(
-            "ner",
-            model=model_name,
-            aggregation_strategy="simple",
-        )
+        self.model_name = model_name
+        self.use_cache = use_cache
+        self.cache_manager = ModelCacheManager.get_instance() if use_cache else None
+
+        # Cache key for this model
+        cache_key = f"ner_{model_name}"
+
+        if self.use_cache and self.cache_manager:
+            # Load from cache or create new
+            def loader():
+                return pipeline(
+                    "ner",
+                    model=model_name,
+                    aggregation_strategy="simple",
+                )
+
+            self.ner_pipeline = self.cache_manager.get_or_load_model(
+                cache_key,
+                loader,
+            )
+        else:
+            # Load without caching
+            self.ner_pipeline = pipeline(
+                "ner",
+                model=model_name,
+                aggregation_strategy="simple",
+            )
 
         # Compile regex patterns for efficiency
         self._compile_patterns()
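With the constructor change above, every DocumentNER(use_cache=True) in the same process should reuse one HuggingFace pipeline through the shared cache instead of re-instantiating it. A hedged sketch; the model is still downloaded/loaded once on the first miss:

from documents.ml.ner import DocumentNER

ner_a = DocumentNER(use_cache=True)   # cache MISS: builds the "ner_dslim/bert-base-NER" entry
ner_b = DocumentNER(use_cache=True)   # cache HIT: the same pipeline object is returned

assert ner_a.ner_pipeline is ner_b.ner_pipeline

# Opting out keeps the old behaviour: a private pipeline per instance.
ner_c = DocumentNER(use_cache=False)
assert ner_c.ner_pipeline is not ner_a.ner_pipeline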
@@ -25,6 +25,8 @@ import numpy as np
 import torch
 from sentence_transformers import SentenceTransformer, util
 
+from documents.ml.model_cache import ModelCacheManager
+
 if TYPE_CHECKING:
     pass
 
@@ -48,6 +50,7 @@ class SemanticSearch:
         self,
         model_name: str = "all-MiniLM-L6-v2",
         cache_dir: str | None = None,
+        use_cache: bool = True,
     ):
         """
         Initialize semantic search.
@@ -60,16 +63,38 @@ class SemanticSearch:
                 - all-mpnet-base-v2 (420MB, highest quality)
                 - all-MiniLM-L12-v2 (120MB, balanced)
             cache_dir: Directory to cache model
+            use_cache: Whether to use model cache (default: True)
         """
-        logger.info(f"Initializing SemanticSearch with model: {model_name}")
+        logger.info(
+            f"Initializing SemanticSearch with model: {model_name} "
+            f"(caching: {use_cache})"
+        )
 
         self.model_name = model_name
-        self.model = SentenceTransformer(model_name, cache_folder=cache_dir)
-
-        # Storage for embeddings
-        # In production, this should be in a vector database like Faiss or Milvus
-        self.document_embeddings = {}
-        self.document_metadata = {}
+        self.use_cache = use_cache
+        self.cache_manager = ModelCacheManager.get_instance(
+            disk_cache_dir=cache_dir,
+        ) if use_cache else None
+
+        # Cache key for this model
+        cache_key = f"semantic_search_{model_name}"
+
+        if self.use_cache and self.cache_manager:
+            # Load model from cache
+            def loader():
+                return SentenceTransformer(model_name, cache_folder=cache_dir)
+
+            self.model = self.cache_manager.get_or_load_model(cache_key, loader)
+
+            # Try to load embeddings from disk
+            embeddings = self.cache_manager.load_embeddings_from_disk("document_embeddings")
+            self.document_embeddings = embeddings if embeddings else {}
+            self.document_metadata = {}
+        else:
+            # Load without caching
+            self.model = SentenceTransformer(model_name, cache_folder=cache_dir)
+            self.document_embeddings = {}
+            self.document_metadata = {}
 
         logger.info("SemanticSearch initialized successfully")
 
@@ -139,6 +164,13 @@ class SemanticSearch:
             self.document_metadata[doc_id] = metadata
 
         logger.info(f"Indexed {len(documents)} documents successfully")
 
+        # Save embeddings to disk cache if enabled
+        if self.use_cache and self.cache_manager:
+            self.cache_manager.save_embeddings_to_disk(
+                "document_embeddings",
+                self.document_embeddings,
+            )
+
     def search(
         self,
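The indexing path now persists embeddings through the cache manager, so a rebuilt SemanticSearch can start from the vectors saved on disk instead of re-encoding every document. A hedged sketch: the cache path is a placeholder, the indexing method's exact signature is not visible in the hunk, and get_instance() only honours disk_cache_dir the first time the singleton is created in a process:

from documents.ml.semantic_search import SemanticSearch

CACHE_DIR = "/var/cache/intellidocs/ml"  # placeholder path, not taken from the diff

# First run: the SentenceTransformer goes through the in-memory cache and the
# embedding store starts empty (nothing persisted yet).
search = SemanticSearch(cache_dir=CACHE_DIR, use_cache=True)

# ... index documents with the indexing method shown in the hunk above; on
# success it calls save_embeddings_to_disk("document_embeddings", ...) ...

# A later instance (even in a new process) finds "document_embeddings.pkl" under
# CACHE_DIR and starts with the persisted vectors instead of an empty store.
warm_start = SemanticSearch(cache_dir=CACHE_DIR, use_cache=True)
print(f"restored embeddings for {len(warm_start.document_embeddings)} documents")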