mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-12-19 21:16:56 +01:00
Unify, respect perms
[ci skip]
This commit is contained in:
parent
ccfc7d98b1
commit
5f26139a5f
3 changed files with 37 additions and 40 deletions
|
|
@ -1,20 +1,38 @@
|
|||
import logging
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.ai.client import AIClient
|
||||
from paperless.ai.indexing import get_document_retriever
|
||||
from paperless.ai.indexing import load_index
|
||||
|
||||
logger = logging.getLogger("paperless.ai.chat")
|
||||
|
||||
|
||||
def chat_with_documents(prompt: str, user: User) -> str:
|
||||
retriever = get_document_retriever(top_k=5)
|
||||
def chat_with_documents(prompt: str, documents: list[Document]) -> str:
|
||||
client = AIClient()
|
||||
|
||||
index = load_index()
|
||||
|
||||
doc_ids = [doc.pk for doc in documents]
|
||||
|
||||
# Filter only the node(s) that match the document IDs
|
||||
nodes = [
|
||||
node
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in doc_ids
|
||||
]
|
||||
|
||||
if len(nodes) == 0:
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
return "Sorry, I couldn't find any content to answer your question."
|
||||
|
||||
local_index = VectorStoreIndex.from_documents(nodes)
|
||||
retriever = local_index.as_retriever(
|
||||
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||
)
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
|
|
@ -24,29 +42,3 @@ def chat_with_documents(prompt: str, user: User) -> str:
|
|||
response = query_engine.query(prompt)
|
||||
logger.debug("Document chat response: %s", response)
|
||||
return str(response)
|
||||
|
||||
|
||||
def chat_with_single_document(document, question: str, user):
|
||||
index = load_index()
|
||||
|
||||
# Filter only the node(s) belonging to this doc
|
||||
nodes = [
|
||||
node
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") == str(document.id)
|
||||
]
|
||||
|
||||
if not nodes:
|
||||
raise Exception("This document is not indexed yet.")
|
||||
|
||||
local_index = VectorStoreIndex.from_documents(nodes)
|
||||
|
||||
client = AIClient()
|
||||
|
||||
engine = RetrieverQueryEngine.from_args(
|
||||
retriever=local_index.as_retriever(similarity_top_k=3),
|
||||
llm=client.llm,
|
||||
)
|
||||
|
||||
response = engine.query(question)
|
||||
return str(response)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue