rahideer committed on
Commit
b77a775
·
verified ·
1 Parent(s): 73cb629

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
+ from datasets import load_dataset
4
+ from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification
5
+
6
# Streamlit front end for a Retrieval-Augmented Generation (RAG) question
# answering demo.
#
# NOTE: no dataset is loaded at start-up. The passages used for retrieval come
# from the retriever's own index, so fetching XNLI here would only add a slow
# download whose result the app never reads.

# One checkpoint is shared by the tokenizer, the retriever and the generator so
# that token ids, question embeddings and the passage index all agree.
# NOTE(review): this is a RAG-Token checkpoint driven through
# RagSequenceForGeneration, and it appears to be trained on English Natural
# Questions only - confirm whether "facebook/rag-sequence-nq" or a genuinely
# multilingual checkpoint is what the UI text below promises.
RAG_CHECKPOINT = "facebook/rag-token-nq"

# Name of the canonical passage index the retriever should open.
# NOTE(review): `passages_path` is presumably only consulted for a custom
# index, so the placeholder path is not passed - confirm against the
# RagRetriever documentation before pointing this at a custom corpus.
RAG_INDEX_NAME = "compressed"

# Maximum number of characters of each retrieved passage shown in the UI.
PREVIEW_CHARS = 300


def preview_text(text: str, limit: int = PREVIEW_CHARS) -> str:
    """Return *text* cut to at most *limit* characters.

    An ellipsis is appended only when something was actually cut off, so short
    passages are shown verbatim.
    """
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


def passage_texts(doc_dicts: list) -> list:
    """Flatten retriever output into a plain list of passage strings.

    Each element of *doc_dicts* is expected to be a mapping whose ``"text"``
    entry is a list of passages (one mapping per query). Mappings without a
    ``"text"`` entry contribute nothing.
    """
    return [text for doc in doc_dicts for text in doc.get("text", [])]


def load_rag_components():
    """Load the RAG tokenizer, retriever and generator.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple. The retriever is attached to
        the model so that ``model.generate`` can fetch its own context.
    """
    tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
    retriever = RagRetriever.from_pretrained(
        RAG_CHECKPOINT, index_name=RAG_INDEX_NAME
    )
    model = RagSequenceForGeneration.from_pretrained(
        RAG_CHECKPOINT, retriever=retriever
    )
    return tokenizer, retriever, model


def answer_question(question: str, tokenizer, retriever, model):
    """Answer *question* and collect the passages retrieved for it.

    Returns:
        An ``(answer, passages)`` tuple: the decoded answer string and a list
        of retrieved passage strings.
    """
    # RagTokenizer encodes with the question-encoder vocabulary and decodes
    # with the generator vocabulary, which is what the model expects.
    inputs = tokenizer(
        question, return_tensors="pt", padding=True, truncation=True
    )
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # The model owns the retriever, so generate() performs retrieval itself.
    generated_ids = model.generate(
        input_ids=input_ids, attention_mask=attention_mask
    )
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Retrieve once more, explicitly, to obtain the passage texts for display:
    # the retriever searches with question-encoder hidden states (as a NumPy
    # array), not with raw token ids.
    question_hidden_states = model.question_encoder(
        input_ids, attention_mask=attention_mask
    )[0]
    _, _, doc_dicts = retriever.retrieve(
        question_hidden_states.detach().numpy(), n_docs=model.config.n_docs
    )
    return answer, passage_texts(doc_dicts)


def main() -> None:
    """Render the page and answer the user's question, if one was entered."""
    st.title('Multilingual RAG Translator/Answer Bot')

    st.markdown("This app uses a multilingual RAG model to answer your questions in the language of the query. Ask questions in languages like Urdu, Hindi, or French!")

    # User input for query
    user_query = st.text_input("Ask a question in Urdu, Hindi, or French:")

    # Ignore empty and whitespace-only input.
    question = user_query.strip() if user_query else ""
    if not question:
        return

    with st.spinner("Thinking..."):
        # Streamlit re-executes this script on every interaction; caching the
        # loader keeps the heavy models in memory across reruns.
        tokenizer, retriever, model = st.cache_resource(load_rag_components)()
        answer, passages = answer_question(question, tokenizer, retriever, model)

    # Display the answer
    st.write(f"Answer: {answer}")

    # Display the most relevant documents
    st.subheader("Relevant Documents:")
    for passage in passages:
        st.write(preview_text(passage))


# Streamlit runs the script as ``__main__``; the guard keeps the helpers
# importable without starting the UI.
if __name__ == "__main__":
    main()