HCK-GPT / app.py
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from model.utils import preprocess_input, save_feedback
from model.auto_learn import trigger_auto_learning
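# Note: the helpers imported above live elsewhere in this repo and are not shown here.
# Assumed interfaces (illustrative only, not verified against model/utils.py or model/auto_learn.py):
#   preprocess_input(prompt, tokenizer) -> dict with at least "input_ids" (torch tensors)
#   save_feedback(prompt, generated_text, user_feedback) -> persists the triple for later training
#   trigger_auto_learning() -> optionally launches a fine-tuning run over the saved feedback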
# Load the model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Move the model to GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only; disables dropout
# Main inference function
def generate_text(prompt, max_length=100, temperature=0.7):
    inputs = preprocess_input(prompt, tokenizer)
    input_ids = inputs["input_ids"].to(device)
    # Gradio sliders return floats, so cast max_length to int before generating
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=int(max_length),
            temperature=temperature,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
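# Illustrative direct call (assumes the helper behavior sketched near the imports):
#   generate_text("Once upon a time", max_length=80, temperature=0.9)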
# Collect user feedback and trigger auto-learning
def submit_feedback(prompt, generated_text, user_feedback):
    save_feedback(prompt, generated_text, user_feedback)
    trigger_auto_learning()  # kicks off fine-tuning when needed
    return "Feedback saved successfully!"
# Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# GPT-2 on Hugging Face")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Enter your prompt")
                max_length = gr.Slider(50, 500, value=100, label="Maximum length")
                temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
                generate_btn = gr.Button("Generate Text")
            with gr.Column():
                output = gr.Textbox(label="Generated Text")
                feedback = gr.Textbox(label="Feedback (optional)")
                feedback_btn = gr.Button("Submit Feedback")
                feedback_status = gr.Textbox(label="Status", interactive=False)

        generate_btn.click(
            fn=generate_text,
            inputs=[prompt, max_length, temperature],
            outputs=output
        )
        feedback_btn.click(
            fn=submit_feedback,
            inputs=[prompt, output, feedback],
            outputs=feedback_status  # write to a visible status box instead of an orphan gr.Textbox()
        )
    return demo
# Launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)
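# Optional: once the app is running, it can also be called programmatically with
# gradio_client (the endpoint name below is an assumption; check the running app's API docs):
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   print(client.predict("Hello", 100, 0.7, api_name="/generate_text"))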