# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import gradio as gr | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
import torch | |
import os | |
from model.utils import preprocess_input, save_feedback | |
from model.auto_learn import trigger_auto_learning | |
# Load the pretrained model and its tokenizer.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Move the model to the GPU when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
# Main inference function.
def generate_text(prompt, max_length=100, temperature=0.7):
    """Generate a sampled continuation of ``prompt`` with the module-level model.

    Args:
        prompt: Text used to seed generation.
        max_length: Length limit forwarded to ``model.generate`` as
            ``max_length``. Coerced to ``int`` because UI sliders may hand
            the value over as a float.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded text of the single generated sequence with special
        tokens removed, or an empty string when ``prompt`` is empty, ``None``
        or whitespace-only.
    """
    # Nothing to condition on: skip tokenization and generation entirely
    # instead of handing the model an empty sequence.
    if not prompt or not prompt.strip():
        return ""

    inputs = preprocess_input(prompt, tokenizer)
    input_ids = inputs["input_ids"].to(device)
    outputs = model.generate(
        input_ids,
        max_length=int(max_length),
        temperature=temperature,
        num_return_sequences=1,
        do_sample=True,
        # The EOS token is reused as the padding token for generation.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only one sequence is requested, so decode the first (and only) result.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
# Collect user feedback and kick off the auto-learning hook.
def submit_feedback(prompt, generated_text, user_feedback):
    """Store one feedback record, then call the auto-learning trigger.

    Args:
        prompt: The prompt the user submitted.
        generated_text: The text that was shown for that prompt.
        user_feedback: The user's feedback text.

    Returns:
        A fixed confirmation message for display in the UI.
    """
    save_feedback(prompt, generated_text, user_feedback)
    # Per the original author's note: starts fine-tuning if necessary.
    trigger_auto_learning()
    confirmation = "Feedback salvo com sucesso!"
    return confirmation
# Gradio user interface.
def create_interface():
    """Build the Gradio Blocks UI and wire its buttons to the handlers.

    The left column holds the prompt and the generation controls, the right
    column holds the generated text and the feedback form.

    Returns:
        The assembled ``gr.Blocks`` object; it is not launched here.
    """
    with gr.Blocks() as app:
        gr.Markdown("# GPT-2 no Hugging Face")

        with gr.Row():
            # Left column: prompt and sampling controls.
            with gr.Column():
                prompt_box = gr.Textbox(label="Digite seu prompt")
                length_slider = gr.Slider(
                    50, 500, value=100, label="Comprimento máximo"
                )
                temperature_slider = gr.Slider(
                    0.1, 1.0, value=0.7, label="Temperatura"
                )
                generate_button = gr.Button("Gerar Texto")

            # Right column: generated text and feedback form.
            with gr.Column():
                generated_box = gr.Textbox(label="Texto Gerado")
                feedback_box = gr.Textbox(label="Feedback (opcional)")
                feedback_button = gr.Button("Enviar Feedback")

        generation_inputs = [prompt_box, length_slider, temperature_slider]
        generate_button.click(
            fn=generate_text,
            inputs=generation_inputs,
            outputs=generated_box,
        )

        # Unlabelled textbox that receives the confirmation message.
        status_box = gr.Textbox()
        feedback_inputs = [prompt_box, generated_box, feedback_box]
        feedback_button.click(
            fn=submit_feedback,
            inputs=feedback_inputs,
            outputs=status_box,
        )

    return app
# Start the interface when this file is run as a script.
if __name__ == "__main__":
    demo = create_interface()
    # Listen on all network interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)