"""Minimal RAG demo: sentence-transformer retrieval + distilgpt2 generation behind a Gradio UI."""

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Small retriever + small causal LM keep the demo CPU-friendly.
retriever = SentenceTransformer("all-MiniLM-L6-v2")
generator = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")


class VectorStore:
    """In-memory vector store ranked by cosine similarity over dense embeddings."""

    def __init__(self):
        self.documents = []
        self.embeddings = []  # one 1-D numpy embedding per document, parallel to documents

    def add_document(self, document):
        """Embed *document* with the module-level retriever and append it to the store."""
        self.documents.append(document)
        self.embeddings.append(retriever.encode(document))

    def search(self, query, k=3):
        """Return up to *k* stored documents most similar to *query*.

        Similarity is cosine similarity between the query embedding and each
        stored embedding. Returns an empty list when the store is empty.
        """
        # Guard: np.dot / np.linalg.norm on an empty embedding list would raise.
        if not self.documents:
            return []
        query_embedding = retriever.encode(query)
        matrix = np.asarray(self.embeddings)
        # Cosine similarity; the epsilon floor avoids divide-by-zero if any
        # vector happens to be all zeros.
        denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_embedding)
        similarities = (matrix @ query_embedding) / np.maximum(denom, 1e-12)
        # Indices of the k best scores, best first.
        top_k_indices = np.argsort(similarities)[-k:][::-1]
        return [self.documents[i] for i in top_k_indices]


# Build the knowledge base from the first 1000 Simple-English Wikipedia articles.
vector_store = VectorStore()
dataset = load_dataset(
    "wikipedia", "20220301.simple", split="train[:1000]", trust_remote_code=True
)
for doc in dataset["text"]:
    vector_store.add_document(doc)


def rag_query(query, max_length=100):
    """Answer *query* by retrieving context and generating with the language model.

    Args:
        query: The user's question.
        max_length: Number of NEW tokens to generate beyond the prompt
            (equivalent to the original prompt-length + max_length budget).

    Returns:
        The text the model produced after the final "Answer:" marker.
    """
    retrieved_docs = vector_store.search(query)
    context = " ".join(retrieved_docs)
    input_text = f"Context: {context}\n\nQuestion: {query}\nAnswer:"
    # Truncate the prompt so context + question always fits the model window.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = generator.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,  # explicit mask: avoids the HF warning and padding ambiguity
            max_new_tokens=max_length,  # clearer than max_length + prompt length
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,  # distilgpt2 defines no pad token
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Everything after the last "Answer:" marker is the model's answer.
    return response.split("Answer:")[-1].strip()


def gradio_interface(query):
    """Gradio callback: thin wrapper around rag_query."""
    return rag_query(query)


iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="RAG System with Hugging Face and Gradio",
    description="Ask questions based on a Wikipedia-based knowledge base.",
)

if __name__ == "__main__":
    iface.launch()