# Rightlight / app.py
# mike23415 — "Update app.py" (commit 8505235, verified)
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import tempfile
from skimage import measure
import trimesh
import torch.nn.functional as F
import torchvision.transforms as transforms
# Run on the CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
class SimpleDepthNet(torch.nn.Module):
    """Tiny convolutional encoder/decoder producing a one-channel map in [0, 1].

    Intended as a per-pixel depth estimator for ``image_to_3d``.

    NOTE(review): the weights are randomly initialised and nothing in this file
    trains or loads them, so the output is not a meaningful depth estimate.
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv1..conv4) is kept so parameter order and
        # seeded random initialisation stay the same.
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = torch.nn.Conv2d(128, 1, kernel_size=3, padding=1)
        # Parameter-free resampling layers shared by the encoder and decoder.
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to an (N, 1, H', W') map squashed by a sigmoid.

        Two 2x poolings followed by two 2x upsamplings give
        H' = 4 * ((H // 2) // 2), so H' == H only when H is divisible by 4
        (likewise for W).
        """
        # Encoder: conv + ReLU, then halve the resolution — twice.
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Decoder: double the resolution, conv + ReLU; double again and
        # project down to a single channel.
        features = F.relu(self.conv3(self.upsample(features)))
        logits = self.conv4(self.upsample(features))
        return torch.sigmoid(logits)
# Module-level model instance shared by every request.
# NOTE(review): SimpleDepthNet's weights are randomly initialised and are never
# trained or loaded here, so its output is not a real depth estimate.
model = SimpleDepthNet().to(device)
# Inference only. SimpleDepthNet currently has no dropout or batch-norm, so this
# is a safeguard in case such layers are added later.
model.eval()

# Preprocessing for model inputs. Force 3-channel RGB first: the network's first
# conv expects 3 input channels, and RGBA / greyscale / palette PIL images would
# otherwise become 4- or 1-channel tensors. Then resize to 256x256 and convert
# to a float tensor.
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
def _grid_faces(h, w):
    """Return a (2*(h-1)*(w-1), 3) int array of triangles covering an h x w vertex grid.

    Vertices are numbered row-major (index = row * w + col). Each grid cell is
    split into two triangles, wound counter-clockwise when viewed from +Z with
    columns running towards +X and rows running towards -Y.
    """
    idx = np.arange(h * w).reshape(h, w)
    top_left = idx[:-1, :-1].ravel()
    top_right = idx[:-1, 1:].ravel()
    bottom_left = idx[1:, :-1].ravel()
    bottom_right = idx[1:, 1:].ravel()
    return np.concatenate([
        np.stack([top_left, bottom_left, top_right], axis=1),
        np.stack([top_right, bottom_left, bottom_right], axis=1),
    ])


def _temp_path(suffix):
    """Create an empty temporary file ending in ``suffix`` and return its path.

    The file is closed before returning, so it can be reopened by name for
    writing (required on Windows). It is not deleted automatically.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def image_to_3d(image):
    """Convert a single image into a vertex-coloured height-field mesh on disk.

    The image is preprocessed by ``transform`` and passed through ``model`` to
    get a one-channel map, which is used as per-pixel depth. Each map pixel
    becomes one vertex: x and y come from its grid position and z from its
    depth. Each vertex is coloured with the matching pixel of the resized
    image, and neighbouring pixels are joined by two triangles per grid cell.

    NOTE(review): ``model`` has randomly initialised weights (nothing here
    trains or loads it), so the "depth" is not a real depth estimate.

    Args:
        image: a PIL image, or ``None``.

    Returns:
        ``([obj_path, ply_path], message)`` on success. The paths point to
        temporary files that are not cleaned up automatically.
        ``(None, message)`` if no image was given or any step failed.
    """
    if image is None:
        return None, "No image provided"
    try:
        # The model needs 3 input channels and the colour reshape below needs
        # RGB, so normalise RGBA / greyscale / palette images once, up front.
        rgb = image.convert("RGB")
        img_tensor = transform(rgb).unsqueeze(0).to(device)
        with torch.no_grad():
            depth = model(img_tensor)[0, 0].cpu().numpy()

        # One vertex per depth-map pixel, with coordinates normalised by the
        # longer side.
        h, w = depth.shape
        rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
        scale = max(w, h)
        x = (cols - w / 2) / scale
        # Image rows grow downwards; negate so the model is upright in a Y-up scene.
        y = -(rows - h / 2) / scale
        z = depth - 0.5  # centre around zero
        vertices = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)

        # One uint8 RGB colour per vertex, in the same row-major order as the vertices.
        colors = np.asarray(rgb.resize((w, h)), dtype=np.uint8).reshape(-1, 3)

        # skimage.measure.marching_cubes requires a 3D volume and raises on this
        # 2D depth map, so triangulate the pixel grid directly instead.
        mesh = trimesh.Trimesh(
            vertices=vertices,
            faces=_grid_faces(h, w),
            vertex_colors=colors,
        )

        # Export only after the temp file handles are closed.
        obj_path = _temp_path('.obj')
        mesh.export(obj_path)
        # Also save as PLY for better compatibility with Unity
        ply_path = _temp_path('.ply')
        mesh.export(ply_path)
        return [obj_path, ply_path], "3D model generated successfully!"
    except Exception as e:
        # UI boundary: report the failure as a status message instead of crashing the app.
        return None, f"Error: {str(e)}"
def process_image(image):
    """Gradio click handler: convert ``image`` and fan the result out to three outputs.

    Returns:
        ``(obj_path, ply_path, message)``. Both paths are ``None`` when no
        image was supplied or ``image_to_3d`` produced no files. ``message``
        is always a human-readable status string.
    """
    if image is None:
        return None, None, "Please upload an image first."
    try:
        files, status = image_to_3d(image)
        if not files:
            return None, None, status
        obj_path, ply_path = files[0], files[1]
    except Exception as exc:
        # Defensive: covers anything raised while calling image_to_3d or
        # unpacking its result.
        return None, None, f"Error: {str(exc)}"
    return obj_path, ply_path, status
# Create Gradio interface.
# Inside gr.Blocks, components are placed in the order they are created, so the
# statement order below is the page layout.
with gr.Blocks(title="Simple Image to 3D Converter") as demo:
    gr.Markdown("# Simple Image to 3D Converter")
    gr.Markdown("Upload an image to convert it to a simple 3D model that you can use in Unity or other engines.")
    with gr.Row():
        # Left column: image input (delivered to the handler as a PIL image) and the trigger button.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            submit_btn = gr.Button("Convert to 3D")
        # Right column: the two exported files and a status message.
        with gr.Column(scale=1):
            obj_file = gr.File(label="OBJ File (for editing)")
            ply_file = gr.File(label="PLY File (for Unity)")
            output_message = gr.Textbox(label="Output Message")
    # process_image returns (obj_path, ply_path, message); the outputs list
    # below is in that same order.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[obj_file, ply_file, output_message]
    )
# Entry point when run as a script: serve the UI on every network interface
# at port 7860.
if __name__ == "__main__":
    demo.launch(server_port=7860, server_name="0.0.0.0")