avreymi commited on
Commit
ff9f17d
·
1 Parent(s): 2923ad4

add mpt model

Browse files
Files changed (3) hide show
  1. model.py +52 -0
  2. test.py +2 -1
  3. tests.sh +1 -0
model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import warnings
4
+
5
+
6
# --- Module-level setup: load mosaicml/mpt-7b once at import time. ---
torch_dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "mosaicml/mpt-7b"


# trust_remote_code is required: MPT ships custom modeling code on the Hub.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    trust_remote_code=True,
    use_auth_token=None,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=None,
)
model.eval()
model.to(device=device, dtype=torch_dtype)

# Fall back to EOS as the pad token ONLY when the tokenizer defines none.
# The assignment now lives inside the guard, matching the warning's stated
# intent; an unconditional assignment would clobber a real pad token.
if tokenizer.pad_token_id is None:
    warnings.warn(
        "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
    )
    tokenizer.pad_token = tokenizer.eos_token
# Left-padding so generation continues from the true end of each prompt.
tokenizer.padding_side = "left"

# Generation defaults shared by every call to mpt_7b().
gkw = {
    "temperature": 0.5,
    "top_p": 0.92,
    "top_k": 0,
    "max_new_tokens": 512,
    "use_cache": True,
    "do_sample": True,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
    "repetition_penalty": 1.1,  # 1.0 means no penalty, > 1.0 means penalty, 1.2 from CTRL paper
}
42
+
43
+
44
def mpt_7b(s):
    """Generate a completion for prompt *s* using the module-level MPT-7B model.

    Parameters
    ----------
    s : str
        The prompt text.

    Returns
    -------
    str
        Only the newly generated text (the prompt tokens are sliced off),
        decoded with special tokens removed.
    """
    # Keep the full tokenizer output so generate() receives the attention
    # mask explicitly instead of inferring it — identical for a single
    # unpadded prompt, correct if padding is ever used, and avoids the
    # transformers warning about a missing attention mask.
    enc = tokenizer(s, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            enc.input_ids,
            attention_mask=enc.attention_mask,
            **gkw,
        )
    # Slice the output tensor to keep only tokens past the prompt length.
    new_tokens = output_ids[0, enc.input_ids.shape[1] :]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
test.py CHANGED
@@ -1,3 +1,4 @@
1
  import subprocess
 
2
 
3
- print(subprocess.run(["dir"], shell=True, capture_output=True).stdout)
 
1
  import subprocess
2
+ import model
3
 
4
+ print(model.mpt_7b("Hello, world!, please generate some text for me."))
tests.sh CHANGED
@@ -1 +1,2 @@
1
  echo "running tests"
 
 
1
  echo "running tests"
2
+ python test.py