# paperless-ngx/src/paperless/tests/test_ai_client.py
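"""Tests for paperless.ai.client.AIClient: LLM backend selection and query/chat calls."""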

from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from llama_index.core.llms import ChatMessage

from paperless.ai.client import AIClient


@pytest.fixture
def mock_ai_config():
    # Patch AIConfig so each test can set the backend, model, and URL on a plain MagicMock.
    with patch("paperless.ai.client.AIConfig") as MockAIConfig:
        mock_config = MagicMock()
        MockAIConfig.return_value = mock_config
        yield mock_config


@pytest.fixture
def mock_ollama_llm():
    # Patch the OllamaLLM class that AIClient instantiates for the "ollama" backend.
    with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
        yield MockOllamaLLM


@pytest.fixture
def mock_openai_llm():
    # Patch the OpenAI class that AIClient instantiates for the "openai" backend.
    with patch("paperless.ai.client.OpenAI") as MockOpenAI:
        yield MockOpenAI


def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    client = AIClient()

    mock_ollama_llm.assert_called_once_with(
        model="test_model",
        base_url="http://test-url",
    )
    assert client.llm == mock_ollama_llm.return_value


def test_get_llm_openai(mock_ai_config, mock_openai_llm):
    mock_ai_config.llm_backend = "openai"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.openai_api_key = "test_api_key"

    client = AIClient()

    mock_openai_llm.assert_called_once_with(
        model="test_model",
        api_key="test_api_key",
    )
    assert client.llm == mock_openai_llm.return_value


def test_get_llm_unsupported_backend(mock_ai_config):
    mock_ai_config.llm_backend = "unsupported"

    with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
        AIClient()


def test_run_llm_query(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    mock_llm_instance = mock_ollama_llm.return_value
    mock_llm_instance.complete.return_value = "test_result"

    client = AIClient()
    result = client.run_llm_query("test_prompt")

    mock_llm_instance.complete.assert_called_once_with("test_prompt")
    assert result == "test_result"


def test_run_chat(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    mock_llm_instance = mock_ollama_llm.return_value
    mock_llm_instance.chat.return_value = "test_chat_result"

    client = AIClient()
    messages = [ChatMessage(role="user", content="Hello")]
    result = client.run_chat(messages)

    mock_llm_instance.chat.assert_called_once_with(messages)
    assert result == "test_chat_result"