# Add a Model

Currently, we support Hugging Face (HF) models, some API-based models, and some third-party models.

## Adding API Models

To add a new API-based model, create a new file named `mymodel_api.py` in the `opencompass/models` directory. In this file, inherit from `BaseAPIModel` and implement two methods: `generate` for inference and `get_token_len` to compute the number of tokens in a prompt. Once the model is defined, update the corresponding configuration file to use it.

```python
from typing import Dict, List, Optional

from .base_api import BaseAPIModel


class MyModelAPI(BaseAPIModel):

    # Marks this model as API-based.
    is_api: bool = True

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 retry: int = 2,
                 meta_template: Optional[Dict] = None,
                 **kwargs):
        # Forward the shared settings to the base class.
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         retry=retry)
        ...

    def generate(
        self,
        inputs,
        max_out_len: int = 512,
        temperature: float = 0.7,
    ) -> List[str]:
        """Generate results given a list of inputs."""
        pass

    def get_token_len(self, prompt: str) -> int:
        """Get the token length of the given prompt."""
        pass
```
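To make the two required methods more concrete, here is a minimal, self-contained sketch of how `generate` and `get_token_len` might be filled in. Everything beyond the method names is an assumption made for illustration: `BaseAPIModelStub` is a hypothetical stand-in for `BaseAPIModel` so the example runs without OpenCompass installed, `backend` is a hypothetical callable standing in for a real HTTP client, the retry policy (one attempt plus `retry` retries) is one possible choice, and whitespace splitting is only a rough proxy for a real tokenizer.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


class BaseAPIModelStub:
    """Hypothetical stand-in for BaseAPIModel (not the real class)."""

    def __init__(self, path: str, max_seq_len: int = 2048,
                 query_per_second: int = 1, retry: int = 2):
        self.path = path
        self.max_seq_len = max_seq_len
        self.query_per_second = query_per_second
        self.retry = retry


class ToyAPIModel(BaseAPIModelStub):
    """Toy API model; `backend` is an assumed callable, not a real client."""

    def __init__(self, path: str, backend: Callable[[str, int], str],
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.backend = backend

    def _generate_one(self, prompt: str, max_out_len: int) -> str:
        """Query the backend for one prompt, retrying on connection errors."""
        last_error = None
        # Assumed policy: one initial attempt plus `self.retry` retries.
        for attempt in range(self.retry + 1):
            try:
                return self.backend(prompt, max_out_len)
            except ConnectionError as err:
                last_error = err
                time.sleep(0.01 * (attempt + 1))  # short, growing pause
        raise RuntimeError(
            f'request failed after {self.retry + 1} attempts') from last_error

    def generate(self, inputs: List[str],
                 max_out_len: int = 512) -> List[str]:
        """Generate one output per input, issuing requests concurrently."""
        with ThreadPoolExecutor() as executor:
            # executor.map preserves the order of `inputs`.
            return list(
                executor.map(self._generate_one, inputs,
                             [max_out_len] * len(inputs)))

    def get_token_len(self, prompt: str) -> int:
        """Approximate the token count by splitting on whitespace."""
        return len(prompt.split())
```

In a real model, `backend` would be replaced by the actual request logic for the target API, and `get_token_len` would ideally use the tokenizer that matches the served model.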

## Adding Third-Party Models

To add a new third-party model, create a new file named `mymodel.py` in the `opencompass/models` directory. In this file, inherit from `BaseModel` and implement three methods: `generate` for generative inference, `get_ppl` for discriminative inference, and `get_token_len` to compute the number of tokens in a prompt. Once the model is defined, update the corresponding configuration file to use it.

```python
from typing import Dict, List, Optional

from .base import BaseModel


class MyModel(BaseModel):

    def __init__(self,
                 pkg_root: str,
                 ckpt_path: str,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 **kwargs):
        # Load the third-party package, checkpoint, and tokenizer here.
        ...

    def get_token_len(self, prompt: str) -> int:
        """Get the token length of the given prompt."""
        pass

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs."""
        pass

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs."""
        pass
```
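As a rough illustration of what the three methods are expected to return, the sketch below implements them for a toy unigram model. All of it is assumed for the sake of the example: `ToyUnigramModel`, its whitespace tokenizer, its `token_probs` table, and the treatment of `mask_length` as the number of leading tokens to exclude from scoring are hypothetical. The class deliberately does not inherit from `BaseModel` so that it runs without OpenCompass installed, and a real model may score inputs differently (for example, by returning an average loss rather than an exponentiated perplexity).

```python
import math
from typing import Dict, List, Optional


class ToyUnigramModel:
    """Hypothetical toy model; a stand-in, not an OpenCompass class."""

    def __init__(self, token_probs: Dict[str, float],
                 unk_prob: float = 1e-3):
        self.token_probs = token_probs  # probability of each known token
        self.unk_prob = unk_prob        # probability given to unknown tokens

    def _tokenize(self, prompt: str) -> List[str]:
        # Assumption: whitespace tokenization stands in for a real tokenizer.
        return prompt.split()

    def get_token_len(self, prompt: str) -> int:
        """Get the token length of the given prompt."""
        return len(self._tokenize(prompt))

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Return one output per input: the most likely token, repeated."""
        best = max(self.token_probs, key=self.token_probs.get)
        return [' '.join([best] * max_out_len) for _ in inputs]

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Return one perplexity score per input (lower means more likely)."""
        scores = []
        for i, text in enumerate(inputs):
            tokens = self._tokenize(text)
            # Assumption: mask_length[i] leading tokens are not scored.
            skip = mask_length[i] if mask_length is not None else 0
            scored = tokens[skip:]
            if not scored:
                raise ValueError(f'input {i} has no tokens left to score')
            # Mean negative log-likelihood over the scored tokens.
            nll = -sum(
                math.log(self.token_probs.get(tok, self.unk_prob))
                for tok in scored) / len(scored)
            scores.append(math.exp(nll))
        return scores
```

In discriminative evaluation, each candidate answer would be scored with `get_ppl` and the lowest-scoring candidate selected, which is why the method returns one float per input.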