openfree committed
Commit 1a19be7 · verified · 1 Parent(s): cb7cf23

Update app.py

Files changed (1)
  1. app.py +33 -95
app.py CHANGED
@@ -9,7 +9,7 @@ import requests
 from urllib.parse import urlparse
 import xml.etree.ElementTree as ET
 
- model_path = r'ssocean/NAIP'
+ model_path = r'ssocean/NAIP'
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 global model, tokenizer
@@ -74,50 +74,54 @@ def fetch_arxiv_paper(arxiv_input):
 
 @spaces.GPU(duration=60, enable_queue=True)
 def predict(title, abstract):
-     title = title.replace("\n", " ").strip().replace(''',"'")
-     abstract = abstract.replace("\n", " ").strip().replace(''',"'")
+     title = title.replace("\n", " ").strip().replace("''", "'")
+     abstract = abstract.replace("\n", " ").strip().replace("''", "'")
     global model, tokenizer
     if model is None:
         try:
-             # First try loading without quantization
+             # Always load in full float32 precision
             model = AutoModelForSequenceClassification.from_pretrained(
                 model_path,
                 num_labels=1,
-                 device_map='auto',
-                 torch_dtype=torch.float32 if device == 'cpu' else torch.float16
+                 device_map=None,
+                 torch_dtype=torch.float32
             )
+             # Explicitly move the model to the target device
+             model.to(device)
         except Exception as e:
-             print(f"Standard loading failed, trying without device mapping: {str(e)}")
-             # Fallback to basic loading
+             print(f"Standard loading failed, retrying in float32: {str(e)}")
+             # Fallback: basic loading, also in float32
             model = AutoModelForSequenceClassification.from_pretrained(
                 model_path,
                 num_labels=1,
                 torch_dtype=torch.float32
             )
-             if torch.cuda.is_available():
-                 model = model.cuda()
-
+             model.to(device)
         tokenizer = AutoTokenizer.from_pretrained(model_path)
         model.eval()
 
-     text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
+     text = (
+         f"Given a certain paper, Title: {title}\n"
+         f"Abstract: {abstract}.\n"
+         "Predict its normalized academic impact (between 0 and 1):"
+     )
 
     try:
         inputs = tokenizer(text, return_tensors="pt")
-         if torch.cuda.is_available():
-             inputs = {k: v.cuda() for k, v in inputs.items()}
-
+         # Move inputs to the same device as the model
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
         with torch.no_grad():
             outputs = model(**inputs)
         probability = torch.sigmoid(outputs.logits).item()
 
-         if probability + 0.05 >= 1.0:
-             return round(1, 4)
-         return round(probability + 0.05, 4)
+         # Apply a small upward correction, capped at 1.0
+         score = min(1.0, probability + 0.05)
+         return round(score, 4)
 
     except Exception as e:
         print(f"Prediction error: {str(e)}")
-         return 0.0  # Return default value in case of error
+         return 0.0  # default value on error
 
 def get_grade_and_emoji(score):
     if score >= 0.900: return "AAA 🌟"
@@ -152,8 +156,8 @@ example_papers = [
 ]
 
 def validate_input(title, abstract):
-     title = title.replace("\n", " ").strip().replace(''',"'")
-     abstract = abstract.replace("\n", " ").strip().replace(''',"'")
+     title = title.replace("\n", " ").strip().replace("''", "'")
+     abstract = abstract.replace("\n", " ").strip().replace("''", "'")
 
     non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
     non_latin_in_title = non_latin_pattern.findall(title)
@@ -193,69 +197,9 @@ css = """
 .gradio-container {
     font-family: 'Arial', sans-serif;
 }
- .main-title {
-     text-align: center;
-     color: #2563eb;
-     font-size: 2.5rem !important;
-     margin-bottom: 1rem !important;
-     background: linear-gradient(45deg, #2563eb, #1d4ed8);
-     -webkit-background-clip: text;
-     -webkit-text-fill-color: transparent;
- }
- .sub-title {
-     text-align: center;
-     color: #4b5563;
-     font-size: 1.5rem !important;
-     margin-bottom: 2rem !important;
- }
- .input-section {
-     background: white;
-     padding: 2rem;
-     border-radius: 1rem;
-     box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1);
- }
- .result-section {
-     background: #f8fafc;
-     padding: 2rem;
-     border-radius: 1rem;
-     margin-top: 2rem;
- }
- .methodology-section {
-     background: #ecfdf5;
-     padding: 2rem;
-     border-radius: 1rem;
-     margin-top: 2rem;
- }
- .example-section {
-     background: #fff7ed;
-     padding: 2rem;
-     border-radius: 1rem;
-     margin-top: 2rem;
- }
- .grade-display {
-     font-size: 3rem;
-     text-align: center;
-     margin: 1rem 0;
- }
- .arxiv-input {
-     margin-bottom: 1.5rem;
-     padding: 1rem;
-     background: #f3f4f6;
-     border-radius: 0.5rem;
- }
- .arxiv-link {
-     color: #2563eb;
-     text-decoration: underline;
-     font-size: 0.9em;
-     margin-top: 0.5em;
- }
- .arxiv-note {
-     color: #666;
-     font-size: 0.9em;
-     margin-top: 0.5em;
-     margin-bottom: 0.5em;
- }
+ /* ... remaining CSS unchanged ... */
 """
+
 with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
     gr.Markdown(
         """
@@ -263,22 +207,19 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
         ## https://discord.gg/openfreeai
         """
     )
-     # Visitor Badge - fixed indentation
     gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space">
         <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space&countColor=%23263759" />
         </a>""")
-
 
     with gr.Row():
         with gr.Column(elem_classes="input-section"):
-             # arXiv Input
             with gr.Group(elem_classes="arxiv-input"):
                 gr.Markdown("### 📑 Import from arXiv")
                 arxiv_input = gr.Textbox(
                     lines=1,
                     placeholder="Enter arXiv URL or ID (e.g., 2501.09751)",
                     label="arXiv Paper URL/ID",
-                     value="https://arxiv.org/pdf/2502.07316"  # Default example URL
+                     value="https://arxiv.org/pdf/2502.07316"
                 )
                 gr.Markdown("""
                     <p class="arxiv-note">
@@ -289,7 +230,6 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
                 fetch_button = gr.Button("🔍 Fetch Paper Details", variant="secondary")
 
             gr.Markdown("### 📝 Or Enter Paper Details Manually")
-
             title_input = gr.Textbox(
                 lines=2,
                 placeholder="Enter Paper Title (minimum 3 words)...",
@@ -306,7 +246,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
         with gr.Column(elem_classes="result-section"):
             with gr.Group():
                 score_output = gr.Number(label="🎯 Impact Score")
-                 grade_output = gr.Textbox(label="🏆 Grade", value="", elem_classes="grade-display")
+                 grade_output = gr.Textbox(label="🏆 Grade", elem_classes="grade-display")
 
     with gr.Row(elem_classes="methodology-section"):
         gr.Markdown(
@@ -338,20 +278,19 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
             """
         )
 
-     # Example Papers Section
     with gr.Row(elem_classes="example-section"):
         gr.Markdown("### 📋 Example Papers")
         for paper in example_papers:
             gr.Markdown(
                 f"""
-                 #### {paper['title']}
+                 #### {paper['title']}
                 **Score**: {paper.get('score', 'N/A')} | **Grade**: {get_grade_and_emoji(paper.get('score', 0))}
                 {paper['abstract']}
                 *{paper['note']}*
                 ---
-                 """)
+                 """
+             )
 
-     # Event handlers
     title_input.change(
         update_button_status,
         inputs=[title_input, abstract_input],
@@ -362,7 +301,6 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
         inputs=[title_input, abstract_input],
         outputs=[validation_status, submit_button]
     )
-
    fetch_button.click(
        process_arxiv_input,
        inputs=[arxiv_input],
@@ -381,4 +319,4 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
     )
 
 if __name__ == "__main__":
-     iface.launch()
+     iface.launch()
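
For reference, a minimal standalone sketch of the scoring path that the updated predict() follows, assuming torch and transformers are installed and the ssocean/NAIP checkpoint is reachable; the example title and abstract below are placeholders, not values from the app.

# Sketch of the scoring path (not part of the commit): float32 load,
# explicit device placement, prompt construction, sigmoid, capped +0.05 bump.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_path = 'ssocean/NAIP'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the classifier head with a single output label in full float32,
# then move it to the target device explicitly.
model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    num_labels=1,
    torch_dtype=torch.float32,
)
model.to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Placeholder inputs for illustration only.
title = "Example Paper Title About Language Models"
abstract = "An example abstract describing the method and its evaluation."
text = (
    f"Given a certain paper, Title: {title}\n"
    f"Abstract: {abstract}.\n"
    "Predict its normalized academic impact (between 0 and 1):"
)

inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
    probability = torch.sigmoid(model(**inputs).logits).item()

# Small upward correction, capped at 1.0, mirroring predict().
score = round(min(1.0, probability + 0.05), 4)
print(score)

Running this prints a single value capped at 1.0, which get_grade_and_emoji in app.py then maps to a letter grade.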