# Hugging Face Space source file (Space status: Running; file size: 4,596 bytes; revision e29acd9).
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
# Load the translation checkpoint once, at module import, so translate_text()
# reuses the same tokenizer and seq2seq model on every call.
translation_model_name = "VietAI/envit5-translation"
translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)
# Translation Function
def translate_text(text, source_lang, target_lang):
    """Translate ``text`` between English and Vietnamese.

    The envit5 checkpoint is not instruction-tuned: its model card feeds it the
    raw text prefixed with a source-language tag (``"en: ..."`` or
    ``"vi: ..."``) and it answers with the other language, likewise tagged.
    A free-form "Translate the following text ..." sentence would itself be
    treated as text to translate, so the tag format is used whenever the
    language names can be recognised.

    Args:
        text: Text to translate.
        source_lang: Source language as typed by the user (e.g. ``"English"``,
            ``"vi"``).
        target_lang: Target language as typed by the user. Only used to infer
            the source tag when ``source_lang`` is not recognised.

    Returns:
        The translated text without the model's language tag, or ``""`` when
        ``text`` is blank.
    """
    # Nothing to translate: skip tokenisation and generation entirely.
    if not text or not text.strip():
        return ""

    language_tags = {
        "en": "en",
        "eng": "en",
        "english": "en",
        "vi": "vi",
        "vie": "vi",
        "vietnamese": "vi",
        "tieng viet": "vi",
        "tiếng việt": "vi",
    }

    source_tag = language_tags.get((source_lang or "").strip().lower())
    if source_tag is None:
        # Source not recognised: the model only handles the en/vi pair, so a
        # recognised target implies the opposite source.
        target_tag = language_tags.get((target_lang or "").strip().lower())
        if target_tag is not None:
            source_tag = "vi" if target_tag == "en" else "en"

    if source_tag is not None:
        prompt = f"{source_tag}: {text}"
    else:
        # Neither language recognised: keep the previous free-form prompt so
        # existing behaviour is preserved for such inputs.
        prompt = f"Translate the following text from {source_lang} to {target_lang}: {text}"

    inputs = translation_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        output = translation_model.generate(**inputs, max_length=256)
    decoded = translation_tokenizer.decode(output[0], skip_special_tokens=True)

    # Drop the language tag the model prepends to its output.
    for tag in ("en: ", "vi: "):
        if decoded.startswith(tag):
            return decoded[len(tag):]
    return decoded
# Load the extractive question-answering checkpoint once, at module import,
# so answer_question() reuses the same tokenizer and model on every call.
qa_model_name = "atharvamundada99/bert-large-question-answering-finetuned-legal"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
# Question Answering Function
def answer_question(question, context):
    """Extract the span of ``context`` that best answers ``question``.

    Args:
        question: Question text.
        context: Passage in which to look for the answer.

    Returns:
        The decoded answer span, or a fixed apology string when either input
        is blank or the model does not select a usable span.
    """
    no_answer = "Sorry, I couldn't find a relevant answer."

    # Nothing to search: skip the forward pass entirely.
    if not question or not question.strip() or not context or not context.strip():
        return no_answer

    inputs = qa_tokenizer(question, context, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = qa_model(**inputs)

    # Start and end are chosen independently, so the end can land before the
    # start; treat that as "no answer" explicitly rather than by accident.
    span_start = int(torch.argmax(outputs.start_logits))
    span_end = int(torch.argmax(outputs.end_logits)) + 1  # exclusive bound
    if span_end <= span_start:
        return no_answer

    # skip_special_tokens keeps markers such as [CLS]/[SEP] out of the answer
    # when the predicted span touches them.
    answer = qa_tokenizer.decode(
        inputs["input_ids"][0][span_start:span_end],
        skip_special_tokens=True,
    )
    return answer if answer.strip() else no_answer
# Load the summarization checkpoint once, at module import, so
# summarize_text() reuses the same tokenizer and seq2seq model on every call.
summarization_model_name = "Falconsai/medical_summarization"
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)
# Summarization Function
def summarize_text(text):
    """Summarize ``text`` with the module-level summarization model.

    Args:
        text: Text to summarize. Input is truncated to 1024 tokens.

    Returns:
        The generated summary, or ``""`` when ``text`` is blank.
    """
    # With min_length=50 the generator must emit at least 50 tokens, so blank
    # input would yield invented content; return early instead.
    if not text or not text.strip():
        return ""

    inputs = summarization_tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        summary_ids = summarization_model.generate(
            **inputs,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
        )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Function to toggle UI visibility based on selected task
def select_task(task):
    """Build the visibility updates for the three task panels.

    Returns a 3-tuple of ``gr.update`` objects in the order Translation,
    Question Answering, Summarization; only the entry whose name equals
    ``task`` is marked visible.
    """
    panel_names = ("Translation", "Question Answering", "Summarization")
    return tuple(gr.update(visible=(task == panel)) for panel in panel_names)
# Function to clear inputs and outputs
def clear_fields():
    """Return four empty strings, used to blank text boxes on Clear."""
    blank = ""
    return (blank,) * 4
def clear_fields_summary():
    """Return two empty strings: one for the summary input, one for the output.

    The summarization Clear button is wired to two components, so the handler
    must return two values; a single ``""`` gives Gradio too few outputs.
    """
    return "", ""
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## AI-Powered Language Processing")
    task_buttons = gr.Radio(
        ["Translation", "Question Answering", "Summarization"],
        label="Choose a task",
    )

    # Each task has its own panel. All start hidden; select_task() reveals the
    # one matching the radio selection.
    with gr.Group(visible=False) as translation_ui:
        source_lang = gr.Textbox(label="Source Language")
        target_lang = gr.Textbox(label="Target Language")
        text_input = gr.Textbox(label="Enter Text")
        translate_button = gr.Button("Translate")
        translation_output = gr.Textbox(label="Translated Text")
        clear_button_t = gr.Button("Clear")
        clear_button_t.click(
            clear_fields,
            inputs=[],
            outputs=[source_lang, target_lang, text_input, translation_output],
        )
        translate_button.click(
            translate_text,
            inputs=[text_input, source_lang, target_lang],
            outputs=translation_output,
        )

    with gr.Group(visible=False) as qa_ui:
        question_input = gr.Textbox(label="Enter Question")
        context_input = gr.Textbox(label="Enter Context")
        answer_button = gr.Button("Get Answer")
        qa_output = gr.Textbox(label="Answer")
        clear_button_qa = gr.Button("Clear")
        # This panel has three text boxes, so the handler returns exactly three
        # blanks; clear_fields() returns four, which does not match.
        clear_button_qa.click(
            lambda: ("", "", ""),
            inputs=[],
            outputs=[question_input, context_input, qa_output],
        )
        answer_button.click(
            answer_question,
            inputs=[question_input, context_input],
            outputs=qa_output,
        )

    with gr.Group(visible=False) as summarization_ui:
        text_input_summary = gr.Textbox(label="Enter Text")
        summarize_button = gr.Button("Summarize")
        summary_output = gr.Textbox(label="Summary")
        clear_button_s = gr.Button("Clear")
        clear_button_s.click(
            clear_fields_summary,
            inputs=[],
            outputs=[text_input_summary, summary_output],
        )
        summarize_button.click(
            summarize_text,
            inputs=[text_input_summary],
            outputs=summary_output,
        )

    # Toggle panel visibility whenever the selected task changes.
    task_buttons.change(
        select_task,
        inputs=[task_buttons],
        outputs=[translation_ui, qa_ui, summarization_ui],
    )

demo.launch(share=True)