"""Build the ICLR 2025 paper table: load the base dataset, overlay the latest
patch values, and derive the markdown/display columns."""

import datasets
import polars as pl
from loguru import logger
from polars import datatypes as pdt

BASE_REPO_ID = "ai-conferences/ICLR2025"
PATCH_REPO_ID = "ai-conferences/ICLR2025-patches"
PAPER_PAGE_REPO_ID = "hysts-bot-data/paper-pages-slim"

def get_patch_latest_values(
    df: pl.DataFrame, all_columns: list[str], id_col: str, timestamp_col: str = "timestamp", delimiter: str = ","
) -> pl.DataFrame:
    """Collapse a patch log into one row per `id_col`, keeping the most recent non-null value of each column."""
    df = df.sort(timestamp_col)
    # List columns can't round-trip through unpivot/pivot, so join them into
    # delimited strings first and split them back afterwards.
    list_cols = [
        col for col, dtype in df.schema.items() if col not in (id_col, timestamp_col) and dtype.base_type() is pdt.List
    ]
    df = df.with_columns(
        [
            pl.when(pl.col(c).is_not_null()).then(pl.col(c).list.join(delimiter)).otherwise(None).alias(c)
            for c in list_cols
        ]
    )
    update_columns = [col for col in df.columns if col not in (id_col, timestamp_col)]
    # Long format: one row per (paper, column) update; dropped nulls mean "this patch didn't touch that column".
    melted = df.unpivot(on=update_columns, index=[timestamp_col, id_col]).drop_nulls()
    # Keep the latest value per (paper, column) pair, then pivot back to wide format.
    latest_rows = (
        melted.sort(timestamp_col)
        .group_by([id_col, "variable"])
        .agg(pl.col("value").last())
        .pivot("variable", index=id_col, values="value")
    )
    # Split the delimited strings back into lists.
    latest_rows = latest_rows.with_columns(
        [
            pl.when(pl.col(c).is_not_null()).then(pl.col(c).str.split(delimiter)).otherwise(None).alias(c)
            for c in list_cols
        ]
    )
    # Add any expected columns no patch has touched, so the caller's join lines up.
    missing_cols = [c for c in all_columns if c not in latest_rows.columns and c != id_col]
    if missing_cols:
        latest_rows = latest_rows.with_columns([pl.lit(None).alias(c) for c in missing_cols])
    return latest_rows.select([id_col] + [col for col in all_columns if col != id_col])
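

# A minimal sketch of the patch semantics above (hypothetical data): two patches
# for the same paper arrive in timestamp order, so the later non-null value wins:
#
#   patches = pl.DataFrame({
#       "paper_id": ["1", "1"],
#       "timestamp": [datetime(2025, 1, 1), datetime(2025, 1, 2)],
#       "github": ["https://github.com/org/old", "https://github.com/org/new"],
#   })
#   get_patch_latest_values(patches, ["paper_id", "github"], id_col="paper_id")
#   # -> one row: paper_id "1", github "https://github.com/org/new"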


def format_author_claim_ratio(row: dict) -> str:
    """Format the claimed-author cell as "linked/total", with a ✅ when at least one author is linked."""
n_linked_authors = row["n_linked_authors"]
n_authors = row["n_authors"]
if n_linked_authors is None or n_authors is None:
return ""
author_linked = "✅" if n_linked_authors > 0 else ""
return f"{n_linked_authors}/{n_authors} {author_linked}".strip()

# Load the base conference table; the Hub repo-id list columns start as empty
# placeholders that later patches can fill in.
df_orig = (
    datasets.load_dataset(BASE_REPO_ID, split="train")
.to_polars()
.rename({"paper_url": "openreview", "submission_number": "paper_id"})
.with_columns(
pl.lit([], dtype=pl.List(pl.Utf8)).alias(col_name) for col_name in ["space_ids", "model_ids", "dataset_ids"]
)
)

# Paper-page metadata (author usernames, upvotes, comments), joined in by arXiv ID below.
df_paper_page = (
    datasets.load_dataset(PAPER_PAGE_REPO_ID, split="train")
.to_polars()
.drop(["summary", "author_names", "ai_keywords"])
)
df_orig = df_orig.join(df_paper_page, on="arxiv_id", how="left")

# Overlay the most recent patch values onto the base table; failures are non-fatal.
try:
    df_patches = (
        datasets.load_dataset(PATCH_REPO_ID, split="train")
.to_polars()
.drop("diff")
.with_columns(pl.col("timestamp").str.strptime(pl.Datetime, "%+"))
)
df_patches = get_patch_latest_values(df_patches, df_orig.columns, id_col="paper_id", timestamp_col="timestamp")
df_orig = (
df_orig.join(df_patches, on="paper_id", how="left")
.with_columns(
[
                # Prefer the patched value (the join's right-hand column) when present.
                pl.coalesce([pl.col(col + "_right"), pl.col(col)]).alias(col)
for col in df_orig.columns
if col != "paper_id"
]
)
.select(df_orig.columns)
)
except Exception as e:  # noqa: BLE001
    # The patch dataset may be missing or empty; fall back to the unpatched table.
    logger.warning(e)
# format authors
df_orig = df_orig.with_columns(pl.col("authors").list.join(", ").alias("authors_str"))
# format links
df_orig = df_orig.with_columns(
[
pl.format("[link]({})", pl.col(col)).fill_null("").alias(f"{col}_md")
for col in ["openreview", "project_page", "github"]
]
)
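
# For example (hypothetical URL), a github value of "https://github.com/org/repo" becomes
# "[link](https://github.com/org/repo)" in github_md; null links render as empty strings.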
# format paper page link
df_orig = df_orig.with_columns(
(pl.lit("https://huggingface.co./papers/") + pl.col("arxiv_id")).alias("paper_page")
).with_columns(pl.format("[{}]({})", pl.col("arxiv_id"), pl.col("paper_page")).fill_null("").alias("paper_page_md"))
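
# For example (hypothetical ID), arxiv_id "2501.00001" yields the cell
# "[2501.00001](https://huggingface.co./papers/2501.00001)"; papers without an arXiv ID render as "".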
# count authors
df_orig = df_orig.with_columns(pl.col("authors").list.len().alias("n_authors"))
df_orig = df_orig.with_columns(
pl.col("author_usernames")
.map_elements(lambda lst: sum(x is not None for x in lst) if lst is not None else None, return_dtype=pl.Int64)
.alias("n_linked_authors")
)
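
# For example (hypothetical list), author_usernames ["alice", None, "bob"] gives
# n_linked_authors = 2; a null list (no paper page) keeps the count null.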
df_orig = df_orig.with_columns(
pl.struct(["n_linked_authors", "n_authors"])
.map_elements(format_author_claim_ratio, return_dtype=pl.Utf8)
.alias("claimed")
)
# TODO: Fix this once https://github.com/gradio-app/gradio/issues/10916 is fixed # noqa: FIX002, TD002
# format numbers as strings
df_orig = df_orig.with_columns(
[pl.col(col).cast(pl.Utf8).fill_null("").alias(col) for col in ["upvotes", "num_comments"]]
)
# format spaces, models, datasets
for repo_id_col, markdown_col, base_url in [
("space_ids", "Spaces", "https://huggingface.co./spaces/"),
("model_ids", "Models", "https://huggingface.co./"),
("dataset_ids", "Datasets", "https://huggingface.co./datasets/"),
]:
    df_orig = df_orig.with_columns(
        pl.col(repo_id_col)
        .map_elements(
            # Bind base_url as a default argument so each iteration's lambda keeps its own URL prefix.
            lambda lst, base_url=base_url: "\n".join([f"[link]({base_url}{x})" for x in lst]) if lst is not None else None,
            return_dtype=pl.Utf8,
        )
        .fill_null("")
        .alias(markdown_col)
    )
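
# For example (hypothetical IDs), space_ids ["org/demo-a", "org/demo-b"] renders in the
# Spaces column as two lines:
#   [link](https://huggingface.co./spaces/org/demo-a)
#   [link](https://huggingface.co./spaces/org/demo-b)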