Mirror of https://github.com/paperless-ngx/paperless-ngx.git

Commit: 4a28be233e ("Fixup some tests")
Parent: 9183bfc0a4
5 changed files with 167 additions and 128 deletions
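The hunk shown below rewrites the AIClient test module. The old tests drove the client through Django settings (LLM_BACKEND, LLM_MODEL, LLM_API_KEY) and intercepted real HTTP traffic with httpx_mock; the new tests instead mock AIConfig and the llama_index OllamaLLM and OpenAI classes directly, and they expect an unsupported backend to fail at AIClient() construction rather than at query time. A sketch of the client interface these tests pin down follows the diff.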
@@ -1,95 +1,93 @@
-import json
+from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
-from django.conf import settings
+from llama_index.core.llms import ChatMessage
 
 from paperless.ai.client import AIClient
 
 
 @pytest.fixture
-def mock_settings():
-    settings.LLM_BACKEND = "openai"
-    settings.LLM_MODEL = "gpt-3.5-turbo"
-    settings.LLM_API_KEY = "test-api-key"
-    yield settings
+def mock_ai_config():
+    with patch("paperless.ai.client.AIConfig") as MockAIConfig:
+        mock_config = MagicMock()
+        MockAIConfig.return_value = mock_config
+        yield mock_config
 
 
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
-    mock_settings.LLM_BACKEND = "openai"
-    mock_openai_query.return_value = "OpenAI response"
+@pytest.fixture
+def mock_ollama_llm():
+    with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
+        yield MockOllamaLLM
+
+
+@pytest.fixture
+def mock_openai_llm():
+    with patch("paperless.ai.client.OpenAI") as MockOpenAI:
+        yield MockOpenAI
+
+
+def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "OpenAI response"
-    mock_openai_query.assert_called_once_with("Test prompt")
-    mock_ollama_query.assert_not_called()
+
+    mock_ollama_llm.assert_called_once_with(
+        model="test_model",
+        base_url="http://test-url",
+    )
+    assert client.llm == mock_ollama_llm.return_value
 
 
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
-    mock_settings.LLM_BACKEND = "ollama"
-    mock_ollama_query.return_value = "Ollama response"
+def test_get_llm_openai(mock_ai_config, mock_openai_llm):
+    mock_ai_config.llm_backend = "openai"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.openai_api_key = "test_api_key"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "Ollama response"
-    mock_ollama_query.assert_called_once_with("Test prompt")
-    mock_openai_query.assert_not_called()
+
+    mock_openai_llm.assert_called_once_with(
+        model="test_model",
+        api_key="test_api_key",
+    )
+    assert client.llm == mock_openai_llm.return_value
 
 
-@pytest.mark.django_db
-def test_run_llm_query_unsupported_backend(mock_settings):
-    mock_settings.LLM_BACKEND = "unsupported"
-    client = AIClient()
+def test_get_llm_unsupported_backend(mock_ai_config):
+    mock_ai_config.llm_backend = "unsupported"
+
     with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
-        client.run_llm_query("Test prompt")
+        AIClient()
 
 
-@pytest.mark.django_db
-def test_run_openai_query(httpx_mock, mock_settings):
-    mock_settings.LLM_BACKEND = "openai"
-    httpx_mock.add_response(
-        url="https://api.openai.com/v1/chat/completions",
-        json={
-            "choices": [{"message": {"content": "OpenAI response"}}],
-        },
-    )
+def test_run_llm_query(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    mock_llm_instance = mock_ollama_llm.return_value
+    mock_llm_instance.complete.return_value = "test_result"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "OpenAI response"
+    result = client.run_llm_query("test_prompt")
 
-    request = httpx_mock.get_request()
-    assert request.method == "POST"
-    assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
-    assert request.headers["Content-Type"] == "application/json"
-    assert json.loads(request.content) == {
-        "model": mock_settings.LLM_MODEL,
-        "messages": [{"role": "user", "content": "Test prompt"}],
-        "temperature": 0.3,
-    }
+    mock_llm_instance.complete.assert_called_once_with("test_prompt")
+    assert result == "test_result"
 
 
-@pytest.mark.django_db
-def test_run_ollama_query(httpx_mock, mock_settings):
-    mock_settings.LLM_BACKEND = "ollama"
-    httpx_mock.add_response(
-        url="http://localhost:11434/api/chat",
-        json={"message": {"content": "Ollama response"}},
-    )
+def test_run_chat(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    mock_llm_instance = mock_ollama_llm.return_value
+    mock_llm_instance.chat.return_value = "test_chat_result"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "Ollama response"
+    messages = [ChatMessage(role="user", content="Hello")]
+    result = client.run_chat(messages)
 
-    request = httpx_mock.get_request()
-    assert request.method == "POST"
-    assert json.loads(request.content) == {
-        "model": mock_settings.LLM_MODEL,
-        "messages": [{"role": "user", "content": "Test prompt"}],
-        "stream": False,
-    }
+    mock_llm_instance.chat.assert_called_once_with(messages)
+    assert result == "test_chat_result"
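For orientation, here is a minimal sketch of the AIClient interface the new tests pin down, reconstructed purely from the assertions above rather than from the paperless-ngx source. The AIConfig stand-in and its default values are assumptions, as are the llama_index import paths; the tests only show that the names AIConfig, OllamaLLM, and OpenAI are importable from paperless.ai.client.

from dataclasses import dataclass

from llama_index.llms.ollama import Ollama as OllamaLLM
from llama_index.llms.openai import OpenAI


@dataclass
class AIConfig:
    # Stand-in for the patched paperless.ai.client.AIConfig; the real
    # class reads these values from Paperless settings.
    llm_backend: str = "ollama"
    llm_model: str = "llama3"
    llm_url: str = "http://localhost:11434"
    openai_api_key: str = ""


class AIClient:
    def __init__(self):
        self.settings = AIConfig()
        # test_get_llm_unsupported_backend expects the ValueError to be
        # raised here, at construction time.
        self.llm = self.get_llm()

    def get_llm(self):
        if self.settings.llm_backend == "ollama":
            # test_get_llm_ollama asserts exactly these kwargs.
            return OllamaLLM(
                model=self.settings.llm_model,
                base_url=self.settings.llm_url,
            )
        if self.settings.llm_backend == "openai":
            # test_get_llm_openai asserts exactly these kwargs.
            return OpenAI(
                model=self.settings.llm_model,
                api_key=self.settings.openai_api_key,
            )
        raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")

    def run_llm_query(self, prompt):
        # test_run_llm_query asserts a single complete() call with the prompt.
        return self.llm.complete(prompt)

    def run_chat(self, messages):
        # test_run_chat asserts a single chat() call with the message list.
        return self.llm.chat(messages)

Note the design shift the tests encode: backend selection and validation happen once in the constructor, so every query path shares one llm instance, and run_llm_query and run_chat reduce to thin wrappers over llama_index's complete() and chat().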