paperless-ngx/src/paperless/ai/chat.py

77 lines
2.2 KiB
Python
Raw Normal View History

2025-04-24 23:41:31 -07:00
import logging
import sys
2025-04-24 23:41:31 -07:00
2025-04-24 23:56:51 -07:00
from llama_index.core import VectorStoreIndex
2025-04-25 10:06:26 -07:00
from llama_index.core.prompts import PromptTemplate
2025-04-24 23:41:31 -07:00
from llama_index.core.query_engine import RetrieverQueryEngine
2025-04-25 00:09:33 -07:00
from documents.models import Document
2025-04-24 23:41:31 -07:00
from paperless.ai.client import AIClient
2025-04-24 23:56:51 -07:00
from paperless.ai.indexing import load_index
2025-04-24 23:41:31 -07:00
# Module-level logger shared by the document-chat helpers in this file.
logger = logging.getLogger(name="paperless.ai.chat")
2025-04-25 10:06:26 -07:00
# Prompt template for document chat. The "{context_str}" and "{query_str}"
# placeholders are filled by stream_chat_with_documents() via
# partial_format(); the instruction tells the model to answer from the
# supplied context rather than from prior knowledge.
CHAT_PROMPT_TMPL = PromptTemplate(
    template="""Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {query_str}
Answer:""",
)
2025-04-24 23:41:31 -07:00
2025-04-25 10:06:26 -07:00
def stream_chat_with_documents(query_str: str, documents: list[Document]):
    """Stream an LLM answer to ``query_str`` grounded in ``documents``.

    This is a generator: it yields the answer as text chunks taken from the
    ``response_gen`` of a streaming ``RetrieverQueryEngine``.

    Only nodes from the index returned by ``load_index()`` whose
    ``document_id`` metadata matches one of ``documents`` are used. A
    temporary ``VectorStoreIndex`` is built from just those nodes, so
    retrieval is scoped to the selected documents.

    Context handling:
    - exactly one document: its title (or filename) and full ``content``;
    - several documents: the nodes retrieved for ``query_str`` (top 5),
      each truncated to its first 500 characters.

    Args:
        query_str: The user's question.
        documents: Documents the answer should be restricted to.

    Yields:
        Text chunks of the answer. If no indexed nodes belong to the given
        documents, a single apology message is yielded instead.
    """
    client = AIClient()
    index = load_index()

    # Compare IDs as strings so the filter matches whether the stored
    # "document_id" metadata is an int or a str; a set keeps each lookup
    # O(1) while scanning every node in the docstore.
    doc_ids = {str(doc.pk) for doc in documents}

    # Filter only the node(s) that match the document IDs
    nodes = [
        node
        for node in index.docstore.docs.values()
        if str(node.metadata.get("document_id")) in doc_ids
    ]

    if not nodes:
        logger.warning("No nodes found for the given documents.")
        # This function is a generator: ``return <value>`` would only set
        # StopIteration.value and the caller would receive no text at all.
        # Yield the fallback message, then stop.
        yield "Sorry, I couldn't find any content to answer your question."
        return

    local_index = VectorStoreIndex(nodes=nodes)
    single_doc = len(documents) == 1
    retriever = local_index.as_retriever(
        similarity_top_k=3 if single_doc else 5,
    )

    if single_doc:
        # Just one doc — provide full content
        doc = documents[0]
        # TODO: include document metadata in the context
        context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
    else:
        top_nodes = retriever.retrieve(query_str)
        context = "\n\n".join(
            f"TITLE: {node.metadata.get('title')}\n{node.text[:500]}"
            for node in top_nodes
        )

    prompt = CHAT_PROMPT_TMPL.partial_format(
        context_str=context,
        query_str=query_str,
    ).format(llm=client.llm)

    query_engine = RetrieverQueryEngine.from_args(
        retriever=retriever,
        llm=client.llm,
        streaming=True,
    )

    logger.debug("Document chat prompt: %s", prompt)

    # NOTE(review): the fully rendered prompt (context already included) is
    # passed as the query string, so the engine presumably retrieves against
    # it and wraps it in its own QA template as well — confirm the doubled
    # context is intended.
    response_stream = query_engine.query(prompt)
    for chunk in response_stream.response_gen:
        yield chunk
        # NOTE(review): this flushes the process's stdout only; it does not
        # affect the chunks yielded to the caller. Kept for parity.
        sys.stdout.flush()