|
import importlib |
|
|
|
import omegaconf |
|
|
|
from .models import ContrastiveModel, DiffuserSTDiT, ResNet18, SegDiTTransformer2DModel |
|
|
|
|
|
def parse_klass_arg(value, full_config):
    """
    Parse an argument value that might represent a class, enum, or basic data type.

    Resolution steps:

    1. A string of the form ``"${key}"`` is treated as an OmegaConf
       interpolation and looked up in ``full_config`` (an OmegaConf config,
       or a plain container that is wrapped with ``OmegaConf.create``). If the
       lookup raises, an error is printed and ``None`` is returned. A
       non-string result is returned as-is; a string result continues to
       step 2.
    2. A dotted string such as ``"torch.nn.ReLU"`` is resolved by importing
       the longest importable module prefix and walking the remaining
       components with ``getattr``.
    3. Anything else -- non-strings, undotted strings, and dotted strings
       that do not name an importable object (``"./data"``, ``".png"``,
       ``"0.001"``) -- is returned unchanged.

    Args:
        value: Raw argument value taken from the config.
        full_config: Configuration used to resolve ``${...}`` interpolations.

    Returns:
        The resolved object, or ``value`` itself when it cannot be resolved.
    """
    if not isinstance(value, str):
        return value

    if value.startswith("${") and value.endswith("}"):
        try:
            if omegaconf.OmegaConf.is_config(full_config):
                cfg = full_config
            else:
                cfg = omegaconf.OmegaConf.create(full_config)
            # OmegaConf.resolve() resolves in place and returns None, so the
            # interpolated key has to be looked up with select() instead.
            value = omegaconf.OmegaConf.select(cfg, value[2:-1])
        except Exception as e:
            print(f"Error resolving OmegaConf interpolation {value}: {e}")
            return None
        if not isinstance(value, str):
            return value

    if "." not in value:
        return value

    parts = value.split(".")
    # Try the longest module prefix first, e.g. "a.b.C" -> module "a.b" then "a".
    for i in range(len(parts) - 1, 0, -1):
        module_parts = parts[:i]
        # Empty components ("./x", ".png", "../y") are not absolute module
        # names; import_module would raise ValueError/TypeError, not ImportError.
        if not all(module_parts):
            continue
        module_name = ".".join(module_parts)
        attr = parts[i]  # bound up front so the warning below never sees a stale/unbound name
        try:
            result = importlib.import_module(module_name)
            for attr in parts[i:]:
                result = getattr(result, attr)
            return result
        except ImportError:
            continue
        except AttributeError as e:
            print(
                f"Warning: Could not resolve attribute {attr} from {module_name}, error: {e}"
            )
            continue

    return value
|
|
|
|
|
def instantiate_class_from_config(config, *args, **kwargs):
    """
    Dynamically instantiate a class based on a configuration object.

    ``config`` must provide a dotted ``target`` (e.g. ``"pkg.module.Class"``)
    and may provide an ``args`` mapping of constructor keyword arguments. Each
    ``args`` value is passed through ``parse_klass_arg`` so dotted strings can
    refer to classes or functions. Additional positional arguments are passed
    as given. Additional keyword arguments override same-named config ``args``.

    The class is first looked up by its bare name among this module's globals
    (e.g. the models imported at the top of the file). If it is not found
    there, the module part of ``target`` is imported and the class is taken
    from it.

    Args:
        config: OmegaConf config (converted to a plain container with
            interpolations resolved) or a plain mapping with ``target`` and
            optional ``args``.
        *args: Positional arguments forwarded to the constructor.
        **kwargs: Keyword arguments forwarded to the constructor.

    Returns:
        The constructed instance.

    Raises:
        ImportError: If ``target`` cannot be resolved to a class.
    """
    if omegaconf.OmegaConf.is_config(config):
        config = omegaconf.OmegaConf.to_container(config, resolve=True)

    target = config["target"]
    # rpartition tolerates a bare class name (no module part).
    module_name, _, class_name = target.rpartition(".")

    klass = globals().get(class_name)
    if klass is None and module_name:
        # Fall back to a real import for targets not re-exported by this module.
        klass = getattr(importlib.import_module(module_name), class_name, None)
    if klass is None:
        raise ImportError(f"Could not resolve target class {target!r}")

    # An absent or empty ("args:" with no value) section means no config kwargs.
    conf_args = config.get("args") or {}
    conf_kwargs = {
        key: parse_klass_arg(value, config) for key, value in conf_args.items()
    }

    # Caller-supplied kwargs take precedence over config-provided ones.
    return klass(*args, **{**conf_kwargs, **kwargs})
|
|
|
|
|
def unscale_latents(latents, vae_scaling=None):
    """
    Undo per-channel latent normalisation: ``latents * std + mean``.

    ``vae_scaling["std"]`` and ``vae_scaling["mean"]`` are viewed as
    ``(1, -1, 1, 1)`` for 4D latents or ``(1, -1, 1, 1, 1)`` for 5D latents,
    so they broadcast along dim 1 (presumably the channel axis -- confirm
    against the caller). The update is applied in place and the same
    ``latents`` object is returned.

    Args:
        latents: 4D or 5D tensor to unscale (modified in place).
        vae_scaling: Mapping with ``"std"`` and ``"mean"`` tensors, or
            ``None`` to leave ``latents`` untouched.

    Returns:
        ``latents`` (the same object that was passed in).

    Raises:
        ValueError: If scaling is given and ``latents`` is neither 4D nor 5D.
    """
    if vae_scaling is None:
        return latents

    # Number of trailing singleton dims after (batch, channel).
    trailing_dims = {4: 2, 5: 3}.get(latents.ndim)
    if trailing_dims is None:
        raise ValueError("Latents should be 4D or 5D")
    broadcast_shape = (1, -1) + (1,) * trailing_dims

    latents *= vae_scaling["std"].view(*broadcast_shape)
    latents += vae_scaling["mean"].view(*broadcast_shape)
    return latents
|
|