# Spaces: Running on Zero
import datetime | |
import difflib | |
import json | |
import re | |
import tempfile | |
import gradio as gr | |
import polars as pl | |
from gradio_modal import Modal | |
from huggingface_hub import CommitOperationAdd, HfApi | |
from table import PATCH_REPO_ID, df_orig | |
# TODO: remove this once https://github.com/gradio-app/gradio/issues/11022 is fixed  # noqa: FIX002, TD002
# Warning text rendered above the table via gr.Markdown(NOTE).
NOTE = """\
#### ⚠️ Note
You may encounter an issue when selecting table data after using the search bar.
This is due to a known bug in Gradio.
The issue typically occurs when multiple rows remain after filtering.
If only one row remains, the selection should work as expected.
"""

# Hub client used by open_pr() to create the pull-request commit.
api = HfApi()

# Columns displayed in the table. "paper_id" stays last because
# df_pr_row_selected() reads it from the end of the selected row.
PR_VIEW_COLUMNS = [
    "title",
    "authors_str",
    "openreview_md",
    "arxiv_id",
    "github_md",
    "Spaces",
    "Models",
    "Datasets",
    "paper_id",
]

# Columns of the raw record that is loaded into the edit form and diffed.
PR_RAW_COLUMNS = [
    "paper_id",
    "title",
    "authors",
    "arxiv_id",
    "project_page",
    "github",
    "space_ids",
    "model_ids",
    "dataset_ids",
]

# Display frame: a leading "Fix" column holding the 📝 marker, followed by the
# view columns, with null arXiv IDs shown as empty strings.
df_pr_view = (
    df_orig.with_columns(pl.lit("📝").alias("Fix"))
    .select(["Fix", *PR_VIEW_COLUMNS])
    .with_columns(pl.col("arxiv_id").fill_null(""))
)

# Raw frame that selections are resolved against by paper_id.
df_pr_raw = df_orig.select(PR_RAW_COLUMNS)
def df_pr_row_selected(
    evt: gr.SelectData,
) -> tuple[
    Modal,
    gr.Textbox,  # title
    gr.Textbox,  # authors
    gr.Textbox,  # arxiv_id
    gr.Textbox,  # project_page
    gr.Textbox,  # github
    gr.Textbox,  # space_ids
    gr.Textbox,  # model_ids
    gr.Textbox,  # dataset_ids
    dict | None,  # original_data
]:
    """Open the edit modal pre-filled with the paper whose 📝 cell was clicked.

    Args:
        evt: Selection event of the table. ``evt.value`` is the clicked cell
            and the last element of ``evt.row_value`` is used as the paper ID.

    Returns:
        Updates for the modal, the eight form text boxes and the stored
        original record. When the clicked cell is not the 📝 marker the
        updates carry no values and the stored record is set to ``None``.

    Raises:
        gr.Error: If no raw record matches the selected paper ID.
    """
    if evt.value != "📝":
        # Not the "Fix" cell: emit argument-less updates for the modal and the
        # eight text boxes, and clear the stored record.
        return (Modal(), *(gr.Textbox() for _ in range(8)), None)

    paper_id = evt.row_value[-1]
    matches = df_pr_raw.filter(pl.col("paper_id") == paper_id).to_dicts()
    if not matches:
        # Surface a readable message instead of an IndexError on matches[0].
        error_msg = f"Paper not found: {paper_id}"
        raise gr.Error(error_msg, print_exception=False)
    original_data = matches[0]

    def joined(key: str) -> str:
        # A null list column must not break str.join; show it as empty.
        return "\n".join(original_data[key] or [])

    return (
        Modal(visible=True),
        gr.Textbox(value=original_data["title"]),  # title
        gr.Textbox(value=joined("authors")),  # authors
        gr.Textbox(value=original_data["arxiv_id"]),  # arxiv_id
        gr.Textbox(value=original_data["project_page"]),  # project_page
        gr.Textbox(value=original_data["github"]),  # github
        gr.Textbox(value=joined("space_ids")),  # space_ids
        gr.Textbox(value=joined("model_ids")),  # model_ids
        gr.Textbox(value=joined("dataset_ids")),  # dataset_ids
        original_data,  # original_data
    )
# Optional http(s) scheme, dotted host name ending in an alphabetic TLD of at
# least two letters, optional port, optional path.
URL_PATTERN = re.compile(r"^(https?://)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(:\d+)?(/.*)?$")

# https-only GitHub URL: either the repository root or a /tree/<ref>/<path> URL.
GITHUB_PATTERN = re.compile(r"^https://github\.com/[^/\s]+/[^/\s]+(/tree/[^/\s]+/[^/\s].*)?$")

# "<namespace>/<name>" Hub repo ID. Dots are allowed inside each part so that
# IDs such as "meta-llama/Llama-3.1-8B" are accepted; a part may not start with
# a dot and ".." is rejected anywhere. Every ID accepted by the previous
# dot-free pattern is still accepted.
REPO_ID_PATTERN = re.compile(r"^(?!.*\.\.)[a-zA-Z0-9_-][a-zA-Z0-9._-]*/[a-zA-Z0-9_-][a-zA-Z0-9._-]*$")

# Four digits, a dot, then four or five digits.
ARXIV_ID_PATTERN = re.compile(r"^\d{4}\.\d{4,5}$")
def is_valid_url(url: str) -> bool:
    """Return True when ``url`` matches ``URL_PATTERN``."""
    # A match object is always truthy, so bool() mirrors an "is not None" test.
    return bool(URL_PATTERN.match(url))
def is_valid_github_url(url: str) -> bool:
    """Return True when ``url`` matches ``GITHUB_PATTERN``."""
    # A match object is always truthy, so bool() mirrors an "is not None" test.
    return bool(GITHUB_PATTERN.match(url))
def is_valid_repo_id(repo_id: str) -> bool:
    """Return True when ``repo_id`` matches ``REPO_ID_PATTERN``."""
    # A match object is always truthy, so bool() mirrors an "is not None" test.
    return bool(REPO_ID_PATTERN.match(repo_id))
def is_valid_arxiv_id(arxiv_id: str) -> bool:
    """Return True when ``arxiv_id`` matches ``ARXIV_ID_PATTERN``."""
    # A match object is always truthy, so bool() mirrors an "is not None" test.
    return bool(ARXIV_ID_PATTERN.match(arxiv_id))
def validate_pr_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids: list[str],
    model_ids: list[str],
    dataset_ids: list[str],
) -> None:
    """Validate the edited paper fields before they are diffed or submitted.

    Args:
        title_pr: Paper title; must contain non-whitespace text.
        authors_pr: Newline-separated authors; must contain non-whitespace text.
        arxiv_id_pr: arXiv ID; checked only when non-empty.
        project_page_pr: Project page URL; checked only when non-empty.
        github_pr: GitHub URL; checked only when non-empty.
        space_ids: Space repo IDs, one per element.
        model_ids: Model repo IDs, one per element.
        dataset_ids: Dataset repo IDs, one per element.

    Raises:
        gr.Error: On the first field that fails validation.
    """
    # Whitespace-only input counts as empty; otherwise a blank authors box
    # would pass here and later be turned into an empty author list.
    if not (title_pr or "").strip():
        raise gr.Error("Title cannot be empty", print_exception=False)
    if not (authors_pr or "").strip():
        raise gr.Error("Authors cannot be empty", print_exception=False)

    if arxiv_id_pr and not is_valid_arxiv_id(arxiv_id_pr):
        # arXiv IDs are year+month then a sequence number; the earlier message
        # advertised 'YYYY' with an example whose month part was 23.
        raise gr.Error(
            "Invalid arXiv ID format. Expected format: 'YYMM.NNNNN' (e.g., '2301.01234')", print_exception=False
        )
    if project_page_pr and not is_valid_url(project_page_pr):
        raise gr.Error("Project page must be a valid URL", print_exception=False)
    if github_pr and not is_valid_github_url(github_pr):
        raise gr.Error("GitHub must be a valid GitHub URL", print_exception=False)

    for repo_id in (*space_ids, *model_ids, *dataset_ids):
        if not is_valid_repo_id(repo_id):
            error_msg = f"Space/Model/Dataset ID must be in the format 'org_name/repo_name'. Got: {repo_id}"
            raise gr.Error(error_msg, print_exception=False)
def _split_nonempty_lines(text: str | None) -> list[str]:
    """Split ``text`` into lines, strip each one and drop the blank ones."""
    # splitlines() also copes with "\r\n" input, which split("\n") left as a
    # trailing "\r" on every entry.
    return [line.strip() for line in (text or "").splitlines() if line.strip()]


def format_submitted_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
) -> dict:
    """Validate the form values and convert them into a paper record.

    Multi-line fields are split into lists of stripped, non-blank lines, so
    stray spaces around an ID no longer fail validation or appear as changes.
    The optional single-line fields are stripped and stored as ``None`` when
    empty. The title is passed through unchanged.

    Args:
        title_pr: Paper title.
        authors_pr: Authors, one per line.
        arxiv_id_pr: arXiv ID, may be empty.
        project_page_pr: Project page URL, may be empty.
        github_pr: GitHub URL, may be empty.
        space_ids_pr: Space repo IDs, one per line.
        model_ids_pr: Model repo IDs, one per line.
        dataset_ids_pr: Dataset repo IDs, one per line.

    Returns:
        A dict with the keys ``title``, ``authors``, ``arxiv_id``,
        ``project_page``, ``github``, ``space_ids``, ``model_ids`` and
        ``dataset_ids``.

    Raises:
        gr.Error: Propagated from ``validate_pr_data`` for invalid input.
    """
    space_ids = _split_nonempty_lines(space_ids_pr)
    model_ids = _split_nonempty_lines(model_ids_pr)
    dataset_ids = _split_nonempty_lines(dataset_ids_pr)

    # Trim before validating so surrounding whitespace is not reported as a
    # malformed ID or URL.
    arxiv_id = (arxiv_id_pr or "").strip()
    project_page = (project_page_pr or "").strip()
    github = (github_pr or "").strip()

    validate_pr_data(title_pr, authors_pr, arxiv_id, project_page, github, space_ids, model_ids, dataset_ids)

    return {
        "title": title_pr,
        "authors": _split_nonempty_lines(authors_pr),
        "arxiv_id": arxiv_id or None,
        "project_page": project_page or None,
        "github": github or None,
        "space_ids": space_ids,
        "model_ids": model_ids,
        "dataset_ids": dataset_ids,
    }
def preview_diff(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
) -> tuple[gr.Markdown, gr.Button]:
    """Render a unified diff between the original record and the edited form.

    Args:
        title_pr: Edited title.
        authors_pr: Edited authors, one per line.
        arxiv_id_pr: Edited arXiv ID.
        project_page_pr: Edited project page URL.
        github_pr: Edited GitHub URL.
        space_ids_pr: Edited Space IDs, one per line.
        model_ids_pr: Edited model IDs, one per line.
        dataset_ids_pr: Edited dataset IDs, one per line.
        original_data: Record stored when the row was selected.

    Returns:
        A Markdown update holding the diff in a ``diff`` code fence, and a
        Button update that makes the "Open PR" button visible.

    Raises:
        gr.Error: Propagated from ``format_submitted_data`` for invalid input.
    """
    edited = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    # The form has no paper_id field; copy it over (first, as in the raw
    # columns) so it does not show up as a removed line.
    edited = {"paper_id": original_data["paper_id"], **edited}

    # ensure_ascii=False keeps non-ASCII names and titles readable in the
    # preview instead of rendering them as \uXXXX escapes.
    before_lines = json.dumps(original_data, indent=2, ensure_ascii=False).splitlines()
    after_lines = json.dumps(edited, indent=2, ensure_ascii=False).splitlines()

    diff_str = "\n".join(
        difflib.unified_diff(
            before_lines,
            after_lines,
            fromfile="before",
            tofile="after",
            lineterm="",
        )
    )
    return gr.Markdown(value=f"```diff\n{diff_str}\n```"), gr.Button(visible=True)
def open_pr(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
    oauth_token: gr.OAuthToken | None,
) -> gr.Markdown:
    """Open a pull request on the patch dataset containing the edited fields.

    The uploaded JSON holds only the fields that differ from the original
    record, plus ``paper_id``, a unified ``diff`` and a UTC ``timestamp``.

    Args:
        title_pr: Edited title.
        authors_pr: Edited authors, one per line.
        arxiv_id_pr: Edited arXiv ID.
        project_page_pr: Edited project page URL.
        github_pr: Edited GitHub URL.
        space_ids_pr: Edited Space IDs, one per line.
        model_ids_pr: Edited model IDs, one per line.
        dataset_ids_pr: Edited dataset IDs, one per line.
        original_data: Record stored when the row was selected.
        oauth_token: OAuth token of the logged-in user, or ``None``.

    Returns:
        A visible Markdown update with the pull request URL, or an empty
        string when nothing was changed.

    Raises:
        gr.Error: Propagated from ``format_submitted_data`` for invalid input.
    """
    submitted_data = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    changed = {key: value for key, value in submitted_data.items() if value != original_data[key]}
    if not changed:
        gr.Info("No data to submit")
        # NOTE(review): returns a plain string although the annotation says
        # gr.Markdown; kept as in the original.
        return ""

    paper_id = original_data["paper_id"]
    changed["paper_id"] = paper_id

    # Put paper_id on the "after" side too, as preview_diff does; otherwise
    # the stored diff reports paper_id as removed.
    after_data = {"paper_id": paper_id, **submitted_data}
    # ensure_ascii=False keeps non-ASCII text readable inside the diff.
    before_lines = json.dumps(original_data, indent=2, ensure_ascii=False).splitlines()
    after_lines = json.dumps(after_data, indent=2, ensure_ascii=False).splitlines()
    changed["diff"] = "\n".join(
        difflib.unified_diff(
            before_lines,
            after_lines,
            fromfile="before",
            tofile="after",
            lineterm="",
        )
    )

    timestamp = datetime.datetime.now(datetime.timezone.utc)
    changed["timestamp"] = timestamp.isoformat()

    # Upload from memory: the previous NamedTemporaryFile(delete=False) was
    # never removed, leaving one stray file behind per submission.
    payload = json.dumps(changed, indent=2).encode("utf-8")
    commit = CommitOperationAdd(
        path_in_repo=f"data/{paper_id}--{timestamp.strftime('%Y-%m-%d-%H-%M-%S')}.json",
        path_or_fileobj=payload,
    )
    res = api.create_commit(
        repo_id=PATCH_REPO_ID,
        operations=[commit],
        commit_message=f"Update {paper_id}",
        repo_type="dataset",
        create_pr=True,
        # NOTE(review): with no OAuth token, None is passed and the client
        # presumably falls back to its default credentials — confirm intended.
        token=oauth_token.token if oauth_token else None,
    )
    return gr.Markdown(value=res.pr_url, visible=True)
def render_open_pr_page(profile: gr.OAuthProfile | None) -> dict:
    """Return a Column update that is visible only when a profile is present.

    Args:
        profile: OAuth profile of the current visitor, or ``None``.

    Returns:
        A ``gr.Column`` update whose ``visible`` flag is ``profile is not None``.
    """
    has_profile = profile is not None
    return gr.Column(visible=has_profile)
# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation — confirm the layout against the deployed app.
with gr.Blocks() as demo:
    gr.LoginButton()

    # Starts hidden; render_open_pr_page() decides its visibility on load.
    with gr.Column(visible=False) as open_pr_col:
        gr.Markdown(NOTE)
        df_pr = gr.Dataframe(
            value=df_pr_view,
            datatype=[
                "str",  # Fix
                "str",  # Title
                "str",  # Authors
                "markdown",  # openreview
                "str",  # arxiv_id
                "markdown",  # github
                "markdown",  # spaces
                "markdown",  # models
                "markdown",  # datasets
                "str",  # paper id
            ],
            column_widths=[
                "50px",  # Fix
                "40%",  # Title
                "20%",  # Authors
                None,  # openreview
                "100px",  # arxiv_id
                None,  # github
                None,  # spaces
                None,  # models
                None,  # datasets
                None,  # paper id
            ],
            type="polars",
            row_count=(0, "dynamic"),
            interactive=False,
            max_height=1000,
            show_search="search",
        )

        # Edit form, opened by df_pr_row_selected() when a 📝 cell is clicked.
        with Modal(visible=False) as pr_modal:
            with gr.Group():
                title_pr = gr.Textbox(label="Title")
                authors_pr = gr.Textbox(label="Authors")
                arxiv_id_pr = gr.Textbox(label="arXiv ID")
                project_page_pr = gr.Textbox(label="Project page")
                github_pr = gr.Textbox(label="GitHub")
                spaces_pr = gr.Textbox(
                    label="Spaces",
                    info="Enter one space ID (e.g., 'org_name/space_name') per line.",
                )
                models_pr = gr.Textbox(
                    label="Models",
                    info="Enter one model ID (e.g., 'org_name/model_name') per line.",
                )
                datasets_pr = gr.Textbox(
                    label="Datasets",
                    info="Enter one dataset ID (e.g., 'org_name/dataset_name') per line.",
                )
            # Holds the raw record of the selected paper between events.
            original_data = gr.State()
            preview_diff_button = gr.Button("Preview diff")
            diff_view = gr.Markdown()
            open_pr_button = gr.Button("Open PR", visible=False)
            pr_url = gr.Markdown(visible=False)

    # The eight form fields, in the positional order expected by
    # df_pr_row_selected(), preview_diff() and open_pr().
    pr_form_fields = [
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        spaces_pr,
        models_pr,
        datasets_pr,
    ]

    def _reset_pr_outputs() -> tuple[None, gr.Button, gr.Markdown]:
        """Clear the diff and hide the "Open PR" button and the PR link."""
        return (None, gr.Button(visible=False), gr.Markdown(visible=False))

    pr_modal.blur(
        fn=_reset_pr_outputs,
        outputs=[diff_view, open_pr_button, pr_url],
    )
    df_pr.select(
        fn=df_pr_row_selected,
        outputs=[pr_modal, *pr_form_fields, original_data],
    )
    preview_diff_button.click(
        fn=preview_diff,
        inputs=[*pr_form_fields, original_data],
        outputs=[diff_view, open_pr_button],
    )
    open_pr_button.click(
        fn=open_pr,
        inputs=[*pr_form_fields, original_data],
        outputs=pr_url,
    )

    demo.load(fn=render_open_pr_page, outputs=open_pr_col)
if __name__ == "__main__":
    # Same two calls as the chained form: launch() runs on whatever queue() returns.
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False)