YiftachEde commited on
Commit
b149af8
·
verified ·
1 Parent(s): c05134c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -150
app.py CHANGED
@@ -1,147 +1,10 @@
1
  import os
2
- # this is a HF Spaces specific hack, as
3
- # (i) building pytorch3d with GPU support is a bit tricky here
4
- # (ii) installing the wheel via requirements.txt breaks ZeroGPU
5
  import spaces
6
 
7
- # Use the dynamic approach from PyTorch3D documentation to determine the correct wheel
8
  import sys
9
  import torch
10
 
11
- # Print debug information about the environment
12
- try:
13
- cuda_version = torch.version.cuda
14
- torch_version = torch.__version__
15
- python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
16
- print(f"CUDA Version: {cuda_version}")
17
- print(f"PyTorch Version: {torch_version}")
18
- print(f"Python Version: {python_version}")
19
- except Exception as e:
20
- print(f"Error detecting environment versions: {e}")
21
-
22
- # Install PyTorch3D properly from source
23
- print("Installing PyTorch3D from source...")
24
-
25
- # First uninstall any existing PyTorch3D installation to avoid conflicts
26
- os.system("pip uninstall -y pytorch3d")
27
-
28
- # Install dependencies required for building PyTorch3D
29
- print("Installing build dependencies...")
30
- os.system("apt-get update && apt-get install -y git build-essential libglib2.0-0 libsm6 libxrender-dev libxext6 ninja-build")
31
- os.system("pip install 'imageio>=2.5.0' 'matplotlib>=3.1.2' 'numpy>=1.17.3' 'psutil>=5.6.5' 'scipy>=1.3.2' 'tqdm>=4.42.1' 'trimesh>=3.0.0'")
32
- os.system("pip install fvcore iopath")
33
-
34
- # Clone the PyTorch3D repository
35
- print("Cloning PyTorch3D repository...")
36
- os.system("rm -rf pytorch3d") # Remove any existing directory
37
- clone_result = os.system("git clone https://github.com/facebookresearch/pytorch3d.git")
38
- if clone_result != 0:
39
- print("Failed to clone PyTorch3D repository. Trying with git protocol...")
40
- clone_result = os.system("git clone git://github.com/facebookresearch/pytorch3d.git")
41
- if clone_result != 0:
42
- print("Failed to clone PyTorch3D repository with both HTTPS and git protocols.")
43
-
44
- # Use a specific release tag that is known to be stable
45
- checkout_result = os.system("cd pytorch3d && git checkout v0.7.4")
46
- if checkout_result != 0:
47
- print("Failed to checkout v0.7.4 tag. Trying with main branch...")
48
- checkout_result = os.system("cd pytorch3d && git checkout main")
49
- if checkout_result != 0:
50
- print("Failed to checkout main branch.")
51
-
52
- # Install PyTorch3D from source with CPU support
53
- print("Building PyTorch3D from source...")
54
- build_result = os.system("cd pytorch3d && pip install -v -e .")
55
- if build_result != 0:
56
- print("Failed to build PyTorch3D from source with default settings.")
57
-
58
- # Try with specific build flags
59
- print("Trying with specific build flags...")
60
- os.environ["FORCE_CUDA"] = "0" # Explicitly disable CUDA for build
61
- build_result = os.system("cd pytorch3d && pip install -v -e .")
62
-
63
- if build_result != 0:
64
- print("Failed to build PyTorch3D from source with specific build flags.")
65
-
66
- # Try with setup.py directly
67
- print("Trying with setup.py directly...")
68
- build_result = os.system("cd pytorch3d && python setup.py install")
69
-
70
- if build_result != 0:
71
- print("All attempts to build from source failed.")
72
-
73
- # Verify the installation
74
- import_result = os.popen('python -c "import pytorch3d; print(\'pytorch3d imported successfully\'); try: from pytorch3d import renderer; print(\'renderer module imported successfully\'); except ImportError as e: print(f\'Error importing renderer: {e}\');" 2>&1').read()
75
- print(import_result)
76
-
77
- # If the installation fails, try a different approach with wheels from PyPI
78
- if "Error importing renderer" in import_result or "No module named" in import_result:
79
- print("Source installation failed to provide renderer module, trying with PyPI...")
80
- os.system("pip uninstall -y pytorch3d")
81
-
82
- # Try with PyPI version first (which might be CPU-only but should have renderer)
83
- os.system("pip install pytorch3d")
84
-
85
- # Verify again
86
- import_result = os.popen('python -c "import pytorch3d; print(\'pytorch3d imported successfully\'); try: from pytorch3d import renderer; print(\'renderer module imported successfully\'); except ImportError as e: print(f\'Error importing renderer: {e}\');" 2>&1').read()
87
- print(import_result)
88
-
89
- # Patch the shap_e renderer to handle PyTorch3D renderer import error if needed
90
- shap_e_renderer_path = "/usr/local/lib/python3.10/site-packages/shap_e/models/stf/renderer.py"
91
- if os.path.exists(shap_e_renderer_path):
92
- print(f"Patching shap_e renderer at {shap_e_renderer_path}")
93
-
94
- # Read the current content
95
- with open(shap_e_renderer_path, "r") as f:
96
- content = f.read()
97
-
98
- # Create a backup
99
- os.system(f"cp {shap_e_renderer_path} {shap_e_renderer_path}.bak")
100
-
101
- # Modify the content to handle the error more gracefully
102
- modified_content = content
103
-
104
- # Replace the error message
105
- if "exception rendering with PyTorch3D" in content:
106
- modified_content = modified_content.replace(
107
- 'warnings.warn(f"exception rendering with PyTorch3D: {exc}")',
108
- 'warnings.warn("Using native PyTorch renderer")'
109
- )
110
-
111
- # Replace the fallback warning
112
- if "falling back on native PyTorch renderer" in modified_content:
113
- modified_content = modified_content.replace(
114
- 'warnings.warn("falling back on native PyTorch renderer, which does not support full gradients")',
115
- 'warnings.warn("Using native PyTorch renderer")'
116
- )
117
-
118
- # Write the modified content
119
- with open(shap_e_renderer_path, "w") as f:
120
- f.write(modified_content)
121
-
122
- print("Successfully patched shap_e renderer")
123
- else:
124
- print(f"shap_e renderer not found at {shap_e_renderer_path}")
125
-
126
- # Add a helper function to ensure PyTorch3D works with ZeroGPU
127
- def ensure_pytorch3d_cuda_compatibility():
128
- """
129
- This function ensures PyTorch3D works correctly with CUDA in ZeroGPU environments.
130
- It should be called at the beginning of any @spaces.GPU decorated function.
131
- """
132
- try:
133
- import pytorch3d
134
- if torch.cuda.is_available():
135
- # Check if we can access the renderer module
136
- from pytorch3d import renderer
137
- print("PyTorch3D renderer module is available with CUDA")
138
- else:
139
- print("CUDA is not available, using CPU version of PyTorch3D")
140
- except ImportError as e:
141
- print(f"Error importing PyTorch3D: {e}")
142
- except Exception as e:
143
- print(f"Unexpected error with PyTorch3D: {e}")
144
-
145
  import torch
146
  import torch.nn as nn
147
  import gradio as gr
@@ -342,9 +205,6 @@ def load_models():
342
  @spaces.GPU(duration=20)
343
  def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
344
  """Process input images and run refinement"""
345
- # Ensure PyTorch3D works with CUDA
346
- ensure_pytorch3d_cuda_compatibility()
347
-
348
  device = pipeline.device
349
 
350
  if isinstance(input_images, list):
@@ -415,9 +275,6 @@ def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=
415
  @spaces.GPU(duration=20)
416
  def create_mesh(refined_image, model, infer_config):
417
  """Generate mesh from refined image"""
418
- # Ensure PyTorch3D works with CUDA
419
- ensure_pytorch3d_cuda_compatibility()
420
-
421
  # Convert PIL image to tensor
422
  image = np.array(refined_image) / 255.0
423
  image = torch.from_numpy(image).float().permute(2, 0, 1)
@@ -680,9 +537,6 @@ def create_demo():
680
  @spaces.GPU(duration=20) # Reduced duration to 20 seconds
681
  def generate(prompt, guidance_scale, num_steps):
682
  try:
683
- # Ensure PyTorch3D works with CUDA
684
- ensure_pytorch3d_cuda_compatibility()
685
-
686
  torch.cuda.empty_cache() # Clear GPU memory before starting
687
  with torch.no_grad():
688
  layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
@@ -696,9 +550,6 @@ def create_demo():
696
  @spaces.GPU(duration=20) # Reduced duration to 20 seconds
697
  def refine(input_image, prompt, steps, guidance_scale):
698
  try:
699
- # Ensure PyTorch3D works with CUDA
700
- ensure_pytorch3d_cuda_compatibility()
701
-
702
  torch.cuda.empty_cache() # Clear GPU memory before starting
703
  refined_img, mesh_path = refiner.refine_model(
704
  input_image,
 
1
  import os
2
+ # this is a HF Spaces specific hack for ZeroGPU
 
 
3
  import spaces
4
 
 
5
  import sys
6
  import torch
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import torch
9
  import torch.nn as nn
10
  import gradio as gr
 
205
  @spaces.GPU(duration=20)
206
  def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
207
  """Process input images and run refinement"""
 
 
 
208
  device = pipeline.device
209
 
210
  if isinstance(input_images, list):
 
275
  @spaces.GPU(duration=20)
276
  def create_mesh(refined_image, model, infer_config):
277
  """Generate mesh from refined image"""
 
 
 
278
  # Convert PIL image to tensor
279
  image = np.array(refined_image) / 255.0
280
  image = torch.from_numpy(image).float().permute(2, 0, 1)
 
537
  @spaces.GPU(duration=20) # Reduced duration to 20 seconds
538
  def generate(prompt, guidance_scale, num_steps):
539
  try:
 
 
 
540
  torch.cuda.empty_cache() # Clear GPU memory before starting
541
  with torch.no_grad():
542
  layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
 
550
  @spaces.GPU(duration=20) # Reduced duration to 20 seconds
551
  def refine(input_image, prompt, steps, guidance_scale):
552
  try:
 
 
 
553
  torch.cuda.empty_cache() # Clear GPU memory before starting
554
  refined_img, mesh_path = refiner.refine_model(
555
  input_image,