
Ruurd committed · Commit 9aaa660 · 1 Parent(s): 0af2920

Load model differently
Files changed (1)
  1. app.py +35 -8
app.py CHANGED
@@ -4,8 +4,10 @@ import numpy as np
 import json
 import time
 from transformers import AutoTokenizer
-from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout
+from llama_diffusion_model import disable_dropout
 import os
+import importlib
+from huggingface_hub import hf_hub_download
 
 hf_token = os.getenv("HF_TOKEN")
 
@@ -24,18 +26,43 @@ token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(toke
 
 
 def load_model():
-    config = CustomTransformerConfig(vocab_size=vocab_size)
-    model = CustomTransformerModel(config)
-    model.load_state_dict(torch.hub.load_state_dict_from_url(
-        "https://huggingface.co/Ruurd/tini_model/resolve/main/diffusion-model.pth",
-        map_location="cuda",
-        headers={"Authorization": f"Bearer {hf_token}"}
-    ))
+
+    # 1. Download the checkpoint
+    checkpoint_path = hf_hub_download(
+        repo_id="ruurd/diffusion-llama",
+        filename="diffusion-model.pth",
+        token=os.getenv("HF_TOKEN")
+    )
+
+    # 2. Prepare dynamic class loading like before
+    torch.serialization.clear_safe_globals()
+    unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(checkpoint_path)
+    missing_class_names = [name.split(".")[-1] for name in unsafe_globals]
+
+    safe_classes = [cls for name, cls in globals().items() if name in missing_class_names]
+
+    for class_path in unsafe_globals:
+        try:
+            module_name, class_name = class_path.rsplit(".", 1)
+            module = importlib.import_module(module_name)
+            cls = getattr(module, class_name)
+            safe_classes.append(cls)
+        except (ImportError, AttributeError) as e:
+            print(f"⚠️ Warning: Could not import {class_path} - {e}")
+
+    torch.serialization.add_safe_globals(safe_classes)
+
+    # 3. Actually load the full model
+    model = torch.load(checkpoint_path, weights_only=True)
+
+    # 4. Final setup
     model = disable_dropout(model)
     model.to("cuda")
     model.eval()
+
     return model
 
+
 rng = np.random.default_rng()
 
 # --- Utility Functions ---
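For reference, the new path replaces a state_dict download via torch.hub.load_state_dict_from_url with hf_hub_download plus a full-model torch.load(..., weights_only=True); the safe-globals registration in step 2 is what lets the restricted unpickler rebuild the pickled model object. Below is a minimal, self-contained sketch of the same pattern, not the Space's exact code: the helper name load_pickled_model and the CPU map_location are illustrative, and it assumes a recent PyTorch (roughly 2.5+) that provides torch.serialization.get_unsafe_globals_in_checkpoint and add_safe_globals.

import importlib
import os

import torch
from huggingface_hub import hf_hub_download


def load_pickled_model(repo_id: str, filename: str) -> torch.nn.Module:
    # Download the checkpoint into the local Hugging Face cache
    # (a token is only needed for private repos).
    path = hf_hub_download(repo_id=repo_id, filename=filename, token=os.getenv("HF_TOKEN"))

    # Ask PyTorch which pickled globals are not yet allowlisted, import them,
    # and register them so the weights_only unpickler will accept the file.
    allowed = []
    for qualname in torch.serialization.get_unsafe_globals_in_checkpoint(path):
        module_name, class_name = qualname.rsplit(".", 1)
        try:
            allowed.append(getattr(importlib.import_module(module_name), class_name))
        except (ImportError, AttributeError) as err:
            print(f"Skipping {qualname}: {err}")
    torch.serialization.add_safe_globals(allowed)

    # weights_only=True keeps the restricted unpickler, but the allowlist above
    # lets it reconstruct the saved nn.Module rather than only a state_dict.
    model = torch.load(path, map_location="cpu", weights_only=True)
    model.eval()
    return model


# Hypothetical usage, mirroring the repo and filename from the commit:
# model = load_pickled_model("ruurd/diffusion-llama", "diffusion-model.pth")

Compared with load_state_dict_from_url, this keeps the file in the shared Hugging Face cache and avoids hand-building Authorization headers, at the cost of explicitly trusting whichever classes end up on the allowlist.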