h2o-llmstudio

Форк
0
/
optimizers.py 
35 строк · 826.0 Байт
1
from functools import partial
2
from typing import Any, List
3

4
import bitsandbytes as bnb
5
from torch import optim
6

7
__all__ = ["Optimizers"]
8

9

10
class Optimizers:
11
    """Optimizers factory."""
12

13
    _optimizers = {
14
        "Adam": optim.Adam,
15
        "AdamW": optim.AdamW,
16
        "SGD": partial(optim.SGD, momentum=0.9, nesterov=True),
17
        "RMSprop": partial(optim.RMSprop, momentum=0.9, alpha=0.9),
18
        "Adadelta": optim.Adadelta,
19
        "AdamW8bit": bnb.optim.Adam8bit,
20
    }
21

22
    @classmethod
23
    def names(cls) -> List[str]:
24
        return sorted(cls._optimizers.keys())
25

26
    @classmethod
27
    def get(cls, name: str) -> Any:
28
        """Access to Optimizers.
29

30
        Args:
31
            name: optimizer name
32
        Returns:
33
            A class to build the Optimizer
34
        """
35
        return cls._optimizers.get(name)
36

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.