# qwen2vl-flux / app.py
from typing import Tuple
import requests
import random
import numpy as np
import gradio as gr
import spaces
import torch
from PIL import Image
from huggingface_hub import login
import os
import time
from gradio_imageslider import ImageSlider
from io import BytesIO
import PIL.Image
import shutil
import glob
from huggingface_hub import snapshot_download, hf_hub_download
MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN: login(token=HF_TOKEN)
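# Download the Qwen2vl-Flux checkpoints plus the auxiliary models used for control features
# (Anyline MTEED edge detector, Depth-Anything-V2, SAM2).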
cp_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints')
snapshot_download("Djrango/Qwen2vl-Flux", local_dir=cp_dir)
hf_hub_download(repo_id="TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline", local_dir=f"{cp_dir}/anyline")
# hf_hub_download places MTEED.pth under an extra Anyline/ subfolder; move it up a level.
try:
    shutil.move(f"{cp_dir}/anyline/Anyline/MTEED.pth", f"{cp_dir}/anyline")
except Exception as e:
    print(f"Failed to move Anyline checkpoint (it may already be in place): {e}")
snapshot_download("depth-anything/Depth-Anything-V2-Large", local_dir=f"{cp_dir}/depth-anything-v2")
snapshot_download("facebook/sam2-hiera-large", local_dir=f"{cp_dir}/segment-anything-2")
# https://github.com/facebookresearch/sam2/issues/26
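# Copy the SAM2 YAML configs into a local sam2_configs directory so the SAM2 loader can find them.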
os.makedirs("sam2_configs", exist_ok=True)
for p in glob.glob(f"{cp_dir}/segment-anything-2/*.yaml"):
shutil.copy(p, "sam2_configs")
from modelmod import FluxModel
model = FluxModel(device=DEVICE, is_turbo=False, required_features=['controlnet', 'depth'], is_quantization=True) # , 'sam'
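# Generation modes and output aspect ratios exposed in the UI.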
QWEN2VLFLUX_MODES = ["variation", "img2img", "inpaint", "controlnet", "controlnet-inpaint"]
QWEN2VLFLUX_ASPECT_RATIO = ["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"]
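# Context manager that logs the start time and elapsed wall-clock time of a named activity.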
class calculateDuration:
def __init__(self, activity_name=""):
self.activity_name = activity_name
def __enter__(self):
self.start_time = time.time()
self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.time()
self.elapsed_time = self.end_time - self.start_time
self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
if self.activity_name:
print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
else:
print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
print(f"Activity: {self.activity_name}, End time: {self.start_time_formatted}")
def resize_image_dimensions(
original_resolution_wh: Tuple[int, int],
maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
width, height = original_resolution_wh
# if width <= maximum_dimension and height <= maximum_dimension:
# width = width - (width % 32)
# height = height - (height % 32)
# return width, height
if width > height:
scaling_factor = maximum_dimension / width
else:
scaling_factor = maximum_dimension / height
new_width = int(width * scaling_factor)
new_height = int(height * scaling_factor)
new_width = new_width - (new_width % 32)
new_height = new_height - (new_height % 32)
return new_width, new_height
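# Download an image from a URL and return it as a PIL image, or None on failure.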
def fetch_from_url(url: str, name: str):
try:
print(f"start to fetch {name} from url", url)
response = requests.get(url)
response.raise_for_status()
image = PIL.Image.open(BytesIO(response.content))
print(f"fetch {name} success")
return image
except Exception as e:
        print(f"Failed to fetch {name}: {e}")
return None
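# Main inference entry point; @spaces.GPU requests a GPU allocation of up to 100 seconds per call.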
@spaces.GPU(duration=100)
@torch.inference_mode()
def process(
mode: str,
input_image_editor: dict,
ref_image: Image.Image,
image_url: str,
mask_url: str,
ref_url: str,
input_text: str,
strength: float,
num_inference_steps: int,
guidance_scale: float,
aspect_ratio: str,
attn_mode: bool,
center_x: float,
center_y: float,
radius: float,
line_mode: bool,
line_strength: float,
depth_mode: bool,
depth_strength: float,
progress=gr.Progress(track_tqdm=True)
):
#if not input_text:
# gr.Info("Please enter a text prompt.")
# return None
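    # The ImageEditor value is a dict: 'background' holds the base image and 'layers'[0] the drawn mask.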
kwargs = {}
image = input_image_editor['background']
mask = input_image_editor['layers'][0]
if image_url: image = fetch_from_url(image_url, "image")
if mask_url: mask = fetch_from_url(mask_url, "mask")
    if ref_url: ref_image = fetch_from_url(ref_url, "reference image")
if not image:
gr.Info("Please upload an image.")
return None
if ref_image: kwargs["input_image_b"] = ref_image
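    # Inpainting modes require a mask, either drawn in the editor or supplied via URL.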
if mode == "inpaint" or mode == "controlnet-inpaint":
if not mask:
gr.Info("Please draw a mask on the image.")
return None
kwargs["mask_image"] = mask
if attn_mode:
kwargs["center_x"] = center_x
kwargs["center_y"] = center_y
kwargs["radius"] = radius
with calculateDuration("run inference"):
result = model.generate(
input_image_a=image,
prompt=input_text,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
aspect_ratio=aspect_ratio,
mode=mode,
denoise_strength=strength,
line_mode=line_mode,
line_strength=line_strength,
depth_mode=depth_mode,
depth_strength=depth_strength,
imageCount=1,
**kwargs
)[0]
#return result
return [image, result]
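# Build the Gradio UI.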
CSS = """
.title { text-align: center; }
"""
with gr.Blocks(fill_width=True, css=CSS) as demo:
gr.Markdown("# Qwen2VL-Flux", elem_classes="title")
with gr.Row():
with gr.Column():
gen_mode = gr.Radio(label="Generation mode", choices=QWEN2VLFLUX_MODES, value="variation")
with gr.Row():
input_image_editor = gr.ImageEditor(label='Image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB',
layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
ref_image = gr.Image(label='Reference image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB')
with gr.Accordion("Image from URL", open=False):
image_url = gr.Textbox(label="Image url", show_label=True, max_lines=1, placeholder="Enter your image url (Optional)")
mask_url = gr.Textbox(label="Mask image url", show_label=True, max_lines=1, placeholder="Enter your mask image url (Optional)")
ref_url = gr.Textbox(label="Reference image url", show_label=True, max_lines=1, placeholder="Enter your reference image url (Optional)")
with gr.Accordion("Prompt Settings", open=True):
input_text = gr.Textbox(label="Prompt", show_label=True, max_lines=1, placeholder="Enter your prompt")
submit_button = gr.Button(value='Submit', variant='primary')
with gr.Accordion("Advanced Settings", open=True):
with gr.Row():
denoise_strength = gr.Slider(label="Denoise strength", minimum=0, maximum=1, step=0.01, value=0.75)
aspect_ratio = gr.Radio(label="Output image ratio", choices=QWEN2VLFLUX_ASPECT_RATIO, value="1:1")
num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=28)
guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.5, value=3.5)
with gr.Accordion("Attention Control", open=True):
with gr.Row():
attn_mode = gr.Checkbox(label="Attention Control", value=False)
center_x = gr.Slider(label="X coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
center_y = gr.Slider(label="Y coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
radius = gr.Slider(label="Radius of attention circle", minimum=0, maximum=1, step=0.01, value=0.5)
with gr.Accordion("ControlNet Settings", open=True):
with gr.Row():
line_mode = gr.Checkbox(label="Line mode", value=True)
line_strength = gr.Slider(label="Line strength", minimum=0, maximum=1, step=0.01, value=0.4)
depth_mode = gr.Checkbox(label="Depth mode", value=True)
depth_strength = gr.Slider(label="Depth strength", minimum=0, maximum=1, step=0.01, value=0.2)
with gr.Column():
#output_image = gr.Image(label="Generated image", type="pil", format="png", show_download_button=True, show_share_button=False)
output_image = ImageSlider(label="Generated image", type="pil")
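    # Run process() when the Submit button is clicked or the prompt is submitted with Enter.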
gr.on(triggers=[submit_button.click, input_text.submit], fn=process,
inputs=[gen_mode, input_image_editor, ref_image, image_url, mask_url, ref_url,
input_text, denoise_strength, num_inference_steps, guidance_scale, aspect_ratio,
attn_mode, center_x, center_y, radius, line_mode, line_strength, depth_mode, depth_strength],
outputs=[output_image], queue=True)
demo.queue().launch(debug=True, show_error=True)
#demo.queue().launch(debug=True, show_error=True, ssr_mode=False) # Gradio 5