File size: 2,927 Bytes
30ac9ed
 
85aa854
30ac9ed
 
 
 
 
 
 
 
23cefb2
 
4c54fb1
30ac9ed
 
 
 
23cefb2
30ac9ed
 
 
 
 
 
 
 
23cefb2
 
 
 
 
 
 
 
 
30ac9ed
 
 
23cefb2
30ac9ed
 
 
23cefb2
 
 
 
 
 
 
 
 
 
 
30ac9ed
85aa854
 
30ac9ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85aa854
 
 
 
4c54fb1
23cefb2
30ac9ed
 
 
23cefb2
 
30ac9ed
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import html
import json
import sys
import time
from pathlib import Path

from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher
import streamlit as st
# Append the relative path "./" to the module search path, presumably so that
# modules sitting in the launch directory can be imported -- confirm.
path_root = Path("./")
sys.path.extend([str(path_root)])


# Maps the encoder label shown in the UI to a 3-tuple:
#   [0] encoder name used when the ONNX runtime is selected,
#   [1] encoder name used when the PyTorch runtime is selected,
#   [2] index directory name, looked up under ``indexes/``.
encoder_index_map = {
    'uniCOIL': (
        'UniCoil',
        'castorini/unicoil-noexp-msmarco-passage',
        'index-unicoil',
    ),
    'SPLADE++ Ensemble Distil': (
        'SpladePlusPlusEnsembleDistil',
        'naver/splade-cocondenser-ensembledistil',
        'index-splade-pp-ed',
    ),
    'SPLADE++ Self Distil': (
        'SpladePlusPlusSelfDistil',
        'naver/splade-cocondenser-selfdistil',
        'index-splade-pp-sd',
    ),
}

# Initial values; ``encoder`` and ``index`` are reassigned from the UI
# selection later in the script.
index = 'index-splade-pp-ed'
encoder = 'SpladePlusPlusEnsembleDistil'
encoder_index = 0

# Page chrome: browser-tab title and icon, centered layout. This is the first
# Streamlit call the script makes.
st.set_page_config(
    page_title="Pyserini with ONNX Runtime",
    page_icon='🌸',
    layout="centered",
)

# The logo goes in the middle of three columns (5:4:5); the outer two stay
# empty.
_logo_left, logo_col, _logo_right = st.columns([5, 4, 5])
logo_col.image("logo.jpeg")


# Runtime selector, placed in the wide middle column (1:8:1).
_runtime_left, runtime_col, _runtime_right = st.columns([1, 8, 1])
runtime_choices = ['PyTorch', 'ONNX Runtime']
runtime = runtime_col.select_slider(
    'Select a runtime type', options=runtime_choices)
runtime_col.write('Now using: ', runtime)


# Query-encoder selector in the same 1:8:1 layout. The option labels are the
# keys later used to index ``encoder_index_map``.
_encoder_left, encoder_col, _encoder_right = st.columns([1, 8, 1])
encoder_choices = ['uniCOIL', 'SPLADE++ Ensemble Distil', 'SPLADE++ Self Distil']
encoder = encoder_col.select_slider(
    'Select a query encoder', options=encoder_choices)
encoder_col.write('Now Running Encoder: ', encoder)

# Turn the slider label into the value passed as ``encoder_type`` and choose
# which slot of the ``encoder_index_map`` tuple names the encoder:
# slot 1 for PyTorch, slot 0 for ONNX.
runtime, runtime_index = ('pytorch', 1) if runtime == 'PyTorch' else ('onnx', 0)

# ``encoder`` still holds the UI label here; read the whole tuple once before
# the name is rebound to the resolved encoder.
selected_entry = encoder_index_map[encoder]
encoder = selected_entry[runtime_index]
index = selected_entry[2]

# NOTE(review): these searchers are constructed each time the script body
# runs; if start-up cost matters, consider Streamlit's resource caching.
searcher = LuceneImpactSearcher(
    f'indexes/{index}', encoder, encoder_type=runtime)

# Second searcher over the uniCOIL index; the result loop uses it to fetch
# passage text when a hit's own contents are empty.
corpus = LuceneSearcher('indexes/index-unicoil')

# Search bar: a wide column (9 parts) for the text box and a narrow one
# (1 part) for the button.
query_col, button_col = st.columns([9, 1])
search_query = query_col.text_input(label="search query", placeholder="Search")

# The '#' is written above the button -- presumably a spacer that lines the
# button up with the text box; confirm visually.
button_col.write('#')
button_clicked = button_col.button("🔎")


# Run the search and render the hits.
#
# Reads ``search_query`` / ``button_clicked`` from the search bar and the
# ``searcher`` / ``corpus`` objects built earlier in the script. A search is
# only issued for a non-blank query; clicking the button with an empty box
# shows a hint instead of searching for the empty string.
if button_clicked and not search_query.strip():
    st.warning('Please enter a search query.')

if search_query.strip():
    top_k = 10

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    t_0 = time.perf_counter()
    search_results = searcher.search(search_query, k=top_k)
    search_time = time.perf_counter() - t_0

    st.write(
        f'<p align="right" style="color:grey;">Retrieved {len(search_results):,} documents in {search_time*1000:.2f} ms</p>', unsafe_allow_html=True)

    for rank, result in enumerate(search_results[:top_k], start=1):
        result_score = result.score
        result_id = result.docid

        # The stored JSON keeps the passage under 'contents' or 'content';
        # treat a missing key like an empty passage so the fallback applies.
        stored = json.loads(result.raw)
        contents = stored.get('contents', stored.get('content', ''))
        if not contents:
            # Fall back to the text held in the uniCOIL corpus index.
            corpus_doc = corpus.doc(result_id)
            if corpus_doc is not None:
                contents = json.loads(corpus_doc.raw()).get('contents', '')

        # Both strings are rendered as raw HTML, so escape the values taken
        # from the index: characters such as '<' or '&' in a passage must be
        # shown as text rather than interpreted as markup.
        output = f'<div class="row"> <b>Rank</b>: {rank} | <b>Document ID</b>: {html.escape(str(result_id))} | <b>Score</b>:{result_score:.2f}</div>'
        passage = f'<div class="row">{html.escape(str(contents))}</div>'

        try:
            st.write(output, unsafe_allow_html=True)
            st.write(passage, unsafe_allow_html=True)
        except Exception as exc:
            # Rendering stays best-effort, but the failure is reported and
            # BaseException (e.g. KeyboardInterrupt) is no longer swallowed.
            st.warning(f'Could not render result {rank}: {exc}')
        st.write('---')