|
import json |
|
import logging |
|
import os |
|
import pathlib |
|
import re |
|
from copy import deepcopy |
|
from pathlib import Path |
|
from typing import Optional, Tuple, Union, Dict, Any |
|
import torch |
|
|
|
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] |
|
_MODEL_CONFIGS = {} |
|
|
|
|
|
def _natural_key(string_): |
|
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] |
|
|
|
|
|
def _rescan_model_configs():
    """Update the model-config registry from every entry in ``_MODEL_CONFIG_PATHS``.

    Each entry may be a single ``.json`` file or a directory, in which case
    its top-level ``*.json`` files are read. A file is registered under its
    stem (file name without extension) only if it parses to a JSON object
    containing ``embed_dim``, ``vision_cfg`` and ``text_cfg``. New entries
    are merged into the existing registry (the registry is not cleared; a
    later path overrides an earlier one with the same stem), and the result
    is re-sorted by natural order of the model name.

    Files that cannot be read or parsed are skipped with a logged warning
    instead of aborting the whole scan. This function runs at import time,
    so one bad file would otherwise make the module unimportable.
    """
    global _MODEL_CONFIGS

    config_ext = (".json",)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f"*{ext}"))

    for cf in config_files:
        try:
            with open(cf, "r", encoding="utf8") as f:
                model_cfg = json.load(f)
        except (OSError, ValueError) as err:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            logging.getLogger(__name__).warning(
                "Skipping unreadable model config %s: %s", cf, err
            )
            continue
        # Require a JSON object: a top-level list or string would otherwise
        # pass the membership test below (list membership / substring match).
        if isinstance(model_cfg, dict) and all(
            a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")
        ):
            _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))


# Populate the registry once, at import time.
_rescan_model_configs()
|
|
|
|
|
def list_models():
    """Return the names of all registered model architectures.

    Names are the keys of the config registry (config file stems), in the
    registry's current order. A new list is returned on every call.
    """
    return list(_MODEL_CONFIGS)
|
|
|
|
|
def add_model_config(path):
    """Add a model-config location and rescan the registry.

    Args:
        path: A ``.json`` config file or a directory of them, either as a
            ``pathlib.Path`` or as anything ``Path()`` accepts (e.g. ``str``).

    Raises:
        Whatever the rescan raises. In that case the newly added path is
        removed again, so one bad location does not make every later rescan
        (and every later call to this function) fail as well.
    """
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    try:
        _rescan_model_configs()
    except Exception:
        # Roll back only the entry appended above; it is the last element.
        _MODEL_CONFIG_PATHS.pop()
        raise
|
|
|
|
|
def get_model_config(model_name):
    """Look up a registered model config by name.

    Returns an independent deep copy of the config, so callers may mutate it
    without affecting the registry, or ``None`` when no config is registered
    under ``model_name``.
    """
    # deepcopy(None) is None, so an unknown name simply yields None.
    return deepcopy(_MODEL_CONFIGS.get(model_name))
|
|