import logging
from typing import Any, Optional

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model


def load_config(cfg_path: str) -> Any:
    """
    Load and resolve a configuration file.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        Any: The loaded and resolved configuration object.

    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    return cfg
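
# Usage sketch (editor's addition, not original code): loading and resolving a
# YAML config. The path "config.yaml" is a hypothetical placeholder; any file
# whose interpolations OmegaConf can resolve would work.
#
#     cfg = load_config("config.yaml")
#     print(OmegaConf.to_yaml(cfg))

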
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Parse a configuration dictionary into a structured configuration object.

    Args:
        cfg_type (Any): The type of the structured configuration object.
        cfg (DictConfig): The configuration dictionary to be parsed.

    Returns:
        Any: The structured configuration object created from the dictionary.
    """
    scfg = OmegaConf.structured(cfg_type(**cfg))
    return scfg
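
# Usage sketch (editor's addition, not original code): converting a plain
# DictConfig into a typed, structured config. "ExperimentConfig" is a
# hypothetical dataclass standing in for whatever schema the caller defines.
#
#     from dataclasses import dataclass
#
#     @dataclass
#     class ExperimentConfig:
#         lr: float = 1e-4
#         num_steps: int = 1000
#
#     raw = OmegaConf.create({"lr": 3e-4, "num_steps": 5000})
#     scfg = parse_structured(ExperimentConfig, raw)
#     assert scfg.lr == 3e-4

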
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place.

    Args:
        model: PyTorch model to load weights into.
        ckpt_path: Path to the safetensors checkpoint file.

    Returns:
        None

    Raises:
        AssertionError: If `ckpt_path` does not point to a `.safetensors` file.
    """
    assert ckpt_path.endswith(".safetensors"), (
        f"Checkpoint path '{ckpt_path}' is not a safetensors file"
    )

    load_model(model, ckpt_path)
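
# Usage sketch (editor's addition, not original code): restoring weights from a
# safetensors checkpoint. "MyModel" and the checkpoint path are hypothetical.
#
#     model = MyModel()
#     load_model_weights(model, "checkpoints/model.safetensors")
#     model.eval()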