# cyber-tagger / app.py
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download
import torchvision.transforms as transforms
# Constants
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
# Overall micro-optimized F1 threshold; used as the slider's initial value and
# as the fallback for categories without their own threshold in metadata
DEFAULT_THRESHOLD = 0.32626262307167053
# Download model and metadata from Hugging Face Hub
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
# Initialize ONNX Runtime session and load metadata
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
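# Design note: CPUExecutionProvider keeps the demo portable; a GPU provider
# (e.g. CUDAExecutionProvider) could be substituted if onnxruntime-gpu is installed.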
with open(meta_path, "r", encoding="utf-8") as f:
metadata = json.load(f)
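# Fields of metadata used below: "idx_to_tag" (stringified output index -> tag),
# "tag_to_category" (tag -> category name), and optional per-category
# "category_thresholds".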
def escape_tag(tag: str) -> str:
    """Replace underscores with spaces and backslash-escape parentheses for Markdown/prompt output."""
    return tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)")
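# Example (illustrative tag only):
#   escape_tag("hatsune_miku_(vocaloid)") -> "hatsune miku \(vocaloid\)"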
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    """Process an image for inference using the same preprocessing as training."""
    image_size = 512
# Initialize the same transform used during training
transform = transforms.Compose([
transforms.ToTensor(),
])
    img = pil_image
    # Convert any non-RGB image (RGBA, palette, grayscale, ...) to RGB,
    # since the model expects 3 channels
    if img.mode != 'RGB':
        img = img.convert('RGB')
# Get original dimensions
width, height = img.size
aspect_ratio = width / height
# Calculate new dimensions to maintain aspect ratio
if aspect_ratio > 1:
new_width = image_size
new_height = int(new_width / aspect_ratio)
else:
new_height = image_size
new_width = int(new_height * aspect_ratio)
# Resize with LANCZOS filter
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Create new image with padding
new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
paste_x = (image_size - new_width) // 2
paste_y = (image_size - new_height) // 2
new_image.paste(img, (paste_x, paste_y))
# Apply transforms (without normalization)
img_tensor = transform(new_image)
return img_tensor.numpy() # Convert the PyTorch tensor to NumPy array
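# Shape sketch: for any input image, preprocess_image returns a float32 array
# of shape (3, 512, 512) with values in [0, 1] (letterboxed, no normalization).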
def run_inference(pil_image: Image.Image) -> np.ndarray:
"""
Preprocess the image and run the ONNX model inference.
Returns the refined logits as a numpy array.
"""
input_tensor = preprocess_image(pil_image)
input_name = session.get_inputs()[0].name
# Expand dimensions to make it (1, C, H, W)
input_tensor_expanded = np.expand_dims(input_tensor, axis=0)
# Only refined_logits are used (initial_logits is ignored)
_, refined_logits = session.run(None, {input_name: input_tensor_expanded})
return refined_logits[0]
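# The ONNX graph has two outputs (initial and refined logits); refined_logits[0]
# is a 1-D array with one logit per entry in metadata["idx_to_tag"].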
def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
"""
Compute probabilities from logits and collect tag predictions.
Returns:
results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
prompt_tags_by_cat: Dictionary for prompt-style output (character, general).
all_artist_tags: All artist tags (with probabilities) regardless of threshold.
"""
    # Tags are independent (multi-label), so apply an element-wise sigmoid
    # (not a softmax) to turn logits into per-tag probabilities
    probs = 1 / (1 + np.exp(-refined_logits))
idx_to_tag = metadata["idx_to_tag"]
tag_to_category = metadata.get("tag_to_category", {})
category_thresholds = metadata.get("category_thresholds", {})
results_by_cat = {}
# For prompt style, only include character and general tags (artists handled separately)
prompt_tags_by_cat = {"character": [], "general": []}
all_artist_tags = []
for idx, prob in enumerate(probs):
tag = idx_to_tag[str(idx)]
cat = tag_to_category.get(tag, "unknown")
thresh = category_thresholds.get(cat, default_threshold)
if cat == "artist":
all_artist_tags.append((tag, float(prob)))
if float(prob) >= thresh:
results_by_cat.setdefault(cat, []).append((tag, float(prob)))
if cat in prompt_tags_by_cat:
prompt_tags_by_cat[cat].append((tag, float(prob)))
return results_by_cat, prompt_tags_by_cat, all_artist_tags
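# Usage sketch (illustrative values only):
#   by_cat, prompt_cats, artists = get_tags(refined_logits, metadata, DEFAULT_THRESHOLD)
#   by_cat -> {"general": [("1girl", 0.98), ...], "character": [...], ...}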
def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
"""
Format the tags for prompt-style output.
    Only the top artist tag is shown (regardless of threshold), along with every character and general tag that passed its threshold.
Returns a comma-separated string of escaped tags.
"""
# Always select the best artist tag from all_artist_tags, regardless of threshold.
best_artist_tag = None
if all_artist_tags:
best_artist = max(all_artist_tags, key=lambda item: item[1])
best_artist_tag = escape_tag(best_artist[0])
# Sort character and general tags by probability (descending)
for cat in prompt_tags_by_cat:
prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
prompt_tags = []
if best_artist_tag:
prompt_tags.append(best_artist_tag)
prompt_tags.extend(character_tags)
prompt_tags.extend(general_tags)
return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
"""
Format the tags for detailed output.
Returns a Markdown-formatted string listing tags by category.
"""
if not results_by_cat:
return "No tags predicted for this image."
# Include an artist tag even if below threshold
if "artist" not in results_by_cat and all_artist_tags:
best_artist_tag, best_artist_prob = max(all_artist_tags, key=lambda item: item[1])
results_by_cat["artist"] = [(best_artist_tag, best_artist_prob)]
lines = ["**Predicted Tags by Category:** \n"]
for cat, tag_list in results_by_cat.items():
tag_list.sort(key=lambda x: x[1], reverse=True)
lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
for tag, prob in tag_list:
lines.append(f"- {escape_tag(tag)} (Prob: {prob:.3f})")
lines.append("") # blank line between categories
return "\n".join(lines)
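# Rendered Markdown looks like (hypothetical values):
#   **Category: general** – 2 tags
#   - 1girl (Prob: 0.987)
#   - smile (Prob: 0.861)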
def tag_image(pil_image: Image.Image, output_format: str, threshold: float) -> str:
"""
Run inference on the image and return formatted tags based on the chosen output format.
    The slider value replaces DEFAULT_THRESHOLD as the fallback for tag selection;
    per-category thresholds from metadata, where present, still take precedence.
"""
if pil_image is None:
return "Please upload an image."
refined_logits = run_inference(pil_image)
results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata, default_threshold=threshold)
if output_format == "Prompt-style Tags":
return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
else:
return format_detailed_output(results_by_cat, all_artist_tags)
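# Quick local check (sketch; the path "sample.jpg" is hypothetical):
#   print(tag_image(Image.open("sample.jpg"), "Prompt-style Tags", DEFAULT_THRESHOLD))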
# Build the Gradio Blocks UI
demo = gr.Blocks(theme="gradio/soft")
with demo:
gr.Markdown(
"# 🏷️ Camie Tagger – Anime Image Tagging\n"
"This demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. "
"Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
)
gr.Markdown(
"*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags.)*"
)
with gr.Row():
with gr.Column():
image_in = gr.Image(type="pil", label="Input Image")
format_choice = gr.Radio(
choices=["Prompt-style Tags", "Detailed Output"],
value="Prompt-style Tags",
label="Output Format"
)
            # Slider for the fallback threshold; categories with metadata thresholds are unaffected
threshold_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.05,
value=DEFAULT_THRESHOLD,
label="Threshold"
)
tag_button = gr.Button("🔍 Tag Image")
with gr.Column():
output_box = gr.Markdown("") # Markdown output for formatted results
# Pass the threshold_slider value into the tag_image function
tag_button.click(fn=tag_image, inputs=[image_in, format_choice, threshold_slider], outputs=output_box)
gr.Markdown(
"----\n"
"**Model:** [Camie Tagger ONNX](https://huggingface.co./AngelBottomless/camie-tagger-onnxruntime) • "
"**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • "
"**ONNX Runtime:** for efficient CPU inference • "
"*Demo built with Gradio Blocks.*"
)
if __name__ == "__main__":
demo.launch()