Mirror of https://github.com/paperless-ngx/paperless-ngx.git, synced 2025-12-09 16:25:33 +01:00
feat(ml): Complete ML model caching implementation with settings and startup integration
Co-authored-by: dawnsystem <42047891+dawnsystem@users.noreply.github.com>
parent 894f7e231d
commit f1161ce5fb

3 changed files with 338 additions and 0 deletions
src/documents/apps.py

@@ -30,4 +30,38 @@ class DocumentsConfig(AppConfig):

        import documents.schema  # noqa: F401

        # Initialize ML model cache with warm-up if configured
        self._initialize_ml_cache()

        AppConfig.ready(self)

    def _initialize_ml_cache(self):
        """Initialize ML model cache and optionally warm up models."""
        from django.conf import settings

        # Only initialize if ML features are enabled
        if not getattr(settings, "PAPERLESS_ENABLE_ML_FEATURES", False):
            return

        # Initialize cache manager with settings
        from documents.ml.model_cache import ModelCacheManager

        max_models = getattr(settings, "PAPERLESS_ML_CACHE_MAX_MODELS", 3)
        cache_dir = getattr(settings, "PAPERLESS_ML_MODEL_CACHE", None)

        cache_manager = ModelCacheManager.get_instance(
            max_models=max_models,
            disk_cache_dir=str(cache_dir) if cache_dir else None,
        )

        # Warm up models if configured
        warmup_enabled = getattr(settings, "PAPERLESS_ML_CACHE_WARMUP", False)
        if warmup_enabled:
            try:
                from documents.ai_scanner import get_ai_scanner

                scanner = get_ai_scanner()
                scanner.warm_up_models()
            except Exception as e:
                import logging

                logger = logging.getLogger("paperless.documents")
                logger.warning(f"Failed to warm up ML models: {e}")
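The startup hook above only constructs the singleton; consuming code is expected to go through get_or_load_model so that repeated lookups hit the in-memory cache. A minimal usage sketch, assuming the API exercised by the tests below; load_embedding_model is a hypothetical loader shown only for illustration:

from documents.ml.model_cache import ModelCacheManager

def load_embedding_model():
    # Hypothetical loader: expensive model construction runs only on a
    # cache miss; subsequent calls return the cached instance.
    ...

manager = ModelCacheManager.get_instance()
model = manager.get_or_load_model("embeddings", load_embedding_model)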
src/documents/tests/test_ml_cache.py (new file, 292 lines)
@@ -0,0 +1,292 @@
"""
Tests for ML model caching functionality.
"""

import tempfile
from pathlib import Path
from unittest import mock

from django.test import TestCase

from documents.ml.model_cache import (
    CacheMetrics,
    LRUCache,
    ModelCacheManager,
)
class TestCacheMetrics(TestCase):
    """Test cache metrics tracking."""

    def test_record_hit(self):
        """Test recording cache hits."""
        metrics = CacheMetrics()
        self.assertEqual(metrics.hits, 0)

        metrics.record_hit()
        self.assertEqual(metrics.hits, 1)

        metrics.record_hit()
        self.assertEqual(metrics.hits, 2)

    def test_record_miss(self):
        """Test recording cache misses."""
        metrics = CacheMetrics()
        self.assertEqual(metrics.misses, 0)

        metrics.record_miss()
        self.assertEqual(metrics.misses, 1)

    def test_get_stats(self):
        """Test getting cache statistics."""
        metrics = CacheMetrics()

        # Initial stats
        stats = metrics.get_stats()
        self.assertEqual(stats["hits"], 0)
        self.assertEqual(stats["misses"], 0)
        self.assertEqual(stats["hit_rate"], "0.00%")

        # After some hits and misses
        metrics.record_hit()
        metrics.record_hit()
        metrics.record_hit()
        metrics.record_miss()

        stats = metrics.get_stats()
        self.assertEqual(stats["hits"], 3)
        self.assertEqual(stats["misses"], 1)
        self.assertEqual(stats["total_requests"], 4)
        self.assertEqual(stats["hit_rate"], "75.00%")

    def test_reset(self):
        """Test resetting metrics."""
        metrics = CacheMetrics()
        metrics.record_hit()
        metrics.record_miss()

        metrics.reset()

        stats = metrics.get_stats()
        self.assertEqual(stats["hits"], 0)
        self.assertEqual(stats["misses"], 0)
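These tests pin down the CacheMetrics contract: two counters, a stats dictionary, and a percentage-formatted hit rate. A minimal sketch that would satisfy them, assuming plain integer counters (the committed implementation in documents/ml/model_cache.py may differ):

class CacheMetrics:
    """Sketch of a hit/miss counter matching the tests above."""

    def __init__(self):
        self.hits = 0
        self.misses = 0

    def record_hit(self):
        self.hits += 1

    def record_miss(self):
        self.misses += 1

    def get_stats(self):
        # Hit rate is reported as a formatted percentage, e.g. "75.00%".
        total = self.hits + self.misses
        rate = (self.hits / total * 100) if total else 0.0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "total_requests": total,
            "hit_rate": f"{rate:.2f}%",
        }

    def reset(self):
        self.hits = 0
        self.misses = 0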
class TestLRUCache(TestCase):
    """Test LRU cache implementation."""

    def test_put_and_get(self):
        """Test basic cache operations."""
        cache = LRUCache(max_size=2)

        cache.put("key1", "value1")
        cache.put("key2", "value2")

        self.assertEqual(cache.get("key1"), "value1")
        self.assertEqual(cache.get("key2"), "value2")

    def test_cache_miss(self):
        """Test cache miss returns None."""
        cache = LRUCache(max_size=2)

        result = cache.get("nonexistent")
        self.assertIsNone(result)

    def test_lru_eviction(self):
        """Test LRU eviction policy."""
        cache = LRUCache(max_size=2)

        cache.put("key1", "value1")
        cache.put("key2", "value2")
        cache.put("key3", "value3")  # Should evict key1

        self.assertIsNone(cache.get("key1"))  # Evicted
        self.assertEqual(cache.get("key2"), "value2")
        self.assertEqual(cache.get("key3"), "value3")

    def test_lru_update_access_order(self):
        """Test that accessing an item updates its position."""
        cache = LRUCache(max_size=2)

        cache.put("key1", "value1")
        cache.put("key2", "value2")
        cache.get("key1")  # Access key1, making it most recent
        cache.put("key3", "value3")  # Should evict key2, not key1

        self.assertEqual(cache.get("key1"), "value1")
        self.assertIsNone(cache.get("key2"))  # Evicted
        self.assertEqual(cache.get("key3"), "value3")

    def test_cache_size(self):
        """Test cache size tracking."""
        cache = LRUCache(max_size=3)

        self.assertEqual(cache.size(), 0)

        cache.put("key1", "value1")
        self.assertEqual(cache.size(), 1)

        cache.put("key2", "value2")
        self.assertEqual(cache.size(), 2)

    def test_clear(self):
        """Test clearing cache."""
        cache = LRUCache(max_size=2)

        cache.put("key1", "value1")
        cache.put("key2", "value2")

        cache.clear()

        self.assertEqual(cache.size(), 0)
        self.assertIsNone(cache.get("key1"))
        self.assertIsNone(cache.get("key2"))
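The eviction tests above imply a cache that tracks recency on both reads and writes. A minimal sketch consistent with them, assuming collections.OrderedDict for recency tracking (an illustration, not the committed code):

from collections import OrderedDict

class LRUCache:
    """Sketch of an LRU cache matching the tests above."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._data = OrderedDict()

    def get(self, key):
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key, value):
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = value
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least recently used

    def size(self):
        return len(self._data)

    def clear(self):
        self._data.clear()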
class TestModelCacheManager(TestCase):
    """Test model cache manager."""

    def setUp(self):
        """Set up test fixtures."""
        # Reset singleton instance for each test
        ModelCacheManager._instance = None

    def test_singleton_pattern(self):
        """Test that ModelCacheManager is a singleton."""
        instance1 = ModelCacheManager.get_instance()
        instance2 = ModelCacheManager.get_instance()

        self.assertIs(instance1, instance2)

    def test_get_or_load_model_first_time(self):
        """Test loading a model for the first time (cache miss)."""
        cache_manager = ModelCacheManager.get_instance()

        # Mock loader function
        mock_model = mock.Mock()
        loader = mock.Mock(return_value=mock_model)

        # Load model
        result = cache_manager.get_or_load_model("test_model", loader)

        # Verify loader was called
        loader.assert_called_once()
        self.assertIs(result, mock_model)

    def test_get_or_load_model_cached(self):
        """Test loading a model from cache (cache hit)."""
        cache_manager = ModelCacheManager.get_instance()

        # Mock loader function
        mock_model = mock.Mock()
        loader = mock.Mock(return_value=mock_model)

        # Load model first time
        cache_manager.get_or_load_model("test_model", loader)

        # Load model second time (should be cached)
        result = cache_manager.get_or_load_model("test_model", loader)

        # Verify loader was only called once
        loader.assert_called_once()
        self.assertIs(result, mock_model)

    def test_disk_cache_embeddings(self):
        """Test saving and loading embeddings to/from disk."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache_manager = ModelCacheManager.get_instance(
                disk_cache_dir=tmpdir,
            )

            # Create test embeddings
            embeddings = {
                1: "embedding1",
                2: "embedding2",
                3: "embedding3",
            }

            # Save to disk
            cache_manager.save_embeddings_to_disk("test_embeddings", embeddings)

            # Verify file was created
            cache_file = Path(tmpdir) / "test_embeddings.pkl"
            self.assertTrue(cache_file.exists())

            # Load from disk
            loaded = cache_manager.load_embeddings_from_disk("test_embeddings")

            # Verify embeddings match
            self.assertEqual(loaded, embeddings)

    def test_get_metrics(self):
        """Test getting cache metrics."""
        cache_manager = ModelCacheManager.get_instance()

        # Mock loader
        loader = mock.Mock(return_value=mock.Mock())

        # Generate some cache activity
        cache_manager.get_or_load_model("model1", loader)
        cache_manager.get_or_load_model("model1", loader)  # Cache hit
        cache_manager.get_or_load_model("model2", loader)

        # Get metrics
        metrics = cache_manager.get_metrics()

        # Verify metrics structure
        self.assertIn("hits", metrics)
        self.assertIn("misses", metrics)
        self.assertIn("cache_size", metrics)
        self.assertIn("max_size", metrics)

        # Verify hit/miss counts
        self.assertEqual(metrics["hits"], 1)  # One cache hit
        self.assertEqual(metrics["misses"], 2)  # Two cache misses

    def test_clear_all(self):
        """Test clearing all caches."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache_manager = ModelCacheManager.get_instance(
                disk_cache_dir=tmpdir,
            )

            # Add some models to cache
            loader = mock.Mock(return_value=mock.Mock())
            cache_manager.get_or_load_model("model1", loader)

            # Add embeddings to disk
            embeddings = {1: "embedding1"}
            cache_manager.save_embeddings_to_disk("test", embeddings)

            # Clear all
            cache_manager.clear_all()

            # Verify memory cache is cleared
            self.assertEqual(cache_manager.model_cache.size(), 0)

            # Verify disk cache is cleared
            cache_file = Path(tmpdir) / "test.pkl"
            self.assertFalse(cache_file.exists())

    def test_warm_up(self):
        """Test model warm-up functionality."""
        cache_manager = ModelCacheManager.get_instance()

        # Create mock loaders
        model1 = mock.Mock()
        model2 = mock.Mock()

        loaders = {
            "model1": mock.Mock(return_value=model1),
            "model2": mock.Mock(return_value=model2),
        }

        # Warm up
        cache_manager.warm_up(loaders)

        # Verify all loaders were called
        for loader in loaders.values():
            loader.assert_called_once()

        # Verify models are cached
        self.assertEqual(cache_manager.model_cache.size(), 2)
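Taken together, the manager tests imply a singleton that wraps an LRUCache plus a CacheMetrics instance and serializes embeddings to disk; the ".pkl" filenames the tests expect suggest pickle, but that is an inference. A hedged sketch satisfying these tests, reusing the LRUCache and CacheMetrics sketches above; every detail here is an assumption, not the committed implementation:

import pickle
from pathlib import Path

class ModelCacheManager:
    """Sketch of a singleton cache manager matching the tests above."""

    _instance = None

    def __init__(self, max_models=3, disk_cache_dir=None):
        self.model_cache = LRUCache(max_size=max_models)
        self.metrics = CacheMetrics()
        self.disk_cache_dir = Path(disk_cache_dir) if disk_cache_dir else None

    @classmethod
    def get_instance(cls, **kwargs):
        # First call constructs the singleton; later calls ignore kwargs.
        if cls._instance is None:
            cls._instance = cls(**kwargs)
        return cls._instance

    def get_or_load_model(self, key, loader):
        model = self.model_cache.get(key)
        if model is not None:
            self.metrics.record_hit()
            return model
        self.metrics.record_miss()
        model = loader()  # expensive load happens only on a miss
        self.model_cache.put(key, model)
        return model

    def save_embeddings_to_disk(self, name, embeddings):
        with (self.disk_cache_dir / f"{name}.pkl").open("wb") as f:
            pickle.dump(embeddings, f)

    def load_embeddings_from_disk(self, name):
        path = self.disk_cache_dir / f"{name}.pkl"
        if not path.exists():
            return None
        with path.open("rb") as f:
            return pickle.load(f)

    def get_metrics(self):
        return {
            **self.metrics.get_stats(),
            "cache_size": self.model_cache.size(),
            "max_size": self.model_cache.max_size,
        }

    def clear_all(self):
        self.model_cache.clear()
        if self.disk_cache_dir:
            for path in self.disk_cache_dir.glob("*.pkl"):
                path.unlink()

    def warm_up(self, loaders):
        # Preload each model so the first real request is a cache hit.
        for key, loader in loaders.items():
            self.get_or_load_model(key, loader)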
src/paperless/settings.py

@@ -1195,6 +1195,18 @@ PAPERLESS_ML_MODEL_CACHE: Final[Path | None] = __get_optional_path(
    "PAPERLESS_ML_MODEL_CACHE",
)

# ML Model Cache Settings
# Maximum number of models to keep in memory cache (LRU eviction)
PAPERLESS_ML_CACHE_MAX_MODELS: Final[int] = int(
    os.getenv("PAPERLESS_ML_CACHE_MAX_MODELS", "3"),
)

# Enable model warm-up on startup (preload models for faster first use)
PAPERLESS_ML_CACHE_WARMUP: Final[bool] = __get_boolean(
    "PAPERLESS_ML_CACHE_WARMUP",
    default=False,
)

OCR_COLOR_CONVERSION_STRATEGY = os.getenv(
    "PAPERLESS_OCR_COLOR_CONVERSION_STRATEGY",
    "RGB",
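Putting the settings together, a plausible configuration enabling the cache could look like the following environment snippet. The path and values are illustrative only, and PAPERLESS_ENABLE_ML_FEATURES is assumed from the startup guard in the apps.py hunk above:

# Hypothetical environment file; values for illustration.
PAPERLESS_ENABLE_ML_FEATURES=true
PAPERLESS_ML_MODEL_CACHE=/usr/src/paperless/data/ml_cache
PAPERLESS_ML_CACHE_MAX_MODELS=3
PAPERLESS_ML_CACHE_WARMUP=true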