rasbt committed on
Commit
137c45d
·
verified ·
1 Parent(s): 6a867d4

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. generate_example.py +19 -0
  2. main.py +80 -0
  3. tokenizer.py +1 -1
generate_example.py CHANGED
@@ -32,6 +32,11 @@ TEMPERATURE = 0.
32
  TOP_K = 1
33
  #######################################
34
 
 
 
 
 
 
35
  url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"
36
 
37
  if not os.path.exists(MODEL_FILE):
@@ -58,11 +63,25 @@ device = (
58
  )
59
  model.to(device)
60
 
 
 
 
 
 
 
 
 
 
 
61
  tokenizer = Llama3Tokenizer("tokenizer.model")
62
 
63
  if "instruct" in MODEL_FILE:
64
  tokenizer = ChatFormat(tokenizer)
65
 
 
 
 
 
66
  torch.manual_seed(123)
67
 
68
  start = time.time()
 
32
  TOP_K = 1
33
  #######################################
34
 
35
+
36
+ ###################################
37
+ # Initialize model
38
+ ##################################
39
+
40
  url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"
41
 
42
  if not os.path.exists(MODEL_FILE):
 
63
  )
64
  model.to(device)
65
 
66
+ ###################################
67
+ # Initialize tokenizer
68
+ ##################################
69
+ TOKENIZER_FILE = "tokenizer.model"
70
+
71
+ url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{TOKENIZER_FILE}"
72
+
73
+ if not os.path.exists(TOKENIZER_FILE):
74
+ urllib.request.urlretrieve(url, TOKENIZER_FILE)
75
+ print(f"Downloaded to {TOKENIZER_FILE}")
76
  tokenizer = Llama3Tokenizer("tokenizer.model")
77
 
78
  if "instruct" in MODEL_FILE:
79
  tokenizer = ChatFormat(tokenizer)
80
 
81
+ ###################################
82
+ # Generate text
83
+ ##################################
84
+
85
  torch.manual_seed(123)
86
 
87
  start = time.time()
main.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb
#
# Standalone script: loads a Llama 3.2 checkpoint, generates text from a
# fixed prompt, and prints timing (plus peak CUDA memory when available).

import time
import torch

from model import Llama3Model, generate, text_to_token_ids, token_ids_to_text
from tokenizer import Llama3Tokenizer, ChatFormat, clean_text

#######################################
# Model settings

MODEL_FILE = "llama3.2-1B-instruct.pth"
# MODEL_FILE = "llama3.2-1B-base.pth"
# MODEL_FILE = "llama3.2-3B-instruct.pth"
# MODEL_FILE = "llama3.2-3B-base.pth"

MODEL_CONTEXT_LENGTH = 8192  # Supports up to 131_072

# Text generation settings
if "instruct" in MODEL_FILE:
    PROMPT = "What do llamas eat?"
else:
    PROMPT = "Llamas eat"

MAX_NEW_TOKENS = 150
TEMPERATURE = 0.
TOP_K = 1
#######################################

# Select the architecture config that matches the chosen checkpoint size.
if "1B" in MODEL_FILE:
    from model import LLAMA32_CONFIG_1B as LLAMA32_CONFIG
elif "3B" in MODEL_FILE:
    from model import LLAMA32_CONFIG_3B as LLAMA32_CONFIG
else:
    raise ValueError("Incorrect model file name")

model = Llama3Model(LLAMA32_CONFIG)

# Bug fix: the original line read `Tokenizer("tokenizer.model")`, but only
# `Llama3Tokenizer` is imported above (the class was renamed from `Tokenizer`
# in tokenizer.py within this same commit), so the bare name raised NameError.
tokenizer = Llama3Tokenizer("tokenizer.model")

# Instruct checkpoints expect chat-formatted prompts; wrap the tokenizer.
if "instruct" in MODEL_FILE:
    tokenizer = ChatFormat(tokenizer)

model.load_state_dict(torch.load(MODEL_FILE, weights_only=True))

# Prefer CUDA, then Apple MPS, then CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cpu")
)
model.to(device)

torch.manual_seed(123)  # reproducible sampling

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, tokenizer).to(device),
    max_new_tokens=MAX_NEW_TOKENS,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=TOP_K,
    temperature=TEMPERATURE
)

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)

# Strip chat-template special tokens from instruct-model output.
if "instruct" in MODEL_FILE:
    output_text = clean_text(output_text)

print("\n\nOutput text:\n\n", output_text)
tokenizer.py CHANGED
@@ -10,7 +10,7 @@ import tiktoken
10
  from tiktoken.load import load_tiktoken_bpe
11
 
12
 
13
- class Tokenizer:
14
  def __init__(self, model_path):
15
  assert os.path.isfile(model_path), f"Model file {model_path} not found"
16
  mergeable_ranks = load_tiktoken_bpe(model_path)
 
10
  from tiktoken.load import load_tiktoken_bpe
11
 
12
 
13
+ class Llama3Tokenizer:
14
  def __init__(self, model_path):
15
  assert os.path.isfile(model_path), f"Model file {model_path} not found"
16
  mergeable_ranks = load_tiktoken_bpe(model_path)