import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification
import spacy
from typing import List, Dict, Tuple
import logging
import os
import time
import gradio as gr
from fastapi.middleware.cors import CORSMiddleware
from concurrent.futures import ThreadPoolExecutor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MAX_LENGTH = 512
MODEL_NAME = "microsoft/deberta-v3-small"
WINDOW_SIZE = 6
WINDOW_OVERLAP = 2
CONFIDENCE_THRESHOLD = 0.65
BATCH_SIZE = 8
MAX_WORKERS = 4


class TextWindowProcessor:
    def __init__(self):
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            logger.info("Downloading spacy model...")
            spacy.cli.download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

        # Keep only the sentencizer: sentence splitting is all we need here,
        # and disabling the remaining pipes makes processing much faster.
        if 'sentencizer' not in self.nlp.pipe_names:
            self.nlp.add_pipe('sentencizer')
        disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
        self.nlp.disable_pipes(*disabled_pipes)

        # Reserved for parallel preprocessing; not used in the current code path.
        self.executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

    def split_into_sentences(self, text: str) -> List[str]:
        doc = self.nlp(text)
        return [str(sent).strip() for sent in doc.sents]

    def create_windows(self, sentences: List[str], window_size: int, overlap: int) -> List[str]:
        """Fixed-size sliding windows with the given overlap (used by quick scan)."""
        if len(sentences) < window_size:
            return [" ".join(sentences)]

        windows = []
        stride = window_size - overlap
        for i in range(0, len(sentences) - window_size + 1, stride):
            windows.append(" ".join(sentences[i:i + window_size]))
        return windows

    def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
        """One window per sentence, centered on it (used by detailed scan)."""
        windows = []
        window_sentence_indices = []
        half_window = window_size // 2
        for i in range(len(sentences)):
            start_idx = max(0, i - half_window)
            end_idx = min(len(sentences), i + half_window + 1)
            windows.append(" ".join(sentences[start_idx:end_idx]))
            window_sentence_indices.append(list(range(start_idx, end_idx)))
        return windows, window_sentence_indices
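
# Windowing example (illustrative sketch, using the defaults defined above):
#
#   proc = TextWindowProcessor()
#   sents = [f"Sentence {i}." for i in range(10)]
#   proc.create_windows(sents, WINDOW_SIZE, WINDOW_OVERLAP)
#   # stride = 6 - 2 = 4, so two windows: sentences 0-5 and sentences 4-9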

class TextClassifier:
    def __init__(self):
        if not torch.cuda.is_available():
            torch.set_num_threads(MAX_WORKERS)
            torch.set_num_interop_threads(MAX_WORKERS)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = MODEL_NAME
        self.tokenizer = None
        self.model = None
        self.processor = TextWindowProcessor()
        self.initialize_model()

    def initialize_model(self):
        logger.info("Initializing model and tokenizer...")
        from transformers import DebertaV2TokenizerFast

        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
            self.model_name,
            model_max_length=MAX_LENGTH
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2
        ).to(self.device)

        model_path = "model_20250209_184929_acc1.0000.pt"
        if os.path.exists(model_path):
            logger.info(f"Loading custom model from {model_path}")
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logger.warning("Custom model file not found. Using base model.")
        self.model.eval()

    def quick_scan(self, text: str) -> Dict:
        if not text.strip():
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }

        sentences = self.processor.split_into_sentences(text)
        windows = self.processor.create_windows(sentences, WINDOW_SIZE, WINDOW_OVERLAP)

        predictions = []
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]
            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            # Label convention used throughout: index 0 = AI, index 1 = human.
            for idx, window in enumerate(batch_windows):
                predictions.append({
                    'window': window,
                    'human_prob': probs[idx][1].item(),
                    'ai_prob': probs[idx][0].item(),
                    'prediction': 'human' if probs[idx][1] > probs[idx][0] else 'ai'
                })

            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }

        avg_human_prob = sum(p['human_prob'] for p in predictions) / len(predictions)
        avg_ai_prob = sum(p['ai_prob'] for p in predictions) / len(predictions)

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_windows': len(predictions)
        }
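
    # Scoring sketch for detailed_scan below: every sentence gets one centered
    # window; within each window the center sentence votes with weight 0.7 and
    # the remaining 0.3 is split evenly among the other sentences (with
    # WINDOW_SIZE=6 a full window spans 7 sentences, so each neighbor gets
    # 0.3 / 6 = 0.05). A sentence's final probabilities are its accumulated
    # weighted votes divided by its accumulated weight.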
    def detailed_scan(self, text: str) -> Dict:
        text = text.rstrip()
        # Consistent empty payload so callers can index the result safely.
        empty_result = {
            'sentence_predictions': [],
            'highlighted_text': '',
            'full_text': text,
            'overall_prediction': {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }
        }
        if not text.strip():
            return empty_result

        sentences = self.processor.split_into_sentences(text)
        if not sentences:
            return empty_result

        windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)

        sentence_appearances = {i: 0.0 for i in range(len(sentences))}
        sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}

        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]
            batch_indices = window_sentence_indices[i:i + BATCH_SIZE]

            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            for window_idx, indices in enumerate(batch_indices):
                # Window i + window_idx is centered on that same sentence index,
                # so weight the true center rather than the middle position (the
                # two differ for truncated windows at the start of the text).
                center_sentence = i + window_idx
                if len(indices) > 1:
                    center_weight = 0.7
                    edge_weight = 0.3 / (len(indices) - 1)
                else:
                    # Single-sentence text: avoid division by zero.
                    center_weight, edge_weight = 1.0, 0.0
                for sent_idx in indices:
                    weight = center_weight if sent_idx == center_sentence else edge_weight
                    sentence_appearances[sent_idx] += weight
                    sentence_scores[sent_idx]['human_prob'] += weight * probs[window_idx][1].item()
                    sentence_scores[sent_idx]['ai_prob'] += weight * probs[window_idx][0].item()

            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        sentence_predictions = []
        for i in range(len(sentences)):
            if sentence_appearances[i] > 0:
                human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
                ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]

                # Smooth lone disagreements with the neighboring sentences.
                if 0 < i < len(sentences) - 1:
                    prev_human = sentence_scores[i-1]['human_prob'] / sentence_appearances[i-1]
                    prev_ai = sentence_scores[i-1]['ai_prob'] / sentence_appearances[i-1]
                    next_human = sentence_scores[i+1]['human_prob'] / sentence_appearances[i+1]
                    next_ai = sentence_scores[i+1]['ai_prob'] / sentence_appearances[i+1]

                    current_pred = 'human' if human_prob > ai_prob else 'ai'
                    prev_pred = 'human' if prev_human > prev_ai else 'ai'
                    next_pred = 'human' if next_human > next_ai else 'ai'

                    if current_pred != prev_pred or current_pred != next_pred:
                        smooth_factor = 0.1
                        human_prob = (human_prob * (1 - smooth_factor) +
                                      (prev_human + next_human) * smooth_factor / 2)
                        ai_prob = (ai_prob * (1 - smooth_factor) +
                                   (prev_ai + next_ai) * smooth_factor / 2)

                sentence_predictions.append({
                    'sentence': sentences[i],
                    'human_prob': human_prob,
                    'ai_prob': ai_prob,
                    'prediction': 'human' if human_prob > ai_prob else 'ai',
                    'confidence': max(human_prob, ai_prob)
                })

        return {
            'sentence_predictions': sentence_predictions,
            'highlighted_text': self.format_predictions_html(sentence_predictions),
            'full_text': text,
            'overall_prediction': self.aggregate_predictions(sentence_predictions)
        }

    def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
        html_parts = []
        for pred in sentence_predictions:
            sentence = pred['sentence']
            confidence = pred['confidence']
            if confidence >= CONFIDENCE_THRESHOLD:
                # Saturated colors for confident predictions.
                color = "#90EE90" if pred['prediction'] == 'human' else "#FFB6C6"
            else:
                # Pale colors for low-confidence predictions.
                color = "#E8F5E9" if pred['prediction'] == 'human' else "#FFEBEE"
            html_parts.append(f'<span style="background-color: {color};">{sentence}</span>')
        return " ".join(html_parts)

    def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }

        avg_human_prob = sum(p['human_prob'] for p in predictions) / len(predictions)
        avg_ai_prob = sum(p['ai_prob'] for p in predictions) / len(predictions)

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_sentences': len(predictions)
        }
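
# Direct-use sketch (illustrative; the Gradio UI below is the real entry point,
# and the probability values shown are made-up examples):
#
#   clf = TextClassifier()
#   clf.quick_scan("Some passage...")
#   # -> {'prediction': 'ai', 'confidence': 0.91, 'num_windows': 3}
#   clf.detailed_scan("Some passage...")['overall_prediction']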

def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
    start_time = time.time()
    word_count = len(text.split())

    # Fall back to quick mode for short texts: detailed, sentence-level
    # analysis needs enough context to be reliable.
    original_mode = mode
    if word_count < 200 and mode == "detailed":
        mode = "quick"

    if mode == "quick":
        result = classifier.quick_scan(text)
        quick_analysis = f"""
        PREDICTION: {result['prediction'].upper()}
        Confidence: {result['confidence']*100:.1f}%
        Windows analyzed: {result['num_windows']}
        """
        if original_mode == "detailed":
            quick_analysis += (
                f"\n\nNote: Switched to quick mode because the text contains only "
                f"{word_count} words. A minimum of 200 words is required for detailed analysis."
            )

        execution_time = (time.time() - start_time) * 1000
        logger.info(f"Quick scan finished in {execution_time:.0f} ms")
        return (
            text,
            "Quick scan mode - no sentence-level analysis available",
            quick_analysis
        )
    else:
        analysis = classifier.detailed_scan(text)

        detailed_analysis = []
        for pred in analysis['sentence_predictions']:
            detailed_analysis.append(f"Sentence: {pred['sentence']}")
            detailed_analysis.append(f"Prediction: {pred['prediction'].upper()}")
            detailed_analysis.append(f"Confidence: {pred['confidence']*100:.1f}%")
            detailed_analysis.append("-" * 50)

        final_pred = analysis['overall_prediction']
        overall_result = f"""
        FINAL PREDICTION: {final_pred['prediction'].upper()}
        Overall confidence: {final_pred['confidence']*100:.1f}%
        Number of sentences analyzed: {final_pred['num_sentences']}
        """

        execution_time = (time.time() - start_time) * 1000
        logger.info(f"Detailed scan finished in {execution_time:.0f} ms")
        return (
            analysis['highlighted_text'],
            "\n".join(detailed_analysis),
            overall_result
        )


classifier = TextClassifier()

demo = gr.Interface(
    fn=lambda text, mode: analyze_text(text, mode, classifier),
    inputs=[
        gr.Textbox(
            lines=8,
            placeholder="Enter text to analyze...",
            label="Input Text"
        ),
        gr.Radio(
            choices=["quick", "detailed"],
            value="quick",
            label="Analysis Mode",
            info="Quick mode for faster analysis; detailed mode for sentence-level analysis"
        )
    ],
    outputs=[
        gr.HTML(label="Highlighted Analysis"),
        gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10),
        gr.Textbox(label="Overall Result", lines=4)
    ],
    title="AI Text Detector",
    description="Analyze text to detect if it was written by a human or AI. Choose between quick scan and detailed sentence-level analysis. 200+ words suggested for accurate predictions.",
    api_name="predict",
    flagging_mode="never"
)

# Note: recent Gradio versions only populate demo.app inside launch(); if this
# raises AttributeError, attach the middleware via gr.mount_gradio_app instead.
app = demo.app
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
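
# API usage sketch (assumptions: the app is reachable on localhost:7860, and
# the endpoint name "/predict" follows from api_name="predict" above; the
# gradio_client package ships separately from gradio):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   highlighted, sentences, overall = client.predict(
#       "Text to analyze...", "detailed", api_name="/predict"
#   )
#   print(overall)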