from __future__ import annotations

import copy
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Hashable, Iterator
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union

from PIL import Image
from pydantic import Field, field_validator
from tenacity import retry, stop_after_attempt, stop_after_delay

from ...base import BotBase
from ...utils.env import EnvVar
from ...utils.general import LRUCache
from ...utils.registry import registry
from .prompt.base import _OUTPUT_PARSER, StrParser
from .prompt.parser import BaseOutputParser
from .prompt.prompt import PromptTemplate
from .schemas import Message

T = TypeVar("T", str, dict, list)


class BaseLLM(BotBase, ABC):
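    """Base class for LLM wrappers.

    Subclasses implement ``_call`` (and optionally ``_acall``); ``generate``
    and ``agenerate`` add optional LRU response caching on top.
    """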

    cache: bool = False  # when True, responses are cached in lru_cache
    lru_cache: LRUCache = Field(default=LRUCache(EnvVar.LLM_CACHE_NUM))

    @property
    def workflow_instance_id(self) -> Optional[str]:
        # Delegated to the owning parent object, if any.
        if hasattr(self, "_parent"):
            return self._parent.workflow_instance_id
        return None

    @workflow_instance_id.setter
    def workflow_instance_id(self, value: str):
        if hasattr(self, "_parent"):
            self._parent.workflow_instance_id = value

    @abstractmethod
    def _call(self, records: List[Message], **kwargs) -> str:
        """Run the LLM on the given messages."""

    async def _acall(self, records: List[Message], **kwargs) -> str:
        """Async variant of ``_call``; override in subclasses that support it."""
        raise NotImplementedError("Async generation not implemented for this LLM.")

    def generate(self, records: List[Message], **kwargs) -> str:
        """Run the LLM on the given messages, consulting the cache if enabled."""
        if self.cache:
            key = self._cache_key(records)
            cached_res = self.lru_cache.get(key)
            if cached_res is not None:
                return cached_res
            gen = self._call(records, **kwargs)
            self.lru_cache.put(key, gen)
            return gen
        return self._call(records, **kwargs)

    @retry(
        stop=(
            stop_after_delay(EnvVar.STOP_AFTER_DELAY)
            | stop_after_attempt(EnvVar.STOP_AFTER_ATTEMPT)
        ),
        reraise=True,
    )
    async def agenerate(self, records: List[Message], **kwargs) -> str:
        """Async variant of ``generate``, retried until either stop limit is hit."""
        if self.cache:
            key = self._cache_key(records)
            cached_res = self.lru_cache.get(key)
            if cached_res is not None:
                return cached_res
            gen = await self._acall(records, **kwargs)
            self.lru_cache.put(key, gen)
            return gen
        return await self._acall(records, **kwargs)

    def _cache_key(self, records: List[Message]) -> str:
        # Key on the serialized message dicts so identical requests hit the cache.
        return str([item.model_dump() for item in records])

    def dict(self, *args, **kwargs):
        # Back-compat serialization helpers; the LRU cache is not serializable.
        kwargs["exclude"] = {"lru_cache"}
        return super().model_dump(*args, **kwargs)

    def json(self, *args, **kwargs):
        kwargs["exclude"] = {"lru_cache"}
        return super().model_dump_json(*args, **kwargs)


class BaseLLMBackend(BotBase, ABC):
    """Prepares prompts, runs LLM inference, and parses the model outputs."""

    output_parser: Optional[BaseOutputParser] = None
    prompts: List[PromptTemplate] = []
    llm: BaseLLM
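
    # Typical configuration (illustrative; concrete names depend on the
    # registry and parser tables):
    #   backend = SomeBackend(
    #       llm={"name": "<registered-llm>", ...},         # built via registry.get_llm
    #       prompts=["path/to/task.prompt"],               # or inline template strings
    #       output_parser={"name": "<registered-parser>"}  # looked up in _OUTPUT_PARSER
    #   )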

    @property
    def token_usage(self):
        """Accumulated token usage for this workflow instance, from shared memory."""
        if not hasattr(self, "workflow_instance_id"):
            raise AttributeError("workflow_instance_id not set")
        return dict(
            self.stm(self.workflow_instance_id).get("token_usage", defaultdict(int))
        )

    @field_validator("output_parser", mode="before")
    @classmethod
    def set_output_parser(cls, output_parser: Union[BaseOutputParser, Dict, None]):
        if output_parser is None:
            return StrParser()
        elif isinstance(output_parser, BaseOutputParser):
            return output_parser
        elif isinstance(output_parser, dict):
            return _OUTPUT_PARSER[output_parser["name"]](**output_parser)
        else:
            raise ValueError(
                "output_parser only supports BaseOutputParser objects, dicts and None"
            )

    @field_validator("prompts", mode="before")
    @classmethod
    def set_prompts(
        cls, prompts: List[Union[PromptTemplate, Dict, str]]
    ) -> List[PromptTemplate]:
        init_prompts = []
        for prompt in prompts:
            prompt = copy.deepcopy(prompt)
            if isinstance(prompt, Path):
                if prompt.suffix == ".prompt":
                    init_prompts.append(PromptTemplate.from_file(prompt))
            elif isinstance(prompt, str):
                if prompt.endswith(".prompt"):
                    init_prompts.append(PromptTemplate.from_file(prompt))
                else:
                    init_prompts.append(PromptTemplate.from_template(prompt))
            elif isinstance(prompt, dict):
                init_prompts.append(PromptTemplate.from_config(prompt))
            elif isinstance(prompt, PromptTemplate):
                init_prompts.append(prompt)
            else:
                raise ValueError(
                    "Prompts only support str, Path, dict and PromptTemplate objects"
                )
        return init_prompts

    @field_validator("llm", mode="before")
    @classmethod
    def set_llm(cls, llm: Union[BaseLLM, Dict]):
        if isinstance(llm, dict):
            # A dict config is resolved through the registry by its "name" key.
            return registry.get_llm(llm["name"])(**llm)
        elif isinstance(llm, BaseLLM):
            return llm
        else:
            raise ValueError("llm only supports dict and BaseLLM objects")

    def prep_prompt(
        self, input_list: List[Dict[str, Any]], prompts=None, **kwargs
    ) -> List[List[Message]]:
        """Render each input dict against the prompt templates into Message lists."""
        if prompts is None:
            prompts = self.prompts
        images = []
        if len(kwargs_images := kwargs.get("images", [])):
            images = kwargs_images
        processed_prompts = []
        for inputs in input_list:
            records = []
            for prompt in prompts:
                selected_inputs = {k: inputs.get(k, "") for k in prompt.input_variables}
                prompt_str = prompt.template
                # Split on {{placeholder}} tokens, keeping the delimiters.
                parts = re.split(r"(\{\{.*?\}\})", prompt_str)
                formatted_parts = []
                for part in parts:
                    if part.startswith("{{") and part.endswith("}}"):
                        part = part[2:-2].strip()
                        value = selected_inputs[part]
                        if isinstance(value, (Image.Image, list)):
                            formatted_parts.extend(
                                [value] if isinstance(value, Image.Image) else value
                            )
                        else:
                            formatted_parts.append(str(value))
                    else:
                        formatted_parts.append(str(part))
                # Collapse to a plain string when there is a single text part.
                formatted_parts = (
                    formatted_parts[0] if len(formatted_parts) == 1 else formatted_parts
                )
                if prompt.role == "system":
                    records.append(Message.system(formatted_parts))
                elif prompt.role == "user":
                    records.append(Message.user(formatted_parts))
            if len(images):
                records.append(Message.user(images))
            processed_prompts.append(records)
        return processed_prompts

    def infer(self, input_list: List[Dict[str, Any]], **kwargs) -> List[T]:
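        """Prepare prompts for each input, run the LLM on each, parse the
        outputs, and record token usage in shared memory. For streaming LLMs,
        the returned items are generators that tally usage as chunks arrive."""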
        prompts = self.prep_prompt(input_list, **kwargs)
        res = []
        stm_token_usage = self.stm(self.workflow_instance_id).get(
            "token_usage", defaultdict(int)
        )

        def process_stream(stream_output):
            # Accumulate token usage into shared memory as chunks stream through.
            for chunk in stream_output:
                if chunk.usage is not None:
                    for key, value in chunk.usage.dict().items():
                        if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                            if value is not None:
                                stm_token_usage[key] += value
                    self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
                yield chunk
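
        # Non-streaming outputs are parsed eagerly below; streaming outputs are
        # wrapped in process_stream so usage is tallied as chunks are consumed.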
        for prompt in prompts:
            output = self.llm.generate(prompt, **kwargs)
            if not isinstance(output, Iterator):
                for key, value in output.get("usage", {}).items():
                    if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                        if value is not None:
                            stm_token_usage[key] += value
            if not self.llm.stream:
                for choice in output["choices"]:
                    if choice.get("message"):
                        choice["message"]["content"] = self.output_parser.parse(
                            choice["message"]["content"]
                        )
                res.append(output)
            else:
                res.append(process_stream(output))

        self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
        return res

    async def ainfer(self, input_list: List[Dict[str, Any]], **kwargs) -> List[T]:
        prompts = self.prep_prompt(input_list, **kwargs)
        res = []
        stm_token_usage = self.stm(self.workflow_instance_id).get(
            "token_usage", defaultdict(int)
        )
        for prompt in prompts:
            output = await self.llm.agenerate(prompt, **kwargs)
            for key, value in output.get("usage", {}).items():
                if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                    if value is not None:
                        stm_token_usage[key] += value
            for choice in output["choices"]:
                if choice.get("message"):
                    choice["message"]["content"] = self.output_parser.parse(
                        choice["message"]["content"]
                    )
            res.append(output)
        self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
        return res

    def simple_infer(self, **kwargs: Any) -> T:
        """Single-input convenience wrapper around ``infer``."""
        return self.infer([kwargs])[0]

    async def simple_ainfer(self, **kwargs: Any) -> T:
        """Single-input convenience wrapper around ``ainfer``."""
        return (await self.ainfer([kwargs]))[0]
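

# Minimal usage sketch (illustrative; the subclass, prompt template and llm
# config below are assumptions, not part of this module):
#
#   class SimpleQA(BaseLLMBackend):
#       prompts: List[PromptTemplate] = [
#           PromptTemplate.from_template("Answer briefly: {{question}}")
#       ]
#
#   qa = SimpleQA(llm={"name": "<registered-llm-name>"})
#   print(qa.simple_infer(question="What is the capital of France?"))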