Sophia Koehler committed
Commit 2fa43bc · 1 Parent(s): b91726c
Files changed (2)
  1. app.py +270 -64
  2. nlp4web-codebase +1 -0
app.py CHANGED
@@ -1,49 +1,50 @@
  # -*- coding: utf-8 -*-
+
  from dataclasses import dataclass
- import os
  import pickle
- from typing import List, Dict, Optional, Type, TypeVar, TypedDict
- import re
- import math
+ import os
+ from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
+ from nlp4web_codebase.ir.data_loaders.dm import Document
  from collections import Counter
- import gradio as gr
+ import tqdm
+ import re
  import nltk
- from nlp4web_codebase.ir.data_loaders.dm import Document
- from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
- from nlp4web_codebase.ir.models import BaseRetriever
+ nltk.download("stopwords", quiet=True)
  from nltk.corpus import stopwords as nltk_stopwords

- # Check nltk stopwords data
- try:
-     nltk.data.find("corpora/stopwords")
- except LookupError:
-     nltk.download("stopwords", quiet=True)
-
- # Tokenization and helper functions
  LANGUAGE = "english"
- stopwords = set(nltk_stopwords.words(LANGUAGE))
  word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
+
+
+ def word_splitting(text: str) -> List[str]:
+     return word_splitter(text.lower())
+
+ def lemmatization(words: List[str]) -> List[str]:
+     return words  # We ignore lemmatization here for simplicity

  def simple_tokenize(text: str) -> List[str]:
-     words = word_splitter(text.lower())
-     tokenized = [word for word in words if word not in stopwords]
+     words = word_splitting(text)
+     tokenized = list(filter(lambda w: w not in stopwords, words))
+     tokenized = lemmatization(tokenized)
      return tokenized

+ T = TypeVar("T", bound="InvertedIndex")
+
  @dataclass
  class PostingList:
-     term: str
-     docid_postings: List[int]
-     tweight_postings: List[float]
+     term: str  # The term
+     docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
+     tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting

- T = TypeVar("T", bound="InvertedIndex")

  @dataclass
  class InvertedIndex:
-     posting_lists: List[PostingList]
+     posting_lists: List[PostingList]  # docid -> posting_list
      vocab: Dict[str, int]
-     cid2docid: Dict[str, int]
-     collection_ids: List[str]
-     doc_texts: Optional[List[str]] = None
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text

      def save(self, output_dir: str) -> None:
          os.makedirs(output_dir, exist_ok=True)
@@ -52,28 +53,138 @@ class InvertedIndex:

      @classmethod
      def from_saved(cls: Type[T], saved_dir: str) -> T:
+         index = cls(
+             posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
+         )
          with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
-             return pickle.load(f)
+             index = pickle.load(f)
+         return index
+
+
+ # The output of the counting function:
+ @dataclass
+ class Counting:
+     posting_lists: List[PostingList]
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]
+     collection_ids: List[str]
+     dfs: List[int]  # tid -> df
+     dls: List[int]  # docid -> doc length
+     avgdl: float
+     nterms: int
+     doc_texts: Optional[List[str]] = None
+
+ def run_counting(
+     documents: Iterable[Document],
+     tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
+     store_raw: bool = True,  # store the document text in doc_texts
+     ndocs: Optional[int] = None,
+     show_progress_bar: bool = True,
+ ) -> Counting:
+     """Counting TFs, DFs, doc_lengths, etc."""
+     posting_lists: List[PostingList] = []
+     vocab: Dict[str, int] = {}
+     cid2docid: Dict[str, int] = {}
+     collection_ids: List[str] = []
+     dfs: List[int] = []  # tid -> df
+     dls: List[int] = []  # docid -> doc length
+     nterms: int = 0
+     doc_texts: Optional[List[str]] = []
+     for doc in tqdm.tqdm(
+         documents,
+         desc="Counting",
+         total=ndocs,
+         disable=not show_progress_bar,
+     ):
+         if doc.collection_id in cid2docid:
+             continue
+         collection_ids.append(doc.collection_id)
+         docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
+         toks = tokenize_fn(doc.text)
+         tok2tf = Counter(toks)
+         dls.append(sum(tok2tf.values()))
+         for tok, tf in tok2tf.items():
+             nterms += tf
+             tid = vocab.get(tok, None)
+             if tid is None:
+                 posting_lists.append(
+                     PostingList(term=tok, docid_postings=[], tweight_postings=[])
+                 )
+                 tid = vocab.setdefault(tok, len(vocab))
+             posting_lists[tid].docid_postings.append(docid)
+             posting_lists[tid].tweight_postings.append(tf)
+             if tid < len(dfs):
+                 dfs[tid] += 1
+             else:
+                 dfs.append(0)
+         if store_raw:
+             doc_texts.append(doc.text)
+         else:
+             doc_texts = None
+     return Counting(
+         posting_lists=posting_lists,
+         vocab=vocab,
+         cid2docid=cid2docid,
+         collection_ids=collection_ids,
+         dfs=dfs,
+         dls=dls,
+         avgdl=sum(dls) / len(dls),
+         nterms=nterms,
+         doc_texts=doc_texts,
+     )
+
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ sciq = load_sciq()
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
+
+ """### BM25 Index"""
+
+ from __future__ import annotations
+ from dataclasses import asdict, dataclass
+ import math
+ import os
+ from typing import Iterable, List, Optional, Type
+ import tqdm
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+

  @dataclass
  class BM25Index(InvertedIndex):

+     @staticmethod
+     def tokenize(text: str) -> List[str]:
+         return simple_tokenize(text)
+
      @staticmethod
      def cache_term_weights(
-         posting_lists: List[PostingList], total_docs: int, avgdl: float, dfs: List[int], dls: List[int], k1: float, b: float,
+         posting_lists: List[PostingList],
+         total_docs: int,
+         avgdl: float,
+         dfs: List[int],
+         dls: List[int],
+         k1: float,
+         b: float,
      ) -> None:
+         """Compute term weights and caching"""
+
          N = total_docs
-         for tid, posting_list in enumerate(posting_lists):
+         for tid, posting_list in enumerate(
+             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+         ):
              idf = BM25Index.calc_idf(df=dfs[tid], N=N)
-             for i, docid in enumerate(posting_list.docid_postings):
+             for i in range(len(posting_list.docid_postings)):
+                 docid = posting_list.docid_postings[i]
                  tf = posting_list.tweight_postings[i]
                  dl = dls[docid]
-                 posting_list.tweight_postings[i] = BM25Index.calc_regularized_tf(
+                 regularized_tf = BM25Index.calc_regularized_tf(
                      tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
-                 ) * idf
+                 )
+                 posting_list.tweight_postings[i] = regularized_tf * idf

      @staticmethod
-     def calc_regularized_tf(tf: int, dl: float, avgdl: float, k1: float, b: float) -> float:
+     def calc_regularized_tf(
+         tf: int, dl: float, avgdl: float, k1: float, b: float
+     ) -> float:
          return tf / (tf + k1 * (1 - b + b * dl / avgdl))

      @staticmethod
@@ -82,54 +193,149 @@ class BM25Index(InvertedIndex):

      @classmethod
      def build_from_documents(
-         cls: Type["BM25Index"], documents: List[Document], avgdl: float, total_docs: int, k1: float = 0.9, b: float = 0.4
-     ) -> "BM25Index":
-         # Assume run_counting() is defined to return counting object with relevant data
-         counting = run_counting(documents, simple_tokenize)
-         BM25Index.cache_term_weights(counting.posting_lists, total_docs, avgdl, counting.dfs, counting.dls, k1, b)
-         return cls(counting.posting_lists, counting.vocab, counting.cid2docid, counting.collection_ids, counting.doc_texts)
-
- class BM25Retriever(BaseRetriever):
+         cls: Type[BM25Index],
+         documents: Iterable[Document],
+         store_raw: bool = True,
+         output_dir: Optional[str] = None,
+         ndocs: Optional[int] = None,
+         show_progress_bar: bool = True,
+         k1: float = 0.9,
+         b: float = 0.4,
+     ) -> BM25Index:
+         # Counting TFs, DFs, doc_lengths, etc.:
+         counting = run_counting(
+             documents=documents,
+             tokenize_fn=BM25Index.tokenize,
+             store_raw=store_raw,
+             ndocs=ndocs,
+             show_progress_bar=show_progress_bar,
+         )
+
+         # Compute term weights and caching:
+         posting_lists = counting.posting_lists
+         total_docs = len(counting.cid2docid)
+         BM25Index.cache_term_weights(
+             posting_lists=posting_lists,
+             total_docs=total_docs,
+             avgdl=counting.avgdl,
+             dfs=counting.dfs,
+             dls=counting.dls,
+             k1=k1,
+             b=b,
+         )
+
+         # Assembly and save:
+         index = BM25Index(
+             posting_lists=posting_lists,
+             vocab=counting.vocab,
+             cid2docid=counting.cid2docid,
+             collection_ids=counting.collection_ids,
+             doc_texts=counting.doc_texts,
+         )
+         return index
+
+
+ """### BM25 Retriever"""
+
+ from nlp4web_codebase.ir.models import BaseRetriever
+ from typing import Type
+ from abc import abstractmethod
+
+
+ class BaseInvertedIndexRetriever(BaseRetriever):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[InvertedIndex]:
+         pass
+
      def __init__(self, index_dir: str) -> None:
-         self.index = BM25Index.from_saved(index_dir)
+         self.index = self.index_class.from_saved(index_dir)
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         target_docid = self.index.cid2docid[cid]
+         term_weights = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 if docid == target_docid:
+                     term_weights[tok] = tweight
+                     break
+         return term_weights
+
+     def score(self, query: str, cid: str) -> float:
+         return sum(self.get_term_weights(query=query, cid=cid).values())

      def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
-         toks = simple_tokenize(query)
-         docid2score = Counter()
+         toks = self.index.tokenize(query)
+         docid2score: Dict[int, float] = {}
          for tok in toks:
-             if tok in self.index.vocab:
-                 tid = self.index.vocab[tok]
-                 posting_list = self.index.posting_lists[tid]
-                 for docid, weight in zip(posting_list.docid_postings, posting_list.tweight_postings):
-                     docid2score[docid] += weight
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 docid2score.setdefault(docid, 0)
+                 docid2score[docid] += tweight
+         docid2score = dict(
+             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+         )
          return {
-             self.index.collection_ids[docid]: score for docid, score in docid2score.most_common(topk)
+             self.index.collection_ids[docid]: score
+             for docid, score in docid2score.items()
          }

- # Gradio app setup
+
+ class BM25Retriever(BaseInvertedIndexRetriever):
+
+     @property
+     def index_class(self) -> Type[BM25Index]:
+         return BM25Index
+
+
+ import gradio as gr
+ from typing import TypedDict
+
  class Hit(TypedDict):
-     cid: str
-     score: float
-     text: str
+     cid: str
+     score: float
+     text: str
+
+ demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
+ return_type = List[Hit]

+ ## YOUR_CODE_STARTS_HERE
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True
+ )
+ bm25_index.save("output/bm25_index")
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
  def search_sciq(query: str) -> List[Hit]:
      results = bm25_retriever.retrieve(query)
-     hits = []
+     hitlist = []
      for cid, score in results.items():
-         docid = bm25_retriever.index.cid2docid[cid]
-         text = bm25_retriever.index.doc_texts[docid]
-         hits.append(Hit(cid=cid, score=score, text=text))
-     return hits
+         index = bm25_retriever.index.cid2docid[cid]
+         text = bm25_retriever.index.doc_texts[index]
+         hitlist.append(Hit(cid=cid, score=score, text=text))

- bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+     return hitlist

  demo = gr.Interface(
      fn=search_sciq,
      inputs="textbox",
-     outputs="json",
+     outputs="textbox",
      description="BM25 Search Engine Demo on SciQ Dataset"
  )
-
- if __name__ == "__main__":
-     demo.launch()
+ ## YOUR_CODE_ENDS_HERE
+ demo.launch()

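Note on the cached weights: the value that the new cache_term_weights stores in tweight_postings is the full per-term, per-document BM25 contribution, which is why retrieve only has to sum cached weights at query time. A minimal sketch of that arithmetic, using the k1=0.9 and b=0.4 defaults from build_from_documents; the tf, dl, avgdl and df numbers are made up, and since calc_idf's body lies outside this diff, the log(N/df) form below is only an assumption:

    import math

    k1, b = 0.9, 0.4                       # defaults of build_from_documents
    N = 12160                              # corpus size (the ndocs value used above)
    tf, dl, avgdl, df = 3, 40, 56.5, 420   # hypothetical statistics for one (term, document) pair

    regularized_tf = tf / (tf + k1 * (1 - b + b * dl / avgdl))  # calc_regularized_tf
    idf = math.log(N / df)                 # assumed idf; the real calc_idf is not shown in this diff
    term_weight = regularized_tf * idf     # what tweight_postings[i] would cache
    print(round(term_weight, 4))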
nlp4web-codebase ADDED
@@ -0,0 +1 @@
+ Subproject commit 83f9afbbf7e372c116fdd04997a96449007f861f
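As a quick sanity check of the committed app.py, the search function can be called directly before demo.launch(); the query string here is only an example:

    # Hypothetical smoke test of the new BM25Retriever / search_sciq pipeline
    hits = search_sciq("What is the unit of electrical resistance?")
    for hit in hits[:3]:
        print(hit["cid"], round(hit["score"], 2), hit["text"][:80])

Since the Interface output was switched from "json" to "textbox", the demo now renders the returned list as plain text.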