import os
import io
import copy
import requests
import spaces
import gradio as gr
import numpy as np
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image, ImageDraw, ImageFont
from unittest.mock import patch
import argparse
import huggingface_hub
import onnxruntime as rt
import pandas as pd
import traceback
import tempfile
import zipfile
import re
import ast
import time
from datetime import datetime, timezone
from collections import defaultdict
from apscheduler.schedulers.background import BackgroundScheduler
import json
from modules.classifyTags import classify_tags, process_tags
from modules.florence2 import process_image, single_task_list, update_task_dropdown
from modules.reorganizer_model import reorganizer_list, reorganizer_class
from modules.tag_enhancer import prompt_enhancer

# NOTE(review): import order is kept exactly as in the original (e.g. `spaces` before
# `transformers`). Several imported names are unused in this file; they are kept in case
# importing them has side effects — confirm before pruning.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

TITLE = "Multi-Tagger"
DESCRIPTION = """
Multi-Tagger is a versatile application that combines the Waifu Diffusion and Florence 2 models for advanced image analysis and captioning. Perfect for AI artists and enthusiasts, it offers a range of features:

- Batch processing for multiple images
- Multi-category tagging with structured tag display.
- CUDA or CPU support.
- Image tagging, various captioning tasks which includes: Caption, Detailed Caption, Object Detection with visual outputs and much more.

Example image by [me.](https://huggingface.co/Werli)
"""

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
# IdolSankaku series of models:
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"

# Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tags that keep their underscores when tag names are humanized in load_labels().
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.", "_", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||",
]


def parse_args() -> argparse.Namespace:
    """Parse command-line options: slider step, default thresholds and the --share flag."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.35)
    parser.add_argument("--score-character-threshold", type=float, default=0.85)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tagger label table into tag names and per-category row indexes.

    Args:
        dataframe: label table with ``name`` and ``category`` columns
            (read from ``selected_tags.csv``).

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        Underscores in names are replaced by spaces, except for kaomojis.
        Category 9 rows are ratings, 0 are general tags, 4 are characters.
    """
    kaomoji_set = set(kaomojis)
    name_series = dataframe["name"].map(
        lambda x: x if x in kaomoji_set else x.replace("_", " ")
    )
    tag_names = name_series.tolist()
    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes


def mcut_threshold(probs):
    """Return the MCut threshold: the midpoint of the largest gap between sorted scores.

    Args:
        probs: 1-D numpy array of probabilities.

    Returns:
        The threshold. With fewer than two scores there is no gap to cut, so 0.0 is
        returned instead of raising on an empty ``argmax``.
    """
    if len(probs) < 2:
        return 0.0
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    return (sorted_probs[t] + sorted_probs[t + 1]) / 2


class Timer:
    """Wall-clock timer that records labelled checkpoints and prints the gaps between them."""

    def __init__(self):
        self.start_time = time.perf_counter()
        self.checkpoints = [("Start", self.start_time)]

    def checkpoint(self, label="Checkpoint"):
        """Record the current time under ``label``."""
        self.checkpoints.append((label, time.perf_counter()))

    def _print_intervals(self, max_label_length):
        """Print the elapsed time between each consecutive pair of recorded checkpoints."""
        if not self.checkpoints:
            return
        prev_time = self.checkpoints[0][1]
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(max_label_length)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time

    def report(self, is_clear_checkpoints=True):
        """Print the intervals since the last reset, then optionally start a new segment."""
        if self.checkpoints:
            max_label_length = max(len(label) for label, _ in self.checkpoints)
            self._print_intervals(max_label_length)
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print the remaining intervals plus total time since start, then clear checkpoints.

        Intervals are measured from the first *remaining* checkpoint; after report() has
        cleared the list that is the reset point, not ``start_time``.
        """
        print("\n> Execution Time Report:")
        max_label_length = max((len(label) for label, _ in self.checkpoints), default=0)
        if self.checkpoints:
            self._print_intervals(max_label_length)
            end_time = self.checkpoints[-1][1]
        else:
            end_time = time.perf_counter()
        total_time = end_time - self.start_time
        print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and checkpoint list."""
        self.start_time = time.perf_counter()
        self.checkpoints = [("Start", self.start_time)]


# Image modes Pillow can write as PNG; anything else is converted before saving.
_PNG_MODES = {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}


def _png_compatible(image):
    """Return ``image`` if PNG can store its mode, otherwise an RGBA conversion (e.g. CMYK)."""
    return image if image.mode in _PNG_MODES else image.convert("RGBA")


class Predictor:
    """ONNX tagger that downloads and caches one Hugging Face tagger model at a time."""

    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Download the label CSV and ONNX model of ``model_repo``; return their local paths."""
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load ``model_repo`` unless it is already the loaded model.

        State is only updated after the ONNX session has been created, so a failed load
        does not leave labels from one repo paired with the model of another.
        """
        if model_repo == self.last_loaded_repo:
            return
        csv_path, model_path = self.download_model(model_repo)
        tag_names, rating_indexes, general_indexes, character_indexes = load_labels(
            pd.read_csv(csv_path)
        )
        model = rt.InferenceSession(model_path)
        # The model input is laid out as (batch, height, width, channels); the image is
        # resized to a square of side `height`.
        _, height, width, _ = model.get_inputs()[0].shape
        self.tag_names = tag_names
        self.rating_indexes = rating_indexes
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        self.model_target_size = height
        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, path):
        """Load ``path`` and turn it into a 1xHxWx3 float32 BGR batch for the model.

        Transparency is flattened onto white, the image is padded to a white square,
        then resized to the model's input size.
        """
        target_size = self.model_target_size
        with Image.open(path) as source:
            image = source.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)

        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def create_file(self, content: str, directory: str, fileName: str) -> str:
        """Write ``content`` as UTF-8 to ``directory/fileName`` and return the file path."""
        file_path = os.path.join(directory, fileName)
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(content)
        return file_path

    def predict(
        self,
        gallery,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        characters_merge_enabled,
        reorganizer_model_repo,
        additional_tags_prepend,
        additional_tags_append,
        tag_results,
        progress=gr.Progress(),
    ):
        """Tag every gallery image, write per-image outputs and bundle them into a zip.

        Args:
            gallery: gallery items; ``item[0]`` is used as the image path
                (assumed (path, caption) pairs from gr.Gallery — confirm).
            model_repo: tagger repo to load.
            general_thresh / character_thresh: score cut-offs, replaced per image by the
                MCut threshold when the matching ``*_mcut_enabled`` flag is set
                (characters are floored at 0.15).
            characters_merge_enabled: prepend character tags to the string output.
            reorganizer_model_repo: optional model that appends a description.
            additional_tags_prepend / additional_tags_append: comma-separated extra tags.
            tag_results: dict filled with per-image results keyed by image path.

        Returns:
            ``(download, sorted_general_strings, final_categorized_output, classified_tags,
            rating, character_res, general_res, unclassified_tags, tag_results)``.
            Everything except ``download``, ``final_categorized_output`` and
            ``tag_results`` describes the last successfully processed image. An image
            that fails is logged and skipped.
        """
        # Clear tag_results before starting a new prediction
        tag_results.clear()
        gallery = gallery or []
        gallery_len = len(gallery)
        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")

        timer = Timer()
        progressRatio = 0.5 if reorganizer_model_repo else 1
        progressTotal = gallery_len + 1
        current_progress = 0

        self.load_model(model_repo)
        current_progress += progressRatio / progressTotal
        progress(current_progress, desc="Initialize wd model finished")
        timer.checkpoint("Initialize wd model")

        txt_infos = []
        output_dir = tempfile.mkdtemp()

        # Initialised up front so an empty gallery, or one where every image fails,
        # still returns valid values instead of raising UnboundLocalError.
        sorted_general_strings = ""
        categorized_output_strings = []
        final_categorized_output = ""
        classified_tags = {}
        unclassified_tags = {}
        rating = None
        character_res = None
        general_res = None

        reorganizer = None
        if reorganizer_model_repo:
            print(f"Reorganizer load model {reorganizer_model_repo}")
            reorganizer = reorganizer_class(reorganizer_model_repo, loadModel=True)
            current_progress += progressRatio / progressTotal
            progress(current_progress, desc="Initialize reoganizer model finished")
            timer.checkpoint("Initialize reoganizer model")

        timer.report()

        prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
        append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
        if prepend_list and append_list:
            append_list = [item for item in append_list if item not in prepend_list]

        # Counts occurrences of each base filename to disambiguate duplicates
        name_counters = defaultdict(int)

        try:
            for idx, value in enumerate(gallery):
                try:
                    image_path = value[0]
                    image_name = os.path.splitext(os.path.basename(image_path))[0]
                    name_counters[image_name] += 1
                    if name_counters[image_name] > 1:
                        image_name = f"{image_name}_{name_counters[image_name]:02d}"

                    image = self.prepare_image(image_path)
                    input_name = self.model.get_inputs()[0].name
                    label_name = self.model.get_outputs()[0].name
                    print(f"Gallery {idx:02d}: Starting run wd model...")
                    preds = self.model.run([label_name], {input_name: image})[0]

                    labels = list(zip(self.tag_names, preds[0].astype(float)))

                    # Ratings: keep every rating with its confidence
                    rating = dict(labels[i] for i in self.rating_indexes)

                    # General tags: keep those above the (possibly MCut) threshold
                    general_names = [labels[i] for i in self.general_indexes]
                    if general_mcut_enabled:
                        general_cutoff = mcut_threshold(np.array([p for _, p in general_names]))
                    else:
                        general_cutoff = general_thresh
                    general_res = {name: p for name, p in general_names if p > general_cutoff}

                    # Character tags: same, with the MCut threshold floored at 0.15
                    character_names = [labels[i] for i in self.character_indexes]
                    if character_mcut_enabled:
                        character_cutoff = max(
                            0.15, mcut_threshold(np.array([p for _, p in character_names]))
                        )
                    else:
                        character_cutoff = character_thresh
                    character_res = {name: p for name, p in character_names if p > character_cutoff}

                    sorted_general_list = [
                        name
                        for name, _ in sorted(general_res.items(), key=lambda x: x[1], reverse=True)
                    ]
                    # Do not repeat a character tag that is already a general tag
                    general_set = set(sorted_general_list)
                    character_list = [item for item in character_res if item not in general_set]
                    # User-supplied tags replace any model copies and go at the ends
                    extra_tags = set(prepend_list) | set(append_list)
                    sorted_general_list = prepend_list + [
                        item for item in sorted_general_list if item not in extra_tags
                    ] + append_list

                    merged_tags = (character_list if characters_merge_enabled else []) + sorted_general_list
                    # Escape parentheses so the string can be pasted as a prompt
                    sorted_general_strings = (
                        ", ".join(merged_tags).replace("(", r"\(").replace(")", r"\)")
                    )

                    classified_tags, unclassified_tags = classify_tags(sorted_general_list)
                    # All categorized tags of THIS image as one string
                    categorized_output_string = ", ".join(
                        ", ".join(tags) for tags in classified_tags.values()
                    )
                    categorized_output_strings.append(categorized_output_string)
                    # Aggregate over all images processed so far (shown in the UI)
                    final_categorized_output = ", ".join(categorized_output_strings)

                    # Per-image .txt with this image's string and categorized outputs
                    txt_name = f"{image_name}_output.txt"
                    txt_content = (
                        f"Output (string): {sorted_general_strings}\n"
                        f"Categorized Output (string): {categorized_output_string}"
                    )
                    txt_file = self.create_file(txt_content, output_dir, txt_name)
                    txt_infos.append({"path": txt_file, "name": txt_name})

                    # Per-image .json with the categorized tags
                    json_name = f"{image_name}_categorized_tags.json"
                    json_file = self.create_file(
                        json.dumps(classified_tags, indent=4), output_dir, json_name
                    )
                    txt_infos.append({"path": json_file, "name": json_name})

                    # PNG copy of the uploaded image (converted if PNG can't hold its mode)
                    png_name = f"{image_name}.png"
                    png_path = os.path.join(output_dir, png_name)
                    with Image.open(image_path) as original_image:
                        _png_compatible(original_image).save(png_path, format="PNG")
                    txt_infos.append({"path": png_path, "name": png_name})

                    current_progress += progressRatio / progressTotal
                    progress(current_progress, desc=f"image{idx:02d}, predict finished")
                    timer.checkpoint(f"image{idx:02d}, predict finished")

                    if reorganizer is not None:
                        print("Starting reorganizer...")
                        reorganize_strings = reorganizer.reorganize(sorted_general_strings)
                        reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
                        reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
                        reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
                        sorted_general_strings += ",\n\n" + reorganize_strings

                        current_progress += progressRatio / progressTotal
                        progress(current_progress, desc=f"image{idx:02d}, reorganizer finished")
                        timer.checkpoint(f"image{idx:02d}, reorganizer finished")

                    txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
                    txt_infos.append({"path": txt_file, "name": image_name + ".txt"})

                    # Store the result in tag_results using image_path as the key
                    tag_results[image_path] = {
                        "strings": sorted_general_strings,
                        "strings2": categorized_output_string,
                        "classified_tags": classified_tags,
                        "rating": rating,
                        "character_res": character_res,
                        "general_res": general_res,
                        "unclassified_tags": unclassified_tags,
                        "enhanced_tags": "",
                    }

                    timer.report()
                except Exception as e:
                    # Best effort per image: log and continue with the next one
                    print(traceback.format_exc())
                    print("Error predict: " + str(e))

            # Bundle every generated file into one zip for download
            download = []
            if txt_infos:
                downloadZipPath = os.path.join(
                    output_dir,
                    "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip",
                )
                with zipfile.ZipFile(downloadZipPath, "w", zipfile.ZIP_DEFLATED) as taggers_zip:
                    for info in txt_infos:
                        taggers_zip.write(info["path"], arcname=info["name"])
                download.append(downloadZipPath)
        finally:
            # Always free the reorganizer's GPU memory, even if zipping fails
            if reorganizer is not None:
                reorganizer.release_vram()
                del reorganizer

        progress(1, desc="Predict completed")
        timer.report_all()  # Print all recorded times
        print("Predict is complete.")

        return (
            download,
            sorted_general_strings,
            final_categorized_output,
            classified_tags,
            rating,
            character_res,
            general_res,
            unclassified_tags,
            tag_results,
        )


# Placeholder values shown for a gallery image that has no stored prediction.
_EMPTY_TAG_RESULT = {
    "strings": "",
    "strings2": "",
    "classified_tags": "{}",
    "rating": "",
    "character_res": "",
    "general_res": "",
    "unclassified_tags": "{}",
    "enhanced_tags": "",
}


def _tag_result_outputs(selection, tag_result):
    """Flatten a stored tag result into the 9-value tuple wired to the gallery.select outputs."""
    return (
        selection,
        tag_result["strings"],
        tag_result["strings2"],
        tag_result["classified_tags"],
        tag_result["rating"],
        tag_result["character_res"],
        tag_result["general_res"],
        tag_result["unclassified_tags"],
        tag_result["enhanced_tags"],
    )


def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
    """Return the clicked image and its stored results (or placeholders) for the output widgets.

    The ``gr.SelectData`` annotation is what makes Gradio inject the select event data.
    """
    if not selected_state:
        # Keep the 9-output arity instead of returning a single value
        return _tag_result_outputs(None, _EMPTY_TAG_RESULT)
    image_path = selected_state.value["image"]["path"]
    tag_result = (tag_results or {}).get(image_path, _EMPTY_TAG_RESULT)
    return _tag_result_outputs((image_path, selected_state.value["caption"]), tag_result)


def append_gallery(gallery: list, image: str):
    """Append one uploaded image to the gallery and clear the upload widget."""
    if gallery is None:
        gallery = []
    if not image:
        return gallery, None
    gallery.append(image)
    return gallery, None


def extend_gallery(gallery: list, images):
    """Append several uploaded images to the gallery."""
    if gallery is None:
        gallery = []
    if not images:
        return gallery
    gallery.extend(images)
    return gallery


def remove_image_from_gallery(gallery: list, selected_image: str):
    """Remove the selected item from the gallery.

    ``selected_image`` is the textbox rendering of the (path, caption) tuple produced by
    get_selection_from_gallery. It is parsed back with ``ast.literal_eval`` (literals only,
    never arbitrary code); an unparsable value leaves the gallery untouched.
    """
    if not gallery or not selected_image:
        return gallery
    try:
        selected_item = ast.literal_eval(selected_image)
    except (ValueError, SyntaxError):
        return gallery
    if selected_item in gallery:
        gallery.remove(selected_item)
    return gallery


def _enhance_tags(tags, model):
    """Run the prompt enhancer on ``tags`` with the chosen model and return its first output."""
    return prompt_enhancer("", "", tags, model)[0]


args = parse_args()

predictor = Predictor()

dropdown_list = [
    EVA02_LARGE_MODEL_DSV3_REPO,
    SWINV2_MODEL_DSV3_REPO,
    CONV_MODEL_DSV3_REPO,
    VIT_MODEL_DSV3_REPO,
    VIT_LARGE_MODEL_DSV3_REPO,
    # ---
    MOAT_MODEL_DSV2_REPO,
    SWIN_MODEL_DSV2_REPO,
    CONV_MODEL_DSV2_REPO,
    CONV2_MODEL_DSV2_REPO,
    VIT_MODEL_DSV2_REPO,
    # ---
    SWINV2_MODEL_IS_DSV1_REPO,
    EVA02_LARGE_MODEL_IS_DSV1_REPO,
]


def _restart_space():
    """Restart the Werli/Multi-Tagger Space via the Hub API; requires the HF_TOKEN env var."""
    HF_TOKEN = os.getenv("HF_TOKEN")
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN environment variable is not set.")
    huggingface_hub.HfApi().restart_space(
        repo_id="Werli/Multi-Tagger", token=HF_TOKEN, factory_reboot=False
    )


scheduler = BackgroundScheduler()
# Add a job to restart the space every 2 days (172800 seconds)
restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
scheduler.start()
next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."

css = """
#output {height: 500px; overflow: auto; border: 1px solid #ccc;}
label.float.svelte-i3tvor {position: relative !important;}
.reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
"""

# NOTE(review): the container nesting (Row/Column) below was reconstructed from a
# whitespace-collapsed source; component wiring is exact, but verify the visual layout.
with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True) as demo:
    # NOTE(review): the recovered heading text is only "\n\n{TITLE}\n\n"; any markup that
    # surrounded it appears to have been lost.
    gr.Markdown(value=f"\n\n{TITLE}\n\n")
    gr.Markdown(value=DESCRIPTION)
    gr.Markdown(NEXT_RESTART)

    with gr.Tab(label="Waifu Diffusion"):
        with gr.Row():
            with gr.Column():
                submit = gr.Button(value="Submit", variant="primary", size="lg")
                with gr.Column(variant="panel"):
                    # Create an Image component for uploading images
                    image_input = gr.Image(
                        label="Upload an Image or clicking paste from clipboard button",
                        type="filepath",
                        sources=["upload", "clipboard"],
                        height=150,
                    )
                    with gr.Row():
                        upload_button = gr.UploadButton(
                            "Upload multiple images",
                            file_types=["image"],
                            file_count="multiple",
                            size="sm",
                        )
                        remove_button = gr.Button("Remove Selected Image", size="sm")
                    gallery = gr.Gallery(
                        columns=5,
                        rows=5,
                        show_share_button=False,
                        interactive=True,
                        height="500px",
                        label="Grid of images",
                    )
                model_repo = gr.Dropdown(
                    dropdown_list,
                    value=EVA02_LARGE_MODEL_DSV3_REPO,
                    label="Model",
                )
                with gr.Row():
                    general_thresh = gr.Slider(
                        0,
                        1,
                        step=args.score_slider_step,
                        value=args.score_general_threshold,
                        label="General Tags Threshold",
                        scale=3,
                    )
                    general_mcut_enabled = gr.Checkbox(
                        value=False,
                        label="Use MCut threshold",
                        scale=1,
                    )
                with gr.Row():
                    character_thresh = gr.Slider(
                        0,
                        1,
                        step=args.score_slider_step,
                        value=args.score_character_threshold,
                        label="Character Tags Threshold",
                        scale=3,
                    )
                    character_mcut_enabled = gr.Checkbox(
                        value=False,
                        label="Use MCut threshold",
                        scale=1,
                    )
                with gr.Row():
                    characters_merge_enabled = gr.Checkbox(
                        value=True,
                        label="Merge characters into the string output",
                        scale=1,
                    )
                with gr.Row():
                    reorganizer_model_repo = gr.Dropdown(
                        [None] + reorganizer_list,
                        value=None,
                        label="Reorganizer Model",
                        info="Use a model to create a description for you",
                    )
                with gr.Row():
                    additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
                    additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
                with gr.Row():
                    clear = gr.ClearButton(
                        components=[
                            gallery,
                            model_repo,
                            general_thresh,
                            general_mcut_enabled,
                            character_thresh,
                            character_mcut_enabled,
                            characters_merge_enabled,
                            reorganizer_model_repo,
                            additional_tags_prepend,
                            additional_tags_append,
                        ],
                        variant="secondary",
                        size="lg",
                    )
            with gr.Column(variant="panel"):
                download_file = gr.File(label="Download includes: All outputs* and image(s)")  # 0
                character_res = gr.Label(label="Output (characters)")  # 1
                sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True)  # 2
                final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True)  # 3
                pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")  # 4
                enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)  # 5
                prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")  # 6
                categorized = gr.JSON(label="Categorized (tags)* - JSON")  # 7
                rating = gr.Label(label="Rating")  # 8
                general_res = gr.Label(label="Output (tags)")  # 9
                unclassified = gr.JSON(label="Unclassified (tags)")  # 10
                clear.add(
                    [
                        download_file,
                        sorted_general_strings,
                        final_categorized_output,
                        categorized,
                        rating,
                        character_res,
                        general_res,
                        unclassified,
                        prompt_enhancer_model,
                        enhanced_tags,
                    ]
                )

        tag_results = gr.State({})
        # Add the uploaded image to the gallery
        image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
        # When the upload button is used, add the new images to the gallery
        upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
        # Show the stored results of an image when it is clicked in the gallery
        selected_image = gr.Textbox(label="Selected Image", visible=False)
        gallery.select(
            get_selection_from_gallery,
            inputs=[gallery, tag_results],
            outputs=[
                selected_image,
                sorted_general_strings,
                final_categorized_output,
                categorized,
                rating,
                character_res,
                general_res,
                unclassified,
                enhanced_tags,
            ],
        )
        # Remove the selected image from the gallery
        remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
        # Prompt Enhancer button
        pe_generate_btn.click(
            _enhance_tags,
            inputs=[final_categorized_output, prompt_enhancer_model],
            outputs=[enhanced_tags],
        )
        submit.click(
            predictor.predict,
            inputs=[
                gallery,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
                characters_merge_enabled,
                reorganizer_model_repo,
                additional_tags_prepend,
                additional_tags_append,
                tag_results,
            ],
            outputs=[
                download_file,
                sorted_general_strings,
                final_categorized_output,
                categorized,
                rating,
                character_res,
                general_res,
                unclassified,
                tag_results,
            ],
        )
        gr.Examples(
            [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
            inputs=[
                image_input,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
            ],
        )

    with gr.Tab(label="Tag Categorizer + Enhancer"):
        with gr.Row():
            with gr.Column(variant="panel"):
                input_tags = gr.Textbox(label="Input Tags (Danbooru comma-separated)", placeholder="1girl, cat, horns, blue hair, ...")
                submit_button = gr.Button(value="Submit", variant="primary", size="lg")
            with gr.Column(variant="panel"):
                categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
                categorized_json = gr.JSON(label="Categorized (tags) - JSON")
            submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
            # Distinct names so this tab does not shadow the Waifu Diffusion tab's components
            with gr.Column(variant="panel"):
                cat_pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
                cat_enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
                cat_prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
                cat_pe_generate_btn.click(
                    _enhance_tags,
                    inputs=[categorized_string, cat_prompt_enhancer_model],
                    outputs=[cat_enhanced_tags],
                )

    with gr.Tab(label="Florence 2 Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                input_img = gr.Image(label="Input Picture")
                task_type = gr.Radio(choices=["Single task", "Cascaded task"], label="Task type selector", value="Single task")
                task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
                task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
                text_input = gr.Textbox(label="Text Input (optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column(variant="panel"):
                output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8)
                output_img = gr.Image(label="Output Image")
        gr.Examples(
            examples=[
                ["images/image1.png", "Object Detection"],
                ["images/image2.png", "OCR with Region"],
            ],
            inputs=[input_img, task_prompt],
            outputs=[output_text, output_img],
            fn=process_image,
            cache_examples=False,
            label="Try examples",
        )
        submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])

# Honour the --share flag that parse_args() defines.
demo.queue(max_size=2).launch(share=args.share)