Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 38,242 Bytes
107bed2 9f1f12f 107bed2 9f1f12f 107bed2 5650c4a 107bed2 5650c4a 107bed2 5650c4a 107bed2 bb253d0 107bed2 28fe89a 107bed2 bb253d0 107bed2 bf37db9 103de61 107bed2 bf37db9 107bed2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 |
import logging
import os
import random
import time
import traceback
from io import BytesIO
import io
import base64
from openai import OpenAI
import uuid
import requests
import gradio as gr
import requests
from PIL import Image, PngImagePlugin
# Logging: INFO and above, each record prefixed with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Remote image-editing API (all values come from the environment; unset -> None) ---
API_TOKEN = os.getenv("HIDREAM_API_TOKEN")
API_REQUEST_URL = os.getenv("API_REQUEST_URL")  # endpoint that creates a task
API_RESULT_URL = os.getenv("API_RESULT_URL")  # endpoint polled with ?task_id=...
API_IMAGE_URL = os.getenv("API_IMAGE_URL")  # prefix for "<image name>.png" result URLs
API_VERSION = os.getenv("API_VERSION")
API_MODEL_NAME = os.getenv("API_MODEL_NAME")

# --- Object storage used to upload input images / media ---
OSS_IMAGE_BUCKET = os.getenv("OSS_IMAGE_BUCKET")
OSS_MEDIA_BUCKET = os.getenv("OSS_MEDIA_BUCKET")
OSS_TOKEN_URL = os.getenv("OSS_TOKEN_URL")  # bucket name is appended to this URL

# --- Retry / polling behaviour (an unparsable value fails at import time) ---
MAX_RETRY_COUNT = int(os.getenv("MAX_RETRY_COUNT", "3"))  # attempts per request/download
POLL_INTERVAL = float(os.getenv("POLL_INTERVAL", "1"))  # seconds between status polls
MAX_POLL_TIME = int(os.getenv("MAX_POLL_TIME", "300"))  # seconds before giving up on a task
def get_oss_token(is_image=True, prefix="p_", timeout=30):
    """Request a pre-signed upload URL for a new object in object storage.

    Args:
        is_image (bool): True for an image upload (``OSS_IMAGE_BUCKET``),
            False for an MP4 upload (``OSS_MEDIA_BUCKET``).
        prefix (str): Image filename prefix selector: ``"p_"`` yields
            ``p_<uuid>``, any other value yields ``j_<uuid>``. Ignored for
            media uploads, which are named ``<uuid>.mp4``.
        timeout (float): Seconds to wait for the token service.

    Returns:
        tuple: ``(signed_url, filename)`` on success; ``(None, None)`` when the
        service answers with a non-200 status, a body that is not a JSON
        object, a non-zero ``code``, or no ``result`` field.

    Raises:
        requests.exceptions.RequestException: On network errors or timeout.
    """
    headers = {"Cookie": os.environ.get("OSS_AUTH_COOKIE", "")}
    if is_image:
        filename = f"p_{uuid.uuid4()}" if prefix == "p_" else f"j_{uuid.uuid4()}"
        bucket = OSS_IMAGE_BUCKET
    else:
        filename = f"{uuid.uuid4()}.mp4"
        bucket = OSS_MEDIA_BUCKET
    token_url = f"{OSS_TOKEN_URL}{bucket}?filename={filename}"
    # Without a timeout a hung token service would block the worker forever.
    response = requests.get(token_url, headers=headers, timeout=timeout)

    payload = None
    if response.status_code == 200:
        try:
            payload = response.json()  # parse once
        except ValueError:
            payload = None  # non-JSON body -> treated as failure below
    if isinstance(payload, dict) and payload.get("code") == 0 and "result" in payload:
        return payload["result"], filename

    logger.error(f"Failed to obtain OSS upload token: status={response.status_code}, body={response.text[:500]}")
    return None, None
def upload_to_gcs(signed_url: str, file_io, is_image=True, timeout=60):
    """Upload a file-like object to a pre-signed URL with HTTP PUT.

    In this file the URL comes from ``get_oss_token``.

    Args:
        signed_url (str): Pre-signed PUT URL.
        file_io: Binary file-like object (or bytes) to send as the body.
        is_image (bool): True sends ``image/png``, False sends ``video/mp4``.
        timeout (float): Seconds to wait on the connection / response.

    Returns:
        bool: True if the storage service answered HTTP 200, else False.
        Callers may ignore the result.

    Raises:
        requests.exceptions.RequestException: On network errors or timeout.
    """
    # The Content-Type must match the one the URL was signed for.
    content_type = "image/png" if is_image else "video/mp4"
    response = requests.put(
        signed_url,
        data=file_io,
        headers={"Content-Type": content_type},
        timeout=timeout,
    )
    if response.status_code == 200:
        logger.info("✅ Upload success")
        return True
    logger.error(f"❌ Upload failed, status code: {response.status_code}, response content: {response.text}")
    return False
# Prompt sent to the instruction-refinement model in refine_instruction().
# It is filled with str.format(source_instruction=...), so any literal "{" or
# "}" added to this text must be doubled ("{{" / "}}") or formatting will fail.
# The model is asked for exactly three lines; filter_response() depends on that
# shape (line 2 = instruction, line 3 = target description).
INSTRUCTION_PROMPT = """Your Role: You are an analytical assistant. Your task is to process a source image and a corresponding editing instruction, assuming the instruction accurately describes a desired transformation. You will 1) describe the source image, 2) output the editing instruction (potentially refined for clarity based on the source image context), and 3) describe the *imagined* result of applying that instruction.
Input:
1. Source Image: The original 'before' image.
2. Source Instruction: A text instruction describing the edit to be performed on the Source Image. You *must assume* this instruction is accurate and feasible for the purpose of this task.
Task Breakdown:
1. **Describe Source Image:** Generate a description (e.g., key subject, setting) of the Source Image by analyzing it. This will be the first line of your output.
2. **Output Editing Instruction:** This step determines the second line of your output.
* **Assumption:** The provided Source Instruction *accurately* describes the desired edit.
* **Goal:** Output a concise, single-line instruction based on the Source Instruction.
* **Refinement based on Source Image:** While the Source Instruction is assumed correct, analyze the Source Image to see if the instruction needs refinement for specificity. If the Source Image contains multiple similar objects and the Source Instruction is potentially ambiguous (e.g., "change the car color" when there are three cars), refine the instruction to be specific, using positional qualifiers (e.g., 'the left car', 'the bird on the top branch'), size ('the smaller dog', 'the largest building'), or other distinguishing visual features apparent in the Source Image. If the Source Instruction is already specific or if there's no ambiguity in the Source Image context, you can use it directly or with minor phrasing adjustments for naturalness. The *core meaning* of the Source Instruction must be preserved.
* **Output:** Present the resulting specific, single-line instruction as the second line.
3. **Describe Imagined Target Image:** Based *only* on the Source Image description (Line 1) and the Editing Instruction (Line 2), generate a description of the *imagined outcome*.
* Describe the scene from Line 1 *as if* the instruction from Line 2 has been successfully applied. Conceptualize the result of the edit on the source description.
* This description must be purely a logical prediction based on applying the instruction (Line 2) to the description in Line 1. Do *not* invent details not implied by the instruction or observed in the source image beyond the specified edit. This will be the third line of your output.
Output Format:
* Your response *must* consist of exactly three lines.
* Do not include any other explanations, comments, introductory phrases, labels (like "Line 1:"), or formatting.
* Your output should be in English.
[Description of the Source Image]
[The specific, single-line editing instruction based on the Source Instruction and Source Image context]
[Description of the Imagined Target Image based on Lines 1 & 2]
Now, please generate the three-line output based on the Source Image and the Source Instruction: {source_instruction}
"""
def filter_response(src_instruction):
    """Reduce the refinement model's three-line reply to one prompt string.

    The model is asked to answer with exactly three lines: source
    description, editing instruction, target description. The source
    description is dropped and the other two are combined.

    Args:
        src_instruction (str | None): Raw model reply. Non-string input
            (e.g. ``None`` when the completion has no text content) is
            accepted and treated as unusable.

    Returns:
        str: ``"Editing Instruction: <instruction>. Target Image Description:
        <target>"``, or ``""`` if the reply is not a string or does not have
        exactly three non-blank lines.
    """
    if not isinstance(src_instruction, str):
        return ""
    lines = [line.strip() for line in src_instruction.strip().split("\n")]
    lines = [line for line in lines if line]
    if len(lines) != 3:
        return ""
    _source_description, instruction, target_description = lines
    # The output template supplies its own period, so drop any the model added.
    instruction = instruction.strip(".")
    return f"Editing Instruction: {instruction}. Target Image Description: {target_description}"
import httpx
# HTTP transport for the OpenAI client created in refine_instruction().
# TLS certificate verification stays DISABLED by default so existing
# deployments keep working, but that is a security risk: traffic (including
# the API key, prompts and images) can be intercepted by a man-in-the-middle.
# Set OPENAI_VERIFY_SSL=1 / true / yes to enable verification. The variable
# keeps its historical name `insecure_client` because other code uses it.
_OPENAI_VERIFY_SSL = os.environ.get("OPENAI_VERIFY_SSL", "").strip().lower() in {"1", "true", "yes"}
if not _OPENAI_VERIFY_SSL:
    logger.warning("TLS certificate verification is disabled for the OpenAI client (set OPENAI_VERIFY_SSL=1 to enable)")
insecure_client = httpx.Client(
    verify=_OPENAI_VERIFY_SSL,
    timeout=httpx.Timeout(60.0, connect=10.0),  # 60 s overall, 10 s to connect
)
def refine_instruction(src_image, src_instruction):
    """Ask GPT-4o to rewrite an editing instruction in the structured format.

    The image is sent inline as a base64 JPEG data URL together with
    ``INSTRUCTION_PROMPT``; the reply is post-processed by ``filter_response``.

    Args:
        src_image (PIL.Image): Image the instruction refers to.
        src_instruction (str): User-supplied editing instruction.

    Returns:
        str: ``"Editing Instruction: ... Target Image Description: ..."``, or
        ``""`` when the model's reply could not be parsed.
    """
    client = OpenAI(http_client=insecure_client)

    # JPEG cannot hold alpha/palette modes, so normalise to RGB first.
    jpeg_bytes = io.BytesIO()
    src_image.convert("RGB").save(jpeg_bytes, format="JPEG")
    data_url = "data:image/jpeg;base64," + base64.b64encode(jpeg_bytes.getvalue()).decode('utf-8')

    user_content = [
        {"type": "text", "text": INSTRUCTION_PROMPT.format(source_instruction=src_instruction)},
        {"type": "image_url", "image_url": {"url": data_url}},
    ]
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are a professional digital artist."},
            {"role": "user", "content": user_content},
        ],
        max_tokens=500,  # the expected answer is only three short lines
        temperature=0.2,  # low temperature keeps the format stable
    )
    return filter_response(completion.choices[0].message.content)
# Aspect-ratio choices.
# NOTE(review): this list is not referenced anywhere else in this file;
# confirm it is unused before removing it.
ASPECT_RATIO_OPTIONS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
# Log the effective configuration once, at import time. API_TOKEN and the
# OSS_AUTH_COOKIE value are not part of these messages.
logger.info(f"API configuration loaded: REQUEST_URL={API_REQUEST_URL}, RESULT_URL={API_RESULT_URL}, VERSION={API_VERSION}, MODEL={API_MODEL_NAME}")
logger.info(f"OSS configuration: IMAGE_BUCKET={OSS_IMAGE_BUCKET}, MEDIA_BUCKET={OSS_MEDIA_BUCKET}, TOKEN_URL={OSS_TOKEN_URL}")
logger.info(f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s")
class APIError(Exception):
    """Raised when talking to the remote image-editing service fails.

    Covers bad HTTP statuses, unusable responses, and result images that
    cannot be downloaded or decoded.
    """
def _normalize_seed(seed):
    """Return the integer seed to send to the API.

    ``-1`` (or any value that converts to it) selects a random seed in
    ``[1, 1000000]``; values outside ``[0, 1000000]`` are replaced by 8888.

    Raises:
        ValueError: If ``seed`` cannot be converted to an integer.
    """
    try:
        seed = int(seed)
    except (TypeError, ValueError, OverflowError) as e:
        logger.error(f"Seed validation failed: {str(e)}")
        raise ValueError(f"Seed must be an integer but got {seed}") from e
    if seed == -1:
        seed = random.randint(1, 1000000)
        logger.info(f"Generated random seed: {seed}")
    elif seed < 0 or seed > 1000000:
        logger.info(f"Invalid seed value: {seed}, forcing to 8888")
        seed = 8888
    return seed


def _upload_input_image(image):
    """Upload ``image`` as an RGB PNG to the image bucket and return its object name.

    Raises:
        APIError: If no upload URL can be obtained or the upload request fails
            at the network level.
    """
    buffer = io.BytesIO()
    image.convert("RGB").save(buffer, format="PNG")
    buffer.seek(0)
    try:
        signed_url, filename = get_oss_token(is_image=True)
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to request an upload URL for the input image: {str(e)}")
        raise APIError(f"Failed to request an upload URL for the input image: {str(e)}") from e
    if not signed_url or not filename:
        logger.error("Could not obtain an upload URL for the input image")
        raise APIError("Could not obtain an upload URL for the input image")
    try:
        upload_to_gcs(signed_url, buffer, is_image=True)
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to upload input image: {str(e)}")
        raise APIError(f"Failed to upload input image: {str(e)}") from e
    return filename


def create_request(prompt, image, guidance_scale=5.0, image_guidance_scale=4.0, seed=-1):
    """
    Create an image editing request to the API.

    The input image is uploaded to object storage first, then an
    ``image_edit`` task referencing it is submitted. Timeouts, HTTP 429 and
    5xx responses are retried up to ``MAX_RETRY_COUNT`` times.

    Args:
        prompt (str): Text prompt describing the image edit
        image (PIL.Image): Input image to edit
        guidance_scale (float): Strength of instruction following
        image_guidance_scale (float): Strength of image preservation
        seed (int): Seed for reproducibility, -1 for random; values outside
            [0, 1000000] are replaced by 8888

    Returns:
        tuple: (task_id, seed) - Task ID if successful and the seed used

    Raises:
        ValueError: If the prompt or image is empty, or the seed is not an integer
        APIError: If the image upload or the API request fails
    """
    # Validate before any network traffic so bad input never triggers an upload.
    if not prompt or not prompt.strip():
        logger.error("Empty prompt provided to create_request")
        raise ValueError("Prompt cannot be empty")
    if not image:
        logger.error("No image provided to create_request")
        raise ValueError("Image cannot be empty")
    logger.info(f"Starting create_request with prompt='{prompt[:50]}...', guidance_scale={guidance_scale}, image_guidance_scale={image_guidance_scale}, seed={seed}")

    seed = _normalize_seed(seed)
    filename = _upload_input_image(image)

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
        "X-source": "api",
        "Content-Type": "application/json",
    }
    generate_data = {
        "module": "image_edit",
        "images": [filename],
        "prompt": prompt,
        "params": {
            "seed": seed,
            "custom_params": {
                "sample_steps": 28,
                "guidance_scale": guidance_scale,
                "image_guidance_scale": image_guidance_scale,
            },
        },
        "version": API_VERSION,
    }

    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(f"Sending API request [attempt {retry_count+1}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'")
            response = requests.post(API_REQUEST_URL, json=generate_data, headers=headers, timeout=10)
            logger.info(f"API request response status: {response.status_code}")
            response.raise_for_status()
            result = response.json()
            if not isinstance(result, dict) or "result" not in result:
                logger.error(f"Invalid API response format: {str(result)}")
                raise APIError(f"Invalid response format from API when sending request: {str(result)}")
            task_id = (result.get("result") or {}).get("task_id")
            if not task_id:
                logger.error(f"No task ID in API response: {str(result)}")
                raise APIError(f"No task ID returned from API: {str(result)}")
            logger.info(f"Successfully created task with ID: {task_id}, seed: {seed}")
            return task_id, seed
        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Request timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            # Parse the error body once; it is optional and may not be JSON.
            try:
                error_detail = e.response.json()
            except ValueError:
                error_detail = None
                logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")
            else:
                logger.error(f"API response error content: {error_detail}")

            if status_code == 401:
                logger.error(f"Authentication failed with API token. Status code: {status_code}")
                raise APIError("Authentication failed. Please check your API token.") from e
            elif status_code == 429:
                retry_count += 1
                wait_time = min(2 ** retry_count, 10)  # Exponential backoff, capped at 10 s
                logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(wait_time)
            elif 400 <= status_code < 500:
                error_message = f"HTTP error {status_code}"
                if isinstance(error_detail, dict):
                    error_message += f": {error_detail.get('message', 'Client error')}"
                elif error_detail is not None:
                    error_message += f": {error_detail}"
                logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}")
                raise APIError(error_message) from e
            else:
                error_message = f"HTTP error {status_code}"
                if error_detail is not None:
                    error_message += f": {error_detail}"
                retry_count += 1
                logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(1)
        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            logger.debug(f"Request error details: {traceback.format_exc()}")
            raise APIError(f"Failed to connect to API: {str(e)}") from e
        except APIError:
            # Raised above with a specific, already-logged message; don't re-wrap it.
            raise
        except Exception as e:
            logger.error(f"Unexpected error in create_request: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}") from e

    logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'")
    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
def get_results(task_id):
    """
    Check the status of an image generation task (a single poll, no retries).

    Args:
        task_id (str): The task ID returned by ``create_request``

    Returns:
        dict | None: The decoded JSON response (a dict containing a
        ``"result"`` key), or None when this poll failed in a way the caller
        can simply retry: timeout, network error, non-401 HTTP error,
        malformed body, or any other unexpected error (logged).

    Raises:
        ValueError: If ``task_id`` is empty
        APIError: If the API rejects the token (HTTP 401)
    """
    logger.debug(f"Checking status for task ID: {task_id}")
    if not task_id:
        logger.error("Empty task ID provided to get_results")
        raise ValueError("Task ID cannot be empty")
    url = f"{API_RESULT_URL}?task_id={task_id}"
    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
    }
    try:
        response = requests.get(url, headers=headers, timeout=10)
        logger.debug(f"Status check response code: {response.status_code}")
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        logger.warning(f"HTTP error {status_code} when checking task {task_id}")
        # The error body is optional and may not be JSON.
        try:
            error_content = e.response.json()
        except ValueError:
            error_content = None
            logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}")
        else:
            logger.error(f"Error response content: {error_content}")

        if status_code == 401:
            logger.error(f"Authentication failed when checking task {task_id}")
            raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}") from e
        if 400 <= status_code < 500:
            if isinstance(error_content, dict):
                error_message = f"HTTP error {status_code}: {error_content.get('message', 'Client error')}"
            else:
                error_message = f"HTTP error {status_code}"
            logger.error(error_message)
        else:
            logger.warning(f"Server error {status_code} when checking task {task_id}")
        return None
    except requests.exceptions.RequestException as e:
        logger.warning(f"Network error when checking task {task_id}: {str(e)}")
        logger.debug(f"Network error details: {traceback.format_exc()}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return None

    # A malformed payload counts as a failed poll (None) so the caller keeps polling.
    if not isinstance(result, dict) or "result" not in result:
        logger.warning(f"Invalid response format from API when checking task {task_id}: {str(result)}")
        return None
    return result
def _convert_to_png(image):
    """Re-encode a decoded PIL image as PNG in memory and return the reloaded image.

    Images with an alpha band are saved unchanged; all others are converted
    to RGB first. String-valued ``info`` entries are copied onto the result.
    """
    original_metadata = {
        key: value
        for key, value in image.info.items()
        if isinstance(key, str) and isinstance(value, str)
    }
    logger.debug(f"Original image metadata: {original_metadata}")
    logger.info(f"Converting image from {image.format} to PNG format")
    if 'A' in image.getbands():
        logger.debug("Preserving alpha channel in image conversion")
        image_to_save = image
    else:
        logger.debug("Converting image to RGB mode")
        image_to_save = image.convert('RGB')
    png_buffer = BytesIO()
    image_to_save.save(png_buffer, format='PNG')
    png_buffer.seek(0)
    converted = Image.open(png_buffer)
    logger.debug(f"Image converted to PNG. New size: {converted.size[0]}x{converted.size[1]}, Mode: {converted.mode}")
    # The PNG was written without text chunks; copy string metadata back onto
    # the in-memory image.
    converted.info.update(original_metadata)
    logger.debug("Original metadata preserved in converted image")
    return converted


def download_image(image_url):
    """
    Download an image from a URL and return it as a PIL Image.

    Non-PNG images (WebP, JPEG, ...) are re-encoded as PNG; an alpha band is
    kept, otherwise the image is converted to RGB. Timeouts, 5xx responses
    and network errors are retried up to ``MAX_RETRY_COUNT`` times.

    Args:
        image_url (str): URL of the image

    Returns:
        PIL.Image: Fully decoded image in PNG format

    Raises:
        ValueError: If ``image_url`` is empty
        APIError: On a 4xx response, undecodable image data, or when all
            retries are exhausted
    """
    logger.info(f"Starting download_image from URL: {image_url}")
    if not image_url:
        logger.error("Empty image URL provided to download_image")
        raise ValueError("Image URL cannot be empty when downloading image")

    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(f"Downloading image [attempt {retry_count+1}/{MAX_RETRY_COUNT}] from {image_url}")
            response = requests.get(image_url, timeout=15)
            logger.debug(f"Image download response status: {response.status_code}, Content-Type: {response.headers.get('Content-Type')}, Content-Length: {response.headers.get('Content-Length')}")
            response.raise_for_status()

            image = Image.open(BytesIO(response.content))
            # Image.open() is lazy: decode now so truncated/corrupt data is
            # reported as an APIError below instead of failing later on use.
            image.load()
            logger.info(f"Image opened successfully. Format: {image.format}, Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}")

            if image.format != 'PNG':
                image = _convert_to_png(image)

            logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
            return image
        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Download timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            logger.error(f"HTTP error {status_code} when downloading image from {image_url}")
            try:
                logger.error(f"Error response content: {e.response.text[:500]}")
            except Exception:
                # Best effort only: never let logging the body mask the HTTP error.
                logger.error("Could not read error response content")
            if 400 <= status_code < 500:
                error_message = f"HTTP error {status_code} when downloading image"
                logger.error(error_message)
                raise APIError(error_message) from e
            retry_count += 1
            logger.warning(f"Server error {status_code}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.RequestException as e:
            retry_count += 1
            logger.warning(f"Network error during image download: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            logger.debug(f"Network error details: {traceback.format_exc()}")
            time.sleep(1)
        except Exception as e:
            logger.error(f"Error processing image from {image_url}: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(e)}") from e

    logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
def add_metadata_to_image(image, metadata):
    """Embed ``metadata`` as PNG text chunks and return the re-encoded image.

    String-valued entries already present in ``image.info`` are kept; keys in
    ``metadata`` win on conflict. Every value is written via ``str()``.

    Args:
        image (PIL.Image): Image to annotate.
        metadata (dict): Entries to embed.

    Returns:
        PIL.Image | None: The re-encoded PNG image; None if ``image`` is falsy;
        the untouched ``image`` if anything goes wrong while encoding.
    """
    logger.debug(f"Adding metadata to image: {metadata}")
    if not image:
        logger.error("Null image provided to add_metadata_to_image")
        return None

    try:
        existing_metadata = {
            key: value
            for key, value in image.info.items()
            if isinstance(key, str) and isinstance(value, str)
        }
        logger.debug(f"Existing image metadata: {existing_metadata}")

        all_metadata = {**existing_metadata, **metadata}
        logger.debug(f"Combined metadata: {all_metadata}")

        png_info = PngImagePlugin.PngInfo()
        for key, value in all_metadata.items():
            png_info.add_text(key, str(value))

        encoded = BytesIO()
        image.save(encoded, format='PNG', pnginfo=png_info)
        logger.debug("Image saved to buffer with metadata")

        encoded.seek(0)
        annotated = Image.open(encoded)
        logger.debug("Image reloaded from buffer with metadata")
        return annotated
    except Exception as e:
        logger.error(f"Failed to add metadata to image: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        # Best effort: fall back to the original image rather than failing.
        return image
# Create Gradio interface
def create_ui():
    """Build the Gradio Blocks app for the HiDream-E1-Full image editor.

    The page has an input column (image, instruction, "Refine Instruction" /
    "Generate" buttons, guidance sliders and seed) and an output column
    (generated image plus a JSON panel with generation details). Event
    handlers are closures over these components. Examples are built with
    ``cache_examples=True`` and ``fn=refine_and_generate``, so building the UI
    runs the remote refinement/generation services for each example.

    Returns:
        gr.Blocks: The configured (not yet launched) demo.
    """
    logger.info("Creating Gradio UI")
    with gr.Blocks(title="HiDream-E1-Full Image Editor", theme=gr.themes.Base()) as demo:
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                gr.Markdown("""
# HiDream-E1-Full Image Editor
Edit images using natural language instructions with state-of-the-art AI [🤗 HuggingFace](https://huggingface.co/HiDream-ai/HiDream-E1-Full) | [GitHub](https://github.com/HiDream-ai/HiDream-E1) | [Twitter](https://x.com/vivago_ai)
<span style="color: #FF5733; font-weight: bold">For more features and to experience the full capabilities of our product, please visit [https://vivago.ai/](https://vivago.ai/).</span>
""")
        with gr.Row():
            # Input column
            with gr.Column(scale=1):
                input_image = gr.Image(
                    type="pil",
                    label="Input Image",
                    show_download_button=True,
                    show_label=True,
                    scale=1,
                    container=True,
                    image_mode="RGB",
                )
                instruction = gr.Textbox(
                    label="Editing Instruction",
                    placeholder="e.g., convert the image into a Ghibli style",
                    lines=3,
                )
                gr.Markdown("""
<div style="padding: 8px; margin-bottom: 10px; background-color: #E2F0FF; border-left: 5px solid #2E86DE; color: #2C3E50;">
<strong>Note:</strong> For optimal results, we recommend using the <strong>Refine Instruction</strong> button which formats your input into:
<br><em>"Editing Instruction: [your instruction]. Target Image Description: [expected result]"</em>
</div>
""")
                with gr.Row():
                    refine_btn = gr.Button("Refine Instruction")
                    generate_btn = gr.Button("Generate", variant="primary", size="lg")
                with gr.Accordion("Advanced Settings", open=True):
                    gr.Markdown("""
<div style="padding: 8px; margin: 15px 0; background-color: #FFF3CD; border-left: 5px solid #FFDD57; color: #856404;">
<strong>Important:</strong> Adjust these parameters based on your editing needs:
<ul>
<li>For style changes, use higher image preservation strength (e.g., 3.0-4.0)</li>
<li>For local edits like adding, deleting, replacing elements, use lower image preservation strength (e.g., 2.0-3.0)</li>
<li>If you notice visual artifacts or distortions in the generated image, try <strong>reducing the image preservation strength value</strong>.</li>
</ul>
</div>
""")
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0,
                            label="Instruction Following Strength",
                        )
                        image_guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            step=0.1,
                            value=3.0,
                            label="Image Preservation Strength",
                        )
                    seed = gr.Number(
                        label="Seed (use -1 for random)",
                        value=82706,
                        precision=0,
                    )
            # Output column
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    interactive=False,
                    show_download_button=True,
                    scale=1,
                    container=True,
                    image_mode="RGB",
                )
                with gr.Accordion("Image Information", open=False):
                    image_info = gr.JSON(label="Details")

        def _try_refine(image, instruction):
            """Return the refined instruction, or "" if refinement failed (failures are logged)."""
            try:
                refined = refine_instruction(image, instruction)
            except Exception as e:
                logger.error(f"Error refining instruction: {str(e)}")
                return ""
            if not refined:
                logger.warning("Instruction refinement service returned empty result")
                return ""
            return refined

        def refine_instruction_ui(image, instruction):
            """'Refine Instruction' handler: return the refined text, or the original on failure."""
            if not image or not instruction:
                return instruction
            refined = _try_refine(image, instruction)
            if refined:
                return refined
            gr.Warning("Instruction refinement service is currently not working. Please try again later.")
            return instruction

        # Generate function with progress updates
        def generate_with_progress(image, instruction, seed, guidance_scale, image_guidance_scale, progress=gr.Progress()):
            """'Generate' handler: submit the task, poll until done, return (image, info).

            Returns (None, None) on any failure; the reason is logged and shown
            to the user via gr.Warning.
            """
            logger.info(f"Starting image generation with instruction='{(instruction or '')[:50]}...', seed={seed}")
            try:
                if not image:
                    logger.error("No image provided in UI")
                    gr.Warning("Please upload an input image.")
                    return None, None
                if not instruction or not instruction.strip():
                    logger.error("Empty instruction provided in UI")
                    gr.Warning("Please enter an editing instruction.")
                    return None, None

                logger.info("Creating API request")
                task_id, used_seed = create_request(
                    prompt=instruction,
                    image=image,
                    guidance_scale=guidance_scale,
                    image_guidance_scale=image_guidance_scale,
                    seed=seed,
                )

                # Poll for results
                start_time = time.time()
                last_completion_ratio = 0
                progress(0, desc="Initializing...")
                logger.info(f"Starting to poll for results for task ID: {task_id}")
                while time.time() - start_time < MAX_POLL_TIME:
                    result = get_results(task_id)
                    if not result:
                        time.sleep(POLL_INTERVAL)
                        continue
                    sub_results = (result.get("result") or {}).get("sub_task_results") or []
                    if not sub_results:
                        time.sleep(POLL_INTERVAL)
                        continue
                    sub_result = sub_results[0]
                    status = sub_result.get("task_status")
                    logger.debug(f"Task status for ID {task_id}: {status}")

                    # task_completion may be missing or null while queued.
                    completion_ratio = (sub_result.get('task_completion') or 0) * 100
                    if completion_ratio != last_completion_ratio:
                        # Only update the UI when the completion ratio changes.
                        last_completion_ratio = completion_ratio
                        progress(completion_ratio / 100, desc="Generating image")
                        logger.info(f"Generation progress - Task ID: {task_id}, Completion: {completion_ratio:.1f}%")

                    if status == 1:  # Success
                        logger.info(f"Task completed successfully - Task ID: {task_id}")
                        progress(1.0, desc="Generation complete")
                        image_name = sub_result.get("image")
                        if not image_name:
                            logger.error(f"No image name in successful response. Response: {sub_result}")
                            gr.Warning("Generation finished but no image was returned. Please try again.")
                            return None, None
                        image_url = f"{API_IMAGE_URL}{image_name}.png"
                        logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}")
                        generated = download_image(image_url)
                        if not generated:
                            logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}")
                            gr.Warning("Failed to download the generated image. Please try again.")
                            return None, None

                        logger.info(f"Adding metadata to image - Task ID: {task_id}")
                        metadata = {
                            "prompt": instruction,
                            "seed": str(used_seed),
                            "model": API_MODEL_NAME,
                            "guidance_scale": str(guidance_scale),
                            "image_guidance_scale": str(image_guidance_scale),
                            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                            "generated_by": "HiDream-E1-Full Editor",
                        }
                        image_with_metadata = add_metadata_to_image(generated, metadata)
                        # Info shown in the "Image Information" panel.
                        info = {
                            "model": API_MODEL_NAME,
                            "prompt": instruction,
                            "seed": used_seed,
                            "guidance_scale": guidance_scale,
                            "image_guidance_scale": image_guidance_scale,
                            "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
                        }
                        logger.info(f"Image generation complete - Task ID: {task_id}")
                        return image_with_metadata, info
                    if status in {3, 4}:  # Failed or Canceled
                        error_msg = sub_result.get("task_error", "Unknown error")
                        logger.error(f"Task failed - Task ID: {task_id}, Status: {status}, Error: {error_msg}")
                        gr.Warning(f"Image generation failed: {error_msg}")
                        return None, None
                    time.sleep(POLL_INTERVAL)

                logger.error(f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s")
                gr.Warning("Image generation timed out. Please try again later.")
                return None, None
            except Exception as e:
                logger.error(f"Error during image generation: {str(e)}")
                logger.error(f"Full traceback: {traceback.format_exc()}")
                gr.Warning(f"Image generation failed: {str(e)}")
                return None, None

        # Set up event handlers
        refine_btn.click(
            fn=refine_instruction_ui,
            inputs=[input_image, instruction],
            outputs=[instruction],
        )
        generate_btn.click(
            fn=generate_with_progress,
            inputs=[input_image, instruction, seed, guidance_scale, image_guidance_scale],
            outputs=[output_image, image_info],
        )

        def refine_and_generate(image, instruction, seed, guidance_scale, image_guidance_scale, progress=gr.Progress()):
            """Examples handler: refine the instruction (falling back to the original), then generate.

            Returns:
                tuple: (generated image, info dict, instruction actually used).
            """
            try:
                if not image or not instruction:
                    return None, None, instruction
                logger.info(f"Refining instruction: '{instruction[:50]}...'")
                refined_instruction = _try_refine(image, instruction)
                if not refined_instruction.strip():
                    logger.warning("Instruction refinement failed, using original instruction")
                    refined_instruction = instruction
                    gr.Warning("Instruction refinement failed, using original instruction instead.")
                else:
                    logger.info(f"Instruction refined to: '{refined_instruction[:50]}...'")
                progress(0.2, desc="Instruction refined, generating image...")
                generated_image, generated_info = generate_with_progress(
                    image, refined_instruction, seed, guidance_scale, image_guidance_scale, progress
                )
                return generated_image, generated_info, refined_instruction
            except Exception as e:
                logger.error(f"Error in refine_and_generate: {str(e)}")
                logger.error(f"Full traceback: {traceback.format_exc()}")
                gr.Warning(f"An error occurred during processing: {str(e)}")
                return None, None, instruction

        # Examples (cached when the UI is built)
        gr.Examples(
            examples=[
                ["assets/test_1.png", "convert the image into a Ghibli style", 82706, 5, 4],
                ["assets/test_1.png", "change the image into Disney Pixar style", 82706, 5, 4],
                ["assets/test_1.png", "add a sunglasses to the girl", 82706, 5, 2],
                ["assets/test_2.jpg", "convert this image into a ink sketch image", 82706, 5, 2],
                ["assets/test_2.jpg", "add butterfly", 82706, 5, 2],
                ["assets/test_2.jpg", "remove the wooden sign", 82706, 5, 2],
            ],
            inputs=[input_image, instruction, seed, guidance_scale, image_guidance_scale],
            outputs=[output_image, image_info, instruction],
            fn=refine_and_generate,
            cache_examples=True,
        )

    logger.info("Gradio UI created successfully")
    return demo
# Launch app
if __name__ == "__main__":
    logger.info("Starting HiDream-E1-Full Image Generator application")
    demo = create_ui()
    logger.info("Launching Gradio interface with queue")
    # Up to 50 queued requests, 8 handled concurrently per event by default.
    # launch() blocks by default (see Gradio's `prevent_thread_lock`), so the
    # final log line only runs once the server stops.
    demo.queue(max_size=50, default_concurrency_limit=8).launch(show_api=False)
    logger.info("Application shutdown")