paperless-ngx/src/paperless/tests/test_ai_client.py

import json
from unittest.mock import patch

import pytest
from django.conf import settings

from paperless.ai.client import AIClient
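
# For orientation: the tests below imply an AIClient of roughly this shape.
# This is a sketch inferred from the assertions, not the actual implementation
# in paperless.ai.client:
#
#   class AIClient:
#       def run_llm_query(self, prompt: str) -> str:
#           if settings.LLM_BACKEND == "openai":
#               return self._run_openai_query(prompt)
#           if settings.LLM_BACKEND == "ollama":
#               return self._run_ollama_query(prompt)
#           raise ValueError(f"Unsupported LLM backend: {settings.LLM_BACKEND}")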


@pytest.fixture
def mock_settings():
    """Configure django.conf.settings with test LLM values and hand it to the test."""
    settings.LLM_BACKEND = "openai"
    settings.LLM_MODEL = "gpt-3.5-turbo"
    settings.LLM_API_KEY = "test-api-key"
    yield settings


@pytest.mark.django_db
@patch("paperless.ai.client.AIClient._run_openai_query")
@patch("paperless.ai.client.AIClient._run_ollama_query")
def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
    mock_settings.LLM_BACKEND = "openai"
    mock_openai_query.return_value = "OpenAI response"

    client = AIClient()
    result = client.run_llm_query("Test prompt")

    assert result == "OpenAI response"
    mock_openai_query.assert_called_once_with("Test prompt")
    mock_ollama_query.assert_not_called()


@pytest.mark.django_db
@patch("paperless.ai.client.AIClient._run_openai_query")
@patch("paperless.ai.client.AIClient._run_ollama_query")
def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
    mock_settings.LLM_BACKEND = "ollama"
    mock_ollama_query.return_value = "Ollama response"

    client = AIClient()
    result = client.run_llm_query("Test prompt")

    assert result == "Ollama response"
    mock_ollama_query.assert_called_once_with("Test prompt")
    mock_openai_query.assert_not_called()


@pytest.mark.django_db
def test_run_llm_query_unsupported_backend(mock_settings):
    mock_settings.LLM_BACKEND = "unsupported"

    client = AIClient()
    with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
        client.run_llm_query("Test prompt")


@pytest.mark.django_db
def test_run_openai_query(httpx_mock, mock_settings):
    # httpx_mock is the request-mocking fixture provided by the pytest-httpx plugin.
    mock_settings.LLM_BACKEND = "openai"

    httpx_mock.add_response(
        url="https://api.openai.com/v1/chat/completions",
        json={
            "choices": [{"message": {"content": "OpenAI response"}}],
        },
    )

    client = AIClient()
    result = client.run_llm_query("Test prompt")

    assert result == "OpenAI response"

    request = httpx_mock.get_request()
    assert request.method == "POST"
    assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
    assert request.headers["Content-Type"] == "application/json"
    assert json.loads(request.content) == {
        "model": mock_settings.LLM_MODEL,
        "messages": [{"role": "user", "content": "Test prompt"}],
        "temperature": 0.3,
    }
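
# The request/response pair asserted above suggests _run_openai_query does
# roughly the following (a sketch under those assumptions, not the real code):
#
#   response = httpx.post(
#       "https://api.openai.com/v1/chat/completions",
#       headers={
#           "Authorization": f"Bearer {settings.LLM_API_KEY}",
#           "Content-Type": "application/json",
#       },
#       json={
#           "model": settings.LLM_MODEL,
#           "messages": [{"role": "user", "content": prompt}],
#           "temperature": 0.3,
#       },
#   )
#   return response.json()["choices"][0]["message"]["content"]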


@pytest.mark.django_db
def test_run_ollama_query(httpx_mock, mock_settings):
    mock_settings.LLM_BACKEND = "ollama"

    httpx_mock.add_response(
        url="http://localhost:11434/api/chat",
        json={"message": {"content": "Ollama response"}},
    )

    client = AIClient()
    result = client.run_llm_query("Test prompt")

    assert result == "Ollama response"

    request = httpx_mock.get_request()
    assert request.method == "POST"
    assert json.loads(request.content) == {
        "model": mock_settings.LLM_MODEL,
        "messages": [{"role": "user", "content": "Test prompt"}],
        "stream": False,
    }
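
# Likewise, the mocked Ollama exchange suggests _run_ollama_query does roughly
# the following (a sketch; the localhost:11434 URL is taken from the mock above):
#
#   response = httpx.post(
#       "http://localhost:11434/api/chat",
#       json={
#           "model": settings.LLM_MODEL,
#           "messages": [{"role": "user", "content": prompt}],
#           "stream": False,
#       },
#   )
#   return response.json()["message"]["content"]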