Pedro Cuenca commited on
Commit
d840f35
·
1 Parent(s): 4575fad

Fix state scope (it was global!)

Browse files
Files changed (1) hide show
  1. app.py +109 -91
app.py CHANGED
@@ -8,91 +8,92 @@ from functools import partial
8
  from PIL import Image
9
 
10
  SELECT_LABEL = "Select as seed"
11
- selectors: List[gr.Radio] = []
12
- seeds = []
13
-
14
- block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
15
-
16
- def infer_seeded_image(prompt, seed):
17
- print(f"Prompt: {prompt}, seed: {seed}")
18
- return Image.open(f"sample_outputs/seeded_1.png")
19
-
20
- def infer_grid(prompt):
21
- response = defaultdict(list)
22
- for i in range(1, 7):
23
- response["images"].append(Image.open(f"sample_outputs/{i}.png"))
24
- response["seeds"].append(random.randint(0, 2 ** 32 -1))
25
-
26
- global seeds
27
- seeds = response["seeds"]
28
- return response["images"]
29
-
30
- def infer(prompt):
31
- """
32
- Outputs:
33
- - Grid images (list)
34
- - Seeded Image (Image or None)
35
- - Grid Box with updated visibility
36
- - Seeded Box with updated visibility
37
- """
38
- grid_images = [None] * 6
39
- image_with_seed = None
40
- visible = (False, False)
41
-
42
- if seed_index := current_selection() > -1:
43
- seed = seeds[seed_index]
44
- image_with_seed = infer_seeded_image(prompt, seed)
45
- visible = (False, True)
46
- else:
47
- grid_images = infer_grid(prompt)
48
- visible = (True, False)
49
-
50
- boxes = [gr.Box.update(visible=v) for v in visible]
51
- return grid_images + [image_with_seed] + boxes
52
-
53
-
54
- def image_block():
55
- return gr.Image(
56
- interactive=False, show_label=False
57
- ).style(
58
- # border = (True, True, False, True),
59
- rounded = (True, True, False, False),
60
- )
61
 
62
- selectors_state = [''] * 6
63
- def did_select(radio: gr.Radio):
64
- new_state = list(map(lambda r: SELECT_LABEL if r == radio else '', selectors))
65
- return new_state
66
-
67
- def update_state(radio: gr.Radio, *state):
68
- global selectors_state
69
- if list(state) != selectors_state:
70
- selectors_state = did_select(radio)
71
- return selectors_state
72
-
73
- def current_selection():
74
- try:
75
- return selectors_state.index(SELECT_LABEL)
76
- except:
77
- return -1
78
-
79
- def clear_seed():
80
- """Update state of Radio buttons, grid, seeded_box"""
81
- global selectors_state
82
- selectors_state = [''] * 6
83
- return selectors_state + [gr.Box.update(visible=True), gr.Box.update(visible=False)]
84
-
85
- def radio_block():
86
- radio = gr.Radio(
87
- choices=[SELECT_LABEL], interactive=True, show_label=False,
88
- ).style(
89
- # border = (False, True, True, True),
90
- # rounded = (False, False, True, True)
91
- container=False
92
- )
93
- return radio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- with block:
96
  gr.Markdown("<h1><center>Latent Diffusion Demo</center></h1>")
97
  with gr.Group():
98
  with gr.Box():
@@ -134,22 +135,39 @@ with block:
134
  select6 = radio_block()
135
 
136
  images = [image1, image2, image3, image4, image5, image6]
137
- selectors += [select1, select2, select3, select4, select5, select6]
138
 
139
- for radio in selectors:
140
- radio.change(fn=partial(update_state, radio), inputs=selectors, outputs=selectors)
 
 
 
 
 
141
 
142
  with (seeded_box := gr.Box()):
143
  seeded_image = image_block()
144
  clear_seed_button = gr.Button("Clear Seed")
145
  seeded_box.visible = False
146
- clear_seed_button.click(clear_seed, inputs=[], outputs=selectors + [grid, seeded_box])
 
 
 
 
147
 
148
  all_images = images + [seeded_image]
149
  boxes = [grid, seeded_box]
150
- infer_outputs = all_images + boxes
151
 
152
- text.submit(infer, inputs=text, outputs=infer_outputs)
153
- btn.click(infer, inputs=text, outputs=infer_outputs)
 
 
 
 
 
 
 
 
154
 
155
- block.launch(enable_queue=False)
 
8
  from PIL import Image
9
 
10
  SELECT_LABEL = "Select as seed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
13
+ state = gr.Variable({
14
+ 'selectors': [''] * 6,
15
+ 'seeds': []
16
+ })
17
+ # selectors_state = [''] * 6
18
+ # seeds = []
19
+
20
+ def infer_seeded_image(prompt, seed):
21
+ print(f"Prompt: {prompt}, seed: {seed}")
22
+ return Image.open(f"sample_outputs/seeded_1.png")
23
+
24
+ def infer_grid(prompt):
25
+ response = defaultdict(list)
26
+ for i in range(1, 7):
27
+ response["images"].append(Image.open(f"sample_outputs/{i}.png"))
28
+ response["seeds"].append(random.randint(0, 2 ** 32 -1))
29
+
30
+ return response["images"], response["seeds"]
31
+
32
+ def infer(prompt, state):
33
+ """
34
+ Outputs:
35
+ - Grid images (list)
36
+ - Seeded Image (Image or None)
37
+ - Grid Box with updated visibility
38
+ - Seeded Box with updated visibility
39
+ """
40
+ grid_images = [None] * 6
41
+ image_with_seed = None
42
+ visible = (False, False)
43
+
44
+ if seed_index := current_selection(state) > -1:
45
+ seed = state["seeds"][seed_index]
46
+ image_with_seed = infer_seeded_image(prompt, seed)
47
+ visible = (False, True)
48
+ else:
49
+ grid_images, seeds = infer_grid(prompt)
50
+ state["seeds"] = seeds
51
+ visible = (True, False)
52
+
53
+ boxes = [gr.Box.update(visible=v) for v in visible]
54
+ return grid_images + [image_with_seed] + boxes + [state]
55
+
56
+
57
+ def image_block():
58
+ return gr.Image(
59
+ interactive=False, show_label=False
60
+ ).style(
61
+ # border = (True, True, False, True),
62
+ rounded = (True, True, False, False),
63
+ )
64
+
65
+ def update_state(selected_index: int, value, state):
66
+ if value == '':
67
+ others_value = None
68
+ else:
69
+ others_value = ''
70
+ new_state = [''] * 6
71
+ new_state[selected_index] = SELECT_LABEL
72
+ state["selectors"] = new_state
73
+ others = gr.Radio.update(value=others_value)
74
+ return [others] * 5 + [state]
75
+
76
+ def current_selection(state):
77
+ try:
78
+ return state["selectors"].index(SELECT_LABEL)
79
+ except:
80
+ return -1
81
+
82
+ def clear_seed(state):
83
+ """Update state of Radio buttons, grid, seeded_box"""
84
+ state["selectors"] = [''] * 6
85
+ return state["selectors"] + [gr.Box.update(visible=True), gr.Box.update(visible=False)] + [state]
86
+
87
+ def radio_block():
88
+ radio = gr.Radio(
89
+ choices=[SELECT_LABEL], interactive=True, show_label=False,
90
+ ).style(
91
+ # border = (False, True, True, True),
92
+ # rounded = (False, False, True, True)
93
+ container=False
94
+ )
95
+ return radio
96
 
 
97
  gr.Markdown("<h1><center>Latent Diffusion Demo</center></h1>")
98
  with gr.Group():
99
  with gr.Box():
 
135
  select6 = radio_block()
136
 
137
  images = [image1, image2, image3, image4, image5, image6]
138
+ selectors = [select1, select2, select3, select4, select5, select6]
139
 
140
+ for i, radio in enumerate(selectors):
141
+ others = list(filter(lambda s: s != radio, selectors))
142
+ radio.change(
143
+ partial(update_state, i),
144
+ inputs=[radio, state],
145
+ outputs=others + [state]
146
+ )
147
 
148
  with (seeded_box := gr.Box()):
149
  seeded_image = image_block()
150
  clear_seed_button = gr.Button("Clear Seed")
151
  seeded_box.visible = False
152
+ clear_seed_button.click(
153
+ clear_seed,
154
+ inputs=[state],
155
+ outputs=selectors + [grid, seeded_box] + [state]
156
+ )
157
 
158
  all_images = images + [seeded_image]
159
  boxes = [grid, seeded_box]
160
+ infer_outputs = all_images + boxes + [state]
161
 
162
+ text.submit(
163
+ infer,
164
+ inputs=[text, state],
165
+ outputs=infer_outputs
166
+ )
167
+ btn.click(
168
+ infer,
169
+ inputs=[text, state],
170
+ outputs=infer_outputs
171
+ )
172
 
173
+ demo.launch(enable_queue=False)