hackergeek98 committed
Commit db16ef6 · verified · 1 Parent(s): 5198fc1

Update app.py

Files changed (1)
  1. app.py +116 -11
app.py CHANGED
@@ -1,7 +1,9 @@
- # app.py
import torch
import gradio as gr
import threading
+ import logging
+ import sys
+ from urllib.parse import urlparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
@@ -10,23 +12,127 @@ from transformers import (
    DataCollatorForLanguageModeling
)
from datasets import load_dataset
- import logging
- import sys
- from urllib.parse import urlparse

# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

- def parse_hf_dataset_url(url: str):
-     # ... (keep previous URL parsing logic) ...
+ def parse_hf_dataset_url(url: str) -> tuple[str, str | None]:
+     """Parse Hugging Face dataset URL into (dataset_name, config)"""
+     parsed = urlparse(url)
+     path_parts = parsed.path.split('/')
+
+     try:
+         # Find 'datasets' in path
+         datasets_idx = path_parts.index('datasets')
+     except ValueError:
+         raise ValueError("Invalid Hugging Face dataset URL")
+
+     dataset_parts = path_parts[datasets_idx+1:]
+     dataset_name = "/".join(dataset_parts[0:2])
+
+     # Try to find config (common pattern for datasets with viewer)
+     try:
+         viewer_idx = dataset_parts.index('viewer')
+         config = dataset_parts[viewer_idx+1] if viewer_idx+1 < len(dataset_parts) else None
+     except ValueError:
+         config = None
+
+     return dataset_name, config

def train(dataset_url: str):
    try:
-         # ... (keep previous training logic) ...
+         # Parse dataset URL
+         dataset_name, dataset_config = parse_hf_dataset_url(dataset_url)
+         logging.info(f"Loading dataset: {dataset_name} (config: {dataset_config})")
+
+         # Load model and tokenizer
+         model_name = "microsoft/phi-2"
+         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)
+
+         # Add padding token
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Load dataset from Hugging Face Hub
+         dataset = load_dataset(
+             dataset_name,
+             dataset_config,
+             trust_remote_code=True
+         )
+
+         # Handle dataset splits
+         if "train" not in dataset:
+             raise ValueError("Dataset must have a 'train' split")

+         train_dataset = dataset["train"]
+         eval_dataset = dataset.get("validation", dataset.get("test", None))
+
+         # Split if no validation set
+         if eval_dataset is None:
+             split = train_dataset.train_test_split(test_size=0.1, seed=42)
+             train_dataset = split["train"]
+             eval_dataset = split["test"]
+
+         # Tokenization function
+         def tokenize_function(examples):
+             return tokenizer(
+                 examples["text"],  # Adjust column name as needed
+                 padding="max_length",
+                 truncation=True,
+                 max_length=256,
+                 return_tensors="pt",
+             )
+
+         # Tokenize datasets
+         tokenized_train = train_dataset.map(
+             tokenize_function,
+             batched=True,
+             remove_columns=train_dataset.column_names
+         )
+         tokenized_eval = eval_dataset.map(
+             tokenize_function,
+             batched=True,
+             remove_columns=eval_dataset.column_names
+         )
+
+         # Data collator
+         data_collator = DataCollatorForLanguageModeling(
+             tokenizer=tokenizer,
+             mlm=False
+         )
+
+         # Training arguments
+         training_args = TrainingArguments(
+             output_dir="./phi2-results",
+             per_device_train_batch_size=2,
+             per_device_eval_batch_size=2,
+             num_train_epochs=3,
+             logging_dir="./logs",
+             logging_steps=10,
+             fp16=False,
+         )
+
+         # Trainer
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=tokenized_train,
+             eval_dataset=tokenized_eval,
+             data_collator=data_collator,
+         )
+
+         # Start training
+         logging.info("Training started...")
+         trainer.train()
+         trainer.save_model("./phi2-trained-model")
+         logging.info("Training completed!")
+
+         return "✅ Training succeeded! Model saved."
+
    except Exception as e:
-         logging.error(f"Critical error: {str(e)}")
-         return f"❌ Critical error: {str(e)}"
+         logging.error(f"Training failed: {str(e)}")
+         return f"❌ Training failed: {str(e)}"

# Gradio interface
with gr.Blocks(title="Phi-2 Training") as demo:
@@ -51,6 +157,5 @@ if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
-         enable_queue=True,
-         share=False
+         enable_queue=True
    )
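
As a minimal sketch of how the URL parser added in this commit behaves (not part of the committed file; it assumes the snippet runs alongside app.py, and the example URLs are placeholders):

# Sketch: expected behaviour of parse_hf_dataset_url for placeholder URLs.
from app import parse_hf_dataset_url

# The repo id is the two path segments after /datasets/; with no 'viewer'
# segment in the path, the config comes back as None.
print(parse_hf_dataset_url("https://huggingface.co/datasets/roneneldan/TinyStories"))
# -> ('roneneldan/TinyStories', None)

# URLs without a /datasets/ segment raise ValueError.
try:
    parse_hf_dataset_url("https://example.com/not-a-dataset")
except ValueError as err:
    print(err)  # Invalid Hugging Face dataset URL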
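
The Gradio wiring between the Blocks definition and demo.launch() is collapsed in this view. For local testing, the training entry point can also be exercised directly; a minimal sketch, assuming app.py and its dependencies (torch, transformers, datasets, gradio) are installed and using a placeholder dataset URL:

# Sketch: call the committed train() function without going through the Gradio UI.
from app import train

status = train("https://huggingface.co/datasets/roneneldan/TinyStories")
print(status)  # "✅ Training succeeded! Model saved." or "❌ Training failed: ..."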