Update app.py
app.py
CHANGED
@@ -1,35 +1,67 @@
(Previous 35-line version of app.py removed; replaced in full by the 67-line version below.)
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# Load pre-trained models for retrieval-augmented generation
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
# Use the small dummy wiki_dpr index so the demo stays lightweight
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

# Multilingual sentence encoder (LaBSE) used for the local document retrieval step
embedder = SentenceTransformer("sentence-transformers/LaBSE")

# Set up FAISS for multilingual document retrieval
def setup_faiss():
    # Example multilingual documents (English, French, Urdu)
    docs = [
        "How to learn programming?",
        "Comment apprendre la programmation?",
        "پروگرامنگ سیکھنے کا طریقہ کیا ہے؟"
    ]

    # Embed the documents with LaBSE and index them with FAISS (L2 distance)
    embeddings = embedder.encode(docs, convert_to_numpy=True).astype("float32")
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)

    return faiss_index, docs

# Build the FAISS index once at startup
faiss_index, docs = setup_faiss()

# Retrieve the most relevant document for a query
def retrieve_docs(query):
    # Embed the query with the same LaBSE encoder
    query_embedding = embedder.encode([query], convert_to_numpy=True).astype("float32")

    # Perform nearest-neighbour search with FAISS
    D, I = faiss_index.search(query_embedding, 1)

    # Return the most relevant document
    return docs[I[0][0]]

# Handle question answering
def answer_question(query):
    # Retrieve the most relevant local document
    retrieved_doc = retrieve_docs(query)

    # Tokenize the question together with the retrieved document
    inputs = tokenizer(query, retrieved_doc, return_tensors="pt", padding=True, truncation=True)

    # Generate an answer with the RAG model
    generated = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Decode the generated tokens into text
    answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    return answer

# Streamlit interface for user input
st.title("Multilingual RAG Translator/Answer Bot")
st.write("Ask a question in your preferred language (Urdu, French, Hindi)")

query = st.text_input("Enter your question:")

if query:
    answer = answer_question(query)
    st.write(f"Answer: {answer}")
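Usage note: the app is typically launched with `streamlit run app.py`. Beyond the imports shown above (streamlit, transformers, sentence-transformers, faiss-cpu, numpy), it assumes torch and the datasets library are installed, since RagRetriever loads its wiki_dpr index through them; exact package versions are not specified in this commit.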