# AI Text Detector Code Analysis
# IMPORTS AND CONFIGURATION
import torch
from transformers import AutoModelForSequenceClassification, DebertaV2TokenizerFast  # HuggingFace transformers for NLP models
import torch.nn.functional as F
import spacy  # Used for sentence splitting
from typing import List, Dict, Tuple
import logging
import os
import gradio as gr  # Used for creating the web UI
from fastapi.middleware.cors import CORSMiddleware
from concurrent.futures import ThreadPoolExecutor
import time
# Basic logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# GLOBAL PARAMETERS
MAX_LENGTH = 512 # Maximum token length for the model input
MODEL_NAME = "microsoft/deberta-v3-small" # Using Microsoft's DeBERTa v3 small model as the base
WINDOW_SIZE = 6 # Number of sentences in each analysis window
WINDOW_OVERLAP = 2 # Number of sentences that overlap between adjacent windows
CONFIDENCE_THRESHOLD = 0.65 # Threshold for highlighting predictions with stronger colors
BATCH_SIZE = 8 # Number of windows to process in a single batch for efficiency
MAX_WORKERS = 4 # Maximum number of worker threads for parallel processing
# TEXT WINDOW PROCESSOR
# This class handles sentence splitting and window creation for text analysis
class TextWindowProcessor:
    def __init__(self):
        # Initialize spaCy with a minimal pipeline; only sentence splitting is needed
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # Auto-download the spaCy model if it is not available
            logger.info("Downloading spacy model...")
            spacy.cli.download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

        # Add a rule-based sentencizer if not already present
        if 'sentencizer' not in self.nlp.pipe_names:
            self.nlp.add_pipe('sentencizer')

        # Disable all other components for better performance
        disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
        self.nlp.disable_pipes(*disabled_pipes)

        # Thread pool reserved for parallel preprocessing (not exercised by the current pipeline)
        self.executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
    # Split text into individual sentences using spaCy
    def split_into_sentences(self, text: str) -> List[str]:
        doc = self.nlp(text)
        return [str(sent).strip() for sent in doc.sents]

    # Create overlapping fixed-size windows (used by the quick scan)
    def create_windows(self, sentences: List[str], window_size: int, overlap: int) -> List[str]:
        if len(sentences) < window_size:
            return [" ".join(sentences)]  # Return a single window if there are not enough sentences
        windows = []
        stride = window_size - overlap
        for i in range(0, len(sentences) - window_size + 1, stride):
            window = sentences[i:i + window_size]
            windows.append(" ".join(window))
        return windows
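
    # Example (illustrative): with window_size=6 and overlap=2 the stride is 4,
    # so 10 sentences yield windows over sentences [0:6] and [4:10]; with 11
    # sentences, trailing sentence 10 would not be covered by any window.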
    # Create one window centered on each sentence (used by the detailed scan),
    # so every sentence is scored together with its surrounding context
    def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
        windows = []
        window_sentence_indices = []
        for i in range(len(sentences)):
            half_window = window_size // 2
            start_idx = max(0, i - half_window)
            end_idx = min(len(sentences), i + half_window + 1)
            window = sentences[start_idx:end_idx]
            windows.append(" ".join(window))
            window_sentence_indices.append(list(range(start_idx, end_idx)))
        return windows, window_sentence_indices
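
    # Example (illustrative): with window_size=6 (half_window=3) and 10 sentences,
    # the window for i=5 covers sentences 2..8; at the boundary i=0 it is truncated
    # to sentences 0..3, so the "center" sentence is not always in the middle.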
# TEXT CLASSIFIER
# This class handles the actual AI/Human classification using a pre-trained model
class TextClassifier:
    def __init__(self):
        # Configure CPU threading when CUDA is not available
        if not torch.cuda.is_available():
            torch.set_num_threads(MAX_WORKERS)
            torch.set_num_interop_threads(MAX_WORKERS)
        # Set device (GPU if available, otherwise CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = MODEL_NAME
        self.tokenizer = None
        self.model = None
        self.processor = TextWindowProcessor()
        self.initialize_model()
    # Initialize the model and tokenizer
    def initialize_model(self):
        logger.info("Initializing model and tokenizer...")
        # Use the DeBERTa v2 fast tokenizer explicitly for better compatibility
        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
            self.model_name,
            model_max_length=MAX_LENGTH,
            use_fast=True  # Fast (Rust-backed) tokenizer for better performance
        )
        # Load a classification head with 2 labels (index 0 = AI, index 1 = human)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2
        ).to(self.device)
        # Load custom fine-tuned weights if a checkpoint file is present
        model_path = "model_20250209_184929_acc1.0000.pt"
        if os.path.exists(model_path):
            logger.info(f"Loading custom model from {model_path}")
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logger.warning("Custom model file not found. Using base model.")
        # Set the model to evaluation mode (disables dropout)
        self.model.eval()
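
    # Note: the checkpoint above is expected to be a dict with a 'model_state_dict'
    # key, e.g. one saved during training with (illustrative):
    #   torch.save({'model_state_dict': model.state_dict()}, model_path)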
    # Quick scan: faster but less detailed.
    # Uses fixed-size overlapping windows and averages the window probabilities.
    def quick_scan(self, text: str) -> Dict:
        if not text.strip():
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }

        # Split the text into sentences, then into overlapping windows
        sentences = self.processor.split_into_sentences(text)
        windows = self.processor.create_windows(sentences, WINDOW_SIZE, WINDOW_OVERLAP)

        predictions = []
        # Process windows in batches for efficiency
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]

            # Tokenize and prepare model input
            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            # Run inference without gradient tracking
            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            # Record per-window predictions (label index 0 = AI, 1 = human)
            for idx, window in enumerate(batch_windows):
                prediction = {
                    'window': window,
                    'human_prob': probs[idx][1].item(),
                    'ai_prob': probs[idx][0].item(),
                    'prediction': 'human' if probs[idx][1] > probs[idx][0] else 'ai'
                }
                predictions.append(prediction)

            # Free memory between batches
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }

        # Average probabilities across all windows for the final prediction
        avg_human_prob = sum(p['human_prob'] for p in predictions) / len(predictions)
        avg_ai_prob = sum(p['ai_prob'] for p in predictions) / len(predictions)

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_windows': len(predictions)
        }
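
    # Example quick-scan result (illustrative values):
    #   {'prediction': 'ai', 'confidence': 0.82, 'num_windows': 3}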
    # Detailed scan: slower but provides sentence-level insight.
    # Uses a window centered on each sentence for more precise analysis.
    def detailed_scan(self, text: str) -> Dict:
        text = text.rstrip()
        if not text.strip():
            return {
                'sentence_predictions': [],
                'highlighted_text': '',
                'full_text': '',
                'overall_prediction': {
                    'prediction': 'unknown',
                    'confidence': 0.0,
                    'num_sentences': 0
                }
            }

        # Split the text into sentences
        sentences = self.processor.split_into_sentences(text)
        if not sentences:
            return {}

        # Create one window centered on each sentence
        windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)

        # Accumulated weights and weighted scores for each sentence
        sentence_appearances = {i: 0 for i in range(len(sentences))}
        sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}

        # Process windows in batches
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]
            batch_indices = window_sentence_indices[i:i + BATCH_SIZE]

            # Tokenize and prepare model input
            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            # Run inference
            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            # Distribute each window's prediction across its sentences
            for window_idx, indices in enumerate(batch_indices):
                # Window i + window_idx is centered on sentence i + window_idx; near
                # text boundaries the center is not the middle of `indices`, so
                # locate it by sentence index rather than by position.
                center_sentence_idx = i + window_idx
                if len(indices) > 1:
                    center_weight = 0.7  # The center sentence gets 70% of the weight
                    edge_weight = 0.3 / (len(indices) - 1)  # Context sentences share the remaining 30%
                else:
                    center_weight = 1.0  # Single-sentence window: all weight to that sentence
                    edge_weight = 0.0

                # Apply the weighted prediction to each sentence in the window
                for sent_idx in indices:
                    weight = center_weight if sent_idx == center_sentence_idx else edge_weight
                    sentence_appearances[sent_idx] += weight
                    sentence_scores[sent_idx]['human_prob'] += weight * probs[window_idx][1].item()
                    sentence_scores[sent_idx]['ai_prob'] += weight * probs[window_idx][0].item()

            # Free memory between batches
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
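
        # Weighting example (illustrative): a full 6-sentence window mid-text has
        # len(indices) == 6, so its center sentence receives weight 0.7 and each
        # of the 5 context sentences receives 0.3 / 5 = 0.06.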
        # Compute final per-sentence predictions, smoothing between adjacent sentences
        sentence_predictions = []
        for i in range(len(sentences)):
            if sentence_appearances[i] > 0:
                human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
                ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]

                # Apply smoothing to sentences that are not at the text boundaries
                if i > 0 and i < len(sentences) - 1:
                    prev_human = sentence_scores[i-1]['human_prob'] / sentence_appearances[i-1]
                    prev_ai = sentence_scores[i-1]['ai_prob'] / sentence_appearances[i-1]
                    next_human = sentence_scores[i+1]['human_prob'] / sentence_appearances[i+1]
                    next_ai = sentence_scores[i+1]['ai_prob'] / sentence_appearances[i+1]

                    current_pred = 'human' if human_prob > ai_prob else 'ai'
                    prev_pred = 'human' if prev_human > prev_ai else 'ai'
                    next_pred = 'human' if next_human > next_ai else 'ai'

                    # Only smooth when the current prediction differs from a neighbor's
                    if current_pred != prev_pred or current_pred != next_pred:
                        smooth_factor = 0.1  # Blend in 10% of the neighbors' average
                        human_prob = (human_prob * (1 - smooth_factor) +
                                      (prev_human + next_human) * smooth_factor / 2)
                        ai_prob = (ai_prob * (1 - smooth_factor) +
                                   (prev_ai + next_ai) * smooth_factor / 2)

                sentence_predictions.append({
                    'sentence': sentences[i],
                    'human_prob': human_prob,
                    'ai_prob': ai_prob,
                    'prediction': 'human' if human_prob > ai_prob else 'ai',
                    'confidence': max(human_prob, ai_prob)
                })

        # Return detailed results
        return {
            'sentence_predictions': sentence_predictions,
            'highlighted_text': self.format_predictions_html(sentence_predictions),
            'full_text': text,
            'overall_prediction': self.aggregate_predictions(sentence_predictions)
        }
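
    # Smoothing example (illustrative): with human_prob = 0.40 and neighboring
    # human probabilities 0.90 and 0.80, the smoothed value is
    #   0.40 * 0.9 + (0.90 + 0.80) * 0.1 / 2 = 0.36 + 0.085 = 0.445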
    # Format predictions as HTML with color highlighting for visual assessment
    def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
        html_parts = []
        for pred in sentence_predictions:
            sentence = pred['sentence']
            confidence = pred['confidence']
            # Color coding: stronger colors for high confidence, lighter for low confidence
            if confidence >= CONFIDENCE_THRESHOLD:
                if pred['prediction'] == 'human':
                    color = "#90EE90"  # Green for human (high confidence)
                else:
                    color = "#FFB6C6"  # Pink for AI (high confidence)
            else:
                if pred['prediction'] == 'human':
                    color = "#E8F5E9"  # Light green for human (low confidence)
                else:
                    color = "#FFEBEE"  # Light pink for AI (low confidence)
            html_parts.append(f'<span style="background-color: {color};">{sentence}</span>')
        return " ".join(html_parts)
    # Aggregate individual sentence predictions into an overall result
    def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }

        # Average the probabilities across all sentences
        total_human_prob = sum(p['human_prob'] for p in predictions)
        total_ai_prob = sum(p['ai_prob'] for p in predictions)
        num_sentences = len(predictions)
        avg_human_prob = total_human_prob / num_sentences
        avg_ai_prob = total_ai_prob / num_sentences

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_sentences': num_sentences
        }
# MAIN ANALYSIS FUNCTION
# Brings everything together to analyze text based on selected mode
def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
    start_time = time.time()
    word_count = len(text.split())

    # Auto-switch to quick mode for short texts
    original_mode = mode
    if word_count < 200 and mode == "detailed":
        mode = "quick"

    if mode == "quick":
        # Perform quick analysis
        result = classifier.quick_scan(text)
        quick_analysis = f"""
PREDICTION: {result['prediction'].upper()}
Confidence: {result['confidence']*100:.1f}%
Windows analyzed: {result['num_windows']}
"""
        # Note the automatic switch from detailed to quick mode
        if original_mode == "detailed":
            quick_analysis += f"\n\nNote: Switched to quick mode because text contains only {word_count} words. Minimum 200 words required for detailed analysis."

        execution_time = (time.time() - start_time) * 1000
        logger.info(f"Quick scan completed in {execution_time:.1f} ms")
        return (
            text,  # Original text (no highlighting)
            "Quick scan mode - no sentence-level analysis available",
            quick_analysis
        )
    else:
        # Perform detailed analysis
        analysis = classifier.detailed_scan(text)

        # Format the sentence-by-sentence breakdown
        detailed_analysis = []
        for pred in analysis['sentence_predictions']:
            confidence = pred['confidence'] * 100
            detailed_analysis.append(f"Sentence: {pred['sentence']}")
            detailed_analysis.append(f"Prediction: {pred['prediction'].upper()}")
            detailed_analysis.append(f"Confidence: {confidence:.1f}%")
            detailed_analysis.append("-" * 50)

        # Format the overall summary
        final_pred = analysis['overall_prediction']
        overall_result = f"""
FINAL PREDICTION: {final_pred['prediction'].upper()}
Overall confidence: {final_pred['confidence']*100:.1f}%
Number of sentences analyzed: {final_pred['num_sentences']}
"""
        execution_time = (time.time() - start_time) * 1000
        logger.info(f"Detailed scan completed in {execution_time:.1f} ms")
        return (
            analysis['highlighted_text'],  # HTML-highlighted text
            "\n".join(detailed_analysis),  # Per-sentence analysis
            overall_result                 # Overall summary
        )
# Initialize the classifier
classifier = TextClassifier()
# GRADIO USER INTERFACE
demo = gr.Interface(
    fn=lambda text, mode: analyze_text(text, mode, classifier),
    inputs=[
        gr.Textbox(
            lines=8,
            placeholder="Enter text to analyze...",
            label="Input Text"
        ),
        gr.Radio(
            choices=["quick", "detailed"],
            value="quick",
            label="Analysis Mode",
            info="Quick mode for faster analysis; detailed mode for sentence-level analysis"
        )
    ],
    outputs=[
        gr.HTML(label="Highlighted Analysis"),                        # Color-coded result
        gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10),  # Detailed breakdown
        gr.Textbox(label="Overall Result", lines=4)                   # Summary results
    ],
    title="AI Text Detector",
    description="Analyze text to detect whether it was written by a human or an AI. Choose between a quick scan and a detailed sentence-level analysis. 200+ words are suggested for accurate predictions.",
    api_name="predict",
    flagging_mode="never"
)
# FastAPI configuration
app = demo.app
# Add CORS middleware to allow cross-origin requests.
# Note: browsers do not send credentials to a wildcard origin, so
# allow_credentials=True is only meaningful with explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
# Start the server when run directly
if __name__ == "__main__":
    demo.queue()  # Enable request queuing
    demo.launch(
        server_name="0.0.0.0",  # Listen on all interfaces
        server_port=7860,       # Default Gradio port
        share=True              # Also generate a public share URL
    )
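
# Example client call (illustrative; assumes the gradio_client package and a
# locally running server -- the URL below is hypothetical):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   highlighted, per_sentence, summary = client.predict(
#       "Some text to analyze...", "quick", api_name="/predict"
#   )
#   print(summary)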