davanstrien HF Staff committed on
Commit
d124aee
·
verified ·
1 Parent(s): 33ea3f5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from huggingface_hub import ModelCard, DatasetCard, model_info, dataset_info
5
+ import logging
6
+ from typing import Tuple, Literal
7
+ import functools
8
+
9
# Logging: INFO-level basic configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hub repository id the tokenizer and model are loaded from.
MODEL_NAME = "davanstrien/Smol-Hub-tldr"

# Shared inference state; all three names start out as None.
model = tokenizer = device = None
18
+
19
def load_model():
    """Load the tokenizer and model for ``MODEL_NAME`` into module globals.

    Selects ``"cuda"`` when ``torch.cuda.is_available()`` reports True and
    ``"cpu"`` otherwise, moves the model to that device and puts it in
    eval mode.

    Returns:
        True when every step succeeded; False when any step raised (the
        error is logged instead of being propagated).
    """
    global model, tokenizer, device
    logger.info("Loading model and tokenizer...")
    try:
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        # Two separate assignments on purpose: the global is bound as soon
        # as the weights are loaded, before the device transfer is attempted.
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model = model.to(device)
        model.eval()
    except Exception as exc:
        logger.error("Failed to load model: %s", exc)
        return False
    return True
32
+
33
@functools.lru_cache(maxsize=100)
def get_card_info(hub_id: str) -> Tuple[str, str]:
    """Resolve a Hub repository id to its card type and card text.

    The id is tried as a model first and as a dataset second. Successful
    results are memoised (up to 100 distinct ids) for the lifetime of the
    process; a failed lookup raises and is therefore not cached.

    Args:
        hub_id: Repository id on the Hugging Face Hub.

    Returns:
        ``("model", text)`` or ``("dataset", text)``, where ``text`` is the
        ``text`` attribute of the loaded card.

    Raises:
        ValueError: If neither the model nor the dataset lookup succeeds.
            The dataset lookup error is attached as ``__cause__``.
    """
    try:
        # Return value is unused; the call is kept for the exception it
        # raises when the lookup fails.
        model_info(hub_id)
        return "model", ModelCard.load(hub_id).text
    except Exception as model_err:
        logger.error(f"Error fetching model card for {hub_id}: {model_err}")

    # Not resolvable as a model: fall back to treating the id as a dataset.
    try:
        dataset_info(hub_id)
        return "dataset", DatasetCard.load(hub_id).text
    except Exception as dataset_err:
        logger.error(f"Error fetching dataset card for {hub_id}: {dataset_err}")
        raise ValueError(
            f"Could not find model or dataset with id {hub_id}"
        ) from dataset_err
49
+
50
@functools.lru_cache(maxsize=100)
def generate_summary(card_text: str, card_type: str) -> str:
    """Generate a short summary for a model or dataset card.

    The card text is prefixed with ``<MODEL_CARD>`` (when ``card_type`` is
    ``"model"``) or ``<DATASET_CARD>`` (any other value), wrapped in a
    single user message, run through the tokenizer's chat template and
    passed to ``model.generate``.

    Generation samples (``do_sample=True``, ``temperature=0.4``) and the
    result is memoised per ``(card_text, card_type)`` pair, so repeated
    calls with the same arguments return the first sample drawn.

    Args:
        card_text: Raw card content to summarise.
        card_type: ``"model"`` selects the model prefix; anything else
            selects the dataset prefix.

    Returns:
        The text between the last ``<CARD_SUMMARY>`` and the following
        ``</CARD_SUMMARY>`` in the decoded completion, stripped. If a
        marker is missing, the corresponding end of the completion is used.

    Raises:
        RuntimeError: If the module-level model or tokenizer is still None.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model is not loaded; call load_model() first")

    # Determine prefix based on card type
    prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"

    # Format input according to the chat template
    messages = [{"role": "user", "content": f"{prefix}{card_text}"}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=60,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.4,
            do_sample=True,
            use_cache=True,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    # NOTE(review): special tokens are kept, presumably so the
    # <CARD_SUMMARY> markers survive decoding - confirm against the tokenizer.
    input_length = inputs.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=False)

    # str.split always yields at least one element, so neither index below
    # can raise; a missing marker simply leaves that side of the text intact.
    after_open = response.split("<CARD_SUMMARY>")[-1]
    return after_open.split("</CARD_SUMMARY>")[0].strip()
86
+
87
def summarize(hub_id: str = "", card_type: str = "model", content: str = "") -> str:
    """Interface function for Gradio: summarise a Hub repo or pasted card text.

    Exactly one source is used: ``hub_id`` takes precedence over ``content``.
    Blank or whitespace-only values count as "not provided".

    Args:
        hub_id: Hub repository id; surrounding whitespace is ignored.
        card_type: ``"model"`` or ``"dataset"``. With ``hub_id`` it must
            match the type inferred from the Hub (or be empty); with
            ``content`` it is required.
        content: Raw card text to summarise when no ``hub_id`` is given.

    Returns:
        The generated summary, or a string starting with ``"Error: "``
        describing why no summary was produced. Exceptions are never
        propagated to the caller.
    """
    # Normalise before branching so "   " does not trigger a Hub lookup.
    hub_id = (hub_id or "").strip()
    has_content = bool(content and content.strip())

    try:
        if hub_id:
            # Fetch and validate card type
            inferred_type, card_text = get_card_info(hub_id)
            if card_type and card_type != inferred_type:
                return f"Error: Provided card_type '{card_type}' doesn't match inferred type '{inferred_type}'"
            card_type = inferred_type
        elif has_content:
            if not card_type:
                return "Error: card_type must be provided when using direct content"
            card_text = content
        else:
            return "Error: Either hub_id or content must be provided"

        return generate_summary(card_text, card_type)

    except Exception as e:
        # UI boundary: report the failure to the user, but keep the
        # traceback in the logs instead of discarding it.
        logger.exception("Failed to summarize %s", hub_id or "<custom content>")
        return f"Error: {e}"
108
+
109
# Build the Gradio UI
def create_interface():
    """Assemble and return the Gradio ``Blocks`` app.

    The app has two tabs: one summarises a card looked up by Hub id, the
    other summarises pasted card text. Both buttons route through
    ``summarize``.
    """
    with gr.Blocks(title="Hub TLDR") as interface:
        gr.Markdown("# Hugging Face Hub TLDR Generator")
        gr.Markdown("Generate concise summaries of model and dataset cards from the Hugging Face Hub.")

        # Tab 1: look the card up on the Hub.
        with gr.Tab("Summarize by Hub ID"):
            id_box = gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b")
            id_kind = gr.Radio(
                choices=["model", "dataset"],
                label="Card Type (optional)",
                value="model",
            )
            id_button = gr.Button("Generate Summary")
            id_result = gr.Textbox(label="Summary")

            # Inputs map positionally onto summarize(hub_id, card_type).
            id_button.click(
                fn=summarize,
                inputs=[id_box, id_kind],
                outputs=id_result,
            )

        # Tab 2: summarise text supplied directly by the user.
        with gr.Tab("Summarize Custom Content"):
            text_box = gr.Textbox(
                label="Content",
                placeholder="Paste your model or dataset card content here...",
                lines=10,
            )
            text_kind = gr.Radio(
                choices=["model", "dataset"],
                label="Card Type",
                value="model",
            )
            text_button = gr.Button("Generate Summary")
            text_result = gr.Textbox(label="Summary")

            # The lambda re-routes the two inputs to keyword arguments so
            # the pasted text lands in `content` rather than `hub_id`.
            text_button.click(
                fn=lambda content, card_type: summarize(content=content, card_type=card_type),
                inputs=[text_box, text_kind],
                outputs=text_result,
            )

    return interface
155
+
156
if __name__ == "__main__":
    # Serve the UI only once the model has loaded; otherwise report and stop.
    if not load_model():
        print("Failed to load model. Please check the logs for details.")
    else:
        interface = create_interface()
        interface.launch()