File size: 2,291 Bytes
966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# -*- coding: utf-8 -*-
# Author: ximing
# Description: SVGDreamer - optim
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
from functools import partial

import torch
from omegaconf import DictConfig


def get_optimizer(optimizer_name, parameters, lr=None, config: "DictConfig" = None):
    """Build a ``torch.optim`` optimizer from a name and optional settings.

    Args:
        optimizer_name: One of ``"adam"``, ``"adamW"``, ``"radam"`` or
            ``"sgd"`` (case-sensitive).
        parameters: Parameters (or parameter groups) handed to the optimizer
            as its ``params`` argument.
        lr: Learning rate. When ``None`` the optimizer's own default is used.
        config: Optional mapping-like object exposing ``.get(key)`` (e.g. an
            OmegaConf ``DictConfig`` or a plain ``dict``). Only the keys
            relevant to the chosen optimizer are read:
            ``betas``/``weight_decay``/``eps`` for adam and adamW,
            ``betas``/``weight_decay`` for radam, and
            ``momentum``/``weight_decay``/``nesterov`` for sgd.
            Keys that are absent or ``None`` are left at the optimizer default.

    Returns:
        The constructed optimizer instance.

    Raises:
        NotImplementedError: If ``optimizer_name`` is not supported.
    """
    # name -> (optimizer class, config keys forwarded to its constructor)
    registry = {
        "adam": (torch.optim.Adam, ("betas", "weight_decay", "eps")),
        "adamW": (torch.optim.AdamW, ("betas", "weight_decay", "eps")),
        "radam": (torch.optim.RAdam, ("betas", "weight_decay")),
        "sgd": (torch.optim.SGD, ("momentum", "weight_decay", "nesterov")),
    }
    if optimizer_name not in registry:
        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")
    optimizer_cls, config_keys = registry[optimizer_name]

    kwargs = {}
    if lr is not None:
        kwargs["lr"] = lr

    # ``config`` is optional; without it only ``lr`` (if given) is forwarded.
    if config is not None:
        for key in config_keys:
            value = config.get(key)
            # Compare against None rather than truthiness so that explicit
            # falsy settings (e.g. weight_decay=0 for AdamW, whose default
            # decay is non-zero) are honoured instead of silently dropped.
            if value is not None:
                kwargs[key] = value

    return optimizer_cls(params=parameters, **kwargs)