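"""Launch a VLLM server for the extraction model defined in a YAML config.

Example invocation (script and config file names are illustrative, not taken
from the repository):

    python serve_extraction_model.py --config configs/extraction.yaml \
        --tensor-parallel-size 2 --max-model-len 32768
"""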
import argparse
import warnings
import subprocess
import sys
import os
# Make the parent directory importable so the shared utils module can be found
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import load_extraction_config

def main():
    # Create command-line argument parser
    parser = argparse.ArgumentParser(description='Launch the VLLM server for the extraction model.')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the YAML configuration file.')
    parser.add_argument('--tensor-parallel-size', type=int, default=2,
                        help='Tensor parallel size for the VLLM server.')
    parser.add_argument('--max-model-len', type=int, default=32768,
                        help='Maximum model length for the VLLM server.')

    # Parse command-line arguments
    args = parser.parse_args()

    # Load configuration
    config = load_extraction_config(args.config)
    # Model config
    model_config = config['model']
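    # Keys this script reads from the 'model' section (assumed minimal shape; the
    # full schema expected by load_extraction_config may contain more fields):
    #   vllm_serve: bool          - whether to launch a local VLLM server
    #   model_name_or_path: str   - Hugging Face model id or local checkpoint path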
    if not model_config['vllm_serve']:
        warnings.warn("VLLM-deployed model will not be used for extraction, so no VLLM "
                      "server is started. To enable VLLM, set vllm_serve to true in the "
                      "configuration file.")
        return
    model_name_or_path = model_config['model_name_or_path']

    # Launch the VLLM OpenAI-compatible server for the configured model.
    # Passing an argument list (rather than a shell string) avoids quoting issues.
    command = [
        "vllm", "serve", model_name_or_path,
        "--tensor-parallel-size", str(args.tensor_parallel_size),
        "--max-model-len", str(args.max_model_len),
        "--enforce-eager",
        "--port", "8000",
    ]
    subprocess.run(command)

if __name__ == "__main__":
    main()