Spaces:
Running on Zero

Ruurd committed on
Commit
098132b
·
1 Parent(s): 42ed840

Change version of loading model

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -27,20 +27,18 @@ with open("token_probabilities.json") as f:
27
  token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
28
 
29
  @spaces.GPU
30
- def load_model():
 
31
  ckpt_path = hf_hub_download(
32
  repo_id="ruurd/tini_model",
33
  filename="diffusion-model.pth",
34
  token=os.getenv("HF_TOKEN")
35
  )
 
36
 
37
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
- model = torch.load(ckpt_path, map_location=device)
39
- model = disable_dropout(model)
40
- model.to(device)
41
- model.eval()
42
- return model
43
-
44
 
45
  rng = np.random.default_rng()
46
 
@@ -146,10 +144,6 @@ def diffusion_chat(question, eot_weight, max_it, sharpness):
146
 
147
  # --- Gradio Interface ---
148
 
149
- print("Loading model...")
150
- model = load_model()
151
- print("✅ Model loaded.")
152
-
153
  demo = gr.Interface(
154
  fn=diffusion_chat,
155
  inputs=[
 
27
  token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
28
 
29
  @spaces.GPU
30
+ def load_weights():
31
+ # OK: download & load weights to CPU
32
  ckpt_path = hf_hub_download(
33
  repo_id="ruurd/tini_model",
34
  filename="diffusion-model.pth",
35
  token=os.getenv("HF_TOKEN")
36
  )
37
+ return torch.load(ckpt_path, map_location="cpu") # ✅ returns only CPU tensors
38
 
39
+ model = CustomTransformerModel(...)
40
+ model.load_state_dict(load_weights())
41
+ model.to("cuda") # OK now, after @spaces.GPU is done
 
 
 
 
42
 
43
  rng = np.random.default_rng()
44
 
 
144
 
145
  # --- Gradio Interface ---
146
 
 
 
 
 
147
  demo = gr.Interface(
148
  fn=diffusion_chat,
149
  inputs=[