import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForQuestionAnswering,
)

# Message shown whenever the QA model yields no usable answer span.
_NO_ANSWER_MESSAGE = "Sorry, I couldn't find a relevant answer."

# Load Translation Model
translation_model_name = "VietAI/envit5-translation"
translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)


# Translation Function
def translate_text(text, source_lang, target_lang):
    """Translate ``text`` with the seq2seq translation model.

    Args:
        text: Text to translate.
        source_lang: Name of the source language, interpolated into the prompt.
        target_lang: Name of the target language, interpolated into the prompt.

    Returns:
        The decoded model output, or an empty string when ``text`` is blank.
    """
    # Nothing to translate: skip the (expensive) generation call.
    if not text or not text.strip():
        return ""

    # NOTE(review): the prompt is a free-form instruction; confirm against the
    # model card that this checkpoint expects this format rather than a
    # language-tag prefix.
    prompt = f"Translate the following text from {source_lang} to {target_lang}: {text}"
    inputs = translation_tokenizer(
        prompt, return_tensors="pt", padding=True, truncation=True
    )
    with torch.no_grad():
        output = translation_model.generate(**inputs, max_length=256)
    return translation_tokenizer.decode(output[0], skip_special_tokens=True)


# Load Question Answering Model
qa_model_name = "atharvamundada99/bert-large-question-answering-finetuned-legal"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)


# Question Answering Function
def answer_question(question, context):
    """Extract the answer to ``question`` from ``context``.

    The answer is the token span between the highest-scoring start and end
    positions predicted by the QA model.

    Args:
        question: The question to answer.
        context: The passage in which to look for the answer.

    Returns:
        The decoded answer span, or a fallback message when either input is
        blank or the predicted span is empty, inverted, or made up only of
        special tokens.
    """
    if not question or not question.strip() or not context or not context.strip():
        return _NO_ANSWER_MESSAGE

    inputs = qa_tokenizer(question, context, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = qa_model(**inputs)

    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits)) + 1  # exclusive bound

    # Start and end are chosen independently, so the end may precede the start.
    if answer_end <= answer_start:
        return _NO_ANSWER_MESSAGE

    # skip_special_tokens keeps markers such as [CLS]/[SEP] out of the answer
    # when the predicted span lands on them.
    answer = qa_tokenizer.decode(
        inputs["input_ids"][0][answer_start:answer_end],
        skip_special_tokens=True,
    )
    return answer if answer.strip() else _NO_ANSWER_MESSAGE
# Load Summarization Model
summarization_model_name = "Falconsai/medical_summarization"
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)


# Summarization Function
def summarize_text(text):
    """Summarize ``text`` with the seq2seq summarization model.

    Args:
        text: The text to summarize; tokenized input is truncated to 1024
            tokens.

    Returns:
        The decoded summary, or an empty string when ``text`` is blank.
    """
    # Nothing to summarize: skip the (expensive) generation call.
    if not text or not text.strip():
        return ""

    inputs = summarization_tokenizer(
        text, return_tensors="pt", max_length=1024, truncation=True
    )
    with torch.no_grad():
        summary_ids = summarization_model.generate(
            **inputs,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
        )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)


# Function to toggle UI visibility based on selected task
def select_task(task):
    """Return visibility updates for the three task panels.

    Args:
        task: The label selected in the task radio group.

    Returns:
        A 3-tuple of ``gr.update`` objects, in the order (translation,
        question answering, summarization); only the panel whose label equals
        ``task`` is made visible.
    """
    return (
        gr.update(visible=(task == "Translation")),
        gr.update(visible=(task == "Question Answering")),
        gr.update(visible=(task == "Summarization")),
    )


# Function to clear inputs and outputs
def clear_fields(count=4):
    """Return ``count`` empty strings, one per textbox to be cleared.

    Args:
        count: Number of output components bound to the Clear button.
            Defaults to 4, the number of textboxes in the translation panel.

    Returns:
        A tuple of ``count`` empty strings.
    """
    return ("",) * count


def clear_fields_summary():
    """Return one empty string for each of the two summarization textboxes."""
    # The Clear button in the summarization panel is bound to two outputs
    # (input text and summary), so two values are required.
    return "", ""


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## AI-Powered Language Processing")

    task_buttons = gr.Radio(
        ["Translation", "Question Answering", "Summarization"],
        label="Choose a task",
    )

    with gr.Group(visible=False) as translation_ui:
        source_lang = gr.Textbox(label="Source Language")
        target_lang = gr.Textbox(label="Target Language")
        text_input = gr.Textbox(label="Enter Text")
        translate_button = gr.Button("Translate")
        translation_output = gr.Textbox(label="Translated Text")
        clear_button_t = gr.Button("Clear")

        clear_button_t.click(
            clear_fields,
            inputs=[],
            outputs=[source_lang, target_lang, text_input, translation_output],
        )
        translate_button.click(
            translate_text,
            inputs=[text_input, source_lang, target_lang],
            outputs=translation_output,
        )

    with gr.Group(visible=False) as qa_ui:
        question_input = gr.Textbox(label="Enter Question")
        context_input = gr.Textbox(label="Enter Context")
        answer_button = gr.Button("Get Answer")
        qa_output = gr.Textbox(label="Answer")
        clear_button_qa = gr.Button("Clear")

        # Three outputs are bound here, so request exactly three empty values
        # instead of the four that clear_fields returns by default.
        clear_button_qa.click(
            lambda: clear_fields(3),
            inputs=[],
            outputs=[question_input, context_input, qa_output],
        )
        answer_button.click(
            answer_question,
            inputs=[question_input, context_input],
            outputs=qa_output,
        )

    with gr.Group(visible=False) as summarization_ui:
        text_input_summary = gr.Textbox(label="Enter Text")
        summarize_button = gr.Button("Summarize")
        summary_output = gr.Textbox(label="Summary")
        clear_button_s = gr.Button("Clear")

        clear_button_s.click(
            clear_fields_summary,
            inputs=[],
            outputs=[text_input_summary, summary_output],
        )
        summarize_button.click(
            summarize_text,
            inputs=[text_input_summary],
            outputs=summary_output,
        )

    # Show only the panel that matches the selected task.
    task_buttons.change(
        select_task,
        inputs=[task_buttons],
        outputs=[translation_ui, qa_ui, summarization_ui],
    )

demo.launch(share=True)