# ragiator / app.py
# Author: chibuzordev
# Last commit: "Update app.py" (679bed8, verified)
# !pip install transformers faiss-cpu gradio sentence-transformers nlpaug scikit-learn
# rag_module.py
import json
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer, T5ForConditionalGeneration
import gradio as gr
from rag_pipeline import RAGPipeline
from adversarial_framework import *
# Construct the heavyweight pipeline objects a single time, at import, so
# every Gradio request below reuses the same already-loaded models.
rag = RAGPipeline(
    embedder_model="infly/inf-retriever-v1-1.5b",
    reranker_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    generator_model="google/flan-t5-base",
)

# The adversarial pipeline is handed the RAG system's bound
# `generate_answer` method as its answer generator.
adv_pipeline = AdversarialAttackPipeline(
    answer_generator=rag.generate_answer,
)
# Gradio callback: adapts UI inputs to the adversarial evaluation pipeline.
def gradio_wrapper(query, method, k):
    """Run the adversarial robustness evaluation for one UI submission.

    Args:
        query: Question text from the "Enter a Question" textbox.
        method: Perturbation method chosen in the dropdown
            ("synonym", "delete" or "contextual"); may be ``None`` if the
            dropdown has no value selected.
        k: "Top-K Retrieved Chunks" slider value. Gradio hands slider
            values to callbacks as floats, so it is coerced to ``int``.

    Returns:
        A 6-tuple in the order of the Interface outputs: summary statistics
        text, the PSC-AUC score converted to ``str``, the PSC curve figure,
        and the perturbed-query, perturbed-response and directly-perturbed
        normal-response examples, exactly as returned by
        ``adv_pipeline.evaluate_adversarial_robustness``.

    Raises:
        gr.Error: If the query is blank or no perturbation method is
            selected. Gradio shows this message in the UI instead of running
            the pipeline on meaningless input.
    """
    if query is None or not query.strip():
        raise gr.Error("Please enter a question.")
    if not method:
        raise gr.Error("Please choose a perturbation method.")

    stats_text, auc, fig, pert_q, pert_r, adv_r = (
        adv_pipeline.evaluate_adversarial_robustness(
            query=query,
            method=method,
            # The slider may deliver e.g. 3.0; a top-k count presumably has
            # to be an int downstream (slicing / index search).
            k=int(k),
        )
    )
    return stats_text, str(auc), fig, pert_q, pert_r, adv_r
# Build and launch the UI. Inputs map positionally onto
# gradio_wrapper(query, method, k); outputs map onto its 6-tuple return.
demo = gr.Interface(
    fn=gradio_wrapper,
    inputs=[
        gr.Textbox(label="Enter a Question"),
        # Pre-select a method: without a default value the dropdown would
        # pass None to the callback if the user submits without choosing.
        gr.Dropdown(
            choices=["synonym", "delete", "contextual"],
            value="synonym",
            label="Perturbation Method",
        ),
        gr.Slider(1, 5, step=1, value=3, label="Top-K Retrieved Chunks"),
    ],
    outputs=[
        gr.Textbox(label="📊 Summary Statistics"),
        gr.Textbox(label="🔺 PSC-AUC Score"),
        gr.Plot(label="📈 PSC Curve"),
        gr.Textbox(label="🟠 Perturbed Query Example"),
        gr.Textbox(label="🟢 Perturbed Response Example"),
        gr.Textbox(label="🔴 Directly Perturbed Normal Response Example"),
    ],
    title="Adversarial Testing on RAGiant System",
    description="Evaluate robustness against textual attacks and visualize degradation with ARI & PSC-AUC.",
)
demo.launch()