import base64
import io
import logging
import os
import random
import time
import traceback
import uuid
from io import BytesIO

import gradio as gr
import httpx
import requests
from openai import OpenAI
from PIL import Image, PngImagePlugin

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# API Configuration
API_TOKEN = os.environ.get("HIDREAM_API_TOKEN")
API_REQUEST_URL = os.environ.get("API_REQUEST_URL")
API_RESULT_URL = os.environ.get("API_RESULT_URL")
API_IMAGE_URL = os.environ.get("API_IMAGE_URL")
API_VERSION = os.environ.get("API_VERSION")
API_MODEL_NAME = os.environ.get("API_MODEL_NAME")
OSS_IMAGE_BUCKET = os.environ.get("OSS_IMAGE_BUCKET")
OSS_MEDIA_BUCKET = os.environ.get("OSS_MEDIA_BUCKET")
OSS_TOKEN_URL = os.environ.get("OSS_TOKEN_URL")
MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT", "3"))
POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL", "1"))
MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME", "300"))


def get_oss_token(is_image=True, prefix="p_", timeout=30):
    """Request a signed upload URL for a new, randomly named OSS object.

    Args:
        is_image (bool): If True, target OSS_IMAGE_BUCKET with an
            extension-less ``p_<uuid>`` / ``j_<uuid>`` name; otherwise target
            OSS_MEDIA_BUCKET with a ``<uuid>.mp4`` name.
        prefix (str): ``"p_"`` selects the ``p_`` name prefix, any other value
            selects ``j_``. Only used when ``is_image`` is True.
        timeout (float): Seconds to wait for the token service.

    Returns:
        tuple: ``(signed_url, filename)`` on success, ``(None, None)`` if the
        request fails, is not HTTP 200, is not JSON, or reports a non-zero
        ``code``.
    """
    head = {"Cookie": os.environ.get("OSS_AUTH_COOKIE", "")}
    if is_image:
        name_prefix = "p_" if prefix == "p_" else "j_"
        filename = f"{name_prefix}{uuid.uuid4()}"
        bucket = OSS_IMAGE_BUCKET
    else:
        filename = f"{uuid.uuid4()}.mp4"
        bucket = OSS_MEDIA_BUCKET

    token_url = f"{OSS_TOKEN_URL}{bucket}?filename={filename}"
    try:
        req = requests.get(token_url, headers=head, timeout=timeout)
        payload = req.json() if req.status_code == 200 else None
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on older requests versions.
        logger.error(f"Failed to obtain OSS upload token: {e}")
        return None, None

    if isinstance(payload, dict) and payload.get("code") == 0:
        return payload["result"], filename

    logger.error(f"OSS token request failed: status={req.status_code}, body={req.text[:500]}")
    return None, None


def upload_to_gcs(signed_url: str, file_io, is_image=True, timeout=60):
    """Upload a file-like object to a pre-signed URL with an HTTP PUT.

    Args:
        signed_url (str): Pre-signed upload URL (e.g. from get_oss_token).
        file_io: Readable binary file-like object positioned at the data start.
        is_image (bool): Send ``image/png`` if True, else ``video/mp4``.
        timeout (float): Seconds to wait for the upload request.

    Returns:
        bool: True if the server answered HTTP 200, False otherwise.
    """
    # The Content-Type must match the one the signed URL was issued for.
    headers = {"Content-Type": "image/png" if is_image else "video/mp4"}
    try:
        response = requests.put(signed_url, data=file_io, headers=headers, timeout=timeout)
    except requests.exceptions.RequestException as e:
        logger.error(f"❌ Upload failed: {e}")
        return False

    if response.status_code == 200:
        logger.info("✅ Upload success")
        return True
    logger.error(f"❌ Upload failed, status code: {response.status_code}, response content: {response.text}")
    return False


# Instruction refinement prompt
INSTRUCTION_PROMPT = """Your Role: You are an analytical assistant. Your task is to process a source image and a corresponding editing instruction, assuming the instruction accurately describes a desired transformation. You will 1) describe the source image, 2) output the editing instruction (potentially refined for clarity based on the source image context), and 3) describe the *imagined* result of applying that instruction.

Input:
1. Source Image: The original 'before' image.
2. Source Instruction: A text instruction describing the edit to be performed on the Source Image. You *must assume* this instruction is accurate and feasible for the purpose of this task.

Task Breakdown:
1. **Describe Source Image:** Generate a description (e.g., key subject, setting) of the Source Image by analyzing it. This will be the first line of your output.
2. **Output Editing Instruction:** This step determines the second line of your output.
    * **Assumption:** The provided Source Instruction *accurately* describes the desired edit.
    * **Goal:** Output a concise, single-line instruction based on the Source Instruction.
    * **Refinement based on Source Image:** While the Source Instruction is assumed correct, analyze the Source Image to see if the instruction needs refinement for specificity. If the Source Image contains multiple similar objects and the Source Instruction is potentially ambiguous (e.g., "change the car color" when there are three cars), refine the instruction to be specific, using positional qualifiers (e.g., 'the left car', 'the bird on the top branch'), size ('the smaller dog', 'the largest building'), or other distinguishing visual features apparent in the Source Image. If the Source Instruction is already specific or if there's no ambiguity in the Source Image context, you can use it directly or with minor phrasing adjustments for naturalness. The *core meaning* of the Source Instruction must be preserved.
    * **Output:** Present the resulting specific, single-line instruction as the second line.
3. **Describe Imagined Target Image:** Based *only* on the Source Image description (Line 1) and the Editing Instruction (Line 2), generate a description of the *imagined outcome*.
    * Describe the scene from Line 1 *as if* the instruction from Line 2 has been successfully applied. Conceptualize the result of the edit on the source description.
    * This description must be purely a logical prediction based on applying the instruction (Line 2) to the description in Line 1. Do *not* invent details not implied by the instruction or observed in the source image beyond the specified edit. This will be the third line of your output.

Output Format:
* Your response *must* consist of exactly three lines.
* Do not include any other explanations, comments, introductory phrases, labels (like "Line 1:"), or formatting.
* Your output should be in English.

[Description of the Source Image]
[The specific, single-line editing instruction based on the Source Instruction and Source Image context]
[Description of the Imagined Target Image based on Lines 1 & 2]

Now, please generate the three-line output based on the Source Image and the Source Instruction: {source_instruction}
"""


def filter_response(src_instruction):
    """Turn the model's three-line reply into the prompt sent to the editor.

    Blank lines are ignored. The reply must then have exactly three lines:
    source description, editing instruction, target description. The source
    description is discarded.

    Args:
        src_instruction (str | None): Raw model reply.

    Returns:
        str: ``"Editing Instruction: <instr>. Target Image Description: <desc>"``,
        or ``""`` if the reply is not a string or does not have exactly three
        non-empty lines.
    """
    if not isinstance(src_instruction, str):
        # e.g. the chat completion returned no content (None)
        return ""
    lines = [line.strip() for line in src_instruction.strip().split("\n")]
    lines = [line for line in lines if line]
    if len(lines) != 3:
        return ""
    _, instruction, target_description = lines
    # Drop trailing/leading periods so the template's own ". " separator is not doubled.
    instruction = instruction.strip().strip(".")
    return f"Editing Instruction: {instruction}. Target Image Description: {target_description}"


# Shared httpx client used by the OpenAI SDK in refine_instruction.
# NOTE(review): verify=False DISABLES TLS certificate verification for every
# OpenAI request - SECURITY RISK. Kept to preserve existing behaviour; remove
# once the deployment no longer requires it.
insecure_client = httpx.Client(
    verify=False,
    timeout=httpx.Timeout(60.0, connect=10.0),
)


def refine_instruction(src_image, src_instruction):
    """Ask GPT-4o to rewrite an editing instruction using the source image as context.

    Args:
        src_image (PIL.Image.Image): Image to be edited. It is sent as an RGB
            JPEG data URL.
        src_instruction (str): User's editing instruction.

    Returns:
        str: The result of ``filter_response`` on the model reply. This is ``""``
        if the reply is not in the expected three-line format.
    """
    max_tokens_response = 500  # the output format is short and structured

    client = OpenAI(http_client=insecure_client)

    # JPEG cannot store alpha, so convert to RGB first.
    jpeg_buffer = io.BytesIO()
    src_image.convert("RGB").save(jpeg_buffer, format="JPEG")
    src_base64 = base64.b64encode(jpeg_buffer.getvalue()).decode('utf-8')
    encoded_str = f"data:image/jpeg;base64,{src_base64}"

    message_content = [
        {"type": "text", "text": INSTRUCTION_PROMPT.format(source_instruction=src_instruction)},
        {"type": "image_url", "image_url": {"url": encoded_str}},
    ]

    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are a professional digital artist."},
            {"role": "user", "content": message_content},
        ],
        max_tokens=max_tokens_response,
        temperature=0.2,  # lower temperature for more deterministic output
    )
    return filter_response(completion.choices[0].message.content)


# Resolution options
ASPECT_RATIO_OPTIONS = ["1:1", "3:4", "4:3", "9:16", "16:9"]

# Log configuration details
logger.info(f"API configuration loaded: REQUEST_URL={API_REQUEST_URL}, RESULT_URL={API_RESULT_URL}, VERSION={API_VERSION}, MODEL={API_MODEL_NAME}")
logger.info(f"OSS configuration: IMAGE_BUCKET={OSS_IMAGE_BUCKET}, MEDIA_BUCKET={OSS_MEDIA_BUCKET}, TOKEN_URL={OSS_TOKEN_URL}")
logger.info(f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s")


class APIError(Exception):
    """Custom exception for API-related errors"""
    pass


def _parse_json_or_none(response):
    """Return the decoded JSON body of a requests response, or None if it is not JSON."""
    try:
        return response.json()
    except ValueError:
        # requests' JSONDecodeError (old and new versions) subclasses ValueError.
        return None


def _string_metadata(image):
    """Return the entries of ``image.info`` whose key and value are both str."""
    return {
        key: value
        for key, value in image.info.items()
        if isinstance(key, str) and isinstance(value, str)
    }


def _upload_input_image(image):
    """Upload *image* as an RGB PNG to OSS and return the uploaded object name.

    Raises:
        APIError: If no signed URL could be obtained or the upload failed.
    """
    image_io = io.BytesIO()
    image.convert("RGB").save(image_io, format="PNG")
    image_io.seek(0)

    token_url, filename = get_oss_token(is_image=True)
    if not token_url:
        logger.error("Could not obtain an OSS upload URL for the input image")
        raise APIError("Failed to obtain an upload URL for the input image")
    if not upload_to_gcs(token_url, image_io, is_image=True):
        logger.error(f"Upload of input image {filename} failed")
        raise APIError("Failed to upload the input image")
    return filename


def create_request(prompt, image, guidance_scale=5.0, image_guidance_scale=4.0, seed=-1):
    """
    Create an image editing request to the API.

    The input image is first uploaded to OSS. The edit task is then submitted
    with retries on timeouts, HTTP 429 and 5xx responses.

    Args:
        prompt (str): Text prompt describing the image edit
        image (PIL.Image): Input image to edit
        guidance_scale (float): Strength of instruction following
        image_guidance_scale (float): Strength of image preservation
        seed (int): Seed for reproducibility, -1 for random. Values outside
            [-1, 1000000] are replaced by 8888.

    Returns:
        tuple: (task_id, seed) - Task ID if successful and the seed used

    Raises:
        ValueError: If the prompt is empty, the image is missing, or the seed
            is not an integer
        APIError: If the upload or the API request fails
    """
    # Validate inputs before doing any network work.
    if not prompt or not prompt.strip():
        logger.error("Empty prompt provided to create_request")
        raise ValueError("Prompt cannot be empty")
    if not image:
        logger.error("No image provided to create_request")
        raise ValueError("Image cannot be empty")

    logger.info(f"Starting create_request with prompt='{prompt[:50]}...', guidance_scale={guidance_scale}, image_guidance_scale={image_guidance_scale}, seed={seed}")

    # Normalise the seed to int *before* checking for the -1 sentinel so that
    # e.g. "-1" or -1.0 also trigger random seeding.
    try:
        seed = int(seed)
    except (TypeError, ValueError) as e:
        logger.error(f"Seed validation failed: {str(e)}")
        raise ValueError(f"Seed must be an integer but got {seed}") from e
    if seed == -1:
        seed = random.randint(1, 1000000)
        logger.info(f"Generated random seed: {seed}")
    elif seed < -1 or seed > 1000000:
        logger.info(f"Invalid seed value: {seed}, forcing to 8888")
        seed = 8888

    filename = _upload_input_image(image)

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
        "X-source": "api",
        "Content-Type": "application/json",
    }
    generate_data = {
        "module": "image_edit",
        "images": [filename],
        "prompt": prompt,
        "params": {
            "seed": seed,
            "custom_params": {
                "sample_steps": 28,
                "guidance_scale": guidance_scale,
                "image_guidance_scale": image_guidance_scale,
            },
        },
        "version": API_VERSION,
    }

    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(f"Sending API request [attempt {retry_count+1}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'")
            response = requests.post(API_REQUEST_URL, json=generate_data, headers=headers, timeout=10)
            logger.info(f"API request response status: {response.status_code}")
            response.raise_for_status()

            result = _parse_json_or_none(response)
            if not isinstance(result, dict) or "result" not in result:
                logger.error(f"Invalid API response format: {str(result)}")
                raise APIError(f"Invalid response format from API when sending request: {str(result)}")

            payload = result.get("result")
            task_id = payload.get("task_id") if isinstance(payload, dict) else None
            if not task_id:
                logger.error(f"No task ID in API response: {str(result)}")
                raise APIError(f"No task ID returned from API: {str(result)}")

            logger.info(f"Successfully created task with ID: {task_id}, seed: {seed}")
            return task_id, seed

        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Request timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            error_detail = _parse_json_or_none(e.response)
            if error_detail is None:
                logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")
            else:
                logger.error(f"API response error content: {error_detail}")
            error_message = f"HTTP error {status_code}"

            if status_code == 401:
                logger.error(f"Authentication failed with API token. Status code: {status_code}")
                raise APIError("Authentication failed. Please check your API token.") from e
            elif status_code == 429:
                retry_count += 1
                wait_time = min(2 ** retry_count, 10)  # Exponential backoff, capped at 10s
                logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(wait_time)
            elif 400 <= status_code < 500:
                if isinstance(error_detail, dict):
                    error_message += f": {error_detail.get('message', 'Client error')}"
                elif error_detail is not None:
                    error_message += f": {error_detail}"
                logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}")
                raise APIError(error_message) from e
            else:
                if error_detail is not None:
                    error_message += f": {error_detail}"
                retry_count += 1
                logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(1)

        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            logger.debug(f"Request error details: {traceback.format_exc()}")
            raise APIError(f"Failed to connect to API: {str(e)}") from e

        except APIError:
            # Our own errors (bad response format, missing task id) propagate as-is.
            raise

        except Exception as e:
            logger.error(f"Unexpected error in create_request: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}") from e

    logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'")
    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")


def get_results(task_id):
    """
    Check the status of an image generation task.

    Transient problems return None so the caller can simply poll again. These
    are timeouts, network errors, non-401 HTTP errors and malformed responses.

    Args:
        task_id (str): The task ID to check

    Returns:
        dict | None: Parsed JSON response (a dict containing a "result" key),
        or None if no usable result could be obtained this time

    Raises:
        ValueError: If task_id is empty
        APIError: If the API rejects the token (HTTP 401)
    """
    logger.debug(f"Checking status for task ID: {task_id}")

    if not task_id:
        logger.error("Empty task ID provided to get_results")
        raise ValueError("Task ID cannot be empty")

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
    }

    try:
        # params= URL-encodes the task id.
        response = requests.get(API_RESULT_URL, params={"task_id": task_id}, headers=headers, timeout=10)
        logger.debug(f"Status check response code: {response.status_code}")
        response.raise_for_status()

    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        logger.warning(f"HTTP error {status_code} when checking task {task_id}")
        error_content = _parse_json_or_none(e.response)
        if error_content is None:
            logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}")
        else:
            logger.error(f"Error response content: {error_content}")

        if status_code == 401:
            logger.error(f"Authentication failed when checking task {task_id}")
            raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}") from e
        if 400 <= status_code < 500:
            if isinstance(error_content, dict):
                logger.error(f"HTTP error {status_code}: {error_content.get('message', 'Client error')}")
            else:
                logger.error(f"HTTP error {status_code}")
        else:
            logger.warning(f"Server error {status_code} when checking task {task_id}")
        return None

    except requests.exceptions.RequestException as e:
        logger.warning(f"Network error when checking task {task_id}: {str(e)}")
        logger.debug(f"Network error details: {traceback.format_exc()}")
        return None

    result = _parse_json_or_none(response)
    if not isinstance(result, dict) or "result" not in result:
        logger.warning(f"Invalid response format from API when checking task {task_id}: {str(result)}")
        return None
    return result


def _convert_to_png(image):
    """Return *image* re-encoded as an in-memory PNG, keeping alpha and str metadata.

    Images whose format is already PNG are returned unchanged. Images without
    an alpha band are converted to RGB first.
    """
    original_metadata = _string_metadata(image)
    logger.debug(f"Original image metadata: {original_metadata}")

    if image.format == 'PNG':
        return image

    logger.info(f"Converting image from {image.format} to PNG format")
    if 'A' in image.getbands():
        logger.debug("Preserving alpha channel in image conversion")
        image_to_save = image
    else:
        logger.debug("Converting image to RGB mode")
        image_to_save = image.convert('RGB')

    png_buffer = BytesIO()
    image_to_save.save(png_buffer, format='PNG')
    png_buffer.seek(0)
    converted = Image.open(png_buffer)
    logger.debug(f"Image converted to PNG. New size: {converted.size[0]}x{converted.size[1]}, Mode: {converted.mode}")

    converted.info.update(original_metadata)
    logger.debug("Original metadata preserved in converted image")
    return converted


def download_image(image_url):
    """
    Download an image from a URL and return it as a PIL Image.

    Non-PNG images (WebP, JPEG, ...) are re-encoded as PNG. Alpha and string
    metadata are kept. Timeouts, 5xx responses and network errors are retried
    up to MAX_RETRY_COUNT times.

    Args:
        image_url (str): URL of the image

    Returns:
        PIL.Image: Downloaded image object converted to PNG format

    Raises:
        ValueError: If image_url is empty
        APIError: If the download fails (4xx, retries exhausted) or the data
            cannot be decoded as an image
    """
    logger.info(f"Starting download_image from URL: {image_url}")

    if not image_url:
        logger.error("Empty image URL provided to download_image")
        raise ValueError("Image URL cannot be empty when downloading image")

    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(f"Downloading image [attempt {retry_count+1}/{MAX_RETRY_COUNT}] from {image_url}")
            response = requests.get(image_url, timeout=15)
            logger.debug(f"Image download response status: {response.status_code}, Content-Type: {response.headers.get('Content-Type')}, Content-Length: {response.headers.get('Content-Length')}")
            response.raise_for_status()

            image = Image.open(BytesIO(response.content))
            # Image.open is lazy: decode now so corrupt data is reported here as
            # APIError instead of failing later in the caller.
            image.load()
            logger.info(f"Image opened successfully. Format: {image.format}, Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}")

            image = _convert_to_png(image)
            logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
            return image

        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Download timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            logger.error(f"HTTP error {status_code} when downloading image from {image_url}")
            logger.error(f"Error response content: {e.response.text[:500]}")

            if 400 <= status_code < 500:
                error_message = f"HTTP error {status_code} when downloading image"
                logger.error(error_message)
                raise APIError(error_message) from e
            retry_count += 1
            logger.warning(f"Server error {status_code}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.RequestException as e:
            retry_count += 1
            logger.warning(f"Network error during image download: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            logger.debug(f"Network error details: {traceback.format_exc()}")
            time.sleep(1)

        except Exception as e:
            logger.error(f"Error processing image from {image_url}: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(e)}") from e

    logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")


def add_metadata_to_image(image, metadata):
    """
    Add metadata to a PIL image as PNG text chunks.

    Existing str-valued entries of ``image.info`` are kept. Entries in
    *metadata* override them, and all values are stored via ``str()``.

    Args:
        image (PIL.Image): The image to add metadata to
        metadata (dict): Metadata to add to the image

    Returns:
        PIL.Image: A PNG image reloaded from memory with the metadata. Returns
        None if *image* is falsy, and the original image if encoding fails
        (best effort).
    """
    logger.debug(f"Adding metadata to image: {metadata}")

    if not image:
        logger.error("Null image provided to add_metadata_to_image")
        return None

    try:
        existing_metadata = _string_metadata(image)
        logger.debug(f"Existing image metadata: {existing_metadata}")

        # New values override existing ones.
        all_metadata = {**existing_metadata, **metadata}
        logger.debug(f"Combined metadata: {all_metadata}")

        meta = PngImagePlugin.PngInfo()
        for key, value in all_metadata.items():
            meta.add_text(key, str(value))

        buffer = BytesIO()
        image.save(buffer, format='PNG', pnginfo=meta)
        logger.debug("Image saved to buffer with metadata")

        buffer.seek(0)
        result_image = Image.open(buffer)
        logger.debug("Image reloaded from buffer with metadata")
        return result_image

    except Exception as e:
        # Deliberate best effort: metadata is optional, so never fail the caller.
        logger.error(f"Failed to add metadata to image: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return image


# NOTE(review): the next line is the start of create_ui, still collapsed onto a
# single line behind a leftover comment marker. The function body continues
# further down in the same collapsed form. It must be re-split into proper lines
# and indentation as a whole before create_ui can be used.
# Create Gradio interface def create_ui(): logger.info("Creating Gradio UI") with gr.Blocks(title="HiDream-E1-Full Image Editor", theme=gr.themes.Base()) as demo: with gr.Row(equal_height=True): with gr.Column(scale=1): gr.Markdown(""" # HiDream-E1-Full Image Editor Edit images using natural language instructions with state-of-the-art AI [🤗 HuggingFace](https://huggingface.co./HiDream-ai/HiDream-E1-Full) | [GitHub](https://github.com/HiDream-ai/HiDream-E1) | [Twitter](https://x.com/vivago_ai) For more features and to experience the full capabilities of our product, please visit [https://vivago.ai/](https://vivago.ai/). """) with gr.Row(): # Input column with gr.Column(scale=1): input_image = gr.Image( type="pil", label="Input Image", # height=400, show_download_button=True, show_label=True, scale=1, container=True, image_mode="RGB" ) instruction = gr.Textbox( label="Editing Instruction", placeholder="e.g., convert the image into a Ghibli style", lines=3 ) gr.Markdown("""