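"""Minimal retrieval-augmented generation (RAG) demo.

Embeds Wikipedia snippets with a sentence-transformers encoder, retrieves the
closest matches for a question, and asks distilgpt2 to answer from that
context, all served through a small Gradio UI.
"""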
import gradio as gr
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import torch

# Initialize models
retriever = SentenceTransformer("all-MiniLM-L6-v2")
generator = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
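# distilgpt2 keeps this demo lightweight; it is not instruction-tuned, so
# answers will be rough. A larger causal LM could be swapped in above.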


# Simple vector store
class VectorStore:
    def __init__(self):
        self.documents = []
        self.embeddings = []

    def add_document(self, document):
        self.documents.append(document)
        embedding = retriever.encode(document)
        self.embeddings.append(embedding)

    def search(self, query, k=3):
        # Cosine similarity between the query embedding and every stored document.
        query_embedding = retriever.encode(query)
        embeddings = np.stack(self.embeddings)
        similarities = np.dot(embeddings, query_embedding) / (
            np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
        )
        # Take the k highest-scoring documents, best match first.
        top_k_indices = np.argsort(similarities)[-k:][::-1]
        return [self.documents[i] for i in top_k_indices]
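
# Retrieval in isolation, with two hypothetical documents (the Paris snippet
# should rank first for this query):
#   store = VectorStore()
#   store.add_document("The Eiffel Tower is in Paris.")
#   store.add_document("Mount Everest is the tallest mountain.")
#   store.search("Where is the Eiffel Tower?", k=1)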


# Initialize vector store
vector_store = VectorStore()

# Load sample dataset (e.g., Wikipedia snippets)
dataset = load_dataset(
    "wikipedia", "20220301.simple", split="train[:1000]", trust_remote_code=True
)
for doc in dataset["text"]:
    vector_store.add_document(doc)
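
# Note: encoding 1,000 articles one at a time is slow at startup. If that
# matters, SentenceTransformer.encode also accepts a list of texts and batches
# internally; the store could be extended with a (hypothetical) add_documents()
# method that uses it.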


# RAG function
def rag_query(query, max_length=100):
    # Retrieve relevant documents
    retrieved_docs = vector_store.search(query)
    context = " ".join(retrieved_docs)

    # Generate response
    input_text = f"Context: {context}\n\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        outputs = generator.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_length,  # cap on newly generated tokens
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("Answer:")[-1].strip()
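
# e.g. rag_query("What is gravity?") retrieves the closest Wikipedia snippets
# and returns whatever distilgpt2 completes after the "Answer:" marker.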


# Gradio interface
def gradio_interface(query):
    return rag_query(query)


iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="RAG System with Hugging Face and Gradio",
    description="Ask questions based on a Wikipedia-based knowledge base.",
)

if __name__ == "__main__":
    iface.launch()