rahideer committed on
Commit
c0fc352
·
verified ·
1 Parent(s): 2c19e76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -34
app.py CHANGED
@@ -1,35 +1,67 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
- import groq
4
-
5
# Initialize Groq API
# NOTE(review): groq_client is created here but never referenced by the rest
# of the visible script; it is kept so module-level behaviour is unchanged.
groq_client = groq.Client()


@st.cache_resource
def _load_classifier():
    """Build the zero-shot classification pipeline once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    constructing the pipeline at module level would reload the model on each
    rerun. Caching it as a resource keeps a single shared instance.
    """
    return pipeline("zero-shot-classification", model="joeddav/xlm-roberta-large-xnli")


# Initialize the zero-shot classification pipeline from Hugging Face
classifier = _load_classifier()
10
-
11
def classify_text(sequence, candidate_labels):
    """Classify ``sequence`` against ``candidate_labels`` with the shared pipeline.

    Args:
        sequence: The text to classify.
        candidate_labels: The labels the text is scored against.

    Returns:
        Whatever the module-level ``classifier`` returns for these arguments;
        the caller in this script reads its ``'labels'`` and ``'scores'`` keys.
    """
    return classifier(sequence, candidate_labels)
15
-
16
# Streamlit UI elements
st.title("Zero-Shot Text Classification with XLM-RoBERTa")
st.markdown("Enter a text and select candidate labels for classification.")

# Text input from the user
sequence = st.text_area("Enter text to classify", "", height=150)

# Candidate labels: split on commas and drop blank entries, so input such as
# "a,,b", a trailing comma, or an empty field never yields an empty label.
candidate_labels = st.text_input("Enter candidate labels (comma separated)", "politics, health, education")
candidate_labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]

# When the classify button is pressed, validate both inputs before running
# the model so the user gets a clear message instead of a pipeline error.
if st.button("Classify Text"):
    if not sequence.strip():
        st.error("Please enter text to classify.")
    elif not candidate_labels:
        st.error("Please enter at least one candidate label.")
    else:
        result = classify_text(sequence, candidate_labels)
        st.write("Classification Results:")
        st.write(f"Labels: {result['labels']}")
        st.write(f"Scores: {result['scores']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import numpy as np
6
+
7
@st.cache_resource
def _load_rag_components():
    """Load the RAG tokenizer, retriever and generator once per server process.

    Streamlit re-executes the script on every interaction; caching avoids
    reloading the checkpoints each time.

    Returns:
        A ``(tokenizer, model, retriever)`` tuple.
    """
    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

    # The RagConfig documentation lists "legacy", "exact" and "compressed" as
    # index names; "faiss" is not among them.
    # NOTE(review): use_dummy_dataset=True follows the library's documented
    # example and avoids downloading the full wiki_dpr index -- switch it off
    # (or point at a custom index) if real retrieval quality is required.
    rag_retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
    )

    # The retriever must be attached to the model: generate() is later called
    # with only input_ids/attention_mask, so the model has to fetch its own
    # context documents.
    rag_model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-token-nq", retriever=rag_retriever
    )
    return rag_tokenizer, rag_model, rag_retriever


# Load the pre-trained RAG components used for retrieval and generation.
# NOTE(review): the checkpoint name suggests it was trained on Natural
# Questions -- confirm it behaves acceptably on non-English input.
tokenizer, model, retriever = _load_rag_components()
11
+
12
# Set up FAISS for multilingual document retrieval
@st.cache_resource
def setup_faiss():
    """Embed the example documents with LaBSE and index them in FAISS.

    Cached as a Streamlit resource so the embedding model is not reloaded and
    the index is not rebuilt on every script rerun.

    Returns:
        A ``(faiss_index, docs)`` tuple: a flat L2 index over the document
        embeddings, and the list of documents in the same order as the
        vectors stored in the index.
    """
    # Load multilingual embeddings for documents (e.g., using LaBSE or multilingual BERT)
    model_embed = SentenceTransformer('sentence-transformers/LaBSE')

    # Example multilingual documents
    docs = [
        "How to learn programming?",
        "Comment apprendre la programmation?",
        "پروگرامنگ سیکھنے کا طریقہ کیا ہے؟",
    ]

    # Ask for a NumPy result directly: np.array() cannot convert a tensor that
    # lives on a GPU, and FAISS expects a contiguous float32 matrix.
    embeddings = np.ascontiguousarray(
        model_embed.encode(docs, convert_to_numpy=True), dtype=np.float32
    )

    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)

    return faiss_index, docs


# Set up FAISS index
faiss_index, docs = setup_faiss()
32
+
33
@st.cache_resource
def _load_query_embedder():
    """Return a shared LaBSE encoder for queries, loaded once per process."""
    return SentenceTransformer('sentence-transformers/LaBSE')


# Retrieve documents based on query
def retrieve_docs(query):
    """Return the indexed document closest to ``query``.

    Args:
        query: The user's question as a string.

    Returns:
        The single entry of the module-level ``docs`` list whose embedding is
        nearest to the query embedding in the module-level ``faiss_index``.
    """
    # Reuse the cached encoder instead of instantiating the model per query,
    # and request NumPy output: FAISS needs a contiguous float32 matrix and
    # np.array() cannot convert a tensor that lives on a GPU.
    query_vector = np.ascontiguousarray(
        _load_query_embedder().encode([query], convert_to_numpy=True),
        dtype=np.float32,
    )

    # Perform retrieval using FAISS (top-1 neighbour only)
    _distances, indices = faiss_index.search(query_vector, 1)

    # Row 0 is the only query; column 0 is its best match.
    return docs[indices[0][0]]
43
+
44
# Handle question-answering
def answer_question(query):
    """Generate an answer for ``query`` using the retrieved document as context.

    Args:
        query: The user's question as a string.

    Returns:
        The first generated sequence, decoded to text with special tokens
        removed.
    """
    # Look up the closest document and encode it together with the question.
    context = retrieve_docs(query)
    encoded = tokenizer(query, context, return_tensors="pt", padding=True, truncation=True)

    # Run generation on the encoded pair.
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
    )

    # Only the first generated sequence is returned to the caller.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
58
+
59
# Streamlit interface for user input
st.title("Multilingual RAG Translator/Answer Bot")
st.write("Ask a question in your preferred language (Urdu, French, Hindi)")

query = st.text_input("Enter your question:")

# Ignore empty or whitespace-only input so the retrieval and generation
# pipeline is not run on a blank question.
if query and query.strip():
    answer = answer_question(query.strip())
    st.write(f"Answer: {answer}")