|
from __future__ import annotations |
|
|
|
import sys |
|
|
|
sys.path.append("..") |
|
|
|
import json |
|
from abc import ABC, abstractmethod |
|
from pathlib import Path |
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union |
|
|
|
import yaml |
|
from pydantic import Field |
|
|
|
from ....base import BotBase |
|
from .formatter import FStringFormatter, JinjiaFormatter |
|
from .parser import BaseOutputParser, DictParser, ListParser, StrParser |
|
|
|
# Formatter instances keyed by template-format name. check_valid_template()
# looks a formatter up here and calls its ``validate`` method.
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = dict(
    [
        ("f-string", FStringFormatter()),
        ("jinja2", JinjiaFormatter()),
    ]
)
|
|
|
# Output-parser classes keyed by their class-name string.
_OUTPUT_PARSER = dict(
    StrParser=StrParser,
    ListParser=ListParser,
    DictParser=DictParser,
)
|
|
|
|
|
def check_valid_template(
    template: str, template_format: str, input_variables: List[str]
) -> None:
    """Check that a template string is valid for the given format.

    Args:
        template: The template text to validate.
        template_format: Key into ``DEFAULT_FORMATTER_MAPPING`` selecting the
            formatter whose ``validate`` method is used.
        input_variables: Variable names handed to the formatter's
            ``validate`` method together with the template.

    Raises:
        ValueError: If ``template_format`` is not a registered format, or if
            the formatter's ``validate`` call raises ``KeyError``.
    """
    if template_format not in DEFAULT_FORMATTER_MAPPING:
        valid_formats = list(DEFAULT_FORMATTER_MAPPING)
        raise ValueError(
            f"Invalid template format. Got `{template_format}`;"
            f" should be one of {valid_formats}"
        )

    # Membership was verified above, so only validate() needs guarding.
    formatter = DEFAULT_FORMATTER_MAPPING[template_format]
    try:
        formatter.validate(template, input_variables)
    except KeyError as e:
        # Chain explicitly so the original KeyError is reported as the cause.
        raise ValueError(
            "Invalid prompt schema; check for mismatched or missing input parameters. "
            + str(e)
        ) from e
|
|
|
|
|
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
    """Return the undeclared variable names found in a jinja2 template.

    Args:
        template: The jinja2 template source to inspect.

    Returns:
        The set produced by ``jinja2.meta.find_undeclared_variables`` for the
        parsed template.

    Raises:
        ImportError: If jinja2 is not installed.
    """
    try:
        # Imported lazily so jinja2 is only required when this helper runs.
        from jinja2 import Environment, meta
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays attached.
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`."
        ) from err

    env = Environment()
    parsed = env.parse(template)
    return meta.find_undeclared_variables(parsed)
|
|
|
|
|
class BasePromptTemplate(BotBase, ABC):
    """Base class for all prompt templates, returning a prompt."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    # Pre-bound variables: each value is either a string or a zero-argument
    # callable that is invoked when the variables are merged.
    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
        default_factory=dict
    )

    class Config:
        """Configuration for this pydantic object."""

        extra = "forbid"
        arbitrary_types_allowed = True

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a partial of the prompt template.

        Args:
            kwargs: Variables to bind now; each value is a string or a
                zero-argument callable returning a string.

        Returns:
            A new instance of the same class whose ``input_variables`` no
            longer contain the bound names and whose ``partial_variables``
            include them.
        """
        prompt_dict = self.__dict__.copy()
        # Keep the declared order of the remaining variables (a set difference
        # would return them in arbitrary order); dict.fromkeys also drops
        # duplicates while preserving first-seen order.
        prompt_dict["input_variables"] = [
            name
            for name in dict.fromkeys(self.input_variables)
            if name not in kwargs
        ]
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
        """Combine the partial variables with call-time variables.

        Non-string partial values are called to obtain their value. Keys in
        ``kwargs`` override partial variables of the same name.
        """
        partial_kwargs = {
            k: v if isinstance(v, str) else v()
            for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    @abstractmethod
    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt to a JSON or YAML file.

        Args:
            file_path: Path of the file to write. Its suffix must be
                ``.json`` or ``.yaml``; missing parent directories are
                created.

        Raises:
            ValueError: If the prompt has partial variables, or if the file
                suffix is neither ``.json`` nor ``.yaml``.

        Example:
        .. code-block:: python

            prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")

        save_path = Path(file_path)
        suffix = save_path.suffix
        # Validate the target format first so a bad suffix does not leave
        # freshly created directories behind.
        if suffix not in (".json", ".yaml"):
            raise ValueError(f"{save_path} must be json or yaml")

        save_path.parent.mkdir(parents=True, exist_ok=True)

        prompt_dict = self.dict()

        with open(save_path, "w", encoding="utf-8") as f:
            if suffix == ".json":
                json.dump(prompt_dict, f, indent=4)
            else:
                yaml.dump(prompt_dict, f, default_flow_style=False)

    @classmethod
    @abstractmethod
    def from_template(cls, template: str, **kwargs: Any) -> BasePromptTemplate:
        """Create a prompt from a template."""

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict) -> BasePromptTemplate:
        """Create a prompt from config."""
|
|