Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -1,49 +1,55 @@
 import os
-import
-from flask import Flask, request, jsonify
-from pipeline import Zero123PlusPipeline  # from your local pipeline.py
+import sys
+from flask import Flask, request, jsonify
 from PIL import Image
+import torch
+
+# Add the current directory to sys.path to allow local import
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from pipeline import Zero123PlusPipeline
 
 app = Flask(__name__)
 
-# Load the
+# Load the pipeline once when the app starts
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Running on device: {device}")
+
 pipe = Zero123PlusPipeline.from_pretrained(
     "sudo-ai/zero123plus-v1.2",
-    torch_dtype=torch.float32,
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
 )
-pipe.to(
-pipe.enable_model_cpu_offload()
-print("Model loaded.")
+pipe = pipe.to(device)
 
 @app.route("/")
-def
-    return
-    <input type="submit" value="Generate 3D View">
-    </form>
-    '''
-
-@app.route("/generate", methods=["POST"])
-def generate():
-    if "image" not in request.files:
+def index():
+    return "Zero123Plus API is running!"
+
+@app.route("/predict", methods=["POST"])
+def predict():
+    if 'image' not in request.files:
         return jsonify({"error": "No image uploaded"}), 400
 
-    result = pipe(image, num_inference_steps=50, guidance_scale=3.0)
-    output.save(img_io, "PNG")
-    img_io.seek(0)
-    return send_file(img_io, mimetype="image/png")
+    image = request.files["image"]
+    try:
+        input_image = Image.open(image).convert("RGB")
+
+        result = pipe(input_image, num_inference_steps=75, num_images_per_prompt=4)
+
+        images = result.images  # List of PIL Images
+        output_dir = "outputs"
+        os.makedirs(output_dir, exist_ok=True)
+        saved_paths = []
+
+        for i, img in enumerate(images):
+            path = os.path.join(output_dir, f"output_{i}.png")
+            img.save(path)
+            saved_paths.append(path)
+
+        return jsonify({"outputs": saved_paths})
+
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860)
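For reference, a minimal client sketch for exercising the updated endpoints. This is an illustration only: it assumes the app is reachable locally on the port used in app.run() (7860), that the requests library is installed, and that "input.png" is a placeholder path for any RGB image.

# Hypothetical client for the updated app.py (assumes a local run on port 7860;
# "input.png" is a placeholder image path).
import requests

BASE_URL = "http://localhost:7860"

# Health check against the "/" route.
print(requests.get(f"{BASE_URL}/").text)  # expected: "Zero123Plus API is running!"

# Upload an image to /predict; the server responds with JSON listing the
# paths of the generated views it saved under outputs/.
with open("input.png", "rb") as f:
    response = requests.post(f"{BASE_URL}/predict", files={"image": f})

print(response.status_code)
print(response.json())  # e.g. {"outputs": ["outputs/output_0.png", ...]} or {"error": "..."}

Note that, unlike the removed /generate route, which streamed a single PNG back with send_file, the new /predict route writes the generated views to the outputs/ directory and returns their paths as JSON, so a client has to retrieve those files separately.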