diff --git a/src/documents/index.py b/src/documents/index.py index 406445675..ba766bb31 100644 --- a/src/documents/index.py +++ b/src/documents/index.py @@ -287,77 +287,76 @@ class DelayedQuery: self.first_score = None self.filter_queryset = filter_queryset self.suggested_correction = None - self.manual_hits: list | None = None - ordering = self.query_params.get("ordering") - ordering_check = ordering.lstrip("-") if ordering else None - self._manual_sort_requested = ( - ordering_check.startswith("custom_field_") if ordering_check else False - ) + self._manual_hits_cache: list | None = None def __len__(self) -> int: - manual_hits = self._get_manual_hits() + manual_hits = self._manual_hits() if manual_hits is not None: return len(manual_hits) page = self[0:1] return len(page) - def _get_manual_hits(self): - if not self._manual_sort_requested: + def _manual_sort_requested(self): + ordering = self.query_params.get("ordering", "") + return ordering.lstrip("-").startswith("custom_field_") + + def _manual_hits(self): + if not self._manual_sort_requested(): return None - if self.manual_hits is None: - self.manual_hits = self._build_manual_hits() - return self.manual_hits + if self._manual_hits_cache is None: + q, mask, suggested_correction = self._get_query() + self.suggested_correction = suggested_correction - def _build_manual_hits(self): - q, mask, suggested_correction = self._get_query() - self.suggested_correction = suggested_correction + results = self.searcher.search( + q, + mask=mask, + filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader), + limit=None, + ) + results.fragmenter = highlight.ContextFragmenter(surround=50) + results.formatter = HtmlFormatter(tagname="span", between=" ... ") - results = self.searcher.search( - q, - mask=mask, - filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader), - limit=None, - ) - results.fragmenter = highlight.ContextFragmenter(surround=50) - results.formatter = HtmlFormatter(tagname="span", between=" ... ") + if not self.first_score and len(results) > 0: + self.first_score = results[0].score - if not self.first_score and len(results) > 0: - self.first_score = results[0].score - - if self.first_score: - results.top_n = list( - map( - lambda hit: ( - (hit[0] / self.first_score) if self.first_score else None, - hit[1], + if self.first_score: + results.top_n = list( + map( + lambda hit: ( + (hit[0] / self.first_score) if self.first_score else None, + hit[1], + ), + results.top_n, ), - results.top_n, + ) + + hits_by_id = {hit["id"]: hit for hit in results} + matching_ids = list(hits_by_id.keys()) + + ordered_ids = list( + self.filter_queryset.filter(id__in=matching_ids).values_list( + "id", + flat=True, ), ) + ordered_ids = list(dict.fromkeys(ordered_ids)) - hits_by_id = {hit["id"]: hit for hit in results} - matching_ids = list(hits_by_id.keys()) - - ordered_ids = list( - self.filter_queryset.filter(id__in=matching_ids).values_list( - "id", - flat=True, - ), - ) - ordered_ids = list(dict.fromkeys(ordered_ids)) - - return [hits_by_id[_id] for _id in ordered_ids if _id in hits_by_id] + self._manual_hits_cache = [ + hits_by_id[_id] for _id in ordered_ids if _id in hits_by_id + ] + return self._manual_hits_cache def __getitem__(self, item): if item.start in self.saved_results: return self.saved_results[item.start] - manual_hits = self._get_manual_hits() + manual_hits = self._manual_hits() if manual_hits is not None: start = 0 if item.start is None else item.start stop = item.stop - page = manual_hits[start:stop] if stop is not None else manual_hits[start:] + hits = manual_hits[start:stop] if stop is not None else manual_hits[start:] + page = ManualResultsPage(hits) self.saved_results[start] = page return page @@ -395,6 +394,20 @@ class DelayedQuery: return page +class ManualResultsPage(list): + def __init__(self, hits): + super().__init__(hits) + self.results = ManualResults(hits) + + +class ManualResults: + def __init__(self, hits): + self._docnums = [hit.docnum for hit in hits] + + def docs(self): + return self._docnums + + class LocalDateParser(English): def reverse_timezone_offset(self, d): return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone( diff --git a/src/paperless/views.py b/src/paperless/views.py index aa0e2b9be..e79c0e668 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -70,22 +70,18 @@ class StandardPagination(PageNumberPagination): def get_all_result_ids(self): query = self.page.paginator.object_list if isinstance(query, DelayedQuery): - manual_hits = getattr(query, "manual_hits", None) - if manual_hits is not None: - ids = [hit["id"] for hit in manual_hits] - else: - first_page = query.saved_results.get(0) - if not first_page: - return [] + try: ids = [ query.searcher.ixreader.stored_fields( doc_num, )["id"] - for doc_num in first_page.results.docs() + for doc_num in query.saved_results.get(0).results.docs() ] + except Exception: + pass else: - ids = list(self.page.paginator.object_list.values_list("pk", flat=True)) - return list(ids) + ids = self.page.paginator.object_list.values_list("pk", flat=True) + return ids def get_paginated_response_schema(self, schema): response_schema = super().get_paginated_response_schema(schema)