mike23415 commited on
Commit
388cf5c
·
verified ·
1 Parent(s): 261bbdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -32
app.py CHANGED
@@ -1,49 +1,55 @@
1
  import os
2
- import torch
3
- from flask import Flask, request, jsonify, send_file
4
- from pipeline import Zero123PlusPipeline # from your local pipeline.py
5
  from PIL import Image
6
- from io import BytesIO
 
 
 
 
 
7
 
8
  app = Flask(__name__)
9
 
10
- # Load the model once at startup (on CPU)
11
- print("Loading Zero123Plus pipeline on CPU...")
 
 
12
  pipe = Zero123PlusPipeline.from_pretrained(
13
  "sudo-ai/zero123plus-v1.2",
14
- torch_dtype=torch.float32,
15
  )
16
- pipe.to("cpu")
17
- pipe.enable_model_cpu_offload()
18
- print("Model loaded.")
19
 
20
  @app.route("/")
21
- def home():
22
- return '''
23
- <h1>Zero123Plus Image to 3D Generator</h1>
24
- <form action="/generate" method="post" enctype="multipart/form-data">
25
- <p>Upload a single-view image:</p>
26
- <input type="file" name="image"><br><br>
27
- <input type="submit" value="Generate 3D View">
28
- </form>
29
- '''
30
-
31
- @app.route("/generate", methods=["POST"])
32
- def generate():
33
- if "image" not in request.files:
34
  return jsonify({"error": "No image uploaded"}), 400
35
 
36
- file = request.files["image"]
37
- image = Image.open(file.stream).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- print("Generating 3D view...")
40
- result = pipe(image, num_inference_steps=50, guidance_scale=3.0)
41
 
42
- output = result.images[0]
43
- img_io = BytesIO()
44
- output.save(img_io, "PNG")
45
- img_io.seek(0)
46
- return send_file(img_io, mimetype="image/png")
47
 
48
  if __name__ == "__main__":
49
  app.run(host="0.0.0.0", port=7860)
 
1
  import os
2
+ import sys
3
+ from flask import Flask, request, jsonify
 
4
  from PIL import Image
5
+ import torch
6
+
7
+ # Add the current directory to sys.path to allow local import
8
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9
+
10
+ from pipeline import Zero123PlusPipeline
11
 
12
  app = Flask(__name__)
13
 
14
+ # Load the pipeline once when the app starts
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ print(f"Running on device: {device}")
17
+
18
  pipe = Zero123PlusPipeline.from_pretrained(
19
  "sudo-ai/zero123plus-v1.2",
20
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
21
  )
22
+ pipe = pipe.to(device)
 
 
23
 
24
  @app.route("/")
25
+ def index():
26
+ return "Zero123Plus API is running!"
27
+
28
+ @app.route("/predict", methods=["POST"])
29
+ def predict():
30
+ if 'image' not in request.files:
 
 
 
 
 
 
 
31
  return jsonify({"error": "No image uploaded"}), 400
32
 
33
+ image = request.files["image"]
34
+ try:
35
+ input_image = Image.open(image).convert("RGB")
36
+
37
+ result = pipe(input_image, num_inference_steps=75, num_images_per_prompt=4)
38
+
39
+ images = result.images # List of PIL Images
40
+ output_dir = "outputs"
41
+ os.makedirs(output_dir, exist_ok=True)
42
+ saved_paths = []
43
+
44
+ for i, img in enumerate(images):
45
+ path = os.path.join(output_dir, f"output_{i}.png")
46
+ img.save(path)
47
+ saved_paths.append(path)
48
 
49
+ return jsonify({"outputs": saved_paths})
 
50
 
51
+ except Exception as e:
52
+ return jsonify({"error": str(e)}), 500
 
 
 
53
 
54
  if __name__ == "__main__":
55
  app.run(host="0.0.0.0", port=7860)