# Spaces: Running
import json | |
import os | |
import time | |
import uuid | |
import tempfile | |
from PIL import Image, ImageDraw, ImageFont | |
import gradio as gr | |
import base64 | |
import mimetypes | |
from google import genai | |
from google.genai import types | |
def save_binary_file(file_name, data):
    """Write the bytes in ``data`` to ``file_name``, replacing any existing content."""
    with open(file_name, mode="wb") as sink:
        sink.write(data)
def generate(text, file_name, api_key, model="gemini-2.0-flash-exp"):
    """Upload an image to Gemini with a text prompt and collect the reply.

    Args:
        text: Prompt text sent alongside the uploaded file.
        file_name: Path of the local file to upload.
        api_key: Gemini API key. When empty or whitespace-only, the
            ``GEMINI_API_KEY`` environment variable is used instead.
        model: Name of the Gemini model to call.

    Returns:
        A ``(image_path, text_response)`` tuple. ``image_path`` is the path of
        a temporary ``.png`` file holding the first inline image found in the
        stream, or ``None`` when no image was returned. ``text_response`` is
        the accumulated text of the chunks read before an image was found.
        When the upload or request setup fails, ``(None, "Error: <message>")``
        is returned instead.
    """
    # Prefer the explicit key; fall back to the environment when it is blank.
    explicit_key = api_key.strip() if api_key else ""
    client = genai.Client(api_key=explicit_key or os.environ.get("GEMINI_API_KEY"))

    uploaded = None
    temp_path = None
    image_path = None
    text_response = ""
    try:
        print("Uploading file to Gemini API...")
        uploaded = client.files.upload(file=file_name)
        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_uri(
                        file_uri=uploaded.uri,
                        mime_type=uploaded.mime_type,
                    ),
                    types.Part.from_text(text=text),
                ],
            ),
        ]
        generate_content_config = types.GenerateContentConfig(
            temperature=0,  # Lower temperature for more consistent, conservative results
            top_p=0.92,
            max_output_tokens=8192,
            response_modalities=["image", "text"],
            response_mime_type="text/plain",
            safety_settings=[
                {
                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                    "threshold": "BLOCK_MEDIUM_AND_ABOVE",
                }
            ],
        )

        # Reserve a path for the image; the file is removed again in the
        # ``finally`` block below if no image ends up being written to it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            temp_path = tmp.name

        print("Sending request to Gemini API...")
        start_time = time.time()
        max_wait_time = 60  # Maximum wait time in seconds
        try:
            stream = client.models.generate_content_stream(
                model=model,
                contents=contents,
                config=generate_content_config,
            )
            for chunk in stream:
                # The deadline is only checked between chunks, so a single
                # slow chunk can still overrun it.
                if time.time() - start_time > max_wait_time:
                    print("Gemini API request timed out after", max_wait_time, "seconds")
                    break
                if (
                    not chunk.candidates
                    or not chunk.candidates[0].content
                    or not chunk.candidates[0].content.parts
                ):
                    continue

                # Look at every part, not just the first: the image may follow
                # a text part inside the same chunk.
                image_part = next(
                    (
                        part
                        for part in chunk.candidates[0].content.parts
                        if part.inline_data and part.inline_data.data
                    ),
                    None,
                )
                if image_part is not None:
                    save_binary_file(temp_path, image_part.inline_data.data)
                    print(f"Smile enhancement image generated: {temp_path}")
                    image_path = temp_path
                    # If an image is found, we assume that is the desired output.
                    break

                # ``chunk.text`` can be None for chunks without text parts.
                if chunk.text:
                    text_response += chunk.text + "\n"
                    print("Received text response from Gemini API")
        except Exception as e:
            # Best effort: fall through and return whatever was collected.
            print(f"Error during content generation: {str(e)}")
    except Exception as e:
        print(f"Error in Gemini API setup: {str(e)}")
        return None, f"Error: {str(e)}"
    finally:
        # Drop the reserved temp file when it was never handed to the caller.
        if temp_path and image_path is None:
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                print(f"Could not remove temporary file {temp_path}: {cleanup_error}")
        # Delete the uploaded copy so it does not linger in the API account.
        if uploaded is not None:
            try:
                client.files.delete(name=uploaded.name)
            except Exception as cleanup_error:
                print(f"Could not delete uploaded file: {cleanup_error}")
    return image_path, text_response
def assess_image_quality(original_image, enhanced_image):
    """
    Run basic sanity checks on the enhanced image against the original.

    The checks run in this order: the enhanced image exists, both of its
    sides are at least 100 px, its width and height are each within 20 px of
    the original's, and its mode is 'RGB'.

    Returns a tuple of (is_acceptable, feedback_message). Any exception raised
    while checking is reported as a rejection.
    """
    try:
        if enhanced_image is None:
            return False, "No enhanced image generated"

        new_width = enhanced_image.size[0]
        new_height = enhanced_image.size[1]
        if min(new_width, new_height) < 100:
            return False, "Enhanced image appears to be too small or improperly sized"

        # A large size change suggests the facial proportions were altered.
        width_delta = abs(original_image.size[0] - new_width)
        height_delta = abs(original_image.size[1] - new_height)
        if max(width_delta, height_delta) > 20:
            return False, "Enhanced image dimensions differ significantly from original, suggesting facial proportions may have changed"

        if enhanced_image.mode != 'RGB':
            return False, "Enhanced image does not have the correct color mode"

        # Only basic checks for now; the model is assumed to follow the prompt.
        return True, "Image passes quality assessment criteria"
    except Exception as e:
        print(f"Error in quality assessment: {str(e)}")
        # A failed assessment counts as a rejected image.
        return False, f"Assessment error: {str(e)}"
# Base instructions sent to the model on every attempt.
_SMILE_ENHANCEMENT_PROMPT = """
Create a naturally enhanced smile that suits this specific person's face and character. Make the following personalized improvements:
- Slightly enhance the existing teeth while PRESERVING their natural color, spacing, and individual characteristics
- DO NOT make teeth perfectly white or perfectly aligned - keep some natural variation and character
- Create subtle, natural smile lines around the eyes (crow's feet) appropriate for this person's age and face
- Slightly raise the cheeks WITHOUT widening the face
- Add a slight narrowing of the eyes that happens in genuine smiles
- Create subtle dimples ONLY if they already exist in the original image
- Enhance the overall joyful expression while maintaining the person's unique facial structure
IMPORTANT GUIDELINES:
- PRESERVE THE PERSON'S NATURAL DENTAL CHARACTERISTICS - teeth should still look like THEIR teeth, just slightly enhanced
- Keep teeth coloration natural and appropriate for the person - avoid unnaturally white teeth
- Maintain slight natural imperfections in tooth alignment that give character to the smile
- Create a genuine, authentic-looking smile that affects the entire face naturally
- ABSOLUTELY CRITICAL: DO NOT widen the face or change face width at all
- Preserve the person's identity completely (extremely important)
- Preserve exact facial proportions and face width of the original image
- Maintain natural-looking results appropriate for the person's age and face structure
- Keep teeth proportionate to the face - avoid making them too large or prominent
- Maintain proper tooth-to-face ratio and ensure teeth fit naturally within the mouth
- Keep the original background, lighting, and image quality intact
- Ensure the enhanced smile looks natural, genuine, and believable
- Create a smile that looks like a moment of true happiness for THIS specific person
- Remember that not everyone has or wants perfect white teeth - the enhancements should SUIT THE INDIVIDUAL
- If teeth are enhanced, maintain their natural characteristics while making subtle improvements
"""


def _build_enhancement_prompt(feedback_history):
    """Return the base prompt, extended with feedback from earlier attempts if any."""
    if not feedback_history:
        return _SMILE_ENHANCEMENT_PROMPT
    return (
        _SMILE_ENHANCEMENT_PROMPT
        + "\nIMPORTANT FEEDBACK FROM PREVIOUS ATTEMPT:\n"
        + " ".join(feedback_history)
        + "\nPlease address these issues in this new attempt.\n"
    )


def _remove_temp_file(path):
    """Delete the file at ``path`` if one was given; log instead of raising on failure."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError as cleanup_error:
        print(f"Could not remove temporary file {path}: {cleanup_error}")


def process_smile_enhancement(input_image, max_attempts=2):
    """Run up to ``max_attempts`` smile-enhancement requests for ``input_image``.

    Each attempt saves nothing new locally: the input is written once to a
    temporary PNG, sent through ``generate`` together with the prompt, and any
    returned image is checked with ``assess_image_quality``. Feedback from a
    failed attempt is appended to the next attempt's prompt.

    Args:
        input_image: Image object exposing ``save(path)`` (the UI passes a PIL
            image), or ``None`` when nothing was uploaded.
        max_attempts: Maximum number of generation attempts.

    Returns:
        A 3-tuple ``(gallery_items, "", status_message)``:
        ``(None, "", "")`` when ``input_image`` is ``None``;
        ``([input_image], "", "API key not configured")`` when the
        ``GEMINI_API_KEY`` environment variable is missing or blank;
        ``([result_img], "", success_message)`` when an attempt passes the
        quality assessment; ``([result_img], "", "")`` when the last attempt
        produced an image that failed the assessment; and
        ``([input_image], "", "")`` in every other case, including errors.
    """
    input_path = None
    try:
        if input_image is None:
            return None, "", ""

        # Read the key from the environment; secrets must never live in source.
        gemini_api_key = (os.environ.get("GEMINI_API_KEY") or "").strip()
        if not gemini_api_key:
            print("Error: GEMINI_API_KEY not found in environment variables")
            return [input_image], "", "API key not configured"

        # Save the input image to a temporary file (removed in ``finally``).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            input_path = tmp.name
        input_image.save(input_path)
        print(f"Input image saved to {input_path}")

        current_attempt = 0
        result_img = None
        feedback_history = []
        max_processing_time = 120  # Maximum time in seconds for overall processing
        start_processing_time = time.time()

        while current_attempt < max_attempts:
            if time.time() - start_processing_time > max_processing_time:
                print(f"Overall processing time exceeded {max_processing_time} seconds")
                break
            current_attempt += 1
            print(f"Starting processing attempt {current_attempt}/{max_attempts}...")

            prompt = _build_enhancement_prompt(feedback_history)
            print(f"Processing attempt {current_attempt}/{max_attempts}...")

            # Results arriving after this deadline are discarded and retried.
            api_call_timeout = time.time() + 45
            try:
                image_path, text_response = generate(
                    text=prompt, file_name=input_path, api_key=gemini_api_key
                )
                if time.time() > api_call_timeout:
                    print("API call timeout occurred")
                    _remove_temp_file(image_path)
                    feedback_history.append("API call timed out, trying again with simplified request.")
                    continue

                print(f"API response received: Image path: {image_path is not None}, Text length: {len(text_response)}")
                if image_path:
                    try:
                        # Image.open is lazy; copy the pixels into memory so
                        # the temporary file can be deleted right away.
                        with Image.open(image_path) as opened:
                            if opened.mode == "RGBA":
                                result_img = opened.convert("RGB")
                            else:
                                result_img = opened.copy()
                        print("Successfully loaded generated image")

                        is_acceptable, assessment_feedback = assess_image_quality(input_image, result_img)
                        print(f"Image quality assessment: {is_acceptable}, {assessment_feedback}")
                        if is_acceptable:
                            success_message = "Successfully loaded generated image\nImage quality assessment: True, Image passes quality assessment criteria"
                            return [result_img], "", success_message

                        # Rejected: feed the reason into the next attempt.
                        feedback_history.append(assessment_feedback)
                        if current_attempt >= max_attempts:
                            print("Max attempts reached, returning best result")
                            return [result_img], "", ""
                    except Exception as img_error:
                        print(f"Error processing the generated image: {str(img_error)}")
                        feedback_history.append(f"Error with image: {str(img_error)}")
                    finally:
                        _remove_temp_file(image_path)
                else:
                    print("No image was generated, only text response")
                    feedback_history.append("No image was generated in the previous attempt.")
                    if current_attempt >= max_attempts:
                        print("Max attempts reached, returning original image")
                        return [input_image], "", ""
            except Exception as gen_error:
                print(f"Error during generation attempt {current_attempt}: {str(gen_error)}")
                feedback_history.append(f"Error during processing: {str(gen_error)}")
                if current_attempt >= max_attempts:
                    return [input_image], "", ""

        # Return the original image as a fallback without messages.
        print("Returning original image as fallback")
        return [input_image], "", ""
    except Exception as e:
        # Return the original image silently on error.
        print(f"Overall error in process_smile_enhancement: {str(e)}")
        return [input_image], "", ""
    finally:
        _remove_temp_file(input_path)
# Minimal two-column UI: upload and trigger on the left, results on the right.
with gr.Blocks(title="Smile Enhancement") as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label=None, image_mode="RGB", elem_classes="upload-box")
            submit_btn = gr.Button(
                "Enhance Smile with Natural Expressions",
                elem_classes="generate-btn",
            )
        with gr.Column():
            output_gallery = gr.Gallery(label=None)
            # Visible status line that shows the assessment result.
            feedback_text = gr.Textbox(label="Status", visible=True)
            # Hidden textbox that receives the handler's second return value.
            output_text = gr.Textbox(visible=False)

    submit_btn.click(
        fn=process_smile_enhancement,
        inputs=[image_input],
        outputs=[output_gallery, output_text, feedback_text],
    )

# Enable the request queue (at most 50 waiting requests), then start the app.
demo.queue(max_size=50)
demo.launch(share=True)