This commit is contained in:
shamoon 2025-11-16 10:20:38 -08:00
parent 29de66bf23
commit 3b3e0136b4
2 changed files with 66 additions and 57 deletions

View file

@ -287,77 +287,76 @@ class DelayedQuery:
self.first_score = None self.first_score = None
self.filter_queryset = filter_queryset self.filter_queryset = filter_queryset
self.suggested_correction = None self.suggested_correction = None
self.manual_hits: list | None = None self._manual_hits_cache: list | None = None
ordering = self.query_params.get("ordering")
ordering_check = ordering.lstrip("-") if ordering else None
self._manual_sort_requested = (
ordering_check.startswith("custom_field_") if ordering_check else False
)
def __len__(self) -> int: def __len__(self) -> int:
manual_hits = self._get_manual_hits() manual_hits = self._manual_hits()
if manual_hits is not None: if manual_hits is not None:
return len(manual_hits) return len(manual_hits)
page = self[0:1] page = self[0:1]
return len(page) return len(page)
def _get_manual_hits(self): def _manual_sort_requested(self):
if not self._manual_sort_requested: ordering = self.query_params.get("ordering", "")
return ordering.lstrip("-").startswith("custom_field_")
def _manual_hits(self):
if not self._manual_sort_requested():
return None return None
if self.manual_hits is None: if self._manual_hits_cache is None:
self.manual_hits = self._build_manual_hits() q, mask, suggested_correction = self._get_query()
return self.manual_hits self.suggested_correction = suggested_correction
def _build_manual_hits(self): results = self.searcher.search(
q, mask, suggested_correction = self._get_query() q,
self.suggested_correction = suggested_correction mask=mask,
filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
limit=None,
)
results.fragmenter = highlight.ContextFragmenter(surround=50)
results.formatter = HtmlFormatter(tagname="span", between=" ... ")
results = self.searcher.search( if not self.first_score and len(results) > 0:
q, self.first_score = results[0].score
mask=mask,
filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
limit=None,
)
results.fragmenter = highlight.ContextFragmenter(surround=50)
results.formatter = HtmlFormatter(tagname="span", between=" ... ")
if not self.first_score and len(results) > 0: if self.first_score:
self.first_score = results[0].score results.top_n = list(
map(
if self.first_score: lambda hit: (
results.top_n = list( (hit[0] / self.first_score) if self.first_score else None,
map( hit[1],
lambda hit: ( ),
(hit[0] / self.first_score) if self.first_score else None, results.top_n,
hit[1],
), ),
results.top_n, )
hits_by_id = {hit["id"]: hit for hit in results}
matching_ids = list(hits_by_id.keys())
ordered_ids = list(
self.filter_queryset.filter(id__in=matching_ids).values_list(
"id",
flat=True,
), ),
) )
ordered_ids = list(dict.fromkeys(ordered_ids))
hits_by_id = {hit["id"]: hit for hit in results} self._manual_hits_cache = [
matching_ids = list(hits_by_id.keys()) hits_by_id[_id] for _id in ordered_ids if _id in hits_by_id
]
ordered_ids = list( return self._manual_hits_cache
self.filter_queryset.filter(id__in=matching_ids).values_list(
"id",
flat=True,
),
)
ordered_ids = list(dict.fromkeys(ordered_ids))
return [hits_by_id[_id] for _id in ordered_ids if _id in hits_by_id]
def __getitem__(self, item): def __getitem__(self, item):
if item.start in self.saved_results: if item.start in self.saved_results:
return self.saved_results[item.start] return self.saved_results[item.start]
manual_hits = self._get_manual_hits() manual_hits = self._manual_hits()
if manual_hits is not None: if manual_hits is not None:
start = 0 if item.start is None else item.start start = 0 if item.start is None else item.start
stop = item.stop stop = item.stop
page = manual_hits[start:stop] if stop is not None else manual_hits[start:] hits = manual_hits[start:stop] if stop is not None else manual_hits[start:]
page = ManualResultsPage(hits)
self.saved_results[start] = page self.saved_results[start] = page
return page return page
@ -395,6 +394,20 @@ class DelayedQuery:
return page return page
class ManualResultsPage(list):
def __init__(self, hits):
super().__init__(hits)
self.results = ManualResults(hits)
class ManualResults:
def __init__(self, hits):
self._docnums = [hit.docnum for hit in hits]
def docs(self):
return self._docnums
class LocalDateParser(English): class LocalDateParser(English):
def reverse_timezone_offset(self, d): def reverse_timezone_offset(self, d):
return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone( return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone(

View file

@ -70,22 +70,18 @@ class StandardPagination(PageNumberPagination):
def get_all_result_ids(self): def get_all_result_ids(self):
query = self.page.paginator.object_list query = self.page.paginator.object_list
if isinstance(query, DelayedQuery): if isinstance(query, DelayedQuery):
manual_hits = getattr(query, "manual_hits", None) try:
if manual_hits is not None:
ids = [hit["id"] for hit in manual_hits]
else:
first_page = query.saved_results.get(0)
if not first_page:
return []
ids = [ ids = [
query.searcher.ixreader.stored_fields( query.searcher.ixreader.stored_fields(
doc_num, doc_num,
)["id"] )["id"]
for doc_num in first_page.results.docs() for doc_num in query.saved_results.get(0).results.docs()
] ]
except Exception:
pass
else: else:
ids = list(self.page.paginator.object_list.values_list("pk", flat=True)) ids = self.page.paginator.object_list.values_list("pk", flat=True)
return list(ids) return ids
def get_paginated_response_schema(self, schema): def get_paginated_response_schema(self, schema):
response_schema = super().get_paginated_response_schema(schema) response_schema = super().get_paginated_response_schema(schema)