# -*- coding: utf-8 -*-
# Author: ximing
# Description: SVGDreamer - optim
# Copyright (c) 2023, XiMing Xing.
# License: MIT License

from functools import partial
from typing import Optional

import torch

try:
    from omegaconf import DictConfig
except ImportError:
    # omegaconf is referenced only by the type annotation of `get_optimizer`;
    # the function itself just needs a mapping with a `.get()` method.
    DictConfig = dict

# Optimizer name -> (class name inside `torch.optim`, config keys forwarded to it).
# Classes are looked up by name at call time so that a torch build lacking one
# of them only fails when that particular optimizer is requested.
_OPTIMIZER_SPECS = {
    "adam": ("Adam", ("betas", "weight_decay", "eps")),
    "adamW": ("AdamW", ("betas", "weight_decay", "eps")),
    # NOTE(review): `eps` is not forwarded for RAdam, matching the previous
    # behaviour of this function -- confirm whether that omission is intended.
    "radam": ("RAdam", ("betas", "weight_decay")),
    "sgd": ("SGD", ("momentum", "weight_decay", "nesterov")),
}


def get_optimizer(optimizer_name, parameters, lr=None, config: Optional[DictConfig] = None):
    """Build a ``torch.optim`` optimizer from a name and an optional config.

    Args:
        optimizer_name: One of ``"adam"``, ``"adamW"``, ``"radam"`` or ``"sgd"``
            (case-sensitive).
        parameters: Passed to the optimizer as its ``params`` argument.
        lr: Learning rate. When ``None`` the optimizer's own default is used.
        config: Optional mapping (``DictConfig`` or plain ``dict``) holding
            hyper-parameters. Only the keys relevant to the chosen optimizer
            are read: ``betas``/``weight_decay``/``eps`` for adam and adamW,
            ``betas``/``weight_decay`` for radam, and
            ``momentum``/``weight_decay``/``nesterov`` for sgd. Keys that are
            absent or ``None`` are left at the optimizer's default.

    Returns:
        The constructed optimizer instance.

    Raises:
        NotImplementedError: If ``optimizer_name`` is not a supported name.
    """
    spec = _OPTIMIZER_SPECS.get(optimizer_name) if isinstance(optimizer_name, str) else None
    if spec is None:
        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")
    class_name, config_keys = spec

    kwargs = {}
    if lr is not None:
        kwargs["lr"] = lr

    # `config` is optional: without it only `lr` (if given) overrides defaults.
    if config is not None:
        for key in config_keys:
            value = config.get(key)
            # Compare against None rather than relying on truthiness, so an
            # explicit 0 / 0.0 / False is honoured. This matters where the
            # optimizer's default is non-zero (e.g. AdamW's weight decay).
            if value is not None:
                kwargs[key] = value

    optimizer_cls = getattr(torch.optim, class_name)
    return optimizer_cls(params=parameters, **kwargs)