mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-12-11 00:57:09 +01:00
the document classifier is now stateless
This commit is contained in:
parent
3e50e51b8a
commit
05f20c19c3
4 changed files with 12 additions and 19 deletions
|
|
@@ -34,7 +34,6 @@ class DocumentClassifier(object):
|
|||
self.tags_classifier = None
|
||||
self.correspondent_classifier = None
|
||||
self.document_type_classifier = None
|
||||
self.X = None
|
||||
|
||||
def reload(self):
|
||||
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
|
||||
|
|
@@ -167,14 +166,10 @@ class DocumentClassifier(object):
|
|||
"classifier."
|
||||
)
|
||||
|
||||
def update(self, document):
|
||||
self.X = self.data_vectorizer.transform(
|
||||
[preprocess_content(document.content)]
|
||||
)
|
||||
|
||||
def predict_correspondent(self):
|
||||
def predict_correspondent(self, content):
|
||||
if self.correspondent_classifier:
|
||||
y = self.correspondent_classifier.predict(self.X)
|
||||
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||
y = self.correspondent_classifier.predict(X)
|
||||
correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0]
|
||||
if correspondent_id != -1:
|
||||
return correspondent_id
|
||||
|
|
@@ -183,9 +178,10 @@ class DocumentClassifier(object):
|
|||
else:
|
||||
return None
|
||||
|
||||
def predict_document_type(self):
|
||||
def predict_document_type(self, content):
|
||||
if self.document_type_classifier:
|
||||
y = self.document_type_classifier.predict(self.X)
|
||||
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||
y = self.document_type_classifier.predict(X)
|
||||
document_type_id = self.document_type_binarizer.inverse_transform(y)[0]
|
||||
if document_type_id != -1:
|
||||
return document_type_id
|
||||
|
|
@@ -194,9 +190,10 @@ class DocumentClassifier(object):
|
|||
else:
|
||||
return None
|
||||
|
||||
def predict_tags(self):
|
||||
def predict_tags(self, content):
|
||||
if self.tags_classifier:
|
||||
y = self.tags_classifier.predict(self.X)
|
||||
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||
y = self.tags_classifier.predict(X)
|
||||
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
|
||||
return tags_ids
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue