Spaces:
Running
Running
add prompt safety classifier
Browse files- app.py +57 -4
- requirements.txt +2 -0
app.py
CHANGED
@@ -9,8 +9,11 @@ import os
|
|
9 |
from dotenv import load_dotenv
|
10 |
import json
|
11 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
12 |
import uuid
|
13 |
import threading
|
|
|
14 |
|
15 |
# Load environment variables first
|
16 |
load_dotenv()
|
@@ -41,6 +44,38 @@ BACKENDS = {
|
|
41 |
},
|
42 |
}
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
class BackendStatus:
|
46 |
|
@@ -600,6 +635,24 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
600 |
)
|
601 |
return
|
602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
603 |
# Status message
|
604 |
status_message = f"๐ PROCESSING: '{prompt}'"
|
605 |
|
@@ -749,10 +802,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
749 |
|
750 |
# Launch with increased max_threads
|
751 |
if __name__ == "__main__":
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
demo.queue(max_size=4).launch(
|
757 |
server_name="0.0.0.0",
|
758 |
max_threads=16, # Increase thread count for better concurrency
|
|
|
9 |
from dotenv import load_dotenv
|
10 |
import json
|
11 |
from PIL import Image, ImageDraw, ImageFont
|
12 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
13 |
+
import torch
|
14 |
import uuid
|
15 |
import threading
|
16 |
+
import functools
|
17 |
|
18 |
# Load environment variables first
|
19 |
load_dotenv()
|
|
|
44 |
},
|
45 |
}
|
46 |
|
47 |
+
MODEL_URL = "MichalMlodawski/nsfw-text-detection-large"
|
48 |
+
TITLE = "๐ผ๏ธ๐ Image Prompt Safety Classifier ๐ก๏ธ"
|
49 |
+
DESCRIPTION = "โจ Enter an image generation prompt to classify its safety level! โจ"
|
50 |
+
|
51 |
+
# Load model and tokenizer
|
52 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
|
53 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
|
54 |
+
|
55 |
+
# Define class names with emojis and detailed descriptions
|
56 |
+
CLASS_NAMES = {
|
57 |
+
0: "โ
SAFE - This prompt is appropriate and harmless.",
|
58 |
+
1: "โ ๏ธ QUESTIONABLE - This prompt may require further review.",
|
59 |
+
2: "๐ซ UNSAFE - This prompt is likely to generate inappropriate content."
|
60 |
+
}
|
61 |
+
|
62 |
+
|
63 |
+
@functools.lru_cache(maxsize=128)
|
64 |
+
def classify_text(text):
|
65 |
+
inputs = tokenizer(text,
|
66 |
+
return_tensors="pt",
|
67 |
+
truncation=True,
|
68 |
+
padding=True,
|
69 |
+
max_length=1024)
|
70 |
+
|
71 |
+
with torch.no_grad():
|
72 |
+
outputs = model(**inputs)
|
73 |
+
|
74 |
+
logits = outputs.logits
|
75 |
+
predicted_class = torch.argmax(logits, dim=1).item()
|
76 |
+
|
77 |
+
return predicted_class, CLASS_NAMES[predicted_class]
|
78 |
+
|
79 |
|
80 |
class BackendStatus:
|
81 |
|
|
|
635 |
)
|
636 |
return
|
637 |
|
638 |
+
# Check if the prompt is safe
|
639 |
+
classification, message = classify_text(prompt)
|
640 |
+
if classification != 0:
|
641 |
+
# Handle unsafe prompt case
|
642 |
+
yield (
|
643 |
+
message,
|
644 |
+
message,
|
645 |
+
gr.update(visible=True),
|
646 |
+
gr.update(visible=False),
|
647 |
+
None,
|
648 |
+
None,
|
649 |
+
None,
|
650 |
+
None,
|
651 |
+
session_id, # Return the session ID
|
652 |
+
None,
|
653 |
+
)
|
654 |
+
return
|
655 |
+
|
656 |
# Status message
|
657 |
status_message = f"๐ PROCESSING: '{prompt}'"
|
658 |
|
|
|
802 |
|
803 |
# Launch with increased max_threads
|
804 |
if __name__ == "__main__":
|
805 |
+
demo.queue(max_size=50).launch(
|
806 |
+
server_name="0.0.0.0",
|
807 |
+
max_threads=16, # Increase thread count for better concurrency
|
808 |
+
)
|
809 |
demo.queue(max_size=4).launch(
|
810 |
server_name="0.0.0.0",
|
811 |
max_threads=16, # Increase thread count for better concurrency
|
requirements.txt
CHANGED
@@ -3,3 +3,5 @@ aiohttp
|
|
3 |
plotly
|
4 |
python-dotenv
|
5 |
pydantic==2.8.2
|
|
|
|
|
|
3 |
plotly
|
4 |
python-dotenv
|
5 |
pydantic==2.8.2
|
6 |
+
torch
|
7 |
+
transformers==4.37.2
|