Spaces:
Running
Running
Update models.py
Browse files
models.py
CHANGED
@@ -11,34 +11,38 @@ import gradio_helpers
|
|
11 |
import paligemma_bv
|
12 |
|
13 |
|
14 |
-
ORGANIZATION = '
|
15 |
BASE_MODELS = [
|
16 |
-
|
17 |
-
|
18 |
-
('myPaligem', 'fine-tuned-paligemma-3b-pt-224')
|
19 |
]
|
20 |
MODELS = {
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
'fine-tuned-paligemma-3b-pt-224':('Jegree/myPaligem', 'fine-tuned-paligemma-3b-pt-224.f16.npz', 'main'),
|
30 |
}
|
31 |
|
32 |
MODELS_INFO = {
|
33 |
-
'
|
34 |
'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
|
35 |
'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
|
36 |
'bfloat16 and float16 format for research purposes only.'
|
37 |
),
|
|
|
|
|
|
|
|
|
|
|
38 |
}
|
39 |
|
40 |
MODELS_RES_SEQ = {
|
41 |
-
'
|
|
|
42 |
}
|
43 |
|
44 |
# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
|
@@ -49,7 +53,7 @@ MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
|
|
49 |
config = paligemma_bv.PaligemmaConfig(
|
50 |
ckpt='', # will be set below
|
51 |
res=224,
|
52 |
-
text_len=
|
53 |
tokenizer='gemma(tokensets=("loc", "seg"))',
|
54 |
vocab_size=256_000 + 1024 + 128,
|
55 |
)
|
@@ -80,4 +84,4 @@ def generate(
|
|
80 |
params = model.shard_params(params_cpu)
|
81 |
with gradio_helpers.timed('computation', start_message=True):
|
82 |
tokens = model.predict(params, batch, sampler=sampler)
|
83 |
-
return model.tokenizer.to_str(tokens[0])
|
|
|
11 |
import paligemma_bv
|
12 |
|
13 |
|
14 |
+
ORGANIZATION = 'google'
|
15 |
BASE_MODELS = [
|
16 |
+
('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
|
17 |
+
('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
|
|
|
18 |
]
|
19 |
MODELS = {
|
20 |
+
**{
|
21 |
+
model_name: (
|
22 |
+
f'{ORGANIZATION}/{repo}',
|
23 |
+
f'{model_name}.bf16.npz',
|
24 |
+
'bfloat16', # Model repo revision.
|
25 |
+
)
|
26 |
+
for repo, model_name in BASE_MODELS
|
27 |
+
},
|
|
|
28 |
}
|
29 |
|
30 |
MODELS_INFO = {
|
31 |
+
'paligemma-3b-mix-224': (
|
32 |
'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
|
33 |
'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
|
34 |
'bfloat16 and float16 format for research purposes only.'
|
35 |
),
|
36 |
+
'paligemma-3b-mix-448': (
|
37 |
+
'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output '
|
38 |
+
'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
|
39 |
+
'bfloat16 and float16 format for research purposes only.'
|
40 |
+
),
|
41 |
}
|
42 |
|
43 |
MODELS_RES_SEQ = {
|
44 |
+
'paligemma-3b-mix-224': (224, 256),
|
45 |
+
'paligemma-3b-mix-448': (448, 512),
|
46 |
}
|
47 |
|
48 |
# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
|
|
|
53 |
config = paligemma_bv.PaligemmaConfig(
|
54 |
ckpt='', # will be set below
|
55 |
res=224,
|
56 |
+
text_len=64,
|
57 |
tokenizer='gemma(tokensets=("loc", "seg"))',
|
58 |
vocab_size=256_000 + 1024 + 128,
|
59 |
)
|
|
|
84 |
params = model.shard_params(params_cpu)
|
85 |
with gradio_helpers.timed('computation', start_message=True):
|
86 |
tokens = model.predict(params, batch, sampler=sampler)
|
87 |
+
return model.tokenizer.to_str(tokens[0])
|