from typing import Generator, Sequence

from metagpt.utils.token_counter import TOKEN_MAX, count_output_tokens


def reduce_message_length(
    msgs: Generator[str, None, None],
    model_name: str,
    system_text: str,
    reserved: int = 0,
) -> str:
"""Reduce the length of concatenated message segments to fit within the maximum token size. |
|
|
|
Args: |
|
msgs: A generator of strings representing progressively shorter valid prompts. |
|
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo") |
|
system_text: The system prompts. |
|
reserved: The number of reserved tokens. |
|
|
|
Returns: |
|
The concatenated message segments reduced to fit within the maximum token size. |
|
|
|
Raises: |
|
RuntimeError: If it fails to reduce the concatenated message length. |
|
""" |
|
    # Budget = model context limit (2048 for unknown models) minus the system
    # prompt and any reserved tokens.
    max_token = TOKEN_MAX.get(model_name, 2048) - count_output_tokens(system_text, model_name) - reserved
    for msg in msgs:
        # For models missing from TOKEN_MAX the real limit is unknown, so the
        # first (longest) candidate is returned as-is.
        if count_output_tokens(msg, model_name) < max_token or model_name not in TOKEN_MAX:
            return msg

    raise RuntimeError("Failed to reduce message length.")
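

# A minimal usage sketch for reduce_message_length (illustrative only: the
# prefix sizes, the reserved value, and long_text are assumptions, not values
# prescribed by this module):
#
#     def shrinking_prompts(text: str) -> Generator[str, None, None]:
#         # Yield progressively shorter prefixes of the prompt.
#         for size in (8192, 4096, 2048, 1024, 512):
#             yield text[:size]
#
#     prompt = reduce_message_length(
#         shrinking_prompts(long_text), "gpt-3.5-turbo", system_text="", reserved=200
#     )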


def generate_prompt_chunk(
    text: str,
    prompt_template: str,
    model_name: str,
    system_text: str,
    reserved: int = 0,
) -> Generator[str, None, None]:
"""Split the text into chunks of a maximum token size. |
|
|
|
Args: |
|
text: The text to split. |
|
prompt_template: The template for the prompt, containing a single `{}` placeholder. For example, "### Reference\n{}". |
|
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo") |
|
system_text: The system prompts. |
|
reserved: The number of reserved tokens. |
|
|
|
Yields: |
|
The chunk of text. |
|
""" |
|
    paragraphs = text.splitlines(keepends=True)
    current_token = 0
    current_lines = []

    reserved = reserved + count_output_tokens(prompt_template + system_text, model_name)
    # 100 is a safety margin so the rendered prompt stays below the model's limit.
    max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100

    while paragraphs:
        paragraph = paragraphs.pop(0)
        token = count_output_tokens(paragraph, model_name)
        if current_token + token <= max_token:
            current_lines.append(paragraph)
            current_token += token
        elif token > max_token:
            # A single paragraph exceeds the budget: split it into smaller
            # parts and push them back onto the front of the queue.
            paragraphs = split_paragraph(paragraph) + paragraphs
            continue
        else:
            # The paragraph does not fit in the current chunk: emit the chunk
            # and start a new one with this paragraph.
            yield prompt_template.format("".join(current_lines))
            current_lines = [paragraph]
            current_token = token

    if current_lines:
        yield prompt_template.format("".join(current_lines))
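

# A minimal usage sketch for generate_prompt_chunk (illustrative only: the
# template, model name, and document are assumptions, not values prescribed
# by this module):
#
#     template = "### Reference\n{}\n\nSummarize the reference above."
#     for prompt in generate_prompt_chunk(document, template, "gpt-3.5-turbo", system_text="", reserved=512):
#         ...  # send each rendered prompt to the model as a separate request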


def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
"""Split a paragraph into multiple parts. |
|
|
|
Args: |
|
paragraph: The paragraph to split. |
|
sep: The separator character. |
|
count: The number of parts to split the paragraph into. |
|
|
|
Returns: |
|
A list of split parts of the paragraph. |
|
""" |
|
    for i in sep:
        sentences = list(_split_text_with_ends(paragraph, i))
        if len(sentences) <= 1:
            # This separator does not occur in the paragraph; try the next one.
            continue
        ret = ["".join(j) for j in _split_by_count(sentences, count)]
        return ret
    # No separator matched: fall back to splitting the raw string by length.
    return list(_split_by_count(paragraph, count))
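

# Illustrative examples of the splitting behaviour:
#
#     split_paragraph("Hello. World. Bye")  # split on "." into 2 near-equal parts
#     # -> ["Hello. World.", " Bye"]
#
#     split_paragraph("abcdef")  # neither "." nor "," occurs -> split by length
#     # -> ["abc", "def"]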


def decode_unicode_escape(text: str) -> str:
    """Decode a text with unicode escape sequences.

    Args:
        text: The text to decode.

    Returns:
        The decoded text.

    Note:
        "unicode_escape" interprets the input bytes as Latin-1, so non-ASCII
        characters already present in the text may not round-trip cleanly.
    """
    return text.encode("utf-8").decode("unicode_escape", "ignore")
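

# Illustrative example: decode_unicode_escape("Hello\\nWorld\\u0021") returns
# "Hello\nWorld!" (the escape sequences become a newline and "!").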


def _split_by_count(lst: Sequence, count: int):
    """Split a sequence into `count` contiguous parts of near-equal length."""
    avg = len(lst) // count
    remainder = len(lst) % count
    start = 0
    for i in range(count):
        # The first `remainder` parts receive one extra element each.
        end = start + avg + (1 if i < remainder else 0)
        yield lst[start:end]
        start = end
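

# Illustrative example: _split_by_count([1, 2, 3, 4, 5], 2) yields [1, 2, 3]
# and then [4, 5].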


def _split_text_with_ends(text: str, sep: str = "."):
    """Split `text` on `sep`, keeping the separator at the end of each part."""
    parts = []
    for i in text:
        parts.append(i)
        if i == sep:
            yield "".join(parts)
            parts = []
    if parts:
        yield "".join(parts)
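

# Illustrative example: list(_split_text_with_ends("a.b.c")) returns
# ["a.", "b.", "c"].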