Sophia Koehler committed on
Commit e39c176 · 1 Parent(s): e8df6fa
Files changed (1):
  1. app.py +61 -431
app.py CHANGED
@@ -1,58 +1,49 @@
  # -*- coding: utf-8 -*-
-
-
- """## Pre-requisite code
-
- The code within this section will be used in the tasks. Please do not change these code lines.
-
- ### SciQ loading and counting
- """
-
  from dataclasses import dataclass
- import pickle
  import os
- from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
- from nlp4web_codebase.ir.data_loaders.dm import Document
- from collections import Counter
- import tqdm
  import re
  import nltk
- nltk.download("stopwords", quiet=True)
  from nltk.corpus import stopwords as nltk_stopwords
  LANGUAGE = "english"
- word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
  stopwords = set(nltk_stopwords.words(LANGUAGE))
-
-
- def word_splitting(text: str) -> List[str]:
-     return word_splitter(text.lower())
-
- def lemmatization(words: List[str]) -> List[str]:
-     return words  # We ignore lemmatization here for simplicity
-
  def simple_tokenize(text: str) -> List[str]:
-     words = word_splitting(text)
-     tokenized = list(filter(lambda w: w not in stopwords, words))
-     tokenized = lemmatization(tokenized)
      return tokenized
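
A quick illustration of the tokenizer (a sketch; the output assumes the standard NLTK English stopword list):

```python
# simple_tokenize lowercases, splits on \w\w+ (single-character tokens are dropped),
# and removes English stopwords; lemmatization is a no-op here.
print(simple_tokenize("What type of diseases occur when the immune system attacks normal body cells?"))
# -> ['type', 'diseases', 'occur', 'immune', 'system', 'attacks', 'normal', 'body', 'cells']
```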

- T = TypeVar("T", bound="InvertedIndex")
-
  @dataclass
  class PostingList:
-     term: str  # The term
-     docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
-     tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting

  @dataclass
  class InvertedIndex:
-     posting_lists: List[PostingList]  # tid -> posting_list
      vocab: Dict[str, int]
-     cid2docid: Dict[str, int]  # collection_id -> docid
-     collection_ids: List[str]  # docid -> collection_id
-     doc_texts: Optional[List[str]] = None  # docid -> document text

      def save(self, output_dir: str) -> None:
          os.makedirs(output_dir, exist_ok=True)
@@ -61,138 +52,28 @@ class InvertedIndex:

      @classmethod
      def from_saved(cls: Type[T], saved_dir: str) -> T:
-         index = cls(
-             posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
-         )
          with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
-             index = pickle.load(f)
-         return index
-
-
- # The output of the counting function:
- @dataclass
- class Counting:
-     posting_lists: List[PostingList]
-     vocab: Dict[str, int]
-     cid2docid: Dict[str, int]
-     collection_ids: List[str]
-     dfs: List[int]  # tid -> df
-     dls: List[int]  # docid -> doc length
-     avgdl: float
-     nterms: int
-     doc_texts: Optional[List[str]] = None
-
- def run_counting(
-     documents: Iterable[Document],
-     tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
-     store_raw: bool = True,  # store the document text in doc_texts
-     ndocs: Optional[int] = None,
-     show_progress_bar: bool = True,
- ) -> Counting:
-     """Counting TFs, DFs, doc_lengths, etc."""
-     posting_lists: List[PostingList] = []
-     vocab: Dict[str, int] = {}
-     cid2docid: Dict[str, int] = {}
-     collection_ids: List[str] = []
-     dfs: List[int] = []  # tid -> df
-     dls: List[int] = []  # docid -> doc length
-     nterms: int = 0
-     doc_texts: Optional[List[str]] = []
-     for doc in tqdm.tqdm(
-         documents,
-         desc="Counting",
-         total=ndocs,
-         disable=not show_progress_bar,
-     ):
-         if doc.collection_id in cid2docid:
-             continue
-         collection_ids.append(doc.collection_id)
-         docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
-         toks = tokenize_fn(doc.text)
-         tok2tf = Counter(toks)
-         dls.append(sum(tok2tf.values()))
-         for tok, tf in tok2tf.items():
-             nterms += tf
-             tid = vocab.get(tok, None)
-             if tid is None:
-                 posting_lists.append(
-                     PostingList(term=tok, docid_postings=[], tweight_postings=[])
-                 )
-                 tid = vocab.setdefault(tok, len(vocab))
-             posting_lists[tid].docid_postings.append(docid)
-             posting_lists[tid].tweight_postings.append(tf)
-             if tid < len(dfs):
-                 dfs[tid] += 1
-             else:
-                 dfs.append(0)
-         if store_raw:
-             doc_texts.append(doc.text)
-         else:
-             doc_texts = None
-     return Counting(
-         posting_lists=posting_lists,
-         vocab=vocab,
-         cid2docid=cid2docid,
-         collection_ids=collection_ids,
-         dfs=dfs,
-         dls=dls,
-         avgdl=sum(dls) / len(dls),
-         nterms=nterms,
-         doc_texts=doc_texts,
-     )
-
- from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
- sciq = load_sciq()
- counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
-
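
To make the bookkeeping concrete, a toy run (a sketch; assumes `Document` can be constructed with `collection_id` and `text`, the two fields the loop above reads):

```python
toy = [
    Document(collection_id="d1", text="immune system attacks cells"),
    Document(collection_id="d2", text="cells divide"),
]
c = run_counting(documents=iter(toy), ndocs=2, show_progress_bar=False)
print(c.vocab)         # e.g. {'immune': 0, 'system': 1, 'attacks': 2, 'cells': 3, 'divide': 4}
print(c.dls, c.avgdl)  # [4, 2] 3.0
print(c.dfs)           # [0, 0, 0, 1, 0] -- a term's first document appends 0, so dfs stores df - 1
```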
- """### BM25 Index"""
-
- from __future__ import annotations
- from dataclasses import asdict, dataclass
- import math
- import os
- from typing import Iterable, List, Optional, Type
- import tqdm
- from nlp4web_codebase.ir.data_loaders.dm import Document
-

  @dataclass
  class BM25Index(InvertedIndex):

-     @staticmethod
-     def tokenize(text: str) -> List[str]:
-         return simple_tokenize(text)
-
      @staticmethod
      def cache_term_weights(
-         posting_lists: List[PostingList],
-         total_docs: int,
-         avgdl: float,
-         dfs: List[int],
-         dls: List[int],
-         k1: float,
-         b: float,
      ) -> None:
-         """Compute term weights and cache them"""
-
          N = total_docs
-         for tid, posting_list in enumerate(
-             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
-         ):
              idf = BM25Index.calc_idf(df=dfs[tid], N=N)
-             for i in range(len(posting_list.docid_postings)):
-                 docid = posting_list.docid_postings[i]
                  tf = posting_list.tweight_postings[i]
                  dl = dls[docid]
-                 regularized_tf = BM25Index.calc_regularized_tf(
                      tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
-                 )
-                 posting_list.tweight_postings[i] = regularized_tf * idf

      @staticmethod
-     def calc_regularized_tf(
-         tf: int, dl: float, avgdl: float, k1: float, b: float
-     ) -> float:
          return tf / (tf + k1 * (1 - b + b * dl / avgdl))

      @staticmethod
@@ -201,305 +82,54 @@ class BM25Index(InvertedIndex):

      @classmethod
      def build_from_documents(
-         cls: Type[BM25Index],
-         documents: Iterable[Document],
-         store_raw: bool = True,
-         output_dir: Optional[str] = None,
-         ndocs: Optional[int] = None,
-         show_progress_bar: bool = True,
-         k1: float = 0.9,
-         b: float = 0.4,
      ) -> "BM25Index":
-         # Counting TFs, DFs, doc_lengths, etc.:
-         counting = run_counting(
-             documents=documents,
-             tokenize_fn=BM25Index.tokenize,
-             store_raw=store_raw,
-             ndocs=ndocs,
-             show_progress_bar=show_progress_bar,
-         )
-
-         # Compute and cache the term weights:
-         posting_lists = counting.posting_lists
-         total_docs = len(counting.cid2docid)
-         BM25Index.cache_term_weights(
-             posting_lists=posting_lists,
-             total_docs=total_docs,
-             avgdl=counting.avgdl,
-             dfs=counting.dfs,
-             dls=counting.dls,
-             k1=k1,
-             b=b,
-         )
-
-         # Assemble and save:
-         index = BM25Index(
-             posting_lists=posting_lists,
-             vocab=counting.vocab,
-             cid2docid=counting.cid2docid,
-             collection_ids=counting.collection_ids,
-             doc_texts=counting.doc_texts,
-         )
-         return index
-
- bm25_index = BM25Index.build_from_documents(
-     documents=iter(sciq.corpus),
-     ndocs=12160,
-     show_progress_bar=True,
- )
- bm25_index.save("output/bm25_index")
- !ls
-
- """### BM25 Retriever"""
-
- from nlp4web_codebase.ir.models import BaseRetriever
- from typing import Type
- from abc import abstractmethod
-
-
- class BaseInvertedIndexRetriever(BaseRetriever):
-
-     @property
-     @abstractmethod
-     def index_class(self) -> Type[InvertedIndex]:
-         pass

      def __init__(self, index_dir: str) -> None:
-         self.index = self.index_class.from_saved(index_dir)
-
-     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
-         toks = self.index.tokenize(query)
-         target_docid = self.index.cid2docid[cid]
-         term_weights = {}
-         for tok in toks:
-             if tok not in self.index.vocab:
-                 continue
-             tid = self.index.vocab[tok]
-             posting_list = self.index.posting_lists[tid]
-             for docid, tweight in zip(
-                 posting_list.docid_postings, posting_list.tweight_postings
-             ):
-                 if docid == target_docid:
-                     term_weights[tok] = tweight
-                     break
-         return term_weights
-
-     def score(self, query: str, cid: str) -> float:
-         return sum(self.get_term_weights(query=query, cid=cid).values())

      def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
-         toks = self.index.tokenize(query)
-         docid2score: Dict[int, float] = {}
          for tok in toks:
-             if tok not in self.index.vocab:
-                 continue
-             tid = self.index.vocab[tok]
-             posting_list = self.index.posting_lists[tid]
-             for docid, tweight in zip(
-                 posting_list.docid_postings, posting_list.tweight_postings
-             ):
-                 docid2score.setdefault(docid, 0)
-                 docid2score[docid] += tweight
-         docid2score = dict(
-             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
-         )
          return {
-             self.index.collection_ids[docid]: score
-             for docid, score in docid2score.items()
          }

-
- class BM25Retriever(BaseInvertedIndexRetriever):
-
-     @property
-     def index_class(self) -> Type[BM25Index]:
-         return BM25Index
-
- bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
- bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
-
- """# TASK1: tune b and k1 (4 points)
-
- Tune b and k1 on the **dev** split of SciQ using the metric MAP@10. The evaluation function (`evaluate_map`) is provided. Record the values in `plots_k1` and `plots_b`. Do it in a greedy manner: since the influence of b is larger, first tune b (with k1 fixed to the default value 0.9), then use the best value of b to further tune k1.
-
- $$\text{score}(D,Q)=\sum_{i=1}^{n}\text{IDF}(q_i)\cdot\frac{f(q_i,D)\cdot(k_1+1)}{f(q_i,D)+k_1\cdot\left(1-b+b\cdot\frac{|D|}{\text{avgdl}}\right)}$$
- """
-
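
As a sanity check on the formula with concrete numbers (illustrative only; note that the prerequisite `calc_regularized_tf` drops the constant `(k1 + 1)` numerator factor, which rescales every score equally and leaves the ranking unchanged):

```python
# One query term with tf = 3 in a document of length 20, avgdl = 10, default k1 = 0.9, b = 0.4:
tf, dl, avgdl, k1, b = 3, 20.0, 10.0, 0.9, 0.4
reg_tf = tf / (tf + k1 * (1 - b + b * dl / avgdl))  # 3 / (3 + 0.9 * 1.4) ~= 0.704
# The cached posting weight is reg_tf * IDF(term); summing the weights of the query
# terms that occur in D gives score(D, Q).
```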
- from nlp4web_codebase.ir.data_loaders import Split
- import pytrec_eval
-
-
- def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
-     metric = "map_cut_10"
-     qrels = sciq.get_qrels_dict(split)
-     evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
-     qps = evaluator.evaluate(rankings)
-     return float(np.mean([qp[metric] for qp in qps.values()]))
-
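
For orientation, the shapes `evaluate_map` works with (a sketch with made-up IDs; `rankings` is what the retrieval loops below produce, and pytrec_eval expects exactly these nested dicts):

```python
# rankings: query_id -> {collection_id -> retrieval score} (floats, higher = better)
rankings = {"q1": {"doc-17": 12.3, "doc-4": 9.8}}
# qrels (sciq.get_qrels_dict(split)): query_id -> {collection_id -> relevance label (int)}
# evaluate_map averages the per-query map_cut_10 values over the split.
```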
- """Example of using the prerequisite code:"""
-
- # Loading the dataset:
- from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
- sciq = load_sciq()
- counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
-
- # Building the BM25 index and saving it:
- bm25_index = BM25Index.build_from_documents(
-     documents=iter(sciq.corpus),
-     ndocs=12160,
-     show_progress_bar=True
- )
- bm25_index.save("output/bm25_index")
-
- # Loading the index and using the BM25 retriever to retrieve:
- bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
- print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?"))  # the ranking
-
- plots_b: Dict[str, List[float]] = {
-     "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
-     "Y": []
- }
- plots_k1: Dict[str, List[float]] = {
-     "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
-     "Y": []
- }
-
- ## YOUR_CODE_STARTS_HERE
- class MyBMIndex(BM25Index):
-
-     @staticmethod
-     def calc_regularized_tf(
-         tf: int, dl: float, avgdl: float, k1: float, b: float
-     ) -> float:
-         return tf * (k1 + 1) / (tf + k1 * (1 - b + b * (dl / avgdl)**1.5))
-
-     @staticmethod
-     def calc_idf(df: int, N: int):
-         return math.log((N + 1) / (df + 0.5)) + 1
-
- import numpy as np
-
- # Two steps are involved:
- # Step 1. Fix k1 to the default value 0.9,
- #   go through all the candidate b values (0, 0.1, ..., 1.0),
- #   and record in plots_b["Y"] the corresponding performance obtained via evaluate_map;
- # Step 2. Fix b to the best one from step 1 and do the same for k1.
-
- # Hints (on using the prerequisite code):
- # - One can use the loaded sciq dataset directly (loaded in the prerequisite code);
- # - One can build bm25_index with `BM25Index.build_from_documents`;
- # - One can use BM25Retriever to load the index and perform retrieval on the dev queries
- #   (dev queries can be obtained via sciq.get_split_queries(Split.dev))
-
- counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
-
- def get_ranking(k1, b, counting) -> Dict[str, Dict[str, float]]:
-     # Build the BM25 index and save it:
-     bm25_index = MyBMIndex.build_from_documents(
-         documents=iter(sciq.corpus),
-         ndocs=12160,
-         show_progress_bar=True,
-         k1=k1,
-         b=b
-     )
-     bm25_index.save("output/bm25_index")
-
-     # Load the index and use the BM25 retriever to retrieve:
-     bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
-     queries = sciq.get_split_queries(split=Split.dev)
-     rankings = {}
-     for query in queries:
-         rankings[query.query_id] = bm25_retriever.retrieve(query=query.text)
-     return rankings
-
- for b in plots_b["X"]:
-     ranking = get_ranking(0.9, b, counting)
-     plots_b["Y"].append(evaluate_map(rankings=ranking))
-
- best_b = plots_b["X"][np.argmax(plots_b["Y"])]  # the best b value, not the best MAP score
- for k1 in plots_k1["X"]:
-     ranking = get_ranking(k1, best_b, counting)
-     plots_k1["Y"].append(evaluate_map(rankings=ranking))
- ## YOUR_CODE_ENDS_HERE
-
- ## TEST_CASES (should be close to 0.8135637188208616 and 0.7512916099773244)
- print(plots_k1["Y"][9])
- print(plots_b["Y"][1])
-
- ## RESULT_CHECKING_POINT
- print(plots_k1)
- print(plots_b)
-
- from matplotlib import pyplot as plt
- plt.plot(plots_b["X"], plots_b["Y"], label="b")
- plt.plot(plots_k1["X"], plots_k1["Y"], label="k1")
- plt.ylabel("MAP")
- plt.legend()
- plt.grid()
- plt.show()
-
- """Let's check the effectiveness gain on test after this tuning on dev"""
-
- default_map = 0.7849
- best_b = plots_b["X"][np.argmax(plots_b["Y"])]
- best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
- bm25_index = BM25Index.build_from_documents(
-     documents=iter(sciq.corpus),
-     ndocs=12160,
-     show_progress_bar=True,
-     k1=best_k1,
-     b=best_b
- )
- bm25_index.save("output/bm25_index")
- bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
- rankings = {}
- for query in sciq.get_split_queries(Split.test):  # note this is now on test
-     rankings[query.query_id] = bm25_retriever.retrieve(query=query.text)
- optimized_map = evaluate_map(rankings, split=Split.test)  # note this is now on test
- print(default_map, optimized_map)
-
- """# TASK3: a search-engine demo based on Hugging Face Spaces (4 points)
-
- ## TASK3.1: create the gradio app (2 points)
-
- Create a gradio app to demo the BM25 search engine index on SciQ. The app should have a single input variable for the query (of type `str`) and a single output variable for the returned ranking (of type `List[Hit]` in the code below). Please use the BM25 system with the default k1 and b values.
-
- Hint: it should use a "search" function with the signature:
-
- ```python
- def search(query: str) -> List[Hit]:
-     ...
- ```
- """
-
- import gradio as gr
- from typing import TypedDict
-
  class Hit(TypedDict):
-     cid: str
-     score: float
-     text: str
-
- demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
- return_type = List[Hit]

- ## YOUR_CODE_STARTS_HERE
  def search_sciq(query: str) -> List[Hit]:
      results = bm25_retriever.retrieve(query)
-     hitlist = []
      for cid, score in results.items():
-         index = bm25_retriever.index.cid2docid[cid]
-         text = bm25_retriever.index.doc_texts[index]
-         hitlist.append(Hit(cid=cid, score=score, text=text))
-     return hitlist

  demo = gr.Interface(
      fn=search_sciq,
      inputs="textbox",
-     outputs="textbox",
      description="BM25 Search Engine Demo on SciQ Dataset"
  )
- ## YOUR_CODE_ENDS_HERE
- demo.launch()
  # -*- coding: utf-8 -*-
  from dataclasses import dataclass
  import os
+ import pickle
+ from typing import List, Dict, Optional, Type, TypeVar, TypedDict
  import re
+ import math
+ from collections import Counter
+ import gradio as gr
  import nltk
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ from nlp4web_codebase.ir.models import BaseRetriever
  from nltk.corpus import stopwords as nltk_stopwords

+ # Download the NLTK stopword data only if it is not already present
+ try:
+     nltk.data.find("corpora/stopwords")
+ except LookupError:
+     nltk.download("stopwords", quiet=True)
+
+ # Tokenization and helper functions
  LANGUAGE = "english"
  stopwords = set(nltk_stopwords.words(LANGUAGE))
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall

  def simple_tokenize(text: str) -> List[str]:
+     words = word_splitter(text.lower())
+     tokenized = [word for word in words if word not in stopwords]
      return tokenized

  @dataclass
  class PostingList:
+     term: str
+     docid_postings: List[int]
+     tweight_postings: List[float]

+ T = TypeVar("T", bound="InvertedIndex")

  @dataclass
  class InvertedIndex:
+     posting_lists: List[PostingList]
      vocab: Dict[str, int]
+     cid2docid: Dict[str, int]
+     collection_ids: List[str]
+     doc_texts: Optional[List[str]] = None

      def save(self, output_dir: str) -> None:
          os.makedirs(output_dir, exist_ok=True)

      @classmethod
      def from_saved(cls: Type[T], saved_dir: str) -> T:
          with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             return pickle.load(f)

  @dataclass
  class BM25Index(InvertedIndex):

      @staticmethod
      def cache_term_weights(
+         posting_lists: List[PostingList], total_docs: int, avgdl: float, dfs: List[int], dls: List[int], k1: float, b: float,
      ) -> None:
          N = total_docs
+         for tid, posting_list in enumerate(posting_lists):
              idf = BM25Index.calc_idf(df=dfs[tid], N=N)
+             for i, docid in enumerate(posting_list.docid_postings):
                  tf = posting_list.tweight_postings[i]
                  dl = dls[docid]
+                 posting_list.tweight_postings[i] = BM25Index.calc_regularized_tf(
                      tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                 ) * idf

      @staticmethod
+     def calc_regularized_tf(tf: int, dl: float, avgdl: float, k1: float, b: float) -> float:
          return tf / (tf + k1 * (1 - b + b * dl / avgdl))

      @staticmethod

      @classmethod
      def build_from_documents(
+         cls: Type["BM25Index"], documents: List[Document], avgdl: float, total_docs: int, k1: float = 0.9, b: float = 0.4
      ) -> "BM25Index":
+         # Assumes run_counting() is defined and returns a counting object with the relevant data
+         counting = run_counting(documents, simple_tokenize)
+         BM25Index.cache_term_weights(counting.posting_lists, total_docs, avgdl, counting.dfs, counting.dls, k1, b)
+         return cls(counting.posting_lists, counting.vocab, counting.cid2docid, counting.collection_ids, counting.doc_texts)
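
`run_counting` is referenced here but no longer defined in this file, as the comment above notes. A minimal sketch consistent with the removed version (the `Counting` fields mirror the old dataclass; one deliberate difference: `dfs` here counts a term's first document too, where the removed code effectively stored df - 1):

```python
@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]   # tid -> document frequency
    dls: List[int]   # docid -> document length
    doc_texts: Optional[List[str]] = None

def run_counting(documents, tokenize_fn=simple_tokenize) -> Counting:
    posting_lists, vocab = [], {}
    cid2docid, collection_ids = {}, []
    dfs, dls, doc_texts = [], [], []
    for doc in documents:
        if doc.collection_id in cid2docid:
            continue  # skip duplicate collection ids, as the removed version did
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        tok2tf = Counter(tokenize_fn(doc.text))
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            tid = vocab.get(tok)
            if tid is None:
                tid = vocab.setdefault(tok, len(vocab))
                posting_lists.append(PostingList(term=tok, docid_postings=[], tweight_postings=[]))
                dfs.append(0)
            dfs[tid] += 1
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
        doc_texts.append(doc.text)
    return Counting(posting_lists, vocab, cid2docid, collection_ids, dfs, dls, doc_texts)
```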

+ class BM25Retriever(BaseRetriever):
+     def __init__(self, index_dir: str) -> None:
+         self.index = BM25Index.from_saved(index_dir)

      def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         toks = simple_tokenize(query)
+         docid2score = Counter()
          for tok in toks:
+             if tok in self.index.vocab:
+                 tid = self.index.vocab[tok]
+                 posting_list = self.index.posting_lists[tid]
+                 for docid, weight in zip(posting_list.docid_postings, posting_list.tweight_postings):
+                     docid2score[docid] += weight
          return {
+             self.index.collection_ids[docid]: score for docid, score in docid2score.most_common(topk)
          }

+ # Gradio app setup
  class Hit(TypedDict):
+     cid: str
+     score: float
+     text: str

  def search_sciq(query: str) -> List[Hit]:
      results = bm25_retriever.retrieve(query)
+     hits = []
      for cid, score in results.items():
+         docid = bm25_retriever.index.cid2docid[cid]
+         text = bm25_retriever.index.doc_texts[docid]
+         hits.append(Hit(cid=cid, score=score, text=text))
+     return hits

+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")

  demo = gr.Interface(
      fn=search_sciq,
      inputs="textbox",
+     outputs="json",
      description="BM25 Search Engine Demo on SciQ Dataset"
  )

+ if __name__ == "__main__":
+     demo.launch()
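
For reference, a quick local smoke test of the resulting app (assumes `output/bm25_index/index.pkl` ships with the Space, since this file no longer builds the index itself):

```python
# Query the search function directly, bypassing the Gradio UI.
hits = search_sciq("What type of diseases occur when the immune system attacks normal body cells?")
for hit in hits:
    print(hit["cid"], round(hit["score"], 2), hit["text"][:60])
```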