Jegree committed
Commit 306aed7 · verified · 1 parent: 307292d

Update models.py

Files changed (1): models.py (+21 -17)
models.py CHANGED
@@ -11,34 +11,38 @@ import gradio_helpers
 import paligemma_bv
 
 
-ORGANIZATION = 'Jegree'
+ORGANIZATION = 'google'
 BASE_MODELS = [
-    # ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
-    # ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
-    ('myPaligem', 'fine-tuned-paligemma-3b-pt-224')
+    ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
+    ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
 ]
 MODELS = {
-    # **{
-    #     model_name: (
-    #         f'{ORGANIZATION}/{repo}',
-    #         f'{model_name}.bf16.npz',
-    #         'bfloat16',  # Model repo revision.
-    #     )
-    #     for repo, model_name in BASE_MODELS
-    # },
-    'fine-tuned-paligemma-3b-pt-224': ('Jegree/myPaligem', 'fine-tuned-paligemma-3b-pt-224.f16.npz', 'main'),
+    **{
+        model_name: (
+            f'{ORGANIZATION}/{repo}',
+            f'{model_name}.bf16.npz',
+            'bfloat16',  # Model repo revision.
+        )
+        for repo, model_name in BASE_MODELS
+    },
 }
 
 MODELS_INFO = {
-    'fine-tuned-paligemma-3b-pt-224': (
+    'paligemma-3b-mix-224': (
         'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
         'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
         'bfloat16 and float16 format for research purposes only.'
     ),
+    'paligemma-3b-mix-448': (
+        'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output '
+        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
+        'bfloat16 and float16 format for research purposes only.'
+    ),
 }
 
 MODELS_RES_SEQ = {
-    'fine-tuned-paligemma-3b-pt-224': (224, 128),
+    'paligemma-3b-mix-224': (224, 256),
+    'paligemma-3b-mix-448': (448, 512),
 }
 
 # "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
@@ -49,7 +53,7 @@ MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
 config = paligemma_bv.PaligemmaConfig(
     ckpt='',  # will be set below
     res=224,
-    text_len=128,
+    text_len=64,
     tokenizer='gemma(tokensets=("loc", "seg"))',
     vocab_size=256_000 + 1024 + 128,
 )
@@ -80,4 +84,4 @@ def generate(
   params = model.shard_params(params_cpu)
   with gradio_helpers.timed('computation', start_message=True):
     tokens = model.predict(params, batch, sampler=sampler)
-  return model.tokenizer.to_str(tokens[0])
+  return model.tokenizer.to_str(tokens[0])
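
For context on how the MODELS registry is consumed: each entry maps a display name to a (repo_id, filename, revision) tuple. Below is a minimal sketch of the lookup-and-download step, assuming the space resolves checkpoints through the standard huggingface_hub API; the actual helper lives in gradio_helpers, which this diff does not touch, and resolve_checkpoint is a hypothetical name.

from huggingface_hub import hf_hub_download

# MODELS as defined after this commit: name -> (repo_id, filename, revision).
MODELS = {
    'paligemma-3b-mix-224': ('google/paligemma-3b-mix-224-jax',
                             'paligemma-3b-mix-224.bf16.npz', 'bfloat16'),
    'paligemma-3b-mix-448': ('google/paligemma-3b-mix-448-jax',
                             'paligemma-3b-mix-448.bf16.npz', 'bfloat16'),
}

def resolve_checkpoint(model_name: str) -> str:
  """Downloads the registered .npz checkpoint and returns its local path."""
  repo_id, filename, revision = MODELS[model_name]
  return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)

Note that the third tuple element is a git revision of the model repo (a 'bfloat16' branch here, 'main' in the replaced fine-tuned entry), not a dtype flag.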
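
The vocab_size arithmetic in PaligemmaConfig is also worth unpacking: 256_000 is the base Gemma vocabulary, and the two addends match the extra token sets requested via tokenizer='gemma(tokensets=("loc", "seg"))' in the diff above. A quick sanity check, assuming the usual PaliGemma convention of 1024 location tokens and 128 segmentation tokens:

GEMMA_VOCAB = 256_000  # base Gemma SentencePiece vocabulary
NUM_LOC_TOKENS = 1024  # <loc0000>..<loc1023>, quantized box coordinates
NUM_SEG_TOKENS = 128   # <seg000>..<seg127>, segmentation mask codewords
assert GEMMA_VOCAB + NUM_LOC_TOKENS + NUM_SEG_TOKENS == 257_152  # 256_000 + 1024 + 128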
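
Finally, MODELS_RES_SEQ pairs each model name with its (input resolution, text sequence length), while the shared config hard-codes res=224 and text_len=64, so the per-model values presumably get swapped in at load time. A sketch of that override with a hypothetical config_for helper, assuming PaligemmaConfig is a dataclass (not confirmed by this diff):

import dataclasses

MODELS_RES_SEQ = {
    'paligemma-3b-mix-224': (224, 256),
    'paligemma-3b-mix-448': (448, 512),
}

def config_for(model_name, base_config):
  # Replace resolution and text sequence length with the per-model values.
  res, seq_len = MODELS_RES_SEQ[model_name]
  return dataclasses.replace(base_config, res=res, text_len=seq_len)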