File size: 11,086 Bytes
98a24df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e289c6c
98a24df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import streamlit as st
import torch
import torch.hub
import re
import os
import time

# --- Set Page Config First ---
# Sets the browser-tab title, a centered page layout, and a sidebar that starts
# collapsed. In this script it runs before any other `st.*` call.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command in a script — confirm for the installed version
# before moving anything above it.
st.set_page_config(
    page_title="AI Text Detector",
    layout="centered",
    initial_sidebar_state="collapsed"
)

# --- Improved CSS for a cleaner UI ---
# Injects one global <style> block through st.markdown; unsafe_allow_html=True
# is what lets the raw HTML through. The sheet imports the Inter font from
# Google Fonts and styles the title, text area, buttons and progress bar, and
# it defines the custom classes (.result-box, .highlight-human, .highlight-ai,
# .footer) used by the result and footer markup emitted later in this script.
# NOTE(review): selectors such as .stTextArea / .stButton / .stProgress target
# Streamlit-generated markup — presumably version-dependent; verify on upgrade.
st.markdown("""
<style>
    /* Modern clean font for the entire app */
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
    
    html, body, [class*="css"] {
        font-family: 'Inter', sans-serif;
    }
    
    /* Header styling */
    h1 {
        font-weight: 700;
        color: #1E3A8A;
        padding-bottom: 1rem;
        border-bottom: 2px solid #E5E7EB;
        margin-bottom: 2rem;
    }
    
    /* Text area styling */
    .stTextArea textarea {
        border: 1px solid #D1D5DB;
        border-radius: 8px;
        font-size: 16px;
        padding: 12px;
        background-color: #F9FAFB;
        box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
        transition: border-color 0.15s ease-in-out, box-shadow 0.15s ease-in-out;
    }
    
    .stTextArea textarea:focus {
        border-color: #3B82F6;
        box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.3);
        outline: none;
    }
    
    /* Button styling */
    .stButton button {
        border-radius: 8px;
        font-weight: 600;
        padding: 10px 16px;
        background-color: #2563EB;
        color: white;
        border: none;
        width: 100%;
        transition: background-color 0.2s ease;
    }
    
    .stButton button:hover {
        background-color: #1D4ED8;
    }
    
    /* Result box styling */
    .result-box {
        border-radius: 8px;
        padding: 20px;
        margin-top: 24px;
        text-align: center;
        background-color: white;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1), 0 1px 2px rgba(0, 0, 0, 0.06);
        border: 1px solid #E5E7EB;
    }
    
    /* Result highlights */
    .highlight-human {
        color: #059669;
        font-weight: 600;
        background: rgba(5, 150, 105, 0.1);
        padding: 4px 10px;
        border-radius: 8px;
        display: inline-block;
    }
    
    .highlight-ai {
        color: #DC2626;
        font-weight: 600;
        background: rgba(220, 38, 38, 0.1);
        padding: 4px 10px;
        border-radius: 8px;
        display: inline-block;
    }
    
    /* Footer styling */
    .footer {
        text-align: center;
        margin-top: 40px;
        padding-top: 20px;
        border-top: 1px solid #E5E7EB;
        color: #6B7280;
        font-size: 14px;
    }
    
    /* Progress bar styling */
    .stProgress > div > div {
        background-color: #2563EB;
    }
    
    /* General spacing */
    .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }
</style>
""", unsafe_allow_html=True)

# --- Configuration ---
# Checkpoint for the first ensemble member: a local file path (loaded below
# with is_url=False).
MODEL1_PATH = "modernbert.bin"
# Checkpoints for the second and third ensemble members, fetched over HTTP
# (loaded below with is_url=True).
# NOTE(review): the host is spelled with a trailing dot ("huggingface.co.") —
# presumably deliberate; confirm before "fixing" it.
MODEL2_URL = "https://huggingface.co./mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
MODEL3_URL = "https://huggingface.co./mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
# Base architecture/tokenizer identifier passed to load_tokenizer and load_model.
BASE_MODEL = "answerdotai/ModernBERT-base"
# Size of the classification head; LABEL_MAPPING below has 41 entries (0-40).
NUM_LABELS = 41
# Index of the 'human' class in LABEL_MAPPING; every other index is treated as
# an AI source by classify_text.
HUMAN_LABEL_INDEX = 24
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Model Loading Functions ---
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_name):
    """Return the pretrained tokenizer identified by ``model_name``.

    The function is decorated with ``st.cache_resource`` (spinner disabled),
    so Streamlit caches the returned tokenizer per ``model_name``.

    Args:
        model_name: Identifier handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer object produced by ``AutoTokenizer.from_pretrained``.
    """
    # transformers is imported at call time rather than at module import.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer

@st.cache_resource(show_spinner=False)
def _load_model_cached(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE):
    """Build the classifier and load its weights; raise on any failure.

    This is the cached half of ``load_model``. It deliberately lets exceptions
    propagate: a call that raises returns nothing for ``st.cache_resource`` to
    store, so a failed attempt is retried on the next script run.

    The ``_device`` parameter keeps its leading underscore, which Streamlit
    uses to exclude an argument from the cache key.
    """
    from transformers import AutoModelForSequenceClassification

    # Fetch the weights first so a missing or unreachable checkpoint fails
    # before the (comparatively expensive) base model is instantiated.
    if is_url:
        # NOTE(security): the downloaded checkpoint is unpickled. Consider
        # passing weights_only=True once the hosted files are confirmed to
        # load that way.
        state_dict = torch.hub.load_state_dict_from_url(
            model_path_or_url, map_location=_device, progress=False
        )
    else:
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects. The result is only used as a state dict, so
        # weights_only=True is likely sufficient — verify against the
        # checkpoint before switching.
        state_dict = torch.load(model_path_or_url, map_location=_device, weights_only=False)

    model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=num_labels)
    model.load_state_dict(state_dict)
    model.to(_device).eval()
    return model


def load_model(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE):
    """Load a sequence-classification model, returning ``None`` on failure.

    Successful loads are cached through ``_load_model_cached``. Failures are
    logged and reported as ``None`` without being cached, so a transient
    problem (for example a failed download) can recover on a later run instead
    of being remembered for the lifetime of the server process.

    Args:
        model_path_or_url: Local checkpoint path, or a URL when ``is_url``.
        base_model: Identifier of the base architecture to instantiate.
        num_labels: Number of output classes for the classification head.
        is_url: When true, download the state dict from ``model_path_or_url``;
            otherwise read it from the local filesystem.
        _device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model in eval mode on ``_device``, or ``None`` if the checkpoint is
        missing or anything goes wrong while loading it.
    """
    import logging

    logger = logging.getLogger(__name__)

    if not is_url and not os.path.exists(model_path_or_url):
        logger.error("Model checkpoint not found: %s", model_path_or_url)
        return None

    try:
        return _load_model_cached(
            model_path_or_url, base_model, num_labels, is_url=is_url, _device=_device
        )
    except Exception:
        # Callers treat None as "model unavailable"; keep that contract but
        # record the cause instead of discarding it.
        logger.exception("Failed to load model from %s", model_path_or_url)
        return None

# --- Text Processing Functions ---
def clean_text(text):
    """Normalize raw input text before tokenization.

    Anything that is not a ``str`` yields an empty string. Otherwise the
    following rewrites are applied in order, each on the previous result:

    1. CRLF and lone CR line endings become LF.
    2. A newline, optional whitespace, and one or more newlines collapse to a
       single blank line (two LFs).
    3. Runs of spaces/tabs collapse to one space.
    4. A word ending in ``-`` that continues after a line break is re-joined
       with the hyphen removed.
    5. Every remaining newline that is not adjacent to another newline becomes
       a space.

    Leading and trailing whitespace is stripped from the final result.
    """
    if not isinstance(text, str):
        return ""

    # Order matters: later rules rely on the normalization done by earlier ones.
    substitutions = (
        (r"\r\n?", "\n"),
        (r"\n\s*\n+", "\n\n"),
        (r"[ \t]+", " "),
        (r"(\w+)-\s*\n\s*(\w+)", r"\1\2"),
        (r"(?<!\n)\n(?!\n)", " "),
    )

    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

def classify_text(text, tokenizer, model_1, model_2, model_3, device, label_mapping, human_label_index):
    """Classify ``text`` as human- or AI-written with a three-model ensemble.

    The cleaned text is tokenized once and run through all three models; their
    softmax outputs are averaged. The probability at ``human_label_index`` is
    compared with the summed probability of every other class.

    Args:
        text: Raw user input; it is passed through ``clean_text`` first.
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        model_1, model_2, model_3: Ensemble members; each is called with the
            tokenized inputs and must return an object with ``.logits``.
        device: Device the tokenized inputs are moved to.
        label_mapping: Dict from class index to a display name.
        human_label_index: Index of the human class in the model output.

    Returns:
        ``None`` when the cleaned text is empty; ``{"error": True, "message":
        ...}`` when a model/tokenizer is missing, the index is out of range, or
        inference raises; otherwise ``{"is_human", "probability", "model"}``
        with ``probability`` expressed as a percentage.
    """
    if not (model_1 and model_2 and model_3 and tokenizer):
        return {"error": True, "message": "Models failed to load properly."}

    prepared = clean_text(text)
    if not prepared:
        return None

    try:
        encoded = tokenizer(
            prepared,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=tokenizer.model_max_length,
        ).to(device)

        with torch.no_grad():
            # Run every ensemble member first, then turn logits into
            # per-class probabilities.
            all_logits = [member(**encoded).logits for member in (model_1, model_2, model_3)]
            first, second, third = (torch.softmax(logits, dim=1) for logits in all_logits)

            ensemble = (first + second + third) / 3
            probs = ensemble[0].cpu()

            if human_label_index < 0 or human_label_index >= len(probs):
                return {"error": True, "message": "Configuration error."}

            human_pct = probs[human_label_index].item() * 100

            # Everything except the human class counts towards "AI".
            is_ai_class = torch.ones_like(probs, dtype=torch.bool)
            is_ai_class[human_label_index] = False
            ai_pct = probs[is_ai_class].sum().item() * 100

            # Exclude the human class from the argmax to find the single most
            # likely AI source.
            without_human = probs.clone()
            without_human[human_label_index] = -float('inf')
            top_ai_index = torch.argmax(without_human).item()
            top_ai_name = label_mapping.get(top_ai_index, f"Unknown AI (Index {top_ai_index})")

            if human_pct >= ai_pct:
                return {"is_human": True, "probability": human_pct, "model": "Human"}
            return {"is_human": False, "probability": ai_pct, "model": top_ai_name}

    except Exception as e:
        return {"error": True, "message": f"Analysis failed: {str(e)}"}

# --- Label Mapping ---
# Maps each classifier output index (0-40) to the display name of its source.
# Index 24 is 'human' (see HUMAN_LABEL_INDEX); classify_text looks up the
# non-human index with the highest probability here to name the most likely
# AI model.
# NOTE(review): the index order presumably has to match the label order the
# checkpoints were trained with — confirm before editing or reordering.
LABEL_MAPPING = {
    0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
    6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
    11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
    14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
    18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
    22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
    27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
    31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
    35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
    39: 'text-davinci-002', 40: 'text-davinci-003'
}

# --- Main UI ---
st.title("AI Text Detector")

# Initialization with a progress bar.
# The start-up indicators are drawn inside one placeholder so they can be
# removed once loading finishes; calling a bare `st.empty()` afterwards would
# only add a new empty slot and leave the progress bar and info box on screen.
# No artificial delays are inserted between steps: Streamlit re-executes this
# script on every widget interaction, so any sleep here is paid on every click.
init_placeholder = st.empty()
with init_placeholder.container():
    with st.spinner(""):
        # Create a progress bar
        progress_bar = st.progress(0)
        st.info("Initializing AI detection models...")

        # Step 1: Load tokenizer
        progress_bar.progress(20)
        TOKENIZER = load_tokenizer(BASE_MODEL)

        # Step 2: Load first model (local checkpoint)
        progress_bar.progress(40)
        MODEL_1 = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False, _device=DEVICE)

        # Step 3: Load second model (remote checkpoint)
        progress_bar.progress(60)
        MODEL_2 = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

        # Step 4: Load third model (remote checkpoint)
        progress_bar.progress(80)
        MODEL_3 = load_model(MODEL3_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

        # Complete initialization
        progress_bar.progress(100)

# Clear the initialization messages
init_placeholder.empty()

# Check if models loaded successfully; load_model reports failure as None.
if not all([TOKENIZER, MODEL_1, MODEL_2, MODEL_3]):
    st.error("Failed to initialize one or more AI detection models. Please try refreshing the page.")
    st.stop()

# Input area: free-form text the user wants checked.
input_text = st.text_area(
    label="Enter text to analyze:",
    placeholder="Type or paste your content here for AI detection analysis...",
    height=200,
    key="text_input",
)

# Analyze button, plus the single slot every outcome (warning, error or
# result card) is rendered into.
analyze_button = st.button("Analyze Text", key="analyze_button")
result_placeholder = st.empty()

# True only when the box holds something other than whitespace.
has_text = bool(input_text and input_text.strip())

if analyze_button and not has_text:
    result_placeholder.warning("Please enter some text to analyze.")
elif analyze_button:
    with st.spinner('Analyzing text...'):
        classification_result = classify_text(
            input_text,
            TOKENIZER,
            MODEL_1,
            MODEL_2,
            MODEL_3,
            DEVICE,
            LABEL_MAPPING,
            HUMAN_LABEL_INDEX,
        )

    # Display result. classify_text returns None for text that is empty after
    # cleaning, an {"error": ...} dict on failure, or a verdict dict.
    if classification_result is None:
        result_placeholder.warning("Please enter some text to analyze.")
    elif classification_result.get("error"):
        error_message = classification_result.get("message", "An unknown error occurred during analysis.")
        result_placeholder.error(f"Analysis Error: {error_message}")
    else:
        prob = classification_result['probability']
        if classification_result["is_human"]:
            result_html = (
                f"<div class='result-box'>"
                f"<b>The text is</b> <span class='highlight-human'><b>{prob:.2f}%</b> likely <b>Human written</b>.</span>"
                f"</div>"
            )
        else:  # AI generated: also name the most likely source model
            model_name = classification_result['model']
            result_html = (
                f"<div class='result-box'>"
                f"<b>The text is</b> <span class='highlight-ai'><b>{prob:.2f}%</b> likely <b>AI generated</b>.</span><br><br>"
                f"<b>Most Likely AI Model: {model_name}</b>"
                f"</div>"
            )
        result_placeholder.markdown(result_html, unsafe_allow_html=True)

# Footer
st.markdown("<div class='footer'>Developed by Eeman Majumder</div>", unsafe_allow_html=True)