Upload folder using huggingface_hub

- README.md +1 -0
- generate_example.py +1 -0
- model.py +3 -8
README.md CHANGED

````diff
@@ -116,6 +116,7 @@ import urllib.request
 url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"
 
 if not os.path.exists(MODEL_FILE):
+    print(f"Downloading {MODEL_FILE}...")
     urllib.request.urlretrieve(url, MODEL_FILE)
     print(f"Downloaded to {MODEL_FILE}")
 ```
````
generate_example.py CHANGED

```diff
@@ -40,6 +40,7 @@ TOP_K = 1
 url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"
 
 if not os.path.exists(MODEL_FILE):
+    print(f"Downloading {MODEL_FILE}...")
     urllib.request.urlretrieve(url, MODEL_FILE)
     print(f"Downloaded to {MODEL_FILE}")
 
```
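Both changes above are the same one-line addition: a status message printed before the download begins, so the script no longer sits silently while fetching a large checkpoint. A minimal self-contained sketch of the resulting logic (the `MODEL_FILE` value here is a placeholder; the repo hosts several weight files):

```python
import os
import urllib.request

MODEL_FILE = "llama3.2-1B-instruct.pth"  # placeholder; any weight file from the repo
url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"

if not os.path.exists(MODEL_FILE):
    print(f"Downloading {MODEL_FILE}...")  # added in this commit: feedback before the download
    urllib.request.urlretrieve(url, MODEL_FILE)
    print(f"Downloaded to {MODEL_FILE}")
```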
model.py CHANGED

```diff
@@ -97,11 +97,8 @@ class TransformerBlock(nn.Module):
         self.att = GroupedQueryAttention(
             d_in=cfg["emb_dim"],
             d_out=cfg["emb_dim"],
-            context_length=cfg["context_length"],
             num_heads=cfg["n_heads"],
             num_kv_groups=cfg["n_kv_groups"],
-            rope_base=cfg["rope_base"],
-            rope_config=cfg["rope_freq"],
             dtype=cfg["dtype"]
         )
         self.ff = FeedForward(cfg)
@@ -140,10 +137,8 @@ class FeedForward(nn.Module):
 
 class GroupedQueryAttention(nn.Module):
     def __init__(
-        self, d_in, d_out,
+        self, d_in, d_out, num_heads,
         num_kv_groups,
-        rope_base=10_000,
-        rope_config=None,
         dtype=None
     ):
         super().__init__()
```
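These two hunks drop the RoPE-related arguments (`context_length`, `rope_base`, `rope_config`) from `GroupedQueryAttention`. The diff does not show where that logic moved, but a common pattern for this kind of refactor, and a plausible reading here, is to precompute the RoPE cos/sin tables once at the model level and pass them into each block's forward call instead of having every attention module rebuild them. A sketch under that assumption (`compute_rope_params` and the wiring below are illustrative, not taken from the diff):

```python
import torch

def compute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    # Precompute the RoPE angle tables once, outside the attention module.
    # Llama 3.2's frequency rescaling (the old rope_config/rope_freq) is
    # omitted in this sketch.
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(context_length, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]  # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=1)      # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)

# Hypothetical wiring: the model creates the tables once and shares them,
# which is why the per-block constructor no longer needs the RoPE settings:
#   cos, sin = compute_rope_params(head_dim=cfg["emb_dim"] // cfg["n_heads"])
#   x = transformer_block(x, mask, cos, sin)
```

The remaining hunk in model.py touches the `generate` function: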
```diff
@@ -306,14 +301,14 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None):
         logits = model(idx_cond)
         logits = logits[:, -1, :]
 
-        #
+        # Filter logits with top_k sampling
         if top_k is not None:
             # Keep only top_k values
             top_logits, _ = torch.topk(logits, top_k)
             min_val = top_logits[:, -1]
             logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
 
-        #
+        # Apply temperature scaling
        if temperature > 0.0:
             logits = logits / temperature
 
```
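This last hunk only fills in two previously empty comments, but it annotates the decoding logic in `generate`: top-k filtering first, then temperature scaling, then sampling. As a standalone reference, a minimal sketch of that sampling step (the function name and greedy fallback are illustrative; `min_val` keeps its batch dimension here so the comparison broadcasts for any batch size):

```python
import torch

def sample_next_token(logits, temperature=0.0, top_k=None):
    # logits: (batch, vocab_size) scores for the last position only.

    # Filter logits with top_k sampling: keep the k largest scores per row,
    # set everything below the k-th largest to -inf.
    if top_k is not None:
        top_logits, _ = torch.topk(logits, top_k)
        min_val = top_logits[:, -1:]  # (batch, 1), broadcasts over the vocab dim
        logits = torch.where(logits < min_val,
                             torch.tensor(float("-inf"), device=logits.device),
                             logits)

    # Apply temperature scaling, then sample; temperature == 0.0 means greedy.
    if temperature > 0.0:
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # (batch, 1)
    return torch.argmax(logits, dim=-1, keepdim=True)   # (batch, 1)
```

In a generation loop the result would be appended to the running sequence, e.g. `idx = torch.cat([idx, sample_next_token(logits, temperature=0.7, top_k=50)], dim=1)`.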