import gradio as gr import spaces import torch from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification import torch.nn.functional as F import torch.nn as nn import re import requests from urllib.parse import urlparse import xml.etree.ElementTree as ET ################################################## # Global setup ################################################## model_path = "ssocean/NAIP" device = "cuda" if torch.cuda.is_available() else "cpu" model = None tokenizer = None ################################################## # Fetch paper info from arXiv ################################################## def fetch_arxiv_paper(arxiv_input): """ Fetch paper title & abstract from an arXiv URL or ID. """ try: if "arxiv.org" in arxiv_input: parsed = urlparse(arxiv_input) path = parsed.path arxiv_id = path.split("/")[-1].replace(".pdf", "") else: arxiv_id = arxiv_input.strip() api_url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}" resp = requests.get(api_url) if resp.status_code != 200: return { "title": "", "abstract": "", "success": False, "message": "Error fetching paper from arXiv API", } root = ET.fromstring(resp.text) ns = {"arxiv": "http://www.w3.org/2005/Atom"} entry = root.find(".//arxiv:entry", ns) if entry is None: return {"title": "", "abstract": "", "success": False, "message": "Paper not found"} title = entry.find("arxiv:title", ns).text.strip() abstract = entry.find("arxiv:summary", ns).text.strip() return { "title": title, "abstract": abstract, "success": True, "message": "Paper fetched successfully!", } except Exception as e: return { "title": "", "abstract": "", "success": False, "message": f"Error fetching paper: {e}", } ################################################## # Prediction function ################################################## @spaces.GPU(duration=60, enable_queue=True) def predict(title, abstract): """ Predict a normalized academic impact score (0β1) from title & abstract. """ global model, tokenizer if model is None: # 1) Load config config = AutoConfig.from_pretrained(model_path) # 2) Remove quantization_config if it exists (avoid NoneType error in PEFT) if hasattr(config, "quantization_config"): del config.quantization_config # 3) Optionally set number of labels config.num_labels = 1 # 4) Load the model model_loaded = AutoModelForSequenceClassification.from_pretrained( model_path, config=config, torch_dtype=torch.float32, # float32 for stable cublasLt device_map=None, low_cpu_mem_usage=False ) model_loaded.to(device) model_loaded.eval() # 5) Load tokenizer tokenizer_loaded = AutoTokenizer.from_pretrained(model_path) # Assign to globals model, tokenizer = model_loaded, tokenizer_loaded text = ( f"Given a certain paper,\n" f"Title: {title.strip()}\n" f"Abstract: {abstract.strip()}\n" f"Predict its normalized academic impact (0~1):" ) try: inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits prob = torch.sigmoid(logits).item() score = min(1.0, prob + 0.05) return round(score, 4) except Exception as e: print("Prediction error:", e) return 0.0 ################################################## # Grading ################################################## def get_grade_and_emoji(score): """Map a 0β1 score to an A/B/C style grade with an emoji indicator.""" if score >= 0.900: return "AAA π" if score >= 0.800: return "AA β" if score >= 0.650: return "A β¨" if score >= 0.600: return "BBB π΅" if score >= 0.550: return "BB π" if score >= 0.500: return "B π" if score >= 0.400: return "CCC π" if score >= 0.300: return "CC βοΈ" return "C π" ################################################## # Validation ################################################## def validate_input(title, abstract): """ Ensure the title has at least 3 words, the abstract at least 50, and check for ASCII-only characters. """ non_ascii = re.compile(r"[^\x00-\x7F]") if len(title.split()) < 3: return False, "Title must be at least 3 words." if len(abstract.split()) < 50: return False, "Abstract must be at least 50 words." if non_ascii.search(title): return False, "Title contains non-ASCII characters." if non_ascii.search(abstract): return False, "Abstract contains non-ASCII characters." return True, "Inputs look good." def update_button_status(title, abstract): """Enable or disable the predict button based on validation.""" valid, msg = validate_input(title, abstract) if not valid: return gr.update(value="Error: " + msg), gr.update(interactive=False) return gr.update(value=msg), gr.update(interactive=True) ################################################## # Process arXiv input ################################################## def process_arxiv_input(arxiv_input): """ Called when user clicks 'Fetch Paper Details' to fill in title/abstract from arXiv. """ if not arxiv_input.strip(): return "", "", "Please enter an arXiv URL or ID" res = fetch_arxiv_paper(arxiv_input) if res["success"]: return res["title"], res["abstract"], res["message"] return "", "", res["message"] ################################################## # Custom CSS ################################################## css = """ .gradio-container { font-family: Arial, sans-serif; } .main-title { text-align: center; color: #2563eb; font-size: 2.5rem!important; margin-bottom:1rem!important; background: linear-gradient(45deg,#2563eb,#1d4ed8); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .input-section { background:#fff; padding:1.5rem; border-radius:0.5rem; box-shadow:0 4px 6px rgba(0,0,0,0.1); } .result-section { background:#f7f9fc; padding:1.5rem; border-radius:0.5rem; margin-top:2rem; } .grade-display { font-size:2.5rem; text-align:center; margin-top:1rem; } .arxiv-input { margin-bottom:1.5rem; padding:1rem; background:#f3f4f6; border-radius:0.5rem; } .arxiv-link { color:#2563eb; text-decoration: underline; } """ ################################################## # Header HTML (social links) ################################################## header_html = """
Enter an arXiv ID or URL. For example:
2504.11651
or https://arxiv.org/pdf/2504.11651