flytoe committed on
Commit
c942b0f
·
verified ·
1 Parent(s): 6ab4778

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -21
app.py CHANGED
@@ -1,46 +1,66 @@
1
- import torch
2
- from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
3
  from datasets import load_dataset
 
 
 
4
 
5
- # 1️⃣ Modell & Tokenizer laden
6
- model_name = "allenai/scibert_scivocab_uncased"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
9
 
10
- # 2️⃣ Dataset laden (mit spezifischer Konfiguration: "arxiv" oder "pubmed")
11
- dataset = load_dataset("armanc/scientific_papers", "arxiv", trust_remote_code=True) # Oder "pubmed"
 
 
 
12
 
13
- # 3️⃣ Tokenisierung der Texte
14
  def tokenize_function(examples):
15
- return tokenizer(examples["text"], padding="max_length", truncation=True)
 
 
16
 
17
- tokenized_datasets = dataset.map(tokenize_function, batched=True)
 
18
 
19
- # 4️⃣ Trainingsparameter setzen
20
  training_args = TrainingArguments(
21
  output_dir="./results",
22
  evaluation_strategy="epoch",
23
- save_strategy="epoch",
24
  per_device_train_batch_size=8,
25
  per_device_eval_batch_size=8,
26
  num_train_epochs=3,
 
27
  weight_decay=0.01,
28
  logging_dir="./logs",
 
29
  )
30
 
31
- # 5️⃣ Training starten
32
  trainer = Trainer(
33
  model=model,
34
  args=training_args,
35
- train_dataset=tokenized_datasets["train"],
36
- eval_dataset=tokenized_datasets["validation"],
37
  )
38
-
39
  trainer.train()
40
 
41
- # 6️⃣ Speichern des Modells nach dem Training
42
- model.save_pretrained("./trained_model")
43
  tokenizer.save_pretrained("./trained_model")
44
 
45
- print(dataset) # Zeigt die Struktur des Datensatzes
46
- print("✅ Training abgeschlossen! Modell gespeichert.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from datasets import load_dataset
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
3
+ import gradio as gr
4
+ import torch
5
 
6
+ # Schritt 1: Dataset laden und überprüfen
7
+ # Falls "KeyError: 'text'" auftritt, Spaltennamen prüfen
 
 
8
 
9
+ dataset = load_dataset("armanc/scientific_papers", "arxiv") # Falls du PubMed nutzt, ersetze "arxiv" mit "pubmed"
10
+ print(dataset)
11
+
12
+ # Schritt 2: Tokenizer vorbereiten
13
+ tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
14
 
 
15
  def tokenize_function(examples):
16
+ return tokenizer(examples["abstract"], padding="max_length", truncation=True)
17
+
18
+ dataset = dataset.map(tokenize_function, batched=True)
19
 
20
+ # Schritt 3: Modell laden
21
+ model = AutoModelForSequenceClassification.from_pretrained("allenai/scibert_scivocab_uncased", num_labels=3)
22
 
23
+ # Schritt 4: Trainingsparameter setzen
24
  training_args = TrainingArguments(
25
  output_dir="./results",
26
  evaluation_strategy="epoch",
 
27
  per_device_train_batch_size=8,
28
  per_device_eval_batch_size=8,
29
  num_train_epochs=3,
30
+ learning_rate=5e-5,
31
  weight_decay=0.01,
32
  logging_dir="./logs",
33
+ logging_steps=500,
34
  )
35
 
36
+ # Schritt 5: Trainer erstellen und Training starten
37
  trainer = Trainer(
38
  model=model,
39
  args=training_args,
40
+ train_dataset=dataset["train"],
41
+ eval_dataset=dataset["validation"],
42
  )
 
43
  trainer.train()
44
 
45
+ # Schritt 6: Modell speichern
46
+ trainer.save_model("./trained_model")
47
  tokenizer.save_pretrained("./trained_model")
48
 
49
+ # Schritt 7: Modell für Gradio bereitstellen
50
def predict(text):
    """Classify a piece of text and return the probability of each label.

    Args:
        text: Raw input text (e.g. a paper abstract) to score.

    Returns:
        A dict mapping ``"Label <i>"`` to the softmax probability of class
        ``i`` as a Python float.
    """
    # Switch off dropout: the model may still be in training mode after
    # fine-tuning, which would make predictions non-deterministic.
    model.eval()

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # The Trainer may have moved the model to an accelerator; the input
    # tensors must live on the same device as the model weights.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Single input -> take row 0 of the (1, num_labels) probability matrix.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    return {f"Label {i}": float(p) for i, p in enumerate(probabilities)}
57
+
58
+ iface = gr.Interface(
59
+ fn=predict,
60
+ inputs=gr.Textbox(lines=5, placeholder="Paste an abstract here..."),
61
+ outputs=gr.Label(),
62
+ title="Scientific Paper Evaluator",
63
+ description="This AI model scores scientific papers based on relevance, uniqueness, and redundancy."
64
+ )
65
+
66
+ iface.launch()