# RAG demo (Hugging Face Space).
# NOTE(review): the original paste carried HF page chrome here ("Spaces:" and
# two "Runtime error" banners) plus a trailing " | |" on every code line —
# scrape artifacts, removed so the file is valid Python.
import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize models: a sentence-embedding model for retrieval and a small
# causal LM (distilgpt2) for answer generation. Both are downloaded from the
# Hugging Face Hub on first run.
retriever = SentenceTransformer("all-MiniLM-L6-v2")
generator = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
# Simple in-memory vector store ranked by cosine similarity.
class VectorStore:
    """Minimal vector store: embed documents on insertion, rank by cosine.

    The embedding backend is pluggable via ``encoder`` (any object with an
    ``encode(text) -> 1-D array`` method); it defaults lazily to the
    module-level ``retriever`` so existing ``VectorStore()`` callers work
    unchanged.
    """

    def __init__(self, encoder=None):
        # Resolved lazily in _encode so constructing the store never
        # touches the (heavy) module-level model.
        self._encoder = encoder
        self.documents = []
        self.embeddings = []

    def _encode(self, text):
        """Embed *text* with the configured encoder as a float64 array."""
        enc = self._encoder if self._encoder is not None else retriever
        return np.asarray(enc.encode(text), dtype=np.float64)

    def add_document(self, document):
        """Store *document* together with its embedding."""
        self.documents.append(document)
        self.embeddings.append(self._encode(document))

    def search(self, query, k=3):
        """Return up to *k* stored documents most similar to *query*.

        Returns [] when the store is empty (the original np.dot over an
        empty list would raise).
        """
        if not self.documents:
            return []
        matrix = np.vstack(self.embeddings)
        query_embedding = self._encode(query)
        # Cosine similarity; the epsilon floor guards against a
        # zero-norm embedding producing a division by zero.
        denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_embedding)
        similarities = matrix @ query_embedding / np.maximum(denom, 1e-12)
        top_k_indices = np.argsort(similarities)[-k:][::-1]
        return [self.documents[i] for i in top_k_indices]
# Build the knowledge base: embed and index the first 1000 articles of the
# Simple English Wikipedia dump. Runs once at startup (slow: one encode call
# per article).
vector_store = VectorStore()
dataset = load_dataset(
    "wikipedia", "20220301.simple", split="train[:1000]", trust_remote_code=True
)
for article in dataset["text"]:
    vector_store.add_document(article)
# RAG function: retrieve, then generate.
def rag_query(query, max_length=100):
    """Answer *query* via retrieval-augmented generation.

    Retrieves the top documents from the module-level ``vector_store``,
    packs them into a prompt, and lets distilgpt2 continue it.

    Args:
        query: Natural-language question.
        max_length: Maximum number of NEW tokens to generate (default 100).

    Returns:
        The generated answer text (everything after the final "Answer:").
    """
    # Retrieve relevant documents and join them into one context string.
    retrieved_docs = vector_store.search(query)
    context = " ".join(retrieved_docs)
    # Generate response. Truncate the prompt to 512 tokens so prompt +
    # generation stays within distilgpt2's 1024-token context window.
    input_text = f"Context: {context}\n\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = generator.generate(
            inputs.input_ids,
            # Pass the attention mask explicitly (the original omitted it,
            # which triggers a transformers warning and can misbehave with
            # padding), and use max_new_tokens instead of the deprecated
            # max_length = prompt_len + new arithmetic.
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_length,
            num_return_sequences=1,
            # distilgpt2 has no pad token; reuse EOS to silence the warning.
            pad_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The model echoes the prompt; keep only the text after "Answer:".
    return response.split("Answer:")[-1].strip()
# Gradio callback.
def gradio_interface(query):
    """Delegate the user's question straight to the RAG pipeline."""
    answer = rag_query(query)
    return answer
# Wire the pipeline into a minimal single-textbox question/answer web UI.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="RAG System with Hugging Face and Gradio",
    description="Ask questions based on a Wikipedia-based knowledge base.",
)

if __name__ == "__main__":
    iface.launch()