# NOTE: page-scrape residue that preceded this file (not part of the program):
# "Spaces: Runtime error"
import asyncio
import copy
import os
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from functools import lru_cache
from json import JSONDecodeError
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlencode

import gradio as gr
import httpx
import orjson
from cachetools import TTLCache, cached
from cashews import NOT_NONE, cache
from httpx import AsyncClient, Client
from huggingface_hub import hf_hub_url, logging
from huggingface_hub.utils import disable_progress_bars
from rich import print
from tqdm.auto import tqdm
# Lifetime of entries in the synchronous in-process cache.
CACHE_EXPIRY_TIME = timedelta(hours=3)
# The ttl is a timedelta, so the timer must return datetimes (datetime.now).
sync_cache = TTLCache(maxsize=200_000, ttl=CACHE_EXPIRY_TIME, timer=datetime.now)
# Configure the cashews cache with its in-memory backend.
cache.setup("mem://")

# Silence huggingface_hub progress bars and everything below error level.
disable_progress_bars()
logging.set_verbosity_error()

token = os.getenv("HF_TOKEN")
# Only send an authorization header when a token is actually configured;
# otherwise every request would carry the literal value "Bearer None".
headers = {"authorization": f"Bearer {token}"} if token else {}
async def get_model_labels(model, client):
    """Return the label names declared in a model's ``config.json``.

    Args:
        model: Hub repository id of the model.
        client: Async HTTP client whose ``get(url, timeout=...)`` is awaited.

    Returns:
        The keys of the ``label2id`` mapping as a list, or ``None`` when the
        file cannot be fetched, is not valid JSON, or has no usable
        ``label2id`` mapping.
    """
    try:
        url = hf_hub_url(repo_id=model, filename="config.json")
        resp = await client.get(url, timeout=2)
        return list(resp.json()["label2id"].keys())
    except (
        KeyError,  # no "label2id" entry
        JSONDecodeError,  # body is not JSON (e.g. an error page)
        AttributeError,  # "label2id" is not a mapping
        TypeError,  # the JSON document itself is not a mapping
        httpx.HTTPError,  # timeout / transport failure: labels are optional
    ):
        return None
async def _try_load_model_card(hub_id, client=None):
    """Fetch the raw ``README.md`` (model card) of a Hub repository.

    Args:
        hub_id: Hub repository id.
        client: Optional async HTTP client to reuse. When omitted, a temporary
            ``AsyncClient`` is created for this call and closed afterwards.

    Returns:
        A ``(card_text, length)`` tuple:
        ``(text, len(text))`` for a 200 response,
        ``(None, 0)`` for any other status code (the card is treated as
        missing), and ``(None, None)`` when the connection itself fails.
    """
    owns_client = client is None
    if owns_client:
        client = AsyncClient(headers=headers)
    try:
        url = hf_hub_url(
            repo_id=hub_id, filename="README.md"
        )  # We grab card this way rather than via client library to improve performance
        resp = await client.get(url)
    except httpx.ConnectError:
        return None, None
    finally:
        # Only close clients created here; a caller-supplied client stays open.
        if owns_client:
            await client.aclose()
    if resp.status_code == 200:
        card_text = resp.text
        return card_text, len(card_text)
    # 404 and every other non-200 status: no readable card.
    return None, 0
def _try_parse_card_data(hub_json_data): | |
data = {} | |
keys = ["license", "language", "datasets"] | |
for key in keys: | |
if card_data := hub_json_data.get("cardData"): | |
try: | |
data[key] = card_data.get(key) | |
except (KeyError, AttributeError): | |
data[key] = None | |
else: | |
data[key] = None | |
return data | |
@dataclass
class ModelMetadata:
    """Metadata about one Hub model, as used for metadata-quality scoring.

    This must be a dataclass: instances are built with keyword arguments and
    converted with ``dataclasses.asdict`` elsewhere in this module.
    """

    hub_id: str  # repository id on the Hub
    tags: Optional[List[str]]  # "tags" from the model API response
    license: Optional[str]  # "license" from the card data
    library_name: Optional[str]  # "library_name" from the model API response
    datasets: Optional[List[str]]  # "datasets" from the card data
    pipeline_tag: Optional[str]  # "pipeline_tag" from the model API response
    labels: Optional[List[str]]  # label2id keys from config.json
    languages: Optional[Union[str, List[str]]]  # "language" from the card data
    model_card_text: Optional[str] = None  # raw README.md text
    model_card_length: Optional[int] = None  # len(model_card_text)
    likes: Optional[int] = None
    downloads: Optional[int] = None
    created_at: Optional[datetime] = None  # not populated by from_hub

    @classmethod
    async def from_hub(cls, hub_id, client=None):
        """Build a ``ModelMetadata`` for ``hub_id`` from the Hub.

        Args:
            hub_id: Hub repository id.
            client: Optional async HTTP client to reuse for all requests. When
                omitted, a temporary client is created and closed afterwards.

        Returns:
            The populated instance, or ``None`` if anything goes wrong; the
            failure is printed rather than raised so that one bad model does
            not abort a batch.
        """
        owns_client = client is None
        try:
            if owns_client:
                client = AsyncClient(headers=headers)
            url = f"https://huggingface.co/api/models/{hub_id}"
            resp = await client.get(url)
            hub_json_data = resp.json()
            # Reuse the same client instead of opening a new one per card.
            card_text, length = await _try_load_model_card(hub_id, client=client)
            data = _try_parse_card_data(hub_json_data)
            labels = await get_model_labels(hub_id, client)
            return cls(
                hub_id=hub_id,
                languages=data["language"],
                tags=hub_json_data.get("tags"),
                license=data["license"],
                library_name=hub_json_data.get("library_name"),
                datasets=data["datasets"],
                pipeline_tag=hub_json_data.get("pipeline_tag"),
                labels=labels,
                model_card_text=card_text,
                downloads=hub_json_data.get("downloads"),
                likes=hub_json_data.get("likes"),
                model_card_length=length,
            )
        except Exception as e:
            print(f"Failed to create ModelMetadata for model {hub_id}: {str(e)}")
            return None
        finally:
            if owns_client and client is not None:
                await client.aclose()
# Metadata fields scored for every model regardless of task. Each entry gives
# the points awarded when the field is present and the recommendation shown
# when it is missing. "HUB_ID" in a recommendation is a placeholder that is
# replaced with the model's repository id before display.
# Links use the canonical host "huggingface.co" (no trailing dot), the same
# host that hf_hub_url produces.
COMMON_SCORES = {
    "license": {
        "required": True,
        "score": 2,
        "missing_recommendation": (
            "You have not added a license to your models metadata"
        ),
    },
    "datasets": {
        "required": False,
        "score": 1,
        "missing_recommendation": (
            "You have not added any datasets to your models metadata"
        ),
    },
    "model_card_text": {
        "required": True,
        "score": 3,
        "missing_recommendation": (
            "You haven't created a model card for your model. It is strongly"
            " recommended to have a model card for your model. \nYou can create"
            " one for your model by clicking"
            " [here](https://huggingface.co/HUB_ID/edit/main/README.md)"
        ),
    },
    "tags": {
        "required": False,
        "score": 2,
        "missing_recommendation": (
            "You don't have any tags defined in your model metadata. Tags can help"
            " people find relevant models on the Hub. You can add them for your"
            " model by clicking"
            " [here](https://huggingface.co/HUB_ID/edit/main/README.md)"
        ),
    },
}
# Pipeline tags for which language metadata is scored
# (listed alphabetically; membership is all that matters for a set).
TASK_TYPES_WITH_LANGUAGES = {
    "automatic-speech-recognition",
    "document-question-answering",
    "fill-mask",
    "image-to-text",
    "question-answering",
    "sentence-similarity",
    "summarization",
    "table-question-answering",
    "text-classification",
    "text-generation",
    "text-to-image",
    "text-to-speech",
    "text2text-generation",
    "token-classification",
    "translation",
    "visual-question-answering",
    "zero-shot-classification",
}
# Pipeline tags for which config.json labels are scored (alphabetical).
LABELS_REQUIRED_TASKS = {
    "audio-classification",
    "image-classification",
    "object-detection",
    "tabular-classification",
    "text-classification",
    "token-classification",
}
# Every pipeline tag that gets its own entry in the per-task score table.
# The grouping below is only for readability; the value is a flat set.
ALL_PIPELINES = {
    # audio
    "audio-classification",
    "audio-to-audio",
    "automatic-speech-recognition",
    "text-to-speech",
    "voice-activity-detection",
    # vision
    "depth-estimation",
    "image-classification",
    "image-segmentation",
    "image-to-image",
    "object-detection",
    "unconditional-image-generation",
    "video-classification",
    "zero-shot-image-classification",
    # mixed modalities
    "document-question-answering",
    "feature-extraction",
    "image-to-text",
    "text-to-image",
    "text-to-video",
    "visual-question-answering",
    # text
    "conversational",
    "fill-mask",
    "question-answering",
    "sentence-similarity",
    "summarization",
    "table-question-answering",
    "text-classification",
    "text-generation",
    "text2text-generation",
    "token-classification",
    "translation",
    "zero-shot-classification",
    # tabular
    "tabular-classification",
    "tabular-regression",
    # other
    "graph-ml",
    "reinforcement-learning",
    "robotics",
}
def generate_task_scores_dict():
    """Build the per-task scoring table.

    For every tag in ``ALL_PIPELINES`` the result holds a deep copy of
    ``COMMON_SCORES``, plus a ``"languages"`` entry for tasks in
    ``TASK_TYPES_WITH_LANGUAGES`` and a ``"labels"`` entry for tasks in
    ``LABELS_REQUIRED_TASKS``. Each table also gets ``"_max_score"``, the sum
    of the scores of all its entries.

    Returns:
        Dict mapping pipeline tag to its scoring table.
    """
    scores_by_task = {}
    for task in ALL_PIPELINES:
        requirements = copy.deepcopy(COMMON_SCORES)
        if task in TASK_TYPES_WITH_LANGUAGES:
            requirements["languages"] = {
                "required": True,
                "score": 2,
                "missing_recommendation": (
                    "You haven't defined any languages in your metadata."
                    f" This is usually recommend for {task} task"
                ),
            }
        if task in LABELS_REQUIRED_TASKS:
            requirements["labels"] = {
                "required": True,
                "score": 2,
                "missing_recommendation": (
                    "You haven't defined any labels in the config.json"
                    f" file these are usually recommended for {task}"
                ),
            }
        # The total is computed before "_max_score" itself is inserted.
        requirements["_max_score"] = sum(
            entry["score"] for entry in requirements.values()
        )
        scores_by_task[task] = requirements
    return scores_by_task
def generate_common_scores():
    """Return a deep copy of ``COMMON_SCORES`` extended with ``"_max_score"``.

    ``"_max_score"`` is the sum of the ``"score"`` values of all entries.
    """
    generic = copy.deepcopy(COMMON_SCORES)
    total = 0
    for requirement in generic.values():
        total += requirement["score"]
    generic["_max_score"] = total
    return generic
# Scoring table used when a model has no pipeline tag.
GENERIC_SCORES = generate_common_scores()
# Scoring tables keyed by pipeline tag.
SCORES = generate_task_scores_dict()
def _basic_check(data: Optional[ModelMetadata]):
    """Score a model's metadata and serialise the result.

    A field counts as present when its value is not ``None``. The score is the
    share of achievable points, expressed as a percentage.

    * With a pipeline tag, the task's table from ``SCORES`` is used. Tags that
      have no table (anything outside ``ALL_PIPELINES``) fall back to
      ``GENERIC_SCORES`` instead of raising ``KeyError``. Recommendations for
      missing fields are added under ``"recommendations"``.
    * Without a pipeline tag, ``GENERIC_SCORES`` is used and the score is
      halved as a manual penalty; no recommendations are attached.

    Args:
        data: The model's metadata, or ``None``.

    Returns:
        ``None`` when ``data`` is ``None``; otherwise the ``orjson``-encoded
        (bytes) dict of all metadata fields plus ``"score"`` and, where
        applicable, ``"recommendations"``.
    """
    if data is None:
        return None
    hub_id = data.hub_id
    data_dict = asdict(data)
    task = data.pipeline_tag
    requirements = SCORES.get(task, GENERIC_SCORES) if task else GENERIC_SCORES

    score = 0
    to_fix = {}
    for field_name, requirement in requirements.items():
        if field_name.startswith("_"):
            continue  # bookkeeping entries such as "_max_score"
        if data_dict[field_name] is None:
            to_fix[field_name] = requirement["missing_recommendation"].replace(
                "HUB_ID", hub_id
            )
        else:
            score += requirement["score"]
    score = score / requirements["_max_score"]

    if task:
        if to_fix:
            recommendations = (
                "Here are some suggestions to improve your model's metadata for"
                f" {task}: \n"
            )
            for recommendation in to_fix.values():
                recommendations += f"\n- {recommendation}"
            data_dict["recommendations"] = recommendations
        data_dict["score"] = score * 100
    else:
        data_dict["score"] = max(
            0, (score / 2) * 100
        )  # TODO currently setting a manual penalty for not having a task
    return orjson.dumps(data_dict)
def basic_check(hub_id):
    """Score the metadata of the model identified by ``hub_id``.

    ``_basic_check`` operates on a ``ModelMetadata`` instance, so a repository
    id string is first resolved with ``ModelMetadata.from_hub``. An
    already-loaded ``ModelMetadata`` (or ``None``) is passed through unchanged.

    Args:
        hub_id: Hub repository id, or a ``ModelMetadata`` instance.

    Returns:
        Whatever ``_basic_check`` returns for the resolved metadata.
    """
    model = hub_id
    if isinstance(hub_id, str):
        # NOTE: asyncio.run cannot be used from inside a running event loop.
        model = asyncio.run(ModelMetadata.from_hub(hub_id))
    return _basic_check(model)
def create_query_url(query, skip=0):
    """Build the Hub full-text search URL for ``query``.

    The query is URL-encoded so that spaces, ``&``, ``#`` and similar
    characters cannot corrupt the query string.

    Args:
        query: Free-text search string.
        skip: Number of leading results to skip.

    Returns:
        The request URL, asking for up to 100 results of type ``model``.
    """
    params = urlencode({"q": query, "limit": 100, "skip": skip, "type": "model"})
    return f"https://huggingface.co/api/search/full-text?{params}"
def get_results(query, sync_client=None) -> Dict[Any, Any]:
    """Run a full-text search and return the decoded JSON response.

    Args:
        query: Free-text search string.
        sync_client: Optional synchronous HTTP client to reuse. When omitted,
            a temporary client is created for this call and closed afterwards.

    Returns:
        The parsed JSON body of the search response.

    Raises:
        httpx.HTTPStatusError: If the API answers with an error status, so the
            failure surfaces here rather than as a confusing error later.
    """
    url = create_query_url(query)
    if sync_client is not None:
        response = sync_client.get(url)
    else:
        with Client(http2=True, headers=headers) as temporary_client:
            response = temporary_client.get(url)
    response.raise_for_status()
    return response.json()
def parse_single_result(result):
    """Reduce one search hit to the fields used by the rest of the pipeline.

    Args:
        result: A search hit with at least ``"name"`` (repository id) and
            ``"fileName"`` (the file that matched).

    Returns:
        Dict with ``"name"``, ``"search_result_file_url"`` (download URL of
        the matching file) and ``"repo_hub_url"`` (the repository page, on the
        canonical ``huggingface.co`` host without a trailing dot).
    """
    name = result["name"]
    filename = result["fileName"]
    return {
        "name": name,
        "search_result_file_url": hf_hub_url(name, filename),
        "repo_hub_url": f"https://huggingface.co/{name}",
    }
async def get_hub_models(results, client=None):
    """Load and score the metadata of every search hit concurrently.

    Args:
        results: Search hits as accepted by ``parse_single_result``.
        client: Optional async HTTP client to reuse. When omitted, a temporary
            client is created and closed once all models are loaded.

    Returns:
        A list aligned with ``results``. Each entry is the parsed hit extended
        with ``"metadata_score"``, ``"model_card_length"`` and the boolean
        ``"is_licensed"``, or ``None`` when the model could not be scored.
    """
    parsed_results = [parse_single_result(result) for result in results]
    owns_client = client is None
    if owns_client:
        client = AsyncClient(http2=True, headers=headers)
    try:
        models = await asyncio.gather(
            *(
                ModelMetadata.from_hub(parsed["name"], client=client)
                for parsed in parsed_results
            )
        )
    finally:
        if owns_client:
            await client.aclose()

    scored_results = []
    for parsed, model in zip(parsed_results, models):
        check = _basic_check(model)
        if check is None:
            # Keep a placeholder so positions still match the original hits.
            scored_results.append(None)
            continue
        check = orjson.loads(check)
        parsed["metadata_score"] = check["score"]
        parsed["model_card_length"] = check["model_card_length"]
        # A plain bool: a stray trailing comma here used to produce a
        # one-element tuple, which is truthy even for unlicensed models.
        parsed["is_licensed"] = bool(check["license"])
        scored_results.append(parsed)
    return scored_results
def filter_for_license(results):
    """Yield only the results whose ``"is_licensed"`` value is truthy.

    ``None`` entries (placeholders for models that could not be scored) are
    skipped rather than raising ``TypeError``.
    """
    for result in results:
        if result is not None and result["is_licensed"]:
            yield result
def filter_for_min_model_card_length(results, min_model_card_length):
    """Yield results whose model card is strictly longer than the minimum.

    ``None`` entries are skipped, and a ``None`` card length (card could not
    be loaded) is treated as length 0 instead of raising ``TypeError``.

    Args:
        results: Iterable of scored results (or ``None`` placeholders).
        min_model_card_length: Exclusive lower bound on ``"model_card_length"``.
    """
    for result in results:
        if result is None:
            continue
        card_length = result["model_card_length"] or 0
        if card_length > min_model_card_length:
            yield result
def filter_search_results(
    results: List[Dict[Any, Any]],
    min_score=None,
    min_model_card_length=None,
):  # TODO make code more intuitive
    """Score the search hits and yield those that pass the optional filters.

    This is a generator: the Hub lookups (``asyncio.run(get_hub_models(...))``)
    only start once iteration begins.

    Args:
        results: Raw search hits.
        min_score: When not ``None``, only results whose ``"metadata_score"``
            is strictly greater are kept.
        min_model_card_length: When not ``None``, only results whose
            ``"model_card_length"`` is strictly greater are kept. An unknown
            (``None``) length counts as 0.

    Yields:
        Each passing result, extended with ``"original_position"`` (its index
        in the scored hit list).
    """
    # TODO setup filters as separate functions and chain results
    scored_results = asyncio.run(get_hub_models(results))
    for position, parsed_result in tqdm(enumerate(scored_results)):
        if parsed_result is None:
            continue  # the model could not be loaded or scored
        if min_score is not None and parsed_result["metadata_score"] <= min_score:
            continue
        if min_model_card_length is not None:
            card_length = parsed_result["model_card_length"] or 0
            if not card_length > min_model_card_length:
                continue
        parsed_result["original_position"] = position
        yield parsed_result
def sort_search_results(
    filtered_search_results,
    first_sort_key="metadata_score",
    second_sort_key="original_position",  # TODO expose these in results
):
    """Order results by the first key descending, ties by the second ascending.

    With the defaults, the best metadata score comes first and, among equal
    scores, the result that appeared earlier in the original search ranks
    higher. (Reversing a combined tuple key would put later hits first.)

    Args:
        filtered_search_results: Iterable of result dicts.
        first_sort_key: Key sorted in descending order.
        second_sort_key: Tie-breaker key sorted in ascending order.

    Returns:
        A new, sorted list.
    """
    # Two stable passes: order by the tie-breaker first, then by the primary
    # key. ``reverse=True`` keeps equal elements in their existing order.
    ordered = sorted(filtered_search_results, key=lambda item: item[second_sort_key])
    ordered.sort(key=lambda item: item[first_sort_key], reverse=True)
    return ordered
def find_context(text, query, window_size):
    """Return the words surrounding the first occurrence of ``query`` in ``text``.

    Matching is done on whitespace-separated words, ignoring case and
    punctuation/markup characters attached to a word. A multi-word query
    matches a run of consecutive words.

    Args:
        text: Text to search.
        query: Search string; may be ``None`` or empty.
        window_size: Number of words to keep on each side of the match.

    Returns:
        The match with up to ``window_size`` words on either side, joined by
        single spaces. If there is no match (or no query), the first
        ``window_size`` words of ``text``.
    """
    words = text.split()
    fallback = " ".join(words[:window_size])
    if not query:
        return fallback

    strip_chars = ".,;:!?\"'()[]{}<>*_`#"

    def normalise(token):
        return token.strip(strip_chars).casefold()

    needle = [tok for tok in (normalise(part) for part in str(query).split()) if tok]
    if not needle:
        return fallback

    haystack = [normalise(word) for word in words]
    span = len(needle)
    for index in range(len(haystack) - span + 1):
        if haystack[index : index + span] == needle:
            start = max(0, index - window_size)
            end = min(len(words), index + span + window_size)
            return " ".join(words[start:end])
    return fallback
def create_markdown(results):  # TODO move to separate file
    """Render search results as one markdown document.

    Each result becomes a linked heading, a one-row table (metadata score,
    model card length, licensed marker), the italicised snippet and an
    ``<hr>`` separator. The per-result sections are joined with a newline.

    Args:
        results: Iterable of dicts with ``"name"``, ``"repo_hub_url"``,
            ``"metadata_score"``, ``"model_card_length"``, ``"is_licensed"``
            and ``"text"``.

    Returns:
        The markdown string (empty when there are no results).
    """

    def render(result):
        section_lines = [
            f"# [{result['name']}]({result['repo_hub_url']})",
            "| Metadata Quality Score | Model card length | Licensed |",
            "|------------------------|-------------------|----------|",
            f"| {result['metadata_score']:.0f}% | {result['model_card_length']}"
            f" | {'✅' if result['is_licensed'] else '❌'} |",
            "\n",
            f"*{result['text']}*",
            "<hr>",
            "\n",
        ]
        return "\n".join(section_lines)

    return "\n".join([render(result) for result in results])
async def get_result_card_snippet(result, query=None, client=None):
    """Attach a text snippet from the matching file to ``result``.

    The file at ``result["search_result_file_url"]`` is downloaded and a
    100-word window around ``query`` (see ``find_context``) is stored under
    ``result["text"]``. On any HTTP failure (connection problem, timeout or
    error status) the text is set to ``"Could not load model card"``.

    Args:
        result: Result dict; mutated in place.
        query: Search string used to locate the snippet.
        client: Optional async HTTP client to reuse. When omitted, a temporary
            client is created for this call and closed afterwards.

    Returns:
        The same ``result`` dict.
    """
    owns_client = client is None
    if owns_client:
        client = AsyncClient(http2=True, headers=headers)
    try:
        resp = await client.get(result["search_result_file_url"])
        # Do not present the body of an error response as a snippet.
        resp.raise_for_status()
        result["text"] = find_context(resp.text, query, 100)
    except httpx.HTTPError:
        result["text"] = "Could not load model card"
    finally:
        if owns_client:
            await client.aclose()
    return result
async def get_result_card_snippets(results, query=None, client=None):
    """Fetch the snippets for all ``results`` concurrently.

    Args:
        results: Iterable of result dicts (each is mutated in place).
        query: Search string used to locate each snippet.
        client: Optional async HTTP client shared by all requests. When
            omitted, a temporary client is created and closed once every
            snippet has been fetched.

    Returns:
        List of the result dicts, in input order.
    """
    owns_client = client is None
    if owns_client:
        client = AsyncClient(http2=True, headers=headers)
    try:
        return await asyncio.gather(
            *(
                get_result_card_snippet(result, query=query, client=client)
                for result in results
            )
        )
    finally:
        if owns_client:
            await client.aclose()
# Module-level synchronous client reused for the search API requests.
sync_client = Client(
    headers=headers,
    http2=True,
)
def _search_hub(
    query: str,
    min_score: Optional[int] = None,
    min_model_card_length: Optional[int] = None,
):
    """Search the Hub, filter and rank the hits, and render them as markdown.

    Args:
        query: Free-text search string.
        min_score: Optional minimum metadata score (see
            ``filter_search_results``).
        min_model_card_length: Optional minimum model card length.

    Returns:
        A ``(summary_table, results_markdown)`` tuple of markdown strings: the
        first compares the number of hits before and after filtering, the
        second lists the surviving results.
    """
    hits = get_results(query, sync_client)["hits"]
    number_original_results = len(hits)
    print(f"Found {number_original_results} results")
    filtered_results = filter_search_results(
        hits, min_score=min_score, min_model_card_length=min_model_card_length
    )
    ranked_results = sort_search_results(filtered_results)
    final_results = asyncio.run(get_result_card_snippets(ranked_results, query=query))
    if number_original_results:
        percent_of_original = round(
            len(final_results) / number_original_results * 100, ndigits=0
        )
    else:
        # A search without hits must not fail with ZeroDivisionError.
        percent_of_original = 0.0
    filtered_vs_og = f"""
| Number of original results | Number of results after filtering | Percentage of results after filtering |
| -------------------------- | --------------------------------- | -------------------------------------------- |
| {number_original_results} | {len(final_results)} | {percent_of_original}% |
"""
    return filtered_vs_og, create_markdown(final_results)
def search_hub(query: str, min_score=None, min_model_card_length=None):
    """Entry point used by the search button; delegates to ``_search_hub``.

    Returns:
        The ``(summary_table, results_markdown)`` tuple from ``_search_hub``.
    """
    outputs = _search_hub(query, min_score, min_model_card_length)
    return outputs
# Introductory text shown at the top of the search tab.
_SEARCH_TAB_INTRO = """This search tool relies on the full-text search API.
Your search is passed to this API and the returned models are assessed for metadata quality.
If you don't specify any minimum requirements you will get back your results with metadata quality info
for each result. The results are ordered by:
- Metadata quality i.e. a model with 80% metadata quality will rank higher than one with 75%
- Original search order i.e. if two models have the same metadata quality the one that appeared first in the original search will rank higher.
If there is interest in this app I will expose more options for filtering and sorting results.
"""

with gr.Blocks() as demo:
    with gr.Tab("Hub Search with metadata quality filter"):
        gr.Markdown("# 🤗 Hub model search with metadata quality filters")
        gr.Markdown(_SEARCH_TAB_INTRO)
        # Row 1: the query box next to the search button.
        with gr.Row():
            with gr.Column():
                query = gr.Textbox(value="historic", label="Search query")
            with gr.Column():
                button = gr.Button("Search")
        # Row 2: the optional filters.
        with gr.Row():
            # TODO add option for exact matching i.e. phrase matching
            # TODO add a "Must have license?" checkbox
            mim_model_card_length = gr.Number(
                value=None, label="Minimum model card length"
            )
            min_metadata_score = gr.Slider(0, label="Minimum metadata score")
        # Output areas filled in by search_hub.
        filter_results = gr.Markdown("Filter results vs original search")
        results_markdown = gr.Markdown("Search results")
        button.click(
            fn=search_hub,
            inputs=[query, min_metadata_score, mim_model_card_length],
            outputs=[filter_results, results_markdown],
        )
    # TODO: a second tab, "Scoring metadata quality", that displays
    # COMMON_SCORES and TASK_TYPES_WITH_LANGUAGES, is currently disabled.

demo.launch()