batched generation
- generate.py +27 -6
- gradio_app.py +3 -3
generate.py
CHANGED
@@ -3,6 +3,7 @@ import json
 import logging
 import regex
 import time
+from itertools import chain, islice
 from pathlib import Path
 from typing import Annotated, Iterator
 
@@ -22,14 +23,16 @@ logger = logging.getLogger(__name__)
 
 
 logger.warning("Loading model...")
-model_id = "google/gemma-2b-it"
-# model_id = "Qwen/Qwen1.5-0.5B-Chat"
 if torch.backends.mps.is_available():
     device = "mps"
-
+    model_id = "Qwen/Qwen1.5-0.5B-Chat"
+    batch_size = 4
 else:
     device = "cuda"
-
+    model_id = "google/gemma-2b-it"
+    batch_size = 20
+
+model = models.transformers(model_id, device=device)
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 sampler = PenalizedMultinomialSampler()
@@ -95,6 +98,23 @@ def samples_prommpt(filename: str, prompt: str, columns: str):
 {{ prompt }}
 """
 
+
+def stream_json_objects_from_batched_tokens_generator(batched_tokens_generator: Iterator[list[str]], json_field: str) -> Iterator[dict]:
+    first_batch = next(batched_tokens_generator)
+    batch_size = len(first_batch)
+    streams = [""] * batch_size
+    skips = [0] * batch_size
+    for tokens_batch in chain([first_batch], batched_tokens_generator):
+        for stream_idx, token in enumerate(tokens_batch):
+            streams[stream_idx] += token
+            try:
+                for stream_sample in islice(ijson.items(StringIteratorIO(streams[stream_idx].__iter__()), json_field + ".item", buf_size=1), skips[stream_idx], None):
+                    yield stream_sample
+                    skips[stream_idx] += 1
+            except ijson.IncompleteJSONError:
+                pass
+
+
 def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
     filename = Path(filename).stem
     logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
@@ -134,7 +154,8 @@ def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int,
         tokenize=False,
         add_generation_prompt=True
     )
-
-
+    batched_samples_generator_tokens = samples_generator.stream([text] * batch_size, rng=rng)
+    json_field = list(Dataset.model_fields)[0]
+    for _, sample in zip(range(size), stream_json_objects_from_batched_tokens_generator(batched_samples_generator_tokens, json_field=json_field)):
         yield json.dumps(sample, ensure_ascii=False) + "\n"
     logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
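The heart of the commit is the re-parse-and-skip loop in stream_json_objects_from_batched_tokens_generator: each of the batch_size sequences accumulates its tokens into a growing text buffer, the buffer is re-parsed with ijson after every new token, and a per-stream counter skips the objects already yielded on earlier steps. This relies on outlines' stream() yielding, for a list of prompts, one list of token strings per step. Below is a minimal, self-contained sketch of the same pattern that runs with only ijson installed; it substitutes io.StringIO for the repo's StringIteratorIO helper (defined elsewhere in generate.py) and feeds hand-made token batches instead of model output. stream_items_from_batches and the toy sequences are illustrative, not part of the commit.

import io
from itertools import chain, islice, zip_longest
from typing import Iterator

import ijson


def stream_items_from_batches(batches: Iterator[list[str]], field: str) -> Iterator[dict]:
    first = next(batches)        # peek at the first step to learn the batch size
    buffers = [""] * len(first)  # one growing text buffer per parallel sequence
    done = [0] * len(first)      # objects already yielded from each sequence
    for batch in chain([first], batches):
        for i, token in enumerate(batch):
            buffers[i] += token
            try:
                # Re-parse the whole buffer, skipping the done[i] objects
                # that were already yielded on earlier steps.
                items = ijson.items(io.StringIO(buffers[i]), field + ".item")
                for item in islice(items, done[i], None):
                    yield item
                    done[i] += 1
            except ijson.IncompleteJSONError:
                pass             # buffer ends mid-object: wait for more tokens


# Fake "model output": two sequences streamed four characters per step.
seq_a = '{"names": [{"n": 1}, {"n": 2}]}'
seq_b = '{"names": [{"n": 3}]}'
chunks_a = [seq_a[i:i + 4] for i in range(0, len(seq_a), 4)]
chunks_b = [seq_b[i:i + 4] for i in range(0, len(seq_b), 4)]
fake_batches = (list(pair) for pair in zip_longest(chunks_a, chunks_b, fillvalue=""))
for obj in stream_items_from_batches(fake_batches, "names"):
    print(obj)  # objects appear as soon as their sequence completes them

Re-parsing each buffer from the start on every token is quadratic in the output length, but it keeps the JSON parser stateless across steps, and the skip counters guarantee each completed object is yielded exactly once even though earlier objects are parsed repeatedly.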
gradio_app.py
CHANGED
@@ -6,11 +6,11 @@ import io
 import pandas as pd
 import spaces
 
-from generate import model_id, stream_jsonl_file
+from generate import model_id, stream_jsonl_file, batch_size
 
-MAX_SIZE = 20
+MAX_SIZE = 20 * batch_size
 DEFAULT_SEED = 42
-DEFAULT_SIZE = 5
+DEFAULT_SIZE = 5 * batch_size
 
 @spaces.GPU(duration=120)
 def stream_output(query: str, continue_content: str = ""):
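Scaling MAX_SIZE and DEFAULT_SIZE by batch_size keeps the per-sequence workload unchanged: each of the batch_size parallel streams still contributes at most 20 rows, and 5 by default. A hypothetical smoke test of the updated interface, assuming only the names visible in this diff (stream_jsonl_file, batch_size); the filename, prompt, and columns are made up:

import json

from generate import batch_size, stream_jsonl_file

# Made-up arguments; only the function signature comes from the diff above.
for line in stream_jsonl_file(
    filename="zoo_animals.jsonl",
    prompt="common zoo animals and their diets",
    columns=["name", "diet"],
    seed=42,                # DEFAULT_SEED in gradio_app.py
    size=5 * batch_size,    # DEFAULT_SIZE: five rows per parallel stream
):
    row = json.loads(line)  # each yielded line is one JSON object
    print(row)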