|
from __future__ import annotations |
|
|
|
from pathlib import Path |
|
import os |
|
from string import Formatter |
|
from typing import Any, Dict, List, Union |
|
|
|
from pydantic import model_validator |
|
|
|
from ....utils.registry import registry |
|
from .base import (DEFAULT_FORMATTER_MAPPING, BasePromptTemplate, |
|
_get_jinja2_variables_from_template, check_valid_template) |
|
|
|
|
|
@registry.register_prompt()
class PromptTemplate(BasePromptTemplate):
    """Schema to represent a prompt for an LLM.

    Example:
        .. code-block:: python

            prompt = PromptTemplate(input_variables=["foo"], template="Say {{ foo }}")
    """

    template: str
    """The prompt template."""

    template_format: str = "jinja2"
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""

    validate_template: bool = True
    """Whether or not to try validating the template."""

    # NOTE(review): presumably the chat role this prompt is sent as -- confirm
    # against the code that consumes `role`.
    role: str = "user"

    class Config:
        """Configuration for this pydantic object."""

        # Extra keyword arguments are accepted so that values for input
        # variables can be handed straight to the constructor (see __init__).
        extra = "allow"

    def __init__(self, **kwargs: Any):
        """Create the prompt and pre-fill any input variables given as kwargs.

        Every name listed in ``input_variables`` that also appears as a keyword
        argument is formatted into ``template`` right away and removed from
        ``input_variables``, leaving only the variables that still have to be
        supplied to :meth:`format`.

        Args:
            kwargs: Model fields (``template``, ``input_variables``, ...) plus,
                optionally, values for some of the input variables.
        """
        super().__init__(**kwargs)
        input_variables = kwargs.get("input_variables", [])
        pre_filled_kv = {
            key: kwargs[key] for key in input_variables if key in kwargs
        }
        if pre_filled_kv:
            # NOTE(review): if the configured formatter renders missing
            # variables as empty text, placeholders that are not pre-filled
            # could be lost by this partial render -- confirm against
            # DEFAULT_FORMATTER_MAPPING.
            self.template = self.format(**pre_filled_kv)
            # Keep the caller's ordering (a set difference would return the
            # remaining names in arbitrary order) while still de-duplicating.
            input_variables = list(
                dict.fromkeys(
                    name for name in input_variables if name not in pre_filled_kv
                )
            )
        self.input_variables = input_variables

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        formatter = DEFAULT_FORMATTER_MAPPING[self.template_format]
        return formatter.format(self.template, **kwargs)

    @model_validator(mode="after")
    def template_is_valid(self) -> "PromptTemplate":
        """Check that template and input variables are consistent.

        Validation is skipped entirely when ``validate_template`` is false.

        Returns:
            The validated instance itself.
        """
        if self.validate_template:
            all_inputs = self.input_variables + list(self.partial_variables)
            check_valid_template(self.template, self.template_format, all_inputs)
        return self

    @classmethod
    def from_examples(
        cls,
        examples: List[str],
        suffix: str,
        input_variables: List[str],
        example_separator: str = "\n\n",
        prefix: str = "",
        **kwargs: Any,
    ) -> PromptTemplate:
        """Take examples in list format with prefix and suffix to create a prompt.

        Intended to be used as a way to dynamically create a prompt from examples.

        Args:
            examples: List of examples to use in the prompt.
            suffix: String to go after the list of examples. Should generally
                set up the user's input.
            input_variables: A list of variable names the final prompt template
                will expect.
            example_separator: The separator to use in between examples. Defaults
                to two new line characters.
            prefix: String that should go before any examples. Defaults to an
                empty string.
            kwargs: Additional keyword arguments passed to the constructor.

        Returns:
            The final prompt generated.
        """
        template = example_separator.join([prefix, *examples, suffix])
        return cls(input_variables=input_variables, template=template, **kwargs)

    @classmethod
    def find_file(cls, start_dir, file_name):
        """Search ``start_dir`` recursively for a file called ``file_name``.

        Args:
            start_dir: Directory at which the ``os.walk`` traversal starts.
            file_name: Bare file name to look for.

        Returns:
            The joined path of the first match in ``os.walk`` order, or
            ``None`` when no directory below ``start_dir`` contains the file.
        """
        for root, _dirs, files in os.walk(start_dir):
            if file_name in files:
                return os.path.join(root, file_name)
        return None

    @classmethod
    def from_file(
        cls, template_file: Union[str, Path], **kwargs: Any
    ) -> PromptTemplate:
        """Load a prompt from a file.

        If the path does not exist as given, its leading components are
        stripped one ``/``-separated segment at a time and each shorter path
        is tried in turn, until a file is found or no segment is left.

        Args:
            template_file: The path to the file containing the prompt template.
            kwargs: Additional keyword arguments forwarded to
                :meth:`from_template`.

        Returns:
            The prompt loaded from the file.

        Raises:
            ValueError: If neither the path nor any of its shortened forms
                exists.
        """
        original_file = template_file
        # Normalise to ``str`` so that ``Path`` arguments survive the "/"
        # membership test and split below.
        candidate = os.fspath(template_file)
        while True:
            if os.path.exists(candidate):
                # Read as UTF-8 explicitly so the result does not depend on
                # the platform's locale encoding.
                with open(candidate, "r", encoding="utf-8") as f:
                    template = f.read()
                return cls.from_template(template=template, **kwargs)

            if "/" not in candidate:
                raise ValueError(
                    f"the prompt file path ({original_file}) is not valid"
                )
            # Drop the first path segment and retry with the remainder.
            candidate = candidate.split("/", 1)[1]

    @classmethod
    def from_template(
        cls, template: str, template_format: str = "jinja2", **kwargs: Any
    ) -> PromptTemplate:
        """Load a prompt template from a template string.

        The input variables are discovered from the template itself and
        passed to the constructor in sorted order.

        Args:
            template: The template text.
            template_format: ``"jinja2"`` to extract variables with the jinja2
                helper; any other value parses the template with
                ``string.Formatter``.
            kwargs: Additional keyword arguments passed to the constructor.

        Returns:
            The prompt template built from ``template``.
        """
        if template_format == "jinja2":
            input_variables = _get_jinja2_variables_from_template(template)
        else:
            # Formatter().parse yields (literal, field_name, spec, conversion);
            # field_name is None for trailing literal text.
            input_variables = {
                field_name
                for _, field_name, _, _ in Formatter().parse(template)
                if field_name is not None
            }

        return cls(
            input_variables=sorted(input_variables),
            template=template,
            template_format=template_format,
            **kwargs,
        )

    @classmethod
    def from_config(cls, config: Dict) -> PromptTemplate:
        """Load a prompt template from a config.

        Args:
            config: Mapping with a required ``"template"`` entry; all other
                entries are forwarded as keyword arguments. A value ending in
                ``".prompt"`` is treated as a file path, anything else as the
                template text itself. The mapping is not modified.

        Returns:
            The prompt template described by ``config``.

        Raises:
            KeyError: If ``config`` has no ``"template"`` entry.
        """
        # Work on a shallow copy so the caller's dict keeps its "template" key.
        options = dict(config)
        template = options.pop("template")
        if template.endswith(".prompt"):
            return cls.from_file(template, **options)
        return cls.from_template(template, **options)
|
|