File size: 3,515 Bytes
dab5199 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import importlib
import omegaconf
from .models import ContrastiveModel, DiffuserSTDiT, ResNet18, SegDiTTransformer2DModel
def parse_klass_arg(value, full_config):
    """
    Parse an argument value that might represent a class, enum, or basic data type.

    Non-string values are returned unchanged. A string of the exact form
    ``"${path.to.key}"`` is first resolved against ``full_config`` (an OmegaConf
    config or a plain mapping) via ``OmegaConf.select``. A dotted string such as
    ``"torch.nn.ReLU"`` or ``"torch.float32"`` is then resolved by importing the
    longest importable module prefix and walking the remaining attributes.

    Args:
        value: The raw argument value from the configuration.
        full_config: Config used to resolve ``${...}`` interpolations.

    Returns:
        The resolved Python object; the interpolated value if it is not a
        string; ``None`` if interpolation resolution raised; otherwise the
        original string when no import/attribute resolution succeeds.
    """
    if not isinstance(value, str):
        return value

    if value.startswith("${") and value.endswith("}"):
        try:
            cfg = full_config
            # OmegaConf.select needs a config object; plain dicts (e.g. the
            # output of OmegaConf.to_container) are wrapped first.
            if not omegaconf.OmegaConf.is_config(cfg):
                cfg = omegaconf.OmegaConf.create(cfg)
            # OmegaConf.resolve() works in place and returns None, so it cannot
            # be indexed; select() handles dotted keys and returns the value.
            value = omegaconf.OmegaConf.select(cfg, value[2:-1])
        except Exception as e:
            print(f"Error resolving OmegaConf interpolation {value}: {e}")
            return None
        if not isinstance(value, str):
            # Resolved to a number, container, None, ... - nothing to import.
            return value

    if "." not in value:
        return value

    parts = value.split(".")
    # Try the longest module prefix first, e.g. "a.b.c" -> module "a.b", attr "c".
    for i in range(len(parts) - 1, 0, -1):
        module_name = ".".join(parts[:i])
        try:
            module = importlib.import_module(module_name)
        except (ImportError, ValueError, TypeError):
            # ImportError: no such module. ValueError/TypeError: empty or
            # relative module name, e.g. from strings like ".5" or "..x".
            continue
        result = module
        try:
            for attr in parts[i:]:
                result = getattr(result, attr)
        except ImportError:
            continue
        except AttributeError as e:
            print(
                f"Warning: Could not resolve attribute {attr} from {module_name}, error: {e}"
            )
            continue
        return result

    # No valid import and resolution occurred: keep the original string.
    return value
def instantiate_class_from_config(config, *args, **kwargs):
    """
    Dynamically instantiate a class based on a configuration object.

    The config must contain ``target``, a dotted path such as
    ``"package.module.ClassName"``, and may contain ``args``, a mapping whose
    values are run through ``parse_klass_arg`` and passed as keyword arguments.

    Args:
        config: An OmegaConf config, or a plain mapping (wrapped with
            ``OmegaConf.create``), holding ``target`` and optional ``args``.
        *args: Positional arguments forwarded to the constructor.
        **kwargs: Keyword arguments forwarded to the constructor; they override
            entries of the same name from ``args`` in the config.

    Returns:
        The constructed instance.

    Raises:
        ImportError: If the target class cannot be located.
    """
    if not omegaconf.OmegaConf.is_config(config):
        config = omegaconf.OmegaConf.create(config)
    # Resolve interpolations (relative to the config's root) into plain Python.
    config = omegaconf.OmegaConf.to_container(config, resolve=True)

    target = config["target"]
    module_name, _, class_name = target.rpartition(".")

    # Prefer names already in this module's namespace (e.g. classes it imports),
    # then fall back to importing the module named by the target path.
    klass = globals().get(class_name)
    if klass is None and module_name:
        try:
            klass = getattr(importlib.import_module(module_name), class_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Could not import target class {target!r}") from e
    if klass is None:
        raise ImportError(f"Could not resolve target class {target!r}")

    # A missing or null `args` entry means "no configured keyword arguments".
    conf_kwargs = {
        key: parse_klass_arg(value, config)
        for key, value in (config.get("args") or {}).items()
    }
    # Explicitly passed keyword arguments take precedence over configured ones.
    all_kwargs = {**conf_kwargs, **kwargs}
    return klass(*args, **all_kwargs)
def unscale_latents(latents, vae_scaling=None):
    """
    Undo per-channel normalization of latents: ``latents * std + mean``.

    The operation is applied IN PLACE on ``latents`` (which is also returned).
    The ``std`` and ``mean`` entries are broadcast along dimension 1, so they are
    expected to hold one value per entry of that dimension.

    Args:
        latents: A 4D or 5D tensor; dimension 1 is scaled.
        vae_scaling: Optional mapping with ``"std"`` and ``"mean"`` tensors.
            When ``None``, ``latents`` is returned unchanged.

    Returns:
        ``latents``, modified in place when ``vae_scaling`` is given.

    Raises:
        ValueError: If scaling is requested and ``latents`` is not 4D or 5D.
    """
    if vae_scaling is None:
        return latents
    if latents.ndim == 4:
        shape = (1, -1, 1, 1)
    elif latents.ndim == 5:
        shape = (1, -1, 1, 1, 1)
    else:
        raise ValueError("Latents should be 4D or 5D")
    # Move the stats onto the latents' device: in-place ops with a non-scalar
    # tensor on another device fail. reshape also accepts non-contiguous stats.
    std = vae_scaling["std"].to(device=latents.device).reshape(shape)
    mean = vae_scaling["mean"].to(device=latents.device).reshape(shape)
    latents *= std
    latents += mean
    return latents
|