File size: 2,927 Bytes
30ac9ed 85aa854 30ac9ed 23cefb2 4c54fb1 30ac9ed 23cefb2 30ac9ed 23cefb2 30ac9ed 23cefb2 30ac9ed 23cefb2 30ac9ed 85aa854 30ac9ed 85aa854 4c54fb1 23cefb2 30ac9ed 23cefb2 30ac9ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import time
import json
from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher
import streamlit as st
from pathlib import Path
import sys
# Append the relative path "." to the module search path.
path_root = Path("./")
sys.path.append(str(path_root))

# Maps each encoder label offered in the UI to a 3-tuple:
#   [0] encoder name passed to the searcher when the runtime is 'onnx',
#   [1] model identifier passed to the searcher when the runtime is 'pytorch',
#   [2] index directory name, resolved under "indexes/".
encoder_index_map = {
    'uniCOIL': (
        'UniCoil',
        'castorini/unicoil-noexp-msmarco-passage',
        'index-unicoil',
    ),
    'SPLADE++ Ensemble Distil': (
        'SpladePlusPlusEnsembleDistil',
        'naver/splade-cocondenser-ensembledistil',
        'index-splade-pp-ed',
    ),
    'SPLADE++ Self Distil': (
        'SpladePlusPlusSelfDistil',
        'naver/splade-cocondenser-selfdistil',
        'index-splade-pp-sd',
    ),
}

# Initial values only: the widget-driven selection further down the script
# rebinds `encoder` and `index` before the searcher is built.
encoder, index = 'SpladePlusPlusEnsembleDistil', 'index-splade-pp-ed'
encoder_index = 0
# Page-level settings.
st.set_page_config(
    page_title="Pyserini with ONNX Runtime",
    page_icon='🌸',
    layout="centered",
)

# Three columns; only the middle one receives content (the logo image).
cola, colb, colc = st.columns([5, 4, 5])
with colb:
    st.image("logo.jpeg")

# Runtime picker, placed in the wide middle column of a 1:8:1 row.
colaa, colbb, colcc = st.columns([1, 8, 1])
with colbb:
    runtime = st.select_slider(
        'Select a runtime type',
        options=['PyTorch', 'ONNX Runtime'],
    )
    st.write('Now using: ', runtime)

# Encoder picker, laid out the same way as the runtime picker.
colaa, colbb, colcc = st.columns([1, 8, 1])
with colbb:
    encoder = st.select_slider(
        'Select a query encoder',
        options=['uniCOIL', 'SPLADE++ Ensemble Distil', 'SPLADE++ Self Distil'],
    )
    st.write('Now Running Encoder: ', encoder)

# Turn the runtime label into the encoder_type string and the tuple position
# to read from encoder_index_map. Any label other than 'PyTorch' means ONNX.
runtime, runtime_index = ('pytorch', 1) if runtime == 'PyTorch' else ('onnx', 0)

# Look the selected label up once, then rebind `encoder` from the UI label to
# the runtime-specific encoder identifier and `index` to its index directory.
selected = encoder_index_map[encoder]
encoder = selected[runtime_index]
index = selected[2]

searcher = LuceneImpactSearcher(
    f'indexes/{index}', encoder, encoder_type=runtime)

# Second searcher, always over the uniCOIL index; the results loop uses it to
# fetch a document's text when a hit comes back with empty contents.
corpus = LuceneSearcher('indexes/index-unicoil')
def _hit_contents(hit, fallback_corpus):
    """Return the display text for one search hit.

    The hit's raw JSON payload is read first, preferring the 'contents' key
    and then 'content'. When that yields nothing, the document is looked up
    by docid in ``fallback_corpus`` and its 'contents' field is used.

    Returns an empty string when no text can be found, rather than raising
    on a missing payload, a missing key, or an unknown docid.
    """
    raw = hit.raw
    record = json.loads(raw) if raw else {}
    text = record.get('contents', record.get('content', ''))
    if not text:
        doc = fallback_corpus.doc(hit.docid)
        if doc is not None:
            text = json.loads(doc.raw()).get('contents', '')
    return text


# Search box and trigger button share one row; the input gets most of the width.
col1, col2 = st.columns([9, 1])
with col1:
    search_query = st.text_input(label="search query", placeholder="Search")
with col2:
    st.write('#')  # presumably a spacer so the button lines up with the text box
    button_clicked = st.button("🔎")

# Only search when there is something to search for. Pressing the button with
# a blank box used to send an empty query to the searcher.
if (search_query or '').strip():
    # perf_counter is monotonic, so the reported latency cannot be skewed by
    # wall-clock adjustments.
    t_0 = time.perf_counter()
    search_results = searcher.search(search_query, k=10)
    search_time = time.perf_counter() - t_0
    st.write(
        f'<p align=\"right\" style=\"color:grey;\">Retrieved {len(search_results):,.0f} documents in {search_time*1000:.2f} ms</p>', unsafe_allow_html=True)

    for rank, result in enumerate(search_results[:10], start=1):
        result_score = result.score
        result_id = result.docid
        contents = _hit_contents(result, corpus)

        output = f'<div class="row"> <b>Rank</b>: {rank} | <b>Document ID</b>: {result_id} | <b>Score</b>:{result_score:.2f}</div>'
        # NOTE(review): `contents` is interpolated into HTML unescaped. If the
        # index can ever hold untrusted text, escape it (html.escape) first.
        try:
            st.write(output, unsafe_allow_html=True)
            st.write(
                f'<div class="row">{contents}</div>', unsafe_allow_html=True)
        except Exception as err:
            # Rendering stays best-effort (one bad document must not hide the
            # rest), but the failure is surfaced instead of silently dropped.
            # A bare `except:` would also intercept BaseException subclasses
            # such as KeyboardInterrupt.
            st.warning(f'Could not display document {result_id}: {err}')
        st.write('---')
elif button_clicked:
    st.warning('Please enter a search query.')
|