Spaces: Running on Zero

Ruurd committed · Commit b5abc9b · 1 Parent(s): 098132b

Change spaces.gpu application

Files changed (1):
  1. app.py +14 -7
app.py CHANGED
@@ -26,19 +26,20 @@ with open("token_probabilities.json") as f:
     token_probs_dict = json.load(f)
 token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
 
-@spaces.GPU
-def load_weights():
-    # OK: download & load weights to CPU
+def load_model():
     ckpt_path = hf_hub_download(
         repo_id="ruurd/tini_model",
         filename="diffusion-model.pth",
         token=os.getenv("HF_TOKEN")
     )
-    return torch.load(ckpt_path, map_location="cpu")  # ✅ returns only CPU tensors
 
-model = CustomTransformerModel(...)
-model.load_state_dict(load_weights())
-model.to("cuda")  # ✅ OK now, after @spaces.GPU is done
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = torch.load(ckpt_path, map_location=device)
+    model = disable_dropout(model)
+    model.to(device)
+    model.eval()
+    return model
+
 
 rng = np.random.default_rng()
 
@@ -82,6 +83,8 @@ def generate_diffusion_text(input_ids, answer_start):
     return input_ids[:answer_start] + sampled[answer_start:]
 
 # --- Inference Wrapper ---
+
+@spaces.GPU
 def diffusion_chat(question, eot_weight, max_it, sharpness):
     placeholder = "What do you know about the city of New York?"
     if question.strip() == "":
@@ -144,6 +147,10 @@ def diffusion_chat(question, eot_weight, max_it, sharpness):
 
 # --- Gradio Interface ---
 
+print("Loading model...")
+model = load_model()
+print("✅ Model loaded.")
+
 demo = gr.Interface(
     fn=diffusion_chat,
     inputs=[
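
The new load_model() calls disable_dropout(), which is not defined anywhere in this diff. A minimal sketch of what such a helper plausibly does, assuming it simply zeroes the dropout probability of every nn.Dropout submodule (an assumption, not something this commit confirms):

import torch.nn as nn

def disable_dropout(model: nn.Module) -> nn.Module:
    # Assumed behavior: set p=0 on every Dropout module in place, so
    # inference stays deterministic even outside model.eval().
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0
    return model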
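Taken together, the change moves the Space to the standard ZeroGPU layout: load the checkpoint once at import time, and request a GPU only while a request is being served by decorating the handler with @spaces.GPU. A condensed sketch of that layout, reusing the names from app.py (the handler body is elided):

import os
import spaces
import torch
import gradio as gr
from huggingface_hub import hf_hub_download

def load_model():
    # Download the checkpoint and unpickle the full model object.
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load(ckpt_path, map_location=device)
    model = disable_dropout(model)  # helper sketched above
    model.eval()
    return model

model = load_model()  # runs once at startup, before any request arrives

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of each call
def diffusion_chat(question, eot_weight, max_it, sharpness):
    ...

One consequence of switching from load_state_dict into CustomTransformerModel to torch.load of a whole pickled model: the model's class definition must be importable in app.py when the checkpoint is unpickled, or torch.load will fail.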