h2o-llmstudio
35 строк · 826.0 Байт
1from functools import partial
2from typing import Any, List
3
4import bitsandbytes as bnb
5from torch import optim
6
7__all__ = ["Optimizers"]
8
9
class Optimizers:
    """Registry mapping optimizer names to optimizer constructors.

    Each registered value is either an optimizer class or a
    ``functools.partial`` wrapping one with preset keyword arguments.
    """

    # SGD and RMSprop are wrapped in ``partial`` so their momentum-related
    # keyword arguments are preset; the other entries are the bare classes.
    # NOTE(review): the "AdamW8bit" key resolves to ``bnb.optim.Adam8bit``
    # rather than an AdamW-named class -- confirm this pairing is intentional.
    _optimizers = {
        "Adam": optim.Adam,
        "AdamW": optim.AdamW,
        "SGD": partial(optim.SGD, momentum=0.9, nesterov=True),
        "RMSprop": partial(optim.RMSprop, momentum=0.9, alpha=0.9),
        "Adadelta": optim.Adadelta,
        "AdamW8bit": bnb.optim.Adam8bit,
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return all registered optimizer names in sorted order."""
        # Iterating a dict yields its keys, so no explicit ``.keys()`` call.
        return sorted(cls._optimizers)

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up the optimizer constructor registered under a name.

        Args:
            name: optimizer name
        Returns:
            The class (or partial) registered for ``name``, or ``None`` when
            nothing is registered under that name.
        """
        registry = cls._optimizers
        return registry[name] if name in registry else None
36