hackergeek98 committed on
Commit
501033d
·
verified Β·
1 Parent(s): e504c1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -17
app.py CHANGED
@@ -10,12 +10,40 @@ from transformers import (
10
  from datasets import load_dataset
11
  import logging
12
  import sys
 
13
 
14
  # Configure logging
15
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
16
 
17
- def train(dataset_name: str, dataset_config: str = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  try:
 
 
 
 
19
  # Load model and tokenizer
20
  model_name = "microsoft/phi-2"
21
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -26,31 +54,45 @@ def train(dataset_name: str, dataset_config: str = None):
26
  tokenizer.pad_token = tokenizer.eos_token
27
 
28
  # Load dataset from Hugging Face Hub
29
- logging.info(f"Loading dataset: {} (config: {dataset_config})")
30
  dataset = load_dataset(
31
  dataset_name,
32
- dataset_config, # Optional config (e.g., language for Common Voice)
33
- split="train+validation", # Combine splits
34
- trust_remote_code=True # Required for some datasets
35
  )
36
 
37
- # Split into train/validation
38
- dataset = dataset.train_test_split(test_size=0.1, seed=42)
 
 
 
 
 
 
 
 
 
 
39
 
40
- # Tokenization function (adjust based on dataset columns)
41
  def tokenize_function(examples):
42
  return tokenizer(
43
- examples["text"], # Replace "text" with your dataset's text column
44
  padding="max_length",
45
  truncation=True,
46
  max_length=256,
47
  return_tensors="pt",
48
  )
49
 
50
- tokenized_dataset = dataset.map(
 
 
 
 
 
 
51
  tokenize_function,
52
  batched=True,
53
- remove_columns=dataset["train"].column_names
54
  )
55
 
56
  # Data collator
@@ -74,8 +116,8 @@ def train(dataset_name: str, dataset_config: str = None):
74
  trainer = Trainer(
75
  model=model,
76
  args=training_args,
77
- train_dataset=tokenized_dataset["train"],
78
- eval_dataset=tokenized_dataset["test"],
79
  data_collator=data_collator,
80
  )
81
 
@@ -91,20 +133,22 @@ def train(dataset_name: str, dataset_config: str = None):
91
  logging.error(f"Training failed: {str(e)}")
92
  return f"❌ Training failed: {str(e)}"
93
 
94
- # Gradio UI with dataset input
95
  with gr.Blocks(title="Phi-2 Training") as demo:
96
  gr.Markdown("# πŸš€ Train Phi-2 with HF Hub Data")
97
 
98
  with gr.Row():
99
- dataset_name = gr.Textbox(label="Dataset Name", value="mozilla-foundation/common_voice_11_0")
100
- dataset_config = gr.Textbox(label="Dataset Config (optional)", value="en")
 
 
101
 
102
  start_btn = gr.Button("Start Training", variant="primary")
103
  status_output = gr.Textbox(label="Status", interactive=False)
104
 
105
  start_btn.click(
106
  fn=train,
107
- inputs=[dataset_name, dataset_config],
108
  outputs=status_output
109
  )
110
 
 
10
  from datasets import load_dataset
11
  import logging
12
  import sys
13
+ from urllib.parse import urlparse
14
 
15
  # Configure logging
16
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
17
 
18
+ def parse_hf_dataset_url(url: str) -> tuple[str, str | None]:
19
+ """Parse Hugging Face dataset URL into (dataset_name, config)"""
20
+ parsed = urlparse(url)
21
+ path_parts = parsed.path.split('/')
22
+
23
+ try:
24
+ # Find 'datasets' in path
25
+ datasets_idx = path_parts.index('datasets')
26
+ except ValueError:
27
+ raise ValueError("Invalid Hugging Face dataset URL")
28
+
29
+ dataset_parts = path_parts[datasets_idx+1:]
30
+ dataset_name = "/".join(dataset_parts[0:2])
31
+
32
+ # Try to find config (common pattern for datasets with viewer)
33
+ try:
34
+ viewer_idx = dataset_parts.index('viewer')
35
+ config = dataset_parts[viewer_idx+1] if viewer_idx+1 < len(dataset_parts) else None
36
+ except ValueError:
37
+ config = None
38
+
39
+ return dataset_name, config
40
+
41
+ def train(dataset_url: str):
42
  try:
43
+ # Parse dataset URL
44
+ dataset_name, dataset_config = parse_hf_dataset_url(dataset_url)
45
+ logging.info(f"Loading dataset: {dataset_name} (config: {dataset_config})")
46
+
47
  # Load model and tokenizer
48
  model_name = "microsoft/phi-2"
49
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
54
  tokenizer.pad_token = tokenizer.eos_token
55
 
56
  # Load dataset from Hugging Face Hub
 
57
  dataset = load_dataset(
58
  dataset_name,
59
+ dataset_config,
60
+ trust_remote_code=True
 
61
  )
62
 
63
+ # Handle dataset splits
64
+ if "train" not in dataset:
65
+ raise ValueError("Dataset must have a 'train' split")
66
+
67
+ train_dataset = dataset["train"]
68
+ eval_dataset = dataset.get("validation", None)
69
+
70
+ # Split if no validation set
71
+ if eval_dataset is None:
72
+ split = train_dataset.train_test_split(test_size=0.1, seed=42)
73
+ train_dataset = split["train"]
74
+ eval_dataset = split["test"]
75
 
76
+ # Tokenization function
77
  def tokenize_function(examples):
78
  return tokenizer(
79
+ examples["text"], # Adjust column name as needed
80
  padding="max_length",
81
  truncation=True,
82
  max_length=256,
83
  return_tensors="pt",
84
  )
85
 
86
+ # Tokenize datasets
87
+ tokenized_train = train_dataset.map(
88
+ tokenize_function,
89
+ batched=True,
90
+ remove_columns=train_dataset.column_names
91
+ )
92
+ tokenized_eval = eval_dataset.map(
93
  tokenize_function,
94
  batched=True,
95
+ remove_columns=eval_dataset.column_names
96
  )
97
 
98
  # Data collator
 
116
  trainer = Trainer(
117
  model=model,
118
  args=training_args,
119
+ train_dataset=tokenized_train,
120
+ eval_dataset=tokenized_eval,
121
  data_collator=data_collator,
122
  )
123
 
 
133
  logging.error(f"Training failed: {str(e)}")
134
  return f"❌ Training failed: {str(e)}"
135
 
136
+ # Gradio UI with dataset URL input
137
  with gr.Blocks(title="Phi-2 Training") as demo:
138
  gr.Markdown("# πŸš€ Train Phi-2 with HF Hub Data")
139
 
140
  with gr.Row():
141
+ dataset_url = gr.Textbox(
142
+ label="Dataset URL",
143
+ value="https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0"
144
+ )
145
 
146
  start_btn = gr.Button("Start Training", variant="primary")
147
  status_output = gr.Textbox(label="Status", interactive=False)
148
 
149
  start_btn.click(
150
  fn=train,
151
+ inputs=[dataset_url],
152
  outputs=status_output
153
  )
154