Spaces:
Sleeping
Sleeping
import click | |
import os | |
import sys | |
import importlib | |
import importlib.util | |
import json | |
from click.core import Context, Option | |
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__ | |
from ding.framework import Parallel | |
from ding.entry.cli_parsers import PLATFORM_PARSERS | |
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """Click option callback that reports the package version and stops the CLI.

    Does nothing when the flag value is falsy or when ``ctx.resilient_parsing``
    is set; otherwise echoes the title/version line and the author/e-mail line,
    then exits the click context.

    Args:
        ctx: The active click context.
        param: The option that triggered the callback (unused).
        value: The flag value supplied on the command line.
    """
    should_print = value and not ctx.resilient_parsing
    if should_print:
        click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
        click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
        ctx.exit()
# Shared click context settings: accept both -h and --help as help flags.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
def cli_ditask(*args, **kwargs):
    """Public CLI entry wrapper.

    Forwards every positional and keyword argument unchanged to
    ``_cli_ditask`` and hands back whatever it returns.
    """
    result = _cli_ditask(*args, **kwargs)
    return result
def _parse_platform_args(platform: str, platform_spec: str, all_args: dict):
    """Translate platform-specific launch settings into ``_cli_ditask`` kwargs.

    ``platform_spec`` is either a path to a ``.json`` file or an inline JSON
    string; when truthy it is decoded before being passed, together with the
    remaining CLI arguments, to the parser registered for ``platform`` in
    ``PLATFORM_PARSERS``. A falsy ``platform_spec`` is passed through as-is.

    Args:
        platform: Key into ``PLATFORM_PARSERS`` selecting the parser.
        platform_spec: Path to a ``.json`` file, an inline JSON string, or a
            falsy value.
        all_args: The caller's CLI arguments. NOTE: the ``platform`` and
            ``platform_spec`` entries are removed from this dict in place.

    Returns:
        Whatever the selected platform parser returns; the caller expands it
        as keyword arguments to ``_cli_ditask``.

    Raises:
        SystemExit: If the spec cannot be read/decoded as JSON, or if the
            platform type is not registered.
        Exception: Any error raised by the platform parser is echoed and
            re-raised unchanged.
    """
    if platform_spec:
        try:
            # splitext returns (root, ext) and ext keeps its leading dot, so the
            # extension must be compared against ".json". The previous
            # `splitext(...) == "json"` compared a tuple to a str and was always
            # False, so spec files were never opened.
            if os.path.splitext(platform_spec)[1].lower() == ".json":
                with open(platform_spec) as f:
                    platform_spec = json.load(f)
            else:
                platform_spec = json.loads(platform_spec)
        except (OSError, ValueError, TypeError):
            # OSError: unreadable file; ValueError: bad JSON (JSONDecodeError)
            # or undecodable bytes; TypeError: spec is not a str/path.
            click.echo("platform_spec is not a valid json!")
            sys.exit(1)
    if platform not in PLATFORM_PARSERS:
        click.echo("platform type is invalid! type: {}".format(platform))
        sys.exit(1)
    # These two are consumed here; the parser receives them separately.
    all_args.pop("platform", None)
    all_args.pop("platform_spec", None)
    try:
        parsed_args = PLATFORM_PARSERS[platform](platform_spec, **all_args)
    except Exception as e:
        click.echo("error when parse platform spec configure: {}".format(e))
        raise
    return parsed_args
def _import_main_func(package: str, main: str):
    """Make ``package`` importable and return the function named by ``main``.

    Args:
        package: Directory (or other sys.path entry) holding the user code;
            falls back to the current working directory when falsy.
        main: ``"<module>.<function>"``. When None, the module name is the
            basename of ``package`` without extension and the function is
            ``main``.

    Raises:
        ValueError: If ``main`` is given but contains no ``.`` separator.
    """
    if not package:
        package = os.getcwd()
    sys.path.append(package)
    if main is None:
        # abspath drops a trailing separator ("./my_pkg/" -> ".../my_pkg");
        # without it basename() returns "" and import_module("") fails.
        mod_name = os.path.splitext(os.path.basename(os.path.abspath(package)))[0]
        func_name = "main"
    else:
        if "." not in main:
            raise ValueError("main should be in the form '<module>.<function>', got: {}".format(main))
        mod_name, func_name = main.rsplit(".", 1)
    # Also expose the top-level package directory so that the user's
    # module can import its sibling modules.
    root_mod_name = mod_name.split(".", 1)[0]
    sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    return getattr(mod, func_name)


def _parse_ports(ports):
    """Normalise ``ports``: default 50515; a comma-separated string becomes a single int or a list of ints."""
    ports = ports or 50515
    if isinstance(ports, int):
        return ports
    port_list = [int(p) for p in ports.split(",")]
    return port_list[0] if len(port_list) == 1 else port_list


def _cli_ditask(
    package: str,
    main: str,
    parallel_workers: int,
    protocol: str,
    ports: str,
    attach_to: str,
    address: str,
    labels: str,
    node_ids: str,
    topology: str,
    mq_type: str,
    redis_host: str,
    redis_port: int,
    startup_interval: int,
    local_rank: int = 0,
    platform: str = None,
    platform_spec: str = None,
):
    """Resolve the user's entry function and launch it via ``Parallel.runner``.

    When ``platform`` is set, the arguments are rewritten by
    ``_parse_platform_args`` and this function is re-invoked with the result.
    Otherwise the string-valued CLI options are parsed and passed to
    ``Parallel.runner(...)``, whose return value is then called with the
    resolved entry function.

    Args:
        package: Code directory added to ``sys.path``; defaults to the cwd.
        main: ``"<module>.<function>"`` entry point; when None, derived from
            ``package`` with function name ``main``.
        ports: Int, or comma-separated string of ports; defaults to 50515.
            A single port is passed as an int, several as a list of ints.
        attach_to: Comma-separated string, split into a list of stripped items.
        labels: Comma-separated string, split into a set of stripped items.
        node_ids: Int, or comma-separated string converted to a list of ints.
        parallel_workers, protocol, address, topology, mq_type, redis_host,
        redis_port, startup_interval: Forwarded unchanged to ``Parallel.runner``.
        local_rank: Accepted for interface compatibility; not used here.
        platform, platform_spec: See ``_parse_platform_args``.

    Returns:
        The recursive call's result on the platform path; otherwise None.
    """
    # Snapshot the call arguments before any other local is bound; the
    # platform parser turns them into a fresh set of _cli_ditask kwargs.
    all_args = locals()
    if platform:
        parsed_args = _parse_platform_args(platform, platform_spec, all_args)
        return _cli_ditask(**parsed_args)

    main_func = _import_main_func(package, main)

    ports = _parse_ports(ports)
    if attach_to:
        attach_to = [s.strip() for s in attach_to.split(",")]
    if labels:
        labels = {s.strip() for s in labels.split(",")}
    if node_ids and not isinstance(node_ids, int):
        node_ids = [int(i) for i in node_ids.split(",")]

    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids,
        mq_type=mq_type,
        redis_host=redis_host,
        redis_port=redis_port,
        startup_interval=startup_interval,
    )(main_func)