R1 / app.py
hackergeek98's picture
Update app.py
767fba0 verified
raw
history blame
1.46 kB
# app.py
import torch
import gradio as gr
import threading
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from datasets import load_dataset
import logging
import sys
from urllib.parse import urlparse
# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def parse_hf_dataset_url(url: str):
    """Parse the Hugging Face dataset URL given as ``url``.

    NOTE(review): the implementation is not present in this revision; only
    the placeholder comment below remains. Presumably it extracts the dataset
    identifier from the URL (``urlparse`` is imported at file level), but the
    exact return shape cannot be determined from here. Restore the previous
    logic and confirm the return value before relying on this function.
    """
    # ... (keep previous URL parsing logic) ...
def train(dataset_url: str):
    """Entry point invoked with the dataset URL entered in the UI.

    NOTE(review): the training logic is not present in this revision. The
    ``try`` body holds only a placeholder comment, which is not a valid
    suite on its own, so the module will not import until the previous logic
    is restored.

    Returns:
        When any ``Exception`` is raised inside the ``try`` body, a string of
        the form ``"❌ Critical error: <message>"``. The success-path return
        value is defined by the elided logic and is not visible here.
    """
    try:
        # ... (keep previous training logic) ...
    except Exception as e:
        # Top-level boundary: log the failure and hand the message back to
        # the caller as a string instead of letting the exception propagate.
        logging.error(f"Critical error: {str(e)}")
        return f"❌ Critical error: {str(e)}"
# Gradio interface
with gr.Blocks(title="Phi-2 Training") as demo:
    # Heading: the rocket emoji was previously stored as mis-decoded bytes.
    gr.Markdown("# 🚀 Train Phi-2 with HF Hub Data")

    with gr.Row():
        dataset_url = gr.Textbox(
            label="Dataset URL",
            # Canonical host name; the earlier default had a stray "." after
            # "huggingface.co".
            value="https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0",
        )

    start_btn = gr.Button("Start Training", variant="primary")
    status_output = gr.Textbox(label="Status", interactive=False)

    # Wire train() directly as the handler so that its return value (for
    # example the "❌ Critical error: ..." string) is written to the Status
    # box. The previous handler started a background thread and returned the
    # result of Thread.start(), which is always None, so the Status box could
    # never show the outcome. Gradio runs blocking handlers off the main
    # event loop, so the UI stays responsive while training runs.
    start_btn.click(
        fn=train,
        inputs=[dataset_url],
        outputs=status_output,
    )
if __name__ == "__main__":
    # Enable request queuing via Blocks.queue(). The launch(enable_queue=...)
    # keyword was deprecated in Gradio 3.x and is rejected by Gradio 4+,
    # whereas queue() is available in both.
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        share=False,
    )