chengzeyi committed
Commit e3355c4 · 1 Parent(s): 11d8b5e

add prompt safety classifier

Files changed (2):
  1. app.py            +57 -4
  2. requirements.txt   +2 -0
app.py CHANGED
@@ -9,8 +9,11 @@ import os
 from dotenv import load_dotenv
 import json
 from PIL import Image, ImageDraw, ImageFont
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
 import uuid
 import threading
+import functools

 # Load environment variables first
 load_dotenv()
@@ -41,6 +44,38 @@ BACKENDS = {
     },
 }

+MODEL_URL = "MichalMlodawski/nsfw-text-detection-large"
+TITLE = "🖼️🔍 Image Prompt Safety Classifier 🛡️"
+DESCRIPTION = "✨ Enter an image generation prompt to classify its safety level! ✨"
+
+# Load model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
+model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
+
+# Define class names with emojis and detailed descriptions
+CLASS_NAMES = {
+    0: "✅ SAFE - This prompt is appropriate and harmless.",
+    1: "⚠️ QUESTIONABLE - This prompt may require further review.",
+    2: "🚫 UNSAFE - This prompt is likely to generate inappropriate content."
+}
+
+
+@functools.lru_cache(maxsize=128)
+def classify_text(text):
+    inputs = tokenizer(text,
+                       return_tensors="pt",
+                       truncation=True,
+                       padding=True,
+                       max_length=1024)
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    logits = outputs.logits
+    predicted_class = torch.argmax(logits, dim=1).item()
+
+    return predicted_class, CLASS_NAMES[predicted_class]
+

 class BackendStatus:

 
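The classify_text helper above is wrapped in functools.lru_cache, so identical prompt strings are classified only once per process; repeated submissions reuse the cached (class id, message) pair instead of re-running the model. A minimal sketch of that effect, separate from the commit (the dummy classifier body and the call counter are illustrative stand-ins for the real model forward pass):

import functools

calls = 0

@functools.lru_cache(maxsize=128)
def classify_text(text):
    # Stand-in for the real model call; counts how often the body actually runs.
    global calls
    calls += 1
    return 0, "✅ SAFE - This prompt is appropriate and harmless."

classify_text("a cat wearing a spacesuit")
classify_text("a cat wearing a spacesuit")   # identical string: served from the cache
print(calls)                        # 1 - the body ran only once
print(classify_text.cache_info())   # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)

The cache key is the prompt string itself, which is hashable, so no extra bookkeeping is needed in the app.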
@@ -600,6 +635,24 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             )
             return

+        # Check if the prompt is safe
+        classification, message = classify_text(prompt)
+        if classification != 0:
+            # Handle unsafe prompt case
+            yield (
+                message,
+                message,
+                gr.update(visible=True),
+                gr.update(visible=False),
+                None,
+                None,
+                None,
+                None,
+                session_id,  # Return the session ID
+                None,
+            )
+            return
+
         # Status message
         status_message = f"🔄 PROCESSING: '{prompt}'"

 
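The hunk above gates the generation handler on the classifier: any class other than 0 yields the classifier's message into the UI and returns before any backend work starts. A reduced, self-contained sketch of that early-return pattern; the component names, the two-output layout, and the toy classifier below are placeholders rather than the app's actual interface (the real handler yields a ten-element tuple, as shown above):

import gradio as gr

def classify_text(prompt):
    # Toy stand-in for the cached classifier defined earlier in this commit.
    if "gore" in prompt.lower():
        return 2, "🚫 UNSAFE - This prompt is likely to generate inappropriate content."
    return 0, "✅ SAFE - This prompt is appropriate and harmless."

def generate(prompt):
    classification, message = classify_text(prompt)
    if classification != 0:
        # Surface the classifier's message and stop before any backend call is made.
        yield message, gr.update(visible=False)
        return
    yield f"🔄 PROCESSING: '{prompt}'", gr.update(visible=True)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    status = gr.Markdown()
    image = gr.Image(visible=False)
    prompt.submit(generate, inputs=prompt, outputs=[status, image])

# demo.queue().launch()  # launch configuration is shown in the next hunk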
@@ -749,10 +802,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:

 # Launch with increased max_threads
 if __name__ == "__main__":
-    # demo.queue(max_size=50).launch(
-    #     server_name="0.0.0.0",
-    #     max_threads=16, # Increase thread count for better concurrency
-    # )
+    demo.queue(max_size=50).launch(
+        server_name="0.0.0.0",
+        max_threads=16, # Increase thread count for better concurrency
+    )
     demo.queue(max_size=4).launch(
         server_name="0.0.0.0",
         max_threads=16, # Increase thread count for better concurrency
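For context, the queue()/launch() pattern this hunk adjusts, reduced to a self-contained sketch; the echo handler and the sleep are placeholders. Note that demo.launch() normally blocks the script's main thread, so the demo.queue(max_size=4).launch(...) call that still follows in app.py would only run after the first server shuts down.

import time
import gradio as gr

def echo(prompt):
    time.sleep(1)  # stand-in for a slow image-generation backend call
    return f"processed: {prompt}"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Result")
    box.submit(echo, inputs=box, outputs=out)

if __name__ == "__main__":
    # max_size bounds how many requests may wait in the queue at once;
    # max_threads bounds the worker threads Gradio uses to serve events.
    demo.queue(max_size=50).launch(
        server_name="0.0.0.0",  # listen on all interfaces, as in the commit
        max_threads=16,
    )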
requirements.txt CHANGED
@@ -3,3 +3,5 @@ aiohttp
 plotly
 python-dotenv
 pydantic==2.8.2
+torch
+transformers==4.37.2
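A quick, illustrative way to confirm the two new dependencies resolve after installing the updated requirements; torch is left unpinned, so only transformers has an expected version.

import torch
import transformers

print("torch:", torch.__version__)                # unpinned in requirements.txt
print("transformers:", transformers.__version__)  # the commit pins 4.37.2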