import streamlit as st
import torch
import torch.hub
import re
import os
import time
# --- Set Page Config First ---
st.set_page_config(
    page_title="AI Text Detector",
    layout="centered",
    initial_sidebar_state="collapsed"
)
# --- Improved CSS for a cleaner UI ---
st.markdown("""
<style>
/* Modern clean font for the entire app */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

html, body, [class*="css"] {
    font-family: 'Inter', sans-serif;
}

/* Header styling */
h1 {
    font-weight: 700;
    color: #1E3A8A;
    padding-bottom: 1rem;
    border-bottom: 2px solid #E5E7EB;
    margin-bottom: 2rem;
}

/* Text area styling */
.stTextArea textarea {
    border: 1px solid #D1D5DB;
    border-radius: 8px;
    font-size: 16px;
    padding: 12px;
    background-color: #F9FAFB;
    box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
    transition: border-color 0.15s ease-in-out, box-shadow 0.15s ease-in-out;
}

.stTextArea textarea:focus {
    border-color: #3B82F6;
    box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.3);
    outline: none;
}

/* Button styling */
.stButton button {
    border-radius: 8px;
    font-weight: 600;
    padding: 10px 16px;
    background-color: #2563EB;
    color: white;
    border: none;
    width: 100%;
    transition: background-color 0.2s ease;
}

.stButton button:hover {
    background-color: #1D4ED8;
}

/* Result box styling */
.result-box {
    border-radius: 8px;
    padding: 20px;
    margin-top: 24px;
    text-align: center;
    background-color: white;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1), 0 1px 2px rgba(0, 0, 0, 0.06);
    border: 1px solid #E5E7EB;
}

/* Result highlights */
.highlight-human {
    color: #059669;
    font-weight: 600;
    background: rgba(5, 150, 105, 0.1);
    padding: 4px 10px;
    border-radius: 8px;
    display: inline-block;
}

.highlight-ai {
    color: #DC2626;
    font-weight: 600;
    background: rgba(220, 38, 38, 0.1);
    padding: 4px 10px;
    border-radius: 8px;
    display: inline-block;
}

/* Footer styling */
.footer {
    text-align: center;
    margin-top: 40px;
    padding-top: 20px;
    border-top: 1px solid #E5E7EB;
    color: #6B7280;
    font-size: 14px;
}

/* Progress bar styling */
.stProgress > div > div {
    background-color: #2563EB;
}

/* General spacing */
.block-container {
    padding-top: 2rem;
    padding-bottom: 2rem;
}
</style>
""", unsafe_allow_html=True)
# --- Configuration ---
MODEL1_PATH = "modernbert.bin"
MODEL2_URL = "https://huggingface.co./mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
MODEL3_URL = "https://huggingface.co./mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
BASE_MODEL = "answerdotai/ModernBERT-base"
NUM_LABELS = 41
HUMAN_LABEL_INDEX = 24
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
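
# The detector is an ensemble of three ModernBERT sequence classifiers: one
# local checkpoint plus two state dicts downloaded from the Hugging Face Hub.
# Their softmax outputs are averaged at inference time. Of the 41 classes,
# index 24 is 'human'; the remaining 40 correspond to individual generator
# models (see LABEL_MAPPING below).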
# --- Model Loading Functions ---
# st.cache_resource keeps the tokenizer and models in memory across Streamlit
# reruns, so they load once per process instead of on every interaction. The
# leading underscore in `_device` tells Streamlit not to hash that argument.
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_name):
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(model_name)

@st.cache_resource(show_spinner=False)
def load_model(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE):
    from transformers import AutoModelForSequenceClassification
    # Instantiate the base architecture, then overwrite it with fine-tuned weights
    model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=num_labels)
    try:
        if is_url:
            state_dict = torch.hub.load_state_dict_from_url(model_path_or_url, map_location=_device, progress=False)
        else:
            if not os.path.exists(model_path_or_url):
                return None
            # weights_only=False permits arbitrary pickled objects; only safe for trusted checkpoints
            state_dict = torch.load(model_path_or_url, map_location=_device, weights_only=False)
        model.load_state_dict(state_dict)
        model.to(_device).eval()
        return model
    except Exception:
        # Download or deserialization failures are reported to the caller as None
        return None
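
# Usage sketch, mirroring the calls made during initialization below: the first
# model comes from a local checkpoint file, the other two from Hub URLs.
#   model_local  = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False)
#   model_remote = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True)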
# --- Text Processing Functions ---
def clean_text(text):
    """Normalize whitespace before tokenization: unify line endings, collapse
    blank-line runs and repeated spaces, rejoin words hyphenated across line
    breaks, and turn single (soft-wrapped) newlines into spaces."""
    if not isinstance(text, str):
        return ""
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    text = re.sub(r"\n\s*\n+", "\n\n", text)
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"(\w+)-\s*\n\s*(\w+)", r"\1\2", text)
    text = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
    return text.strip()
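
# Illustrative example of the cleaning rules above:
#   clean_text("exam-\nple  text\nwrapped\n\n\npara")
#   -> "example text wrapped\n\npara"
# The hyphenated break is rejoined, the single newline becomes a space, and
# the run of blank lines collapses to one paragraph break.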
def classify_text(text, tokenizer, model_1, model_2, model_3, device, label_mapping, human_label_index):
    """Run the three-model ensemble on `text`. Returns a result dict
    {"is_human": bool, "probability": float, "model": str} on success,
    None for empty input, or {"error": True, "message": str} on failure."""
    if not all([model_1, model_2, model_3, tokenizer]):
        return {"error": True, "message": "Models failed to load properly."}
    cleaned_text = clean_text(text)
    if not cleaned_text:
        return None
    try:
        inputs = tokenizer(
            cleaned_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=tokenizer.model_max_length
        ).to(device)
        with torch.no_grad():
            logits_1 = model_1(**inputs).logits
            logits_2 = model_2(**inputs).logits
            logits_3 = model_3(**inputs).logits
        # Ensemble by averaging the per-model class distributions
        softmax_1 = torch.softmax(logits_1, dim=1)
        softmax_2 = torch.softmax(logits_2, dim=1)
        softmax_3 = torch.softmax(logits_3, dim=1)
        averaged_probabilities = (softmax_1 + softmax_2 + softmax_3) / 3
        probabilities = averaged_probabilities[0].cpu()
        if not (0 <= human_label_index < len(probabilities)):
            return {"error": True, "message": "Configuration error: human label index out of range."}
        # Binary decision: 'human' probability vs. the summed mass of the 40 AI classes
        human_prob = probabilities[human_label_index].item() * 100
        mask = torch.ones_like(probabilities, dtype=torch.bool)
        mask[human_label_index] = False
        ai_total_prob = probabilities[mask].sum().item() * 100
        # Most likely generator: argmax over AI classes only (human is masked out)
        ai_probs_only = probabilities.clone()
        ai_probs_only[human_label_index] = -float('inf')
        ai_argmax_index = torch.argmax(ai_probs_only).item()
        ai_argmax_model = label_mapping.get(ai_argmax_index, f"Unknown AI (Index {ai_argmax_index})")
        if human_prob >= ai_total_prob:
            return {"is_human": True, "probability": human_prob, "model": "Human"}
        else:
            return {"is_human": False, "probability": ai_total_prob, "model": ai_argmax_model}
    except Exception as e:
        return {"error": True, "message": f"Analysis failed: {str(e)}"}
# --- Label Mapping ---
LABEL_MAPPING = {
    0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
    6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
    11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
    14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
    18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
    22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
    27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
    31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
    35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
    39: 'text-davinci-002', 40: 'text-davinci-003'
}
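
# Sanity checks: the label table must stay in sync with the configuration
# above (index 24 is 'human', and there are 41 classes in total)
assert LABEL_MAPPING[HUMAN_LABEL_INDEX] == 'human'
assert len(LABEL_MAPPING) == NUM_LABELS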
# --- Main UI ---
st.title("AI Text Detector")

# Initialization with a progress bar. The status message and the bar live in
# placeholders so they can be cleared once loading finishes; a bare st.empty()
# call would only create a new empty element, not remove the old ones.
status = st.empty()
progress_bar = st.progress(0)
status.info("Initializing AI detection models...")

# Step 1: Load tokenizer
progress_bar.progress(20)
time.sleep(0.5)  # Small delay for visual feedback
TOKENIZER = load_tokenizer(BASE_MODEL)

# Step 2: Load first model (local checkpoint)
progress_bar.progress(40)
time.sleep(0.5)  # Small delay for visual feedback
MODEL_1 = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False, _device=DEVICE)

# Step 3: Load second model (Hub URL)
progress_bar.progress(60)
time.sleep(0.5)  # Small delay for visual feedback
MODEL_2 = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

# Step 4: Load third model (Hub URL)
progress_bar.progress(80)
time.sleep(0.5)  # Small delay for visual feedback
MODEL_3 = load_model(MODEL3_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

# Complete initialization and clear the status UI
progress_bar.progress(100)
time.sleep(0.5)  # Small delay for visual feedback
progress_bar.empty()
status.empty()

# Check if models loaded successfully
if not all([TOKENIZER, MODEL_1, MODEL_2, MODEL_3]):
    st.error("Failed to initialize one or more AI detection models. Please try refreshing the page.")
    st.stop()
# Input area
input_text = st.text_area(
    label="Enter text to analyze:",
    placeholder="Type or paste your content here for AI detection analysis...",
    height=200,
    key="text_input"
)

# Analyze button and output
analyze_button = st.button("Analyze Text", key="analyze_button")
result_placeholder = st.empty()
if analyze_button:
    if input_text and input_text.strip():
        with st.spinner('Analyzing text...'):
            classification_result = classify_text(
                input_text,
                TOKENIZER,
                MODEL_1,
                MODEL_2,
                MODEL_3,
                DEVICE,
                LABEL_MAPPING,
                HUMAN_LABEL_INDEX
            )
        # Display result
        if classification_result is None:
            result_placeholder.warning("Please enter some text to analyze.")
        elif classification_result.get("error"):
            error_message = classification_result.get("message", "An unknown error occurred during analysis.")
            result_placeholder.error(f"Analysis Error: {error_message}")
        elif classification_result["is_human"]:
            prob = classification_result['probability']
            result_html = (
                f"<div class='result-box'>"
                f"<b>The text is</b> <span class='highlight-human'><b>{prob:.2f}%</b> likely <b>Human written</b>.</span>"
                f"</div>"
            )
            result_placeholder.markdown(result_html, unsafe_allow_html=True)
        else:  # AI generated
            prob = classification_result['probability']
            model_name = classification_result['model']
            result_html = (
                f"<div class='result-box'>"
                f"<b>The text is</b> <span class='highlight-ai'><b>{prob:.2f}%</b> likely <b>AI generated</b>.</span><br><br>"
                f"<b>Most Likely AI Model: {model_name}</b>"
                f"</div>"
            )
            result_placeholder.markdown(result_html, unsafe_allow_html=True)
    else:
        result_placeholder.warning("Please enter some text to analyze.")
# Footer
st.markdown("<div class='footer'>Developed by Eeman Majumder</div>", unsafe_allow_html=True)