Steph254 commited on
Commit
e519624
·
verified ·
1 Parent(s): 196f1dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -23,19 +23,22 @@ def load_llama_model(model_name):
23
 
24
  tokenizer = LlamaTokenizer.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
25
 
26
- # Load the checkpoint manually
27
- model_path = f"{model_name}/consolidated.00.pth"
28
- state_dict = torch.load(model_path, map_location="cpu") # Adjust for GPU if needed
29
 
30
  print("✅ Model state dictionary loaded successfully!")
31
-
32
- return tokenizer, state_dict
 
 
 
33
 
34
  # Load the quantized Llama model
35
  tokenizer, model = load_llama_model(QUANTIZED_MODEL)
36
 
37
  # Load Llama Guard for content moderation
38
- guard_tokenizer, guard_model = load_llama_model(LLAMA_GUARD_NAME, is_guard=True)
39
 
40
  # Define Prompt Templates
41
  PROMPTS = {
 
23
 
24
  tokenizer = LlamaTokenizer.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
25
 
26
+ # Manually load `.pth` state dictionary
27
+ model_url = f"https://huggingface.co/{model_name}/resolve/main/consolidated.00.pth"
28
+ state_dict = torch.hub.load_state_dict_from_url(model_url, map_location="cpu")
29
 
30
  print("✅ Model state dictionary loaded successfully!")
31
+
32
+ # Initialize model and load state_dict
33
+ model = AutoModelForCausalLM.from_pretrained(model_name, state_dict=state_dict)
34
+
35
+ return tokenizer, model
36
 
37
  # Load the quantized Llama model
38
  tokenizer, model = load_llama_model(QUANTIZED_MODEL)
39
 
40
  # Load Llama Guard for content moderation
41
+ guard_tokenizer, guard_model = load_llama_model(LLAMA_GUARD_NAME)
42
 
43
  # Define Prompt Templates
44
  PROMPTS = {