simple_rag / app.py
mischeiwiller's picture
Update to app.py to rag system
12f7691 verified
import gradio as gr
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import torch
# Model setup: `retriever` turns text into embedding vectors; `generator` and
# `tokenizer` are the causal language model and its matching tokenizer.
# Both generation objects are loaded from the same checkpoint id so they
# cannot drift apart.
_RETRIEVER_CHECKPOINT = "all-MiniLM-L6-v2"
_GENERATOR_CHECKPOINT = "distilgpt2"

retriever = SentenceTransformer(_RETRIEVER_CHECKPOINT)
generator = AutoModelForCausalLM.from_pretrained(_GENERATOR_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_GENERATOR_CHECKPOINT)
# Simple vector store
class VectorStore:
    """In-memory document store ranked by cosine similarity.

    Texts and their embeddings are kept in two parallel lists, so
    ``documents[i]`` is the text that produced ``embeddings[i]``.

    Args:
        encoder: Optional object exposing ``encode(text)`` that returns a
            1-D embedding vector. When omitted, the module-level
            ``retriever`` is used.
    """

    def __init__(self, encoder=None):
        self.documents = []
        self.embeddings = []
        self._encoder = encoder

    def _encode(self, text):
        """Embed *text* with the injected encoder, or the global ``retriever``."""
        # The global is looked up at call time so a store created without an
        # explicit encoder keeps working exactly as before.
        encoder = self._encoder if self._encoder is not None else retriever
        return encoder.encode(text)

    def add_document(self, document):
        """Append *document* and its embedding to the store."""
        self.documents.append(document)
        self.embeddings.append(self._encode(document))

    def search(self, query, k=3):
        """Return up to *k* stored documents, most similar to *query* first.

        Args:
            query: Text to embed and compare against the stored documents.
            k: Maximum number of documents to return.

        Returns:
            A list of stored documents ordered by decreasing cosine
            similarity. The list is empty when the store holds no documents
            or when ``k`` is less than 1.
        """
        # Guard both degenerate cases up front: an empty store cannot be
        # reduced along axis 1, and ``[-0:]`` would select *every* row.
        if k <= 0 or not self.documents:
            return []

        query_vec = np.asarray(self._encode(query), dtype=float)
        doc_matrix = np.asarray(self.embeddings, dtype=float)

        dots = doc_matrix @ query_vec
        norms = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_vec)
        # A zero-length vector has no direction; score it 0 rather than NaN,
        # because argsort places NaN last and would rank it as the best hit.
        similarities = np.divide(
            dots, norms, out=np.zeros_like(dots), where=norms > 0
        )

        ranked = np.argsort(similarities)[::-1][:k]
        return [self.documents[i] for i in ranked]
# Build the knowledge base: create an empty store, load the first 1000 rows
# of the train split of the "20220301.simple" Wikipedia configuration, and
# index the "text" field of every row.
vector_store = VectorStore()

_WIKI_SPLIT = "train[:1000]"
dataset = load_dataset(
    "wikipedia",
    "20220301.simple",
    split=_WIKI_SPLIT,
    trust_remote_code=True,
)

_article_texts = dataset["text"]
for article_text in _article_texts:
    vector_store.add_document(article_text)
# RAG function
def _build_prompt_ids(context, query, max_prompt_tokens=512):
    """Tokenize the RAG prompt, trimming only the context to fit the budget.

    The prompt has the shape ``Context: <context>\\n\\nQuestion: <query>\\nAnswer:``.
    Truncating the whole string from the right would cut off the question
    whenever the context is long, so the question/answer tail is tokenized
    first and the context receives whatever token budget remains.

    Returns:
        A list of at most ``max_prompt_tokens`` token ids.
    """
    tail_ids = tokenizer(
        f"\n\nQuestion: {query}\nAnswer:", add_special_tokens=False
    ).input_ids
    budget = max_prompt_tokens - len(tail_ids)

    head_ids = []
    if budget > 0:
        head_ids = tokenizer(
            f"Context: {context}",
            add_special_tokens=False,
            truncation=True,
            max_length=budget,
        ).input_ids

    # If the question alone exceeds the budget, keep its end so the prompt
    # still finishes with "Answer:".
    return (head_ids + tail_ids)[-max_prompt_tokens:]


def rag_query(query, max_length=100):
    """Answer *query* using retrieved documents as context.

    Args:
        query: The user's question.
        max_length: Maximum number of tokens to generate after the prompt.

    Returns:
        The generated continuation, with the prompt removed and surrounding
        whitespace stripped.
    """
    # Retrieve relevant documents
    retrieved_docs = vector_store.search(query)
    context = " ".join(retrieved_docs)

    # Generate response
    prompt_ids = _build_prompt_ids(context, query)
    input_ids = torch.tensor([prompt_ids], dtype=torch.long)
    with torch.no_grad():
        outputs = generator.generate(
            input_ids,
            # Single unpadded sequence: every position is a real token.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only what the model produced after the prompt, instead of
    # searching the full text for the literal "Answer:".
    generated_ids = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Gradio interface
def gradio_interface(query):
    """Gradio callback: return the RAG answer for the submitted question.

    Args:
        query: Text from the question box; may be empty or ``None``.

    Returns:
        The answer from ``rag_query``, or a short prompt message when no
        question was entered.
    """
    question = (query or "").strip()
    # Skip retrieval and generation entirely for blank submissions.
    if not question:
        return "Please enter a question."
    return rag_query(question)
# UI wiring: one text box for the question, one for the answer, both routed
# through `gradio_interface`.
_question_box = gr.Textbox(label="Enter your question")
_answer_box = gr.Textbox(label="Answer")

iface = gr.Interface(
    fn=gradio_interface,
    inputs=_question_box,
    outputs=_answer_box,
    title="RAG System with Hugging Face and Gradio",
    description="Ask questions based on a Wikipedia-based knowledge base.",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()