TenPoisk
commited on
Commit
·
5009091
1
Parent(s):
15b2571
Delete st_utils.py
Browse files- st_utils.py +0 -126
st_utils.py
DELETED
@@ -1,126 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments
|
3 |
-
from pprint import pprint
|
4 |
-
from hf_search import HFSearch
|
5 |
-
import streamlit as st
|
6 |
-
import itertools
|
7 |
-
|
8 |
-
from pbr.version import VersionInfo
|
9 |
-
print("hf_search version:", VersionInfo('hf_search').version_string())
|
10 |
-
|
11 |
-
hf_search = HFSearch(top_k=200)
|
12 |
-
|
13 |
-
@st.cache
|
14 |
-
def hf_api(query, limit=5, sort=None, filters={}):
|
15 |
-
print("query", query)
|
16 |
-
print("filters", filters)
|
17 |
-
print("limit", limit)
|
18 |
-
print("sort", sort)
|
19 |
-
|
20 |
-
api = HfApi()
|
21 |
-
filt = ModelFilter(
|
22 |
-
task=filters["task"],
|
23 |
-
library=filters["library"],
|
24 |
-
)
|
25 |
-
models = api.list_models(search=query, filter=filt, limit=limit, sort=sort, full=True)
|
26 |
-
hits = []
|
27 |
-
for model in models:
|
28 |
-
model = model.__dict__
|
29 |
-
hits.append(
|
30 |
-
{
|
31 |
-
"modelId": model.get("modelId"),
|
32 |
-
"tags": model.get("tags"),
|
33 |
-
"downloads": model.get("downloads"),
|
34 |
-
"likes": model.get("likes"),
|
35 |
-
}
|
36 |
-
)
|
37 |
-
count = len(hits)
|
38 |
-
if len(hits) > limit:
|
39 |
-
hits = hits[:limit]
|
40 |
-
return {"hits": hits, "count": count}
|
41 |
-
|
42 |
-
|
43 |
-
@st.cache
|
44 |
-
def semantic_search(query, limit=5, sort=None, filters={}):
|
45 |
-
print("query", query)
|
46 |
-
print("filters", filters)
|
47 |
-
print("limit", limit)
|
48 |
-
print("sort", sort)
|
49 |
-
|
50 |
-
hits = hf_search.search(query=query, method="retrieve & rerank", limit=limit, sort=sort, filters=filters)
|
51 |
-
hits = [
|
52 |
-
{
|
53 |
-
"modelId": hit["modelId"],
|
54 |
-
"tags": hit["tags"],
|
55 |
-
"downloads": hit["downloads"],
|
56 |
-
"likes": hit["likes"],
|
57 |
-
"readme": hit.get("readme", None),
|
58 |
-
}
|
59 |
-
for hit in hits
|
60 |
-
]
|
61 |
-
return {"hits": hits, "count": len(hits)}
|
62 |
-
|
63 |
-
|
64 |
-
@st.cache
|
65 |
-
def bm25_search(query, limit=5, sort=None, filters={}):
|
66 |
-
print("query", query)
|
67 |
-
print("filters", filters)
|
68 |
-
print("limit", limit)
|
69 |
-
print("sort", sort)
|
70 |
-
|
71 |
-
# TODO: filters
|
72 |
-
hits = hf_search.search(query=query, method="bm25", limit=limit, sort=sort, filters=filters)
|
73 |
-
hits = [
|
74 |
-
{
|
75 |
-
"modelId": hit["modelId"],
|
76 |
-
"tags": hit["tags"],
|
77 |
-
"downloads": hit["downloads"],
|
78 |
-
"likes": hit["likes"],
|
79 |
-
"readme": hit.get("readme", None),
|
80 |
-
}
|
81 |
-
for hit in hits
|
82 |
-
]
|
83 |
-
hits = [
|
84 |
-
hits[i] for i in range(len(hits)) if hits[i]["modelId"] not in [h["modelId"] for h in hits[:i]]
|
85 |
-
] # unique hits
|
86 |
-
return {"hits": hits, "count": len(hits)}
|
87 |
-
|
88 |
-
|
89 |
-
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
|
90 |
-
# https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7
|
91 |
-
"""Lets the user paginate a set of article.
|
92 |
-
Parameters
|
93 |
-
----------
|
94 |
-
label : str
|
95 |
-
The label to display over the pagination widget.
|
96 |
-
article : Iterator[Any]
|
97 |
-
The articles to display in the paginator.
|
98 |
-
articles_per_page: int
|
99 |
-
The number of articles to display per page.
|
100 |
-
on_sidebar: bool
|
101 |
-
Whether to display the paginator widget on the sidebar.
|
102 |
-
|
103 |
-
Returns
|
104 |
-
-------
|
105 |
-
Iterator[Tuple[int, Any]]
|
106 |
-
An iterator over *only the article on that page*, including
|
107 |
-
the item's index.
|
108 |
-
"""
|
109 |
-
|
110 |
-
# Figure out where to display the paginator
|
111 |
-
if on_sidebar:
|
112 |
-
location = st.sidebar.empty()
|
113 |
-
else:
|
114 |
-
location = st.empty()
|
115 |
-
|
116 |
-
# Display a pagination selectbox in the specified location.
|
117 |
-
articles = list(articles)
|
118 |
-
n_pages = (len(articles) - 1) // articles_per_page + 1
|
119 |
-
page_format_func = lambda i: f"Results {i*10} to {i*10 +10 -1}"
|
120 |
-
page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)
|
121 |
-
|
122 |
-
# Iterate over the articles in the page to let the user display them.
|
123 |
-
min_index = page_number * articles_per_page
|
124 |
-
max_index = min_index + articles_per_page
|
125 |
-
|
126 |
-
return itertools.islice(enumerate(articles), min_index, max_index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|