import argparse
import os
import subprocess
import sys
import warnings

# Make the repository root importable so `utils` can be found when this
# script is run from its subdirectory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils import load_extraction_config
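
# Example invocation (illustrative; the script and config filenames are
# placeholders, the flags are the ones defined below):
#   python serve_model.py --config configs/extraction.yaml \
#       --tensor-parallel-size 2 --max-model-len 32768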


def main():
    parser = argparse.ArgumentParser(description='Run the extraction model.')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the YAML configuration file.')
    parser.add_argument('--tensor-parallel-size', type=int, default=2,
                        help='Tensor parallel size for the VLLM server.')
    parser.add_argument('--max-model-len', type=int, default=32768,
                        help='Maximum model length for the VLLM server.')
    args = parser.parse_args()
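
    # The YAML config is expected to provide at least the keys read below.
    # An illustrative shape (placeholder values, not taken from the source):
    #
    #   model:
    #     vllm_serve: true
    #     model_name_or_path: /path/to/model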
    config = load_extraction_config(args.config)

    model_config = config['model']
    if not model_config['vllm_serve']:
        warnings.warn("VLLM-deployed model will not be used for extraction. "
                      "To enable VLLM, set vllm_serve to true in the "
                      "configuration file.")
        # Nothing to serve in this case, so exit early instead of launching
        # a server that the extraction step will not use.
        return

    model_name_or_path = model_config['model_name_or_path']

    # Launch the OpenAI-compatible VLLM server for the configured model.
    # Passing the arguments as a list avoids shell quoting issues.
    command = [
        "vllm", "serve", model_name_or_path,
        "--tensor-parallel-size", str(args.tensor_parallel_size),
        "--max-model-len", str(args.max_model_len),
        "--enforce-eager",
        "--port", "8000",
    ]
    subprocess.run(command)
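
    # While the server is up it exposes an OpenAI-compatible API on the
    # chosen port; an illustrative liveness check from another shell:
    #   curl http://localhost:8000/v1/models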


if __name__ == "__main__":
    main()