Sophia Koehler committed
Commit d661944 · 1 Parent(s): 3f7f963
Files changed (2)
  1. app.py +494 -8
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,3 +1,488 @@
+# -*- coding: utf-8 -*-
+"""Copy of HW1 (more instructed).ipynb
+
+Automatically generated by Colab.
+
+Original file is located at
+    https://colab.research.google.com/drive/1BrX2Zy737ji-Lbb2evMV2P-WfzvTniHj
+"""
+
+from __future__ import annotations  # must precede all other statements; lets annotations reference names defined later in the file
+
+# Dependencies are installed via requirements.txt; in Colab, run these shell commands instead:
+# !pip install git+https://github.com/kwang2049/nlp4web-codebase.git
+# !git clone https://github.com/kwang2049/nlp4web-codebase.git  # You can always check the content of this simple codebase at any time
+# !pip install gradio  # we also need this additionally for this homework
+
+"""## Pre-requisite code
+
+The code within this section is used in the tasks. Please do not change these lines.
+
+### SciQ loading and counting
+"""
+
+from dataclasses import dataclass
+import pickle
+import os
+from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
+from nlp4web_codebase.ir.data_loaders.dm import Document
+from collections import Counter
+import tqdm
+import re
+import nltk
+nltk.download("stopwords", quiet=True)
+from nltk.corpus import stopwords as nltk_stopwords
+
+LANGUAGE = "english"
+word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
+stopwords = set(nltk_stopwords.words(LANGUAGE))
+
+
+def word_splitting(text: str) -> List[str]:
+    return word_splitter(text.lower())
+
+def lemmatization(words: List[str]) -> List[str]:
+    return words  # We ignore lemmatization here for simplicity
+
+def simple_tokenize(text: str) -> List[str]:
+    words = word_splitting(text)
+    tokenized = list(filter(lambda w: w not in stopwords, words))
+    tokenized = lemmatization(tokenized)
+    return tokenized
+
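+# Illustrative check of the tokenizer (the example query is made up, not drawn from SciQ):
+# simple_tokenize("What is the powerhouse of the cell?")
+# -> ['powerhouse', 'cell']   # stopwords are filtered out; 1-character words never match the splitter
+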
+T = TypeVar("T", bound="InvertedIndex")
+
+@dataclass
+class PostingList:
+    term: str  # The term
+    docid_postings: List[int]  # docid_postings[i] is the docid (int) of the i-th associated posting
+    tweight_postings: List[float]  # tweight_postings[i] is the term weight (float) of the i-th associated posting
+
+
+@dataclass
+class InvertedIndex:
+    posting_lists: List[PostingList]  # tid -> posting_list
+    vocab: Dict[str, int]  # term -> tid
+    cid2docid: Dict[str, int]  # collection_id -> docid
+    collection_ids: List[str]  # docid -> collection_id
+    doc_texts: Optional[List[str]] = None  # docid -> document text
+
+    def save(self, output_dir: str) -> None:
+        os.makedirs(output_dir, exist_ok=True)
+        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+            pickle.dump(self, f)
+
+    @classmethod
+    def from_saved(cls: Type[T], saved_dir: str) -> T:
+        index = cls(
+            posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
+        )
+        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+            index = pickle.load(f)
+        return index
+
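+# Shape sketch (toy values for illustration only, not from the real index): for a
+# two-document collection where "cell" occurs twice in docid 0 and once in docid 1,
+# the counting phase would produce roughly
+#   PostingList(term="cell", docid_postings=[0, 1], tweight_postings=[2.0, 1.0])
+# BM25 indexing later overwrites tweight_postings in place with regularized-TF * IDF weights.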
+
+# The output of the counting function:
+@dataclass
+class Counting:
+    posting_lists: List[PostingList]
+    vocab: Dict[str, int]
+    cid2docid: Dict[str, int]
+    collection_ids: List[str]
+    dfs: List[int]  # tid -> df
+    dls: List[int]  # docid -> doc length
+    avgdl: float
+    nterms: int
+    doc_texts: Optional[List[str]] = None
+
+def run_counting(
+    documents: Iterable[Document],
+    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
+    store_raw: bool = True,  # store the document text in doc_texts
+    ndocs: Optional[int] = None,
+    show_progress_bar: bool = True,
+) -> Counting:
+    """Counting TFs, DFs, doc_lengths, etc."""
+    posting_lists: List[PostingList] = []
+    vocab: Dict[str, int] = {}
+    cid2docid: Dict[str, int] = {}
+    collection_ids: List[str] = []
+    dfs: List[int] = []  # tid -> df
+    dls: List[int] = []  # docid -> doc length
+    nterms: int = 0
+    doc_texts: Optional[List[str]] = []
+    for doc in tqdm.tqdm(
+        documents,
+        desc="Counting",
+        total=ndocs,
+        disable=not show_progress_bar,
+    ):
+        if doc.collection_id in cid2docid:
+            continue
+        collection_ids.append(doc.collection_id)
+        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
+        toks = tokenize_fn(doc.text)
+        tok2tf = Counter(toks)
+        dls.append(sum(tok2tf.values()))
+        for tok, tf in tok2tf.items():
+            nterms += tf
+            tid = vocab.get(tok, None)
+            if tid is None:
+                posting_lists.append(
+                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
+                )
+                tid = vocab.setdefault(tok, len(vocab))
+            posting_lists[tid].docid_postings.append(docid)
+            posting_lists[tid].tweight_postings.append(tf)
+            if tid < len(dfs):
+                dfs[tid] += 1
+            else:
+                dfs.append(0)
+        if store_raw:
+            doc_texts.append(doc.text)
+        else:
+            doc_texts = None
+    return Counting(
+        posting_lists=posting_lists,
+        vocab=vocab,
+        cid2docid=cid2docid,
+        collection_ids=collection_ids,
+        dfs=dfs,
+        dls=dls,
+        avgdl=sum(dls) / len(dls),
+        nterms=nterms,
+        doc_texts=doc_texts,
+    )
+
+from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+sciq = load_sciq()
+counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
+
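+# Optional sanity probe (exact numbers depend on the SciQ corpus snapshot, so they are
+# not asserted here):
+# print(len(counting.vocab), round(counting.avgdl, 1))  # vocabulary size and average doc length
+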
+"""### BM25 Index"""
+
+from dataclasses import asdict, dataclass
+import math
+import os
+from typing import Iterable, List, Optional, Type
+import tqdm
+from nlp4web_codebase.ir.data_loaders.dm import Document
+
+
+@dataclass
+class BM25Index(InvertedIndex):
+
+    @staticmethod
+    def tokenize(text: str) -> List[str]:
+        return simple_tokenize(text)
+
+    @staticmethod
+    def cache_term_weights(
+        posting_lists: List[PostingList],
+        total_docs: int,
+        avgdl: float,
+        dfs: List[int],
+        dls: List[int],
+        k1: float,
+        b: float,
+    ) -> None:
+        """Compute term weights and cache them in the posting lists."""
+        N = total_docs
+        for tid, posting_list in enumerate(
+            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+        ):
+            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
+            for i in range(len(posting_list.docid_postings)):
+                docid = posting_list.docid_postings[i]
+                tf = posting_list.tweight_postings[i]
+                dl = dls[docid]
+                regularized_tf = BM25Index.calc_regularized_tf(
+                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                )
+                posting_list.tweight_postings[i] = regularized_tf * idf
+
+    @staticmethod
+    def calc_regularized_tf(
+        tf: int, dl: float, avgdl: float, k1: float, b: float
+    ) -> float:
+        return tf / (tf + k1 * (1 - b + b * dl / avgdl))
+
+    @staticmethod
+    def calc_idf(df: int, N: int) -> float:
+        return math.log(1 + (N - df + 0.5) / (df + 0.5))
+
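+    # Worked example of one cached weight (round numbers chosen for illustration, not
+    # taken from the SciQ index): with N=10000, df=100, tf=2, dl=avgdl, k1=0.9, b=0.4:
+    #   idf            = log(1 + 9900.5 / 100.5) ≈ 4.60
+    #   regularized_tf = 2 / (2 + 0.9 * (1 - 0.4 + 0.4 * 1)) = 2 / 2.9 ≈ 0.69
+    #   cached weight  = regularized_tf * idf ≈ 3.17
+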
+    @classmethod
+    def build_from_documents(
+        cls: Type[BM25Index],
+        documents: Iterable[Document],
+        store_raw: bool = True,
+        output_dir: Optional[str] = None,
+        ndocs: Optional[int] = None,
+        show_progress_bar: bool = True,
+        k1: float = 0.9,
+        b: float = 0.4,
+    ) -> BM25Index:
+        # Counting TFs, DFs, doc_lengths, etc.:
+        counting = run_counting(
+            documents=documents,
+            tokenize_fn=BM25Index.tokenize,
+            store_raw=store_raw,
+            ndocs=ndocs,
+            show_progress_bar=show_progress_bar,
+        )
+
+        # Compute term weights and cache them:
+        posting_lists = counting.posting_lists
+        total_docs = len(counting.cid2docid)
+        BM25Index.cache_term_weights(
+            posting_lists=posting_lists,
+            total_docs=total_docs,
+            avgdl=counting.avgdl,
+            dfs=counting.dfs,
+            dls=counting.dls,
+            k1=k1,
+            b=b,
+        )
+
+        # Assemble and save:
+        index = BM25Index(
+            posting_lists=posting_lists,
+            vocab=counting.vocab,
+            cid2docid=counting.cid2docid,
+            collection_ids=counting.collection_ids,
+            doc_texts=counting.doc_texts,
+        )
+        return index
+
+bm25_index = BM25Index.build_from_documents(
+    documents=iter(sciq.corpus),
+    ndocs=12160,
+    show_progress_bar=True,
+)
+bm25_index.save("output/bm25_index")
+# !ls  # Colab-only shell command
+
+
+"""### BM25 Retriever"""
+
+from nlp4web_codebase.ir.models import BaseRetriever
+from typing import Type
+from abc import abstractmethod
+
+
+class BaseInvertedIndexRetriever(BaseRetriever):
+
+    @property
+    @abstractmethod
+    def index_class(self) -> Type[InvertedIndex]:
+        pass
+
+    def __init__(self, index_dir: str) -> None:
+        self.index = self.index_class.from_saved(index_dir)
+
+    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+        toks = self.index.tokenize(query)
+        target_docid = self.index.cid2docid[cid]
+        term_weights = {}
+        for tok in toks:
+            if tok not in self.index.vocab:
+                continue
+            tid = self.index.vocab[tok]
+            posting_list = self.index.posting_lists[tid]
+            for docid, tweight in zip(
+                posting_list.docid_postings, posting_list.tweight_postings
+            ):
+                if docid == target_docid:
+                    term_weights[tok] = tweight
+                    break
+        return term_weights
+
+    def score(self, query: str, cid: str) -> float:
+        return sum(self.get_term_weights(query=query, cid=cid).values())
+
+    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+        toks = self.index.tokenize(query)
+        docid2score: Dict[int, float] = {}
+        for tok in toks:
+            if tok not in self.index.vocab:
+                continue
+            tid = self.index.vocab[tok]
+            posting_list = self.index.posting_lists[tid]
+            for docid, tweight in zip(
+                posting_list.docid_postings, posting_list.tweight_postings
+            ):
+                docid2score.setdefault(docid, 0)
+                docid2score[docid] += tweight
+        docid2score = dict(
+            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+        )
+        return {
+            self.index.collection_ids[docid]: score
+            for docid, score in docid2score.items()
+        }
+
+
+class BM25Retriever(BaseInvertedIndexRetriever):
+
+    @property
+    def index_class(self) -> Type[BM25Index]:
+        return BM25Index
+
+bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
+
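+# The call above returns a dict mapping collection ids to BM25 scores, at most topk=10
+# entries, highest score first (the ids and scores below are made up for illustration):
+#   {"train-00032": 12.3, "train-01057": 11.8, ...}
+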
+"""# TASK1: tune b and k1 (4 points)
+
+Tune b and k1 on the **dev** split of SciQ using the metric MAP@10. The evaluation function (`evaluate_map`) is provided. Record the values in `plots_k1` and `plots_b`. Do it in a greedy manner: since b has the larger influence, first tune b (with k1 fixed to the default value 0.9), then use the best b to tune k1.
+
+$${\displaystyle {\text{score}}(D,Q)=\sum _{i=1}^{n}{\text{IDF}}(q_{i})\cdot {\frac {f(q_{i},D)\cdot (k_{1}+1)}{f(q_{i},D)+k_{1}\cdot \left(1-b+b\cdot {\frac {|D|}{\text{avgdl}}}\right)}}}$$
+"""
+
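+# Note on the formula vs. the code: calc_regularized_tf in the pre-requisite code drops
+# the constant (k1 + 1) numerator factor shown above. That factor scales every document's
+# contribution for a query term equally, so the induced ranking, and hence MAP@10, is
+# unchanged.
+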
+from nlp4web_codebase.ir.data_loaders import Split
+import numpy as np  # needed by evaluate_map below
+import pytrec_eval
+
+
+def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
+    metric = "map_cut_10"
+    qrels = sciq.get_qrels_dict(split)
+    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
+    qps = evaluator.evaluate(rankings)
+    return float(np.mean([qp[metric] for qp in qps.values()]))
+
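+# evaluate_map consumes pytrec_eval-style nested dicts (the ids and scores below are toy
+# values for illustration):
+#   rankings = {"q1": {"doc-a": 12.3, "doc-b": 11.8}}  # query_id -> {collection_id -> score}
+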
+"""Example of using the pre-requisite code:"""
+
+# Loading the dataset:
+from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+sciq = load_sciq()
+counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
+
+# Building the BM25 index and saving it:
+bm25_index = BM25Index.build_from_documents(
+    documents=iter(sciq.corpus),
+    ndocs=12160,
+    show_progress_bar=True,
+)
+bm25_index.save("output/bm25_index")
+
+# Loading the index and retrieving with the BM25 retriever:
+bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?"))  # the ranking
+
+plots_b: Dict[str, List[float]] = {
+    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
+    "Y": [],
+}
+plots_k1: Dict[str, List[float]] = {
+    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
+    "Y": [],
+}
+
+## YOUR_CODE_STARTS_HERE
+class MyBMIndex(BM25Index):
+
+    @staticmethod
+    def calc_regularized_tf(
+        tf: int, dl: float, avgdl: float, k1: float, b: float
+    ) -> float:
+        return tf * (k1 + 1) / (tf + k1 * (1 - b + b * (dl / avgdl) ** 1.5))
+
+    @staticmethod
+    def calc_idf(df: int, N: int) -> float:
+        return math.log((N + 1) / (df + 0.5)) + 1
+
+# Two steps are involved:
+# Step 1. Fix k1 to the default value 0.9, go through all the candidate b values
+#         (0, 0.1, ..., 1.0), and record in plots_b["Y"] the corresponding
+#         performance obtained via evaluate_map;
+# Step 2. Fix b to the best value from step 1 and do the same for k1.
+
+# Hint (on using the pre-requisite code):
+# - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
+# - One can build bm25_index with `BM25Index.build_from_documents`;
+# - One can use BM25Retriever to load the index and perform retrieval on the dev queries
+#   (dev queries can be obtained via sciq.get_split_queries(Split.dev))
+
+def get_ranking(k1: float, b: float) -> Dict[str, Dict[str, float]]:
+    # Build the BM25 index with the given hyperparameters and save it:
+    bm25_index = MyBMIndex.build_from_documents(
+        documents=iter(sciq.corpus),
+        ndocs=12160,
+        show_progress_bar=True,
+        k1=k1,
+        b=b,
+    )
+    bm25_index.save("output/bm25_index")
+
+    # Load the index and retrieve for every dev query:
+    bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+    rankings = {}
+    for query in sciq.get_split_queries(split=Split.dev):
+        rankings[query.query_id] = bm25_retriever.retrieve(query=query.text)
+    return rankings
+
+# Step 1: tune b with k1 fixed to 0.9.
+for b in plots_b["X"]:
+    ranking = get_ranking(0.9, b)
+    plots_b["Y"].append(evaluate_map(rankings=ranking))
+
+# Step 2: fix b to the best value found above (the argmax over the candidate b values,
+# not the best MAP score itself) and tune k1.
+best_b = plots_b["X"][int(np.argmax(plots_b["Y"]))]
+for k1 in plots_k1["X"]:
+    ranking = get_ranking(k1, best_b)
+    plots_k1["Y"].append(evaluate_map(rankings=ranking))
+## YOUR_CODE_ENDS_HERE
+
+## TEST_CASES (should be close to 0.8135637188208616 and 0.7512916099773244)
+print(plots_k1["Y"][9])
+print(plots_b["Y"][1])
+
+## RESULT_CHECKING_POINT
+print(plots_k1)
+print(plots_b)
+
+from matplotlib import pyplot as plt
+plt.plot(plots_b["X"], plots_b["Y"], label="b")
+plt.plot(plots_k1["X"], plots_k1["Y"], label="k1")
+plt.ylabel("MAP")
+plt.legend()
+plt.grid()
+plt.show()
+
+"""Let's check the effectiveness gain on test after this tuning on dev"""
+
+default_map = 0.7849
+best_b = plots_b["X"][np.argmax(plots_b["Y"])]
+best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
+bm25_index = BM25Index.build_from_documents(
+    documents=iter(sciq.corpus),
+    ndocs=12160,
+    show_progress_bar=True,
+    k1=best_k1,
+    b=best_b,
+)
+bm25_index.save("output/bm25_index")
+bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+rankings = {}
+for query in sciq.get_split_queries(Split.test):  # note this is now on test
+    ranking = bm25_retriever.retrieve(query=query.text)
+    rankings[query.query_id] = ranking
+optimized_map = evaluate_map(rankings, split=Split.test)  # note this is now on test
+print(default_map, optimized_map)
+
+"""# TASK3: a search-engine demo based on Huggingface space (4 points)
+
+## TASK3.1: create the gradio app (2 points)
+
+Create a gradio app to demo the BM25 search engine index on SciQ. The app should have a single input variable for the query (of type `str`) and a single output variable for the returned ranking (of type `List[Hit]` in the code below). Please use the BM25 system with default k1 and b values.
+
+Hint: it should use a "search" function with the signature:
+
+```python
+def search(query: str) -> List[Hit]:
+    ...
+```
+"""
+
+# !pip install gradio  # Colab-only; the Space installs gradio via requirements.txt
+
 import gradio as gr
 from typing import TypedDict
 
@@ -8,24 +493,25 @@ class Hit(TypedDict):
 
 demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
 return_type = List[Hit]
+
 ## YOUR_CODE_STARTS_HERE
 def search_sciq(query: str) -> List[Hit]:
     results = bm25_retriever.retrieve(query)
-
-    # Format the output to match the List[Hit] structure
-    hits = []
+    hitlist = []
     for cid, score in results.items():
         index = bm25_retriever.index.cid2docid[cid]
         text = bm25_retriever.index.doc_texts[index]
-        hits.append(Hit(cid=cid, score=score, text=text))
+        hitlist.append(Hit(cid=cid, score=score, text=text))
 
-    return hits
+    return hitlist
 
-# Set up the Gradio interface
 demo = gr.Interface(
     fn=search_sciq,
-    inputs=gr.Textbox(label="Enter your query"),
-    outputs=gr.JSON(label="Top 10 Results"),
+    inputs="textbox",
+    outputs="textbox",
     description="BM25 Search Engine Demo on SciQ Dataset"
 )
+## YOUR_CODE_ENDS_HERE
 demo.launch()
+
+
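+# Each Hit is a TypedDict, so it serializes as a plain dict (the cid, score, and text
+# values below are invented; the field names follow from the Hit(...) call above):
+#   [{"cid": "train-00032", "score": 12.3, "text": "..."}, ...]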
requirements.txt ADDED
@@ -0,0 +1,2 @@
+gradio
+git+https://github.com/kwang2049/nlp4web-codebase.git