"""Gradio page that lets logged-in users propose metadata fixes for a paper.

A user clicks the "📝" cell of a table row, edits the paper's fields in a
modal, previews the resulting diff and finally opens a pull request against
``PATCH_REPO_ID`` containing the changed fields.
"""

import datetime
import difflib
import json
import re
import tempfile
from pathlib import Path

import gradio as gr
import polars as pl
from gradio_modal import Modal
from huggingface_hub import CommitOperationAdd, HfApi

from table import PATCH_REPO_ID, df_orig

# TODO: remove this once https://github.com/gradio-app/gradio/issues/11022 is fixed # noqa: FIX002, TD002
NOTE = """\
#### ⚠️ Note
You may encounter an issue when selecting table data after using the search bar.
This is due to a known bug in Gradio.

The issue typically occurs when multiple rows remain after filtering.
If only one row remains, the selection should work as expected.
"""

api = HfApi()

# Columns shown in the selection table; "paper_id" must stay last because
# `df_pr_row_selected` reads it from the end of the selected row.
PR_VIEW_COLUMNS = [
    "title",
    "authors_str",
    "openreview_md",
    "arxiv_id",
    "github_md",
    "Spaces",
    "Models",
    "Datasets",
    "paper_id",
]
# Raw (unformatted) columns that are editable / diffed.
PR_RAW_COLUMNS = [
    "paper_id",
    "title",
    "authors",
    "arxiv_id",
    "project_page",
    "github",
    "space_ids",
    "model_ids",
    "dataset_ids",
]

df_pr_view = df_orig.with_columns(pl.lit("📝").alias("Fix")).select(["Fix", *PR_VIEW_COLUMNS])
df_pr_view = df_pr_view.with_columns(pl.col("arxiv_id").fill_null(""))
df_pr_raw = df_orig.select(PR_RAW_COLUMNS)


def df_pr_row_selected(
    evt: gr.SelectData,
) -> tuple[
    Modal,
    gr.Textbox,  # title
    gr.Textbox,  # authors
    gr.Textbox,  # arxiv_id
    gr.Textbox,  # project_page
    gr.Textbox,  # github
    gr.Textbox,  # space_ids
    gr.Textbox,  # model_ids
    gr.Textbox,  # dataset_ids
    dict | None,  # original_data
]:
    """Open the edit modal pre-filled with the selected paper's raw data.

    Only a click on the "📝" cell triggers the modal; any other cell yields
    no-op updates so the UI is left untouched.

    Args:
        evt: Selection event; ``evt.value`` is the clicked cell and
            ``evt.row_value`` the whole row, whose last entry is ``paper_id``.

    Returns:
        Updates for the modal and the eight text boxes, followed by the
        original row as a dict (``None`` when nothing was opened).

    Raises:
        gr.Error: If no row with the selected ``paper_id`` exists.
    """
    if evt.value != "📝":
        return (
            Modal(),
            gr.Textbox(),  # title
            gr.Textbox(),  # authors
            gr.Textbox(),  # arxiv_id
            gr.Textbox(),  # project_page
            gr.Textbox(),  # github
            gr.Textbox(),  # space_ids
            gr.Textbox(),  # model_ids
            gr.Textbox(),  # dataset_ids
            None,  # original_data
        )

    paper_id = evt.row_value[-1]
    row = df_pr_raw.filter(pl.col("paper_id") == paper_id)
    if row.height == 0:
        # Surface a readable message instead of an IndexError from `[0]`.
        raise gr.Error(f"Paper not found: {paper_id}", print_exception=False)
    original_data = row.to_dicts()[0]

    def joined(key: str) -> str:
        # List columns may be null; treat that the same as an empty list.
        return "\n".join(original_data[key] or [])

    return (
        Modal(visible=True),
        gr.Textbox(value=original_data["title"]),  # title
        gr.Textbox(value=joined("authors")),  # authors
        gr.Textbox(value=original_data["arxiv_id"]),  # arxiv_id
        gr.Textbox(value=original_data["project_page"]),  # project_page
        gr.Textbox(value=original_data["github"]),  # github
        gr.Textbox(value=joined("space_ids")),  # space_ids
        gr.Textbox(value=joined("model_ids")),  # model_ids
        gr.Textbox(value=joined("dataset_ids")),  # dataset_ids
        original_data,  # original_data
    )


URL_PATTERN = re.compile(r"^(https?://)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(:\d+)?(/.*)?$")
GITHUB_PATTERN = re.compile(r"^https://github\.com/[^/\s]+/[^/\s]+(/tree/[^/\s]+/[^/\s].*)?$")
REPO_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+$")
ARXIV_ID_PATTERN = re.compile(r"^\d{4}\.\d{4,5}$")


def is_valid_url(url: str) -> bool:
    """Return whether *url* matches ``URL_PATTERN`` (scheme is optional)."""
    return URL_PATTERN.match(url) is not None


def is_valid_github_url(url: str) -> bool:
    """Return whether *url* matches ``GITHUB_PATTERN`` (repo or ``/tree/`` sub-path)."""
    return GITHUB_PATTERN.match(url) is not None


def is_valid_repo_id(repo_id: str) -> bool:
    """Return whether *repo_id* has the ``owner/name`` shape of ``REPO_ID_PATTERN``."""
    return REPO_ID_PATTERN.match(repo_id) is not None


def is_valid_arxiv_id(arxiv_id: str) -> bool:
    """Return whether *arxiv_id* is four digits, a dot, then four or five digits."""
    return ARXIV_ID_PATTERN.match(arxiv_id) is not None


def validate_pr_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids: list[str],
    model_ids: list[str],
    dataset_ids: list[str],
) -> None:
    """Validate the submitted fields, raising a user-facing error on the first problem.

    Title and authors are mandatory (whitespace-only counts as empty); the
    arXiv ID, project page and GitHub URL are optional but must be well
    formed when given; every repo ID must look like ``org_name/repo_name``.

    Raises:
        gr.Error: Describing the first invalid field.
    """
    if not (title_pr or "").strip():
        raise gr.Error("Title cannot be empty", print_exception=False)
    if not (authors_pr or "").strip():
        raise gr.Error("Authors cannot be empty", print_exception=False)
    if arxiv_id_pr and not is_valid_arxiv_id(arxiv_id_pr):
        raise gr.Error(
            "Invalid arXiv ID format. Expected format: 'YYMM.NNNNN' (e.g., '2301.01234')",
            print_exception=False,
        )
    if project_page_pr and not is_valid_url(project_page_pr):
        raise gr.Error("Project page must be a valid URL", print_exception=False)
    if github_pr and not is_valid_github_url(github_pr):
        raise gr.Error("GitHub must be a valid GitHub URL", print_exception=False)
    for repo_id in [*space_ids, *model_ids, *dataset_ids]:
        if not is_valid_repo_id(repo_id):
            error_msg = f"Space/Model/Dataset ID must be in the format 'org_name/repo_name'. Got: {repo_id}"
            raise gr.Error(error_msg, print_exception=False)


def _non_blank_lines(text: str | None, *, strip: bool) -> list[str]:
    """Split *text* on newlines, dropping blank lines; optionally trim each kept line."""
    return [line.strip() if strip else line for line in (text or "").split("\n") if line.strip()]


def format_submitted_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
) -> dict:
    """Validate the raw text-box values and convert them to the stored schema.

    Identifier-like fields (arXiv ID, URLs, repo IDs) are trimmed because
    surrounding whitespace is never meaningful there and would otherwise fail
    validation. Title and author text is kept exactly as typed so that
    untouched values do not produce spurious diffs.

    Returns:
        A dict with the keys ``title``, ``authors``, ``arxiv_id``,
        ``project_page``, ``github``, ``space_ids``, ``model_ids`` and
        ``dataset_ids``; empty optional strings become ``None``.

    Raises:
        gr.Error: Propagated from `validate_pr_data`.
    """
    arxiv_id = (arxiv_id_pr or "").strip()
    project_page = (project_page_pr or "").strip()
    github = (github_pr or "").strip()
    space_ids = _non_blank_lines(space_ids_pr, strip=True)
    model_ids = _non_blank_lines(model_ids_pr, strip=True)
    dataset_ids = _non_blank_lines(dataset_ids_pr, strip=True)

    validate_pr_data(title_pr, authors_pr, arxiv_id, project_page, github, space_ids, model_ids, dataset_ids)

    return {
        "title": title_pr,
        "authors": _non_blank_lines(authors_pr, strip=False),
        "arxiv_id": arxiv_id or None,
        "project_page": project_page or None,
        "github": github or None,
        "space_ids": space_ids,
        "model_ids": model_ids,
        "dataset_ids": dataset_ids,
    }


def _unified_json_diff(before: dict, after: dict) -> str:
    """Return a unified diff between the pretty-printed JSON of *before* and *after*."""
    before_lines = json.dumps(before, indent=2).splitlines()
    after_lines = json.dumps(after, indent=2).splitlines()
    return "\n".join(
        difflib.unified_diff(
            before_lines,
            after_lines,
            fromfile="before",
            tofile="after",
            lineterm="",
        )
    )


def preview_diff(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
) -> tuple[gr.Markdown, gr.Button]:
    """Render the diff between the original row and the edited values.

    Returns:
        A Markdown update holding a fenced ``diff`` block, and an update that
        makes the "Open PR" button visible.

    Raises:
        gr.Error: If the edited values fail validation.
    """
    submitted_data = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    # paper_id is not editable; carry it over so it does not appear as removed.
    submitted_data = {"paper_id": original_data["paper_id"], **submitted_data}
    diff_str = _unified_json_diff(original_data, submitted_data)
    return gr.Markdown(value=f"```diff\n{diff_str}\n```"), gr.Button(visible=True)


def open_pr(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
    oauth_token: gr.OAuthToken | None,
) -> gr.Markdown | str:
    """Open a pull request on ``PATCH_REPO_ID`` containing the changed fields.

    The uploaded JSON holds only the fields that differ from the original,
    plus ``paper_id``, a unified ``diff`` and a UTC ``timestamp``.

    Returns:
        A visible Markdown update with the PR URL, or ``""`` (after an info
        toast) when nothing was changed.

    Raises:
        gr.Error: If the edited values fail validation.
    """
    submitted_data = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    diff_dict = {key: value for key, value in submitted_data.items() if value != original_data[key]}
    if not diff_dict:
        gr.Info("No data to submit")
        return ""

    paper_id = original_data["paper_id"]
    diff_dict["paper_id"] = paper_id
    # Include paper_id on the "after" side so the stored diff matches the
    # one shown by `preview_diff` instead of reporting paper_id as deleted.
    diff_dict["diff"] = _unified_json_diff(original_data, {"paper_id": paper_id, **submitted_data})

    timestamp = datetime.datetime.now(datetime.timezone.utc)
    diff_dict["timestamp"] = timestamp.isoformat()

    path_in_repo = f"data/{paper_id}--{timestamp.strftime('%Y-%m-%d-%H-%M-%S')}.json"
    # The directory (and the patch file in it) is removed once the commit is
    # done, so no temporary files pile up on the server.
    with tempfile.TemporaryDirectory() as tmp_dir:
        patch_path = Path(tmp_dir) / "patch.json"
        patch_path.write_text(json.dumps(diff_dict, indent=2), encoding="utf-8")

        commit = CommitOperationAdd(path_in_repo, str(patch_path))
        res = api.create_commit(
            repo_id=PATCH_REPO_ID,
            operations=[commit],
            commit_message=f"Update {paper_id}",
            repo_type="dataset",
            create_pr=True,
            # NOTE(review): with no OAuth token this passes token=None, which
            # presumably falls back to the process's default Hub credentials.
            token=oauth_token.token if oauth_token else None,
        )
    return gr.Markdown(value=res.pr_url, visible=True)


def render_open_pr_page(profile: gr.OAuthProfile | None) -> gr.Column:
    """Show the PR page only when a user profile is present (i.e. logged in)."""
    return gr.Column(visible=profile is not None)


with gr.Blocks() as demo:
    gr.LoginButton()
    with gr.Column(visible=False) as open_pr_col:
        gr.Markdown(NOTE)
        df_pr = gr.Dataframe(
            value=df_pr_view,
            datatype=[
                "str",  # Fix
                "str",  # Title
                "str",  # Authors
                "markdown",  # openreview
                "str",  # arxiv_id
                "markdown",  # github
                "markdown",  # spaces
                "markdown",  # models
                "markdown",  # datasets
                "str",  # paper id
            ],
            column_widths=[
                "50px",  # Fix
                "40%",  # Title
                "20%",  # Authors
                None,  # openreview
                "100px",  # arxiv_id
                None,  # github
                None,  # spaces
                None,  # models
                None,  # datasets
                None,  # paper id
            ],
            type="polars",
            row_count=(0, "dynamic"),
            interactive=False,
            max_height=1000,
            show_search="search",
        )
        with Modal(visible=False) as pr_modal:
            with gr.Group():
                title_pr = gr.Textbox(label="Title")
                authors_pr = gr.Textbox(label="Authors")
                arxiv_id_pr = gr.Textbox(label="arXiv ID")
                project_page_pr = gr.Textbox(label="Project page")
                github_pr = gr.Textbox(label="GitHub")
                spaces_pr = gr.Textbox(
                    label="Spaces",
                    info="Enter one space ID (e.g., 'org_name/space_name') per line.",
                )
                models_pr = gr.Textbox(
                    label="Models",
                    info="Enter one model ID (e.g., 'org_name/model_name') per line.",
                )
                datasets_pr = gr.Textbox(
                    label="Datasets",
                    info="Enter one dataset ID (e.g., 'org_name/dataset_name') per line.",
                )
            original_data = gr.State()
            preview_diff_button = gr.Button("Preview diff")
            diff_view = gr.Markdown()
            open_pr_button = gr.Button("Open PR", visible=False)
            pr_url = gr.Markdown(visible=False)

    # Closing the modal resets the preview, the "Open PR" button and the PR link.
    pr_modal.blur(
        fn=lambda: (None, gr.Button(visible=False), gr.Markdown(visible=False)),
        outputs=[diff_view, open_pr_button, pr_url],
    )

    df_pr.select(
        fn=df_pr_row_selected,
        outputs=[
            pr_modal,
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
    )
    preview_diff_button.click(
        fn=preview_diff,
        inputs=[
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
        outputs=[diff_view, open_pr_button],
    )
    open_pr_button.click(
        fn=open_pr,
        inputs=[
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
        outputs=pr_url,
    )

    demo.load(fn=render_open_pr_page, outputs=open_pr_col)


if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False)