rasbt committed
Commit e436a25 · verified · 1 Parent(s): b984f18

Upload folder using huggingface_hub

Files changed (3):
  1. README.md +1 -0
  2. generate_example.py +1 -0
  3. model.py +3 -8
README.md CHANGED
@@ -116,6 +116,7 @@ import urllib.request
 url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"
 
 if not os.path.exists(MODEL_FILE):
+    print(f"Downloading {MODEL_FILE}...")
     urllib.request.urlretrieve(url, MODEL_FILE)
     print(f"Downloaded to {MODEL_FILE}")
 ```
generate_example.py CHANGED
@@ -40,6 +40,7 @@ TOP_K = 1
 url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"
 
 if not os.path.exists(MODEL_FILE):
+    print(f"Downloading {MODEL_FILE}...")
     urllib.request.urlretrieve(url, MODEL_FILE)
     print(f"Downloaded to {MODEL_FILE}")
 

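Both README.md and generate_example.py receive the same one-line change: print a message before the download starts, since fetching the weights can take a while and `urlretrieve` is otherwise silent. A minimal self-contained sketch of the resulting pattern; the `MODEL_FILE` value below is a placeholder, as the real filename is defined earlier in the script and does not appear in this diff:

```python
import os
import urllib.request

MODEL_FILE = "model.pth"  # placeholder; the real filename is set earlier in the script
url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"

if not os.path.exists(MODEL_FILE):
    print(f"Downloading {MODEL_FILE}...")  # new: progress message before the slow fetch
    urllib.request.urlretrieve(url, MODEL_FILE)
    print(f"Downloaded to {MODEL_FILE}")
```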
model.py CHANGED
@@ -97,11 +97,8 @@ class TransformerBlock(nn.Module):
         self.att = GroupedQueryAttention(
             d_in=cfg["emb_dim"],
             d_out=cfg["emb_dim"],
-            context_length=cfg["context_length"],
             num_heads=cfg["n_heads"],
             num_kv_groups=cfg["n_kv_groups"],
-            rope_base=cfg["rope_base"],
-            rope_config=cfg["rope_freq"],
             dtype=cfg["dtype"]
         )
         self.ff = FeedForward(cfg)
@@ -140,10 +137,8 @@ class FeedForward(nn.Module):
 
 class GroupedQueryAttention(nn.Module):
     def __init__(
-            self, d_in, d_out, context_length, num_heads,
+            self, d_in, d_out, num_heads,
             num_kv_groups,
-            rope_base=10_000,
-            rope_config=None,
             dtype=None
     ):
         super().__init__()
@@ -306,14 +301,14 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
         logits = model(idx_cond)
         logits = logits[:, -1, :]
 
-        # New: Filter logits with top_k sampling
+        # Filter logits with top_k sampling
         if top_k is not None:
             # Keep only top_k values
             top_logits, _ = torch.topk(logits, top_k)
             min_val = top_logits[:, -1]
             logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
 
-        # New: Apply temperature scaling
+        # Apply temperature scaling
         if temperature > 0.0:
             logits = logits / temperature
 
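The last hunk only rewords comments, but the decoding logic is worth seeing in one piece. Below is a minimal sketch of the per-step top-k filtering and temperature scaling that `generate()` applies; the softmax/multinomial sampling and the greedy fallback are assumptions based on the standard decoding pattern and are not shown in this hunk, and `min_val` gains an `unsqueeze` so the mask also broadcasts for batch sizes above 1:

```python
import torch

def sample_next_token(logits, temperature=0.0, top_k=None):
    """Pick the next token id from (batch, vocab_size) logits."""
    if top_k is not None:
        # Keep only the top_k largest logits; everything below the
        # k-th largest value per row is masked to -inf so that
        # softmax assigns it zero probability.
        top_logits, _ = torch.topk(logits, top_k)
        min_val = top_logits[:, -1].unsqueeze(-1)  # (batch, 1) for broadcasting
        logits = torch.where(
            logits < min_val,
            torch.tensor(float("-inf")).to(logits.device),
            logits,
        )
    if temperature > 0.0:
        # Temperature < 1 sharpens the distribution, > 1 flattens it.
        logits = logits / temperature
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # stochastic sampling
    # temperature == 0.0: greedy decoding
    return torch.argmax(logits, dim=-1, keepdim=True)

# Dummy usage: a batch of one with a 128-token vocabulary
next_token = sample_next_token(torch.randn(1, 128), temperature=0.7, top_k=5)
```

With `TOP_K = 1` as in generate_example.py, only the single highest logit survives the filter, so sampling reduces to greedy decoding regardless of temperature.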