|
import streamlit as st |
|
import PyPDF2 |
|
import os |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import numpy as np |
|
from transformers import pipeline |
|
|
|
# Bare filename of the PDF that backs the QA index; open() resolves it against
# the process's current working directory.
PDF_FILE = "ml_large_dataset.pdf"

# Custom CSS injected into the page: background colour, heading colour, and
# the look of text inputs and buttons.
_PAGE_STYLE = """
<style>
.main {background-color: #f7faff;}
h1 {color: #4a4a8a;}
.stTextInput>div>div>input {border: 2px solid #d0d7ff;}
.stButton button {background-color: #4a4a8a; color: white;}
</style>
"""

# NOTE(review): the leading "π" in the user-facing strings below looks like a
# mis-decoded emoji -- confirm the file's encoding before changing them.

# Page configuration is issued before any other Streamlit element.
st.set_page_config(layout="wide", page_title="π PDF RAG QA")

# unsafe_allow_html is required for the raw <style> tag to be rendered.
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)

# Page header.
st.title("π Ask Me Anything About Machine Learning")
st.caption("Using RAG (Retrieval-Augmented Generation) and a preloaded PDF")
|
|
|
def load_pdf(file_path):
    """Read a PDF file and return the extracted text of each page.

    Args:
        file_path: Path of the PDF file to open (opened in binary mode).

    Returns:
        A list with one entry per page, in page order, holding whatever
        ``extract_text()`` produced for that page.
    """
    page_texts = []
    with open(file_path, "rb") as pdf_handle:
        pdf_reader = PyPDF2.PdfReader(pdf_handle)
        # Text is extracted while the file handle is still open.
        for pdf_page in pdf_reader.pages:
            page_texts.append(pdf_page.extract_text())
    return page_texts
|
|
|
def chunk_text(pages, max_len=1000):
    """Split page texts into chunks of at most ``max_len`` words.

    All pages are joined into one text, split on whitespace, and regrouped
    into consecutive, non-overlapping chunks. Chunk boundaries ignore page
    boundaries.

    Args:
        pages: Iterable of page texts. Entries that are ``None`` or empty
            (pages with no extractable text) are skipped.
        max_len: Maximum number of whitespace-separated words per chunk.

    Returns:
        A list of chunk strings with words separated by single spaces.
        Empty when the pages contain no words.

    Raises:
        ValueError: If ``max_len`` is smaller than 1.
    """
    if max_len < 1:
        # A negative step would silently yield no chunks at all and a zero
        # step fails inside range(); reject both up front with a clear error.
        raise ValueError(f"max_len must be a positive integer, got {max_len!r}")

    # Skipping falsy pages keeps a None entry from breaking str.join; empty
    # strings contribute no words either way.
    words = " ".join(page for page in pages if page).split()
    return [" ".join(words[i:i + max_len]) for i in range(0, len(words), max_len)]
|
|
|
@st.cache_resource
def setup_rag():
    """Build the retrieval and question-answering components for ``PDF_FILE``.

    Loads the PDF, splits it into chunks, embeds every chunk with the
    ``all-MiniLM-L6-v2`` sentence-transformer, stores the embeddings in a
    flat L2 FAISS index, and creates an extractive question-answering
    pipeline. The function is wrapped in ``st.cache_resource`` so the result
    is reused instead of being rebuilt on every script rerun.

    Returns:
        A tuple ``(chunks, model, index, qa)``:
        the list of text chunks, the embedding model, the FAISS index holding
        one vector per chunk (same order as ``chunks``), and the
        question-answering pipeline.

    Raises:
        ValueError: If no text could be extracted from the PDF.
    """
    pages = load_pdf(PDF_FILE)
    chunks = chunk_text(pages)
    if not chunks:
        # With nothing to embed, reading the embedding dimension below would
        # fail with an opaque error; report the actual cause instead.
        raise ValueError(
            f"No extractable text found in {PDF_FILE!r}; cannot build the search index."
        )

    model = SentenceTransformer('all-MiniLM-L6-v2')

    # FAISS works on float32 vectors, so make the dtype explicit.
    embeddings = np.array(model.encode(chunks), dtype=np.float32)

    # The index dimension is the embedding width (second axis).
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    qa = pipeline("question-answering", model="deepset/roberta-base-squad2")
    return chunks, model, index, qa
|
|
|
def retrieve_answer(question, chunks, model, index, qa_pipeline, k=6):
    """Answer ``question`` from the chunks nearest to it in the index.

    The question is embedded with ``model``, the nearest chunk ids are looked
    up in ``index``, the matching chunks are joined into one context string,
    and ``qa_pipeline`` extracts the answer from that context.

    Args:
        question: The user's question.
        chunks: Text chunks; position ``i`` must correspond to vector ``i``
            of ``index``.
        model: Embedding model exposing ``encode(list_of_str)``.
        index: Search index exposing ``search(vectors, k)`` and returning
            ``(distances, ids)``.
        qa_pipeline: Callable taking ``question=`` and ``context=`` and
            returning a mapping with an ``'answer'`` key.
        k: Maximum number of chunks to retrieve.

    Returns:
        The ``'answer'`` value produced by ``qa_pipeline``, or an empty
        string when no chunk could be retrieved.
    """
    # Never request more neighbours than there are chunks.
    k = min(k, len(chunks))
    if k < 1:
        return ""

    q_embed = np.array(model.encode([question]), dtype=np.float32)
    _, neighbour_ids = index.search(q_embed, k)

    # FAISS pads missing results with -1; used as a list index that would
    # silently alias the last chunk, so keep only ids that are in range.
    hits = [chunks[int(i)] for i in neighbour_ids[0] if 0 <= i < len(chunks)]
    if not hits:
        return ""

    context = "\n\n".join(hits)
    result = qa_pipeline(question=question, context=context)
    return result['answer']
|
|
|
# Build the RAG components (PDF chunks, embedding model, FAISS index and
# question-answering pipeline) via setup_rag().
try:
    chunks, embed_model, faiss_index, qa_model = setup_rag()
except FileNotFoundError:
    # Show a readable message instead of a raw traceback when the PDF is
    # missing, then halt this script run.
    st.error(f"Could not find the PDF file '{PDF_FILE}'.")
    st.stop()

st.subheader("π¬ Ask a Question")
question = st.text_input("Enter your question:", placeholder="e.g., What is supervised learning?")

# Whitespace-only input is treated as "no question" so it does not trigger a
# retrieval and QA run.
if question and question.strip():
    with st.spinner("π§ Searching for the answer..."):
        answer = retrieve_answer(question.strip(), chunks, embed_model, faiss_index, qa_model)
    st.markdown("#### π Answer:")
    st.write(answer)
|
|