# Spaces status at time of copy: Runtime error
import logging
import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import trimesh
from PIL import Image
from skimage import measure
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# Define a simple neural network to extract depth from images
class SimpleDepthNet(torch.nn.Module):
    """Tiny encoder/decoder CNN mapping an RGB image to a one-channel map.

    Input: float tensor of shape (N, 3, H, W). Output: tensor of shape
    (N, 1, H, W) squashed by a final sigmoid into [0, 1]. The output always
    has the same spatial size as the input, including when H or W is not
    divisible by 4.

    NOTE(review): the weights are randomly initialised and nothing visible in
    this file trains the network or loads a checkpoint, so the output is not a
    learned depth estimate. Confirm whether pretrained weights were intended.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = torch.nn.Conv2d(128, 1, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        # Kept for backward compatibility. It has no parameters, so it does
        # not affect state_dict. forward() resizes to the exact encoder sizes
        # instead of using this fixed x2 factor.
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return an (N, 1, H, W) map in [0, 1] for an (N, 3, H, W) input."""
        full_size = x.shape[-2:]
        # Encoder
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        half_size = x.shape[-2:]  # conv2 uses padding=1, so it keeps this size
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        # Decoder: upsample back to the sizes recorded on the way down.
        # MaxPool2d floors odd sizes, so a fixed x2 upsample would return a
        # map smaller than the input (e.g. 30 -> 28).
        x = F.interpolate(x, size=half_size, mode='bilinear', align_corners=True)
        x = F.relu(self.conv3(x))
        x = F.interpolate(x, size=full_size, mode='bilinear', align_corners=True)
        return torch.sigmoid(self.conv4(x))
# Initialize the model and switch it to inference mode (it is only ever used
# under torch.no_grad() for prediction).
# NOTE(review): the weights are random. Nothing here trains the network or
# loads a checkpoint; confirm whether pretrained weights were meant to be loaded.
model = SimpleDepthNet().to(device).eval()
# Define transformation for input images: resize to 256x256, then convert to a
# float tensor (ToTensor scales 8-bit pixel values to [0, 1]).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
def _depth_to_heightfield(depth):
    """Triangulate a 2-D value map into a regular-grid height-field mesh.

    Vertex (r, c) sits at pixel row r / column c, scaled so the longer side
    spans 1 unit and centred on the origin, with Z = depth[r, c] - 0.5.

    Returns:
        (vertices, faces):
            vertices: an (h*w, 3) float array in row-major pixel order.
            faces: a (2*(h-1)*(w-1), 3) int array, two triangles per grid
                cell, wound so their normals face +Z.
    """
    h, w = depth.shape
    scale = max(w, h)
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    x = (cols - w / 2) / scale
    # Image rows grow downwards; negate so the model is upright in a Y-up scene.
    y = -(rows - h / 2) / scale
    z = depth - 0.5  # Center around zero
    vertices = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)

    idx = np.arange(h * w).reshape(h, w)
    top_left = idx[:-1, :-1].ravel()
    top_right = idx[:-1, 1:].ravel()
    bottom_left = idx[1:, :-1].ravel()
    bottom_right = idx[1:, 1:].ravel()
    faces = np.concatenate([
        np.stack([top_left, bottom_left, top_right], axis=1),
        np.stack([top_right, bottom_left, bottom_right], axis=1),
    ])
    return vertices, faces


def _export_to_temp(mesh, suffix):
    """Export *mesh* to a new temporary file ending in *suffix*; return its path."""
    # mkstemp + close, rather than exporting while a NamedTemporaryFile handle
    # is still open (which fails on Windows). The file is intentionally kept.
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    mesh.export(path)  # trimesh picks the format from the file extension
    return path


def image_to_3d(image):
    """
    Convert a single image to a 3D height-field mesh.

    The image is run through the module-level ``model`` to get a per-pixel
    value map. That map is used as the Z offset of a regular vertex grid:
    each grid cell becomes two triangles, and every vertex takes the colour
    of its pixel in the input image.

    Args:
        image: a PIL image (any mode; it is converted to RGB), or None.

    Returns:
        ``([obj_path, ply_path], message)`` on success. The paths are
        temporary files that this function does not delete.
        ``(None, message)`` when no image is given or conversion fails.
    """
    if image is None:
        return None, "No image provided"
    try:
        # The network takes exactly 3 channels. RGBA, greyscale or palette
        # images would otherwise fail in conv1 and in the colour reshape below.
        rgb_image = image.convert("RGB")
        img_tensor = transform(rgb_image).unsqueeze(0).to(device)
        with torch.no_grad():
            depth = model(img_tensor)[0, 0].cpu().numpy()

        # Triangulate the grid directly. skimage's marching_cubes needs a 3-D
        # volume and raises on this 2-D map.
        h, w = depth.shape
        vertices, faces = _depth_to_heightfield(depth)
        # One RGB colour per vertex, in the same row-major order as the grid.
        colors = np.asarray(rgb_image.resize((w, h)), dtype=np.uint8).reshape(-1, 3)
        mesh = trimesh.Trimesh(
            vertices=vertices,
            faces=faces,
            vertex_colors=colors,
            process=False,
        )

        obj_path = _export_to_temp(mesh, ".obj")
        # Also save as PLY for better compatibility with Unity
        ply_path = _export_to_temp(mesh, ".ply")
        return [obj_path, ply_path], "3D model generated successfully!"
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and return a
        # short message for display.
        logging.getLogger(__name__).exception("image_to_3d failed")
        return None, f"Error: {str(e)}"
def process_image(image):
    """Gradio callback: turn *image* into (obj_path, ply_path, status_message).

    Both paths are None when there is no image or conversion fails. Any
    exception is reported through the status message instead of propagating.
    """
    try:
        if image is None:
            return None, None, "Please upload an image first."
        paths, status = image_to_3d(image)
        if not paths:
            return None, None, status
        obj_path, ply_path = paths[0], paths[1]
        return obj_path, ply_path, status
    except Exception as exc:
        return None, None, f"Error: {str(exc)}"
# Build the Gradio UI: image input and button on the left; the generated OBJ
# and PLY files plus a status box on the right.
with gr.Blocks(title="Simple Image to 3D Converter") as demo:
    gr.Markdown("# Simple Image to 3D Converter")
    gr.Markdown(
        "Upload an image to convert it to a simple 3D model "
        "that you can use in Unity or other engines."
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            submit_btn = gr.Button("Convert to 3D")
        with gr.Column(scale=1):
            obj_file = gr.File(label="OBJ File (for editing)")
            ply_file = gr.File(label="PLY File (for Unity)")
            output_message = gr.Textbox(label="Output Message")
    # process_image returns (obj_path, ply_path, message), matching these outputs.
    submit_btn.click(
        process_image,
        [input_image],
        [obj_file, ply_file, output_message],
    )
# Entry point: serve the UI on port 7860, listening on all network interfaces.
if __name__ == "__main__":
    demo.launch(server_port=7860, server_name="0.0.0.0")