Jegree committed on
Commit c479acb · verified · 1 Parent(s): c781724

Upload 2 files

Files changed (2)
  1. app.py +2 -2
  2. models.py +88 -88
app.py CHANGED
@@ -226,7 +226,7 @@ def create_app():
   )
   def get_gpu_kind():
     device = jax.devices()[0]
-    if not gradio_helpers.should_mock():
+    if not gradio_helpers.should_mock() and device.platform != 'gpu':
       raise gr.Error('GPU not visible to JAX!')
     return f'GPU={device.device_kind}'
   demo.load(get_gpu_kind, None, gpu_kind)
@@ -248,4 +248,4 @@ if __name__ == '__main__':
   for name, (repo, filename, revision) in models.MODELS.items():
     gradio_helpers.register_download(name, repo, filename, revision)
 
-  create_app().queue().launch()
+  create_app().queue().launch(share=True)
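The first hunk fixes an inverted guard: the old check raised "GPU not visible to JAX!" whenever mocking was disabled, even with a GPU attached; the new check raises only when JAX's default device is not a GPU. Below is a minimal sketch of the same probe outside Gradio; `gradio_helpers.should_mock()` is specific to this Space, so a plain environment flag (`MOCK_HARDWARE`, a hypothetical name) stands in for it:

import os

import jax


def assert_gpu_visible() -> str:
  """Raises unless JAX's default device is a GPU (or mocking is enabled)."""
  device = jax.devices()[0]  # JAX's default (first) device.
  mocking = os.environ.get('MOCK_HARDWARE') == '1'  # Stand-in for gradio_helpers.should_mock().
  if not mocking and device.platform != 'gpu':
    raise RuntimeError('GPU not visible to JAX!')
  return f'GPU={device.device_kind}'  # e.g. 'GPU=Tesla T4' on a T4 Space.

The second hunk's `share=True` asks Gradio to create a temporary public `*.gradio.live` link in addition to the local server; Spaces already expose the app at their own URL, so the flag mainly matters when running the app elsewhere.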
models.py CHANGED
@@ -1,88 +1,88 @@
-"""Model-related code and constants."""
-
-import dataclasses
-import os
-import re
-
-import PIL.Image
-
-# pylint: disable=g-bad-import-order
-import gradio_helpers
-import paligemma_bv
-
-
-ORGANIZATION = 'google'
-BASE_MODELS = [
-    ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
-    ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
-]
-MODELS = {
-    # **{
-    #     model_name: (
-    #         f'{ORGANIZATION}/{repo}',
-    #         f'{model_name}.bf16.npz',
-    #         'bfloat16',  # Model repo revision.
-    #     )
-    #     for repo, model_name in BASE_MODELS
-    # },
-    'testPaligemma': ('Jegree/myPaligem', 'fine-tuned-paligemma-3b-pt-224.f16.npz', 'main'),
-}
-
-MODELS_INFO = {
-    'paligemma-3b-mix-224': (
-        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
-        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
-        'bfloat16 and float16 format for research purposes only.'
-    ),
-    'paligemma-3b-mix-448': (
-        'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output '
-        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
-        'bfloat16 and float16 format for research purposes only.'
-    ),
-}
-
-MODELS_RES_SEQ = {
-    'paligemma-3b-mix-224': (224, 256),
-    'paligemma-3b-mix-448': (448, 512),
-}
-
-# "CPU basic" has 16 GB RAM, "T4 small" has 15 GB RAM.
-# Below value should be smaller than "available RAM - one model".
-# A single bf16 is about 5860 MB.
-MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
-
-config = paligemma_bv.PaligemmaConfig(
-    ckpt='',  # will be set below
-    res=224,
-    text_len=64,
-    tokenizer='gemma(tokensets=("loc", "seg"))',
-    vocab_size=256_000 + 1024 + 128,
-)
-
-
-def get_cached_model(
-    model_name: str,
-) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
-  """Returns model and params, using RAM cache."""
-  res, seq = MODELS_RES_SEQ[model_name]
-  model_path = gradio_helpers.get_paths()[model_name]
-  config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
-  model, params_cpu = gradio_helpers.get_memory_cache(
-      config_,
-      lambda: paligemma_bv.load_model(config_),
-      max_cache_size_bytes=MAX_RAM_CACHE,
-  )
-  return model, params_cpu
-
-
-def generate(
-    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
-) -> str:
-  """Generates output with specified `model_name`, `sampler`."""
-  model, params_cpu = get_cached_model(model_name)
-  batch = model.shard_batch(model.prepare_batch([image], [prompt]))
-  with gradio_helpers.timed('sharding'):
-    params = model.shard_params(params_cpu)
-  with gradio_helpers.timed('computation', start_message=True):
-    tokens = model.predict(params, batch, sampler=sampler)
-  return model.tokenizer.to_str(tokens[0])
+"""Model-related code and constants."""
+
+import dataclasses
+import os
+import re
+
+import PIL.Image
+
+# pylint: disable=g-bad-import-order
+import gradio_helpers
+import paligemma_bv
+
+
+ORGANIZATION = 'google'
+BASE_MODELS = [
+    ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
+    ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
+]
+MODELS = {
+    # **{
+    #     model_name: (
+    #         f'{ORGANIZATION}/{repo}',
+    #         f'{model_name}.bf16.npz',
+    #         'bfloat16',  # Model repo revision.
+    #     )
+    #     for repo, model_name in BASE_MODELS
+    # },
+    'myPaligem': ('Jegree/myPaligem', 'fine-tuned-paligemma-3b-pt-224.f16.npz', 'main'),
+}
+
+MODELS_INFO = {
+    'paligemma-3b-mix-224': (
+        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
+        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
+        'bfloat16 and float16 format for research purposes only.'
+    ),
+    'paligemma-3b-mix-448': (
+        'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output '
+        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
+        'bfloat16 and float16 format for research purposes only.'
+    ),
+}
+
+MODELS_RES_SEQ = {
+    'paligemma-3b-mix-224': (224, 256),
+    'paligemma-3b-mix-448': (448, 512),
+}
+
+# "CPU basic" has 16 GB RAM, "T4 small" has 15 GB RAM.
+# Below value should be smaller than "available RAM - one model".
+# A single bf16 is about 5860 MB.
+MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
+
+config = paligemma_bv.PaligemmaConfig(
+    ckpt='',  # will be set below
+    res=224,
+    text_len=64,
+    tokenizer='gemma(tokensets=("loc", "seg"))',
+    vocab_size=256_000 + 1024 + 128,
+)
+
+
+def get_cached_model(
+    model_name: str,
+) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
+  """Returns model and params, using RAM cache."""
+  res, seq = MODELS_RES_SEQ[model_name]
+  model_path = gradio_helpers.get_paths()[model_name]
+  config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
+  model, params_cpu = gradio_helpers.get_memory_cache(
+      config_,
+      lambda: paligemma_bv.load_model(config_),
+      max_cache_size_bytes=MAX_RAM_CACHE,
+  )
+  return model, params_cpu
+
+
+def generate(
+    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
+) -> str:
+  """Generates output with specified `model_name`, `sampler`."""
+  model, params_cpu = get_cached_model(model_name)
+  batch = model.shard_batch(model.prepare_batch([image], [prompt]))
+  with gradio_helpers.timed('sharding'):
+    params = model.shard_params(params_cpu)
+  with gradio_helpers.timed('computation', start_message=True):
+    tokens = model.predict(params, batch, sampler=sampler)
+  return model.tokenizer.to_str(tokens[0])
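The only content change in models.py is the `MODELS` key, renamed from 'testPaligemma' to 'myPaligem'. That key is the model's identity throughout the app: app.py registers each entry for download under it, and `generate()` resolves it via `gradio_helpers.get_paths()` and `MODELS_RES_SEQ`. Below is a hypothetical end-to-end use of the renamed entry, runnable only inside this Space's codebase. Note that the commit does not add a matching `MODELS_RES_SEQ['myPaligem']` entry, so `get_cached_model()` would raise a `KeyError` unless one is supplied; the `(224, 64)` pair and the sampler/prompt strings below are assumed values, not part of the commit:

import PIL.Image

import gradio_helpers
import models

# Each MODELS entry maps display name -> (repo, filename, revision).
for name, (repo, filename, revision) in models.MODELS.items():
  gradio_helpers.register_download(name, repo, filename, revision)

# Assumed addition; the commit itself leaves MODELS_RES_SEQ untouched,
# so the renamed key has no (resolution, text_len) entry without this.
models.MODELS_RES_SEQ['myPaligem'] = (224, 64)

image = PIL.Image.open('example.jpg')  # Hypothetical input image.
print(models.generate('myPaligem', 'greedy', image, 'caption en'))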