Commit: Upload 2 files
app.py CHANGED
@@ -226,7 +226,7 @@ def create_app():
   )
   def get_gpu_kind():
     device = jax.devices()[0]
-    if not gradio_helpers.should_mock():
+    if not gradio_helpers.should_mock() and device.platform != 'gpu':
      raise gr.Error('GPU not visible to JAX!')
     return f'GPU={device.device_kind}'
   demo.load(get_gpu_kind, None, gpu_kind)
@@ -248,4 +248,4 @@ if __name__ == '__main__':
   for name, (repo, filename, revision) in models.MODELS.items():
     gradio_helpers.register_download(name, repo, filename, revision)
 
-  create_app().queue().launch()
+  create_app().queue().launch(share=True)
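
Two changes here. The first hunk repairs the startup GPU check: the previous condition raised whenever mocking was off, even with a GPU present; the new condition raises only when mocking is off and the first JAX device is not a GPU. The second hunk passes share=True to launch(), which asks Gradio to create a public share link (on a hosted Space this is generally redundant, since the app is already served publicly). A minimal standalone sketch of the new check, with the Space's gradio_helpers.should_mock() replaced by a plain flag for illustration:

import jax
import gradio as gr

MOCK = False  # stand-in for gradio_helpers.should_mock()

def get_gpu_kind() -> str:
  device = jax.devices()[0]
  # Raise only when not mocking *and* JAX cannot see a GPU.
  if not MOCK and device.platform != 'gpu':
    raise gr.Error('GPU not visible to JAX!')
  return f'GPU={device.device_kind}'

device.platform is 'cpu', 'gpu' or 'tpu', while device.device_kind names the concrete accelerator (e.g. 'Tesla T4'), which is the string the demo displays.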
models.py CHANGED
@@ -1,88 +1,88 @@
 """Model-related code and constants."""
 
 import dataclasses
 import os
 import re
 
 import PIL.Image
 
 # pylint: disable=g-bad-import-order
 import gradio_helpers
 import paligemma_bv
 
 
 ORGANIZATION = 'google'
 BASE_MODELS = [
     ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
     ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
 ]
 MODELS = {
     # **{
     #     model_name: (
     #         f'{ORGANIZATION}/{repo}',
     #         f'{model_name}.bf16.npz',
     #         'bfloat16',  # Model repo revision.
     #     )
     #     for repo, model_name in BASE_MODELS
     # },
-    '
+    'myPaligem': ('Jegree/myPaligem', 'fine-tuned-paligemma-3b-pt-224.f16.npz', 'main'),
 }
 
 MODELS_INFO = {
     'paligemma-3b-mix-224': (
         'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
         'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
         'bfloat16 and float16 format for research purposes only.'
     ),
     'paligemma-3b-mix-448': (
         'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output '
         'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
         'bfloat16 and float16 format for research purposes only.'
     ),
 }
 
 MODELS_RES_SEQ = {
     'paligemma-3b-mix-224': (224, 256),
     'paligemma-3b-mix-448': (448, 512),
 }
 
 # "CPU basic" has 16 GB RAM, "T4 small" has 15 GB RAM.
 # Below value should be smaller than "available RAM - one model".
 # A single bf16 is about 5860 MB.
 MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
 
 config = paligemma_bv.PaligemmaConfig(
     ckpt='',  # will be set below
     res=224,
     text_len=64,
     tokenizer='gemma(tokensets=("loc", "seg"))',
     vocab_size=256_000 + 1024 + 128,
 )
 
 
 def get_cached_model(
     model_name: str,
 ) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
   """Returns model and params, using RAM cache."""
   res, seq = MODELS_RES_SEQ[model_name]
   model_path = gradio_helpers.get_paths()[model_name]
   config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
   model, params_cpu = gradio_helpers.get_memory_cache(
       config_,
       lambda: paligemma_bv.load_model(config_),
       max_cache_size_bytes=MAX_RAM_CACHE,
   )
   return model, params_cpu
 
 
 def generate(
     model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
 ) -> str:
   """Generates output with specified `model_name`, `sampler`."""
   model, params_cpu = get_cached_model(model_name)
   batch = model.shard_batch(model.prepare_batch([image], [prompt]))
   with gradio_helpers.timed('sharding'):
     params = model.shard_params(params_cpu)
   with gradio_helpers.timed('computation', start_message=True):
     tokens = model.predict(params, batch, sampler=sampler)
   return model.tokenizer.to_str(tokens[0])
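
The only substantive change in models.py is line 28: the previous revision left a dangling ' inside the MODELS dict (a syntax error), and this commit replaces it with an entry for a custom fine-tuned checkpoint hosted at Jegree/myPaligem. Note, however, that get_cached_model() also looks up MODELS_RES_SEQ[model_name], which has no 'myPaligem' key, so selecting the new model would raise a KeyError until a matching entry is added. A hedged sketch of that follow-up wiring; the (224, 256) resolution/sequence values are an assumption inferred from the '-224' in the checkpoint filename, not part of this commit:

import models
import gradio_helpers

# Assumed: res/seq inferred from the '...-224...' filename, not in the commit.
models.MODELS_RES_SEQ['myPaligem'] = (224, 256)
# Assumed: a short description so the UI has something to display.
models.MODELS_INFO['myPaligem'] = 'Custom fine-tuned PaliGemma 3B checkpoint.'

# Mirrors the __main__ loop in app.py: queue one download per MODELS entry.
for name, (repo, filename, revision) in models.MODELS.items():
  gradio_helpers.register_download(name, repo, filename, revision)

For sizing the RAM cache, the comments in the file give the arithmetic: with RAM_CACHE_GB=12, MAX_RAM_CACHE is 12e9 bytes, and at roughly 5.86 GB per bf16 checkpoint the cache holds about two models.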