colossalai

Форк
0
38 строк · 1.2 Кб
1
from typing import Optional
2

3
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
4
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
5

6
from ..base import Actor
7

8

9
class GPTActor(Actor):
10
    """
11
    GPT Actor model.
12

13
    Args:
14
        pretrained (str): Pretrained model name or path.
15
        config (GPT2Config): Model config.
16
        checkpoint (bool): Enable gradient checkpointing.
17
        lora_rank (int): Rank of the LoRa layer.
18
        lora_train_bias (str): Bias training strategy for the LoRa layer.
19
    """
20

21
    def __init__(
22
        self,
23
        pretrained: Optional[str] = None,
24
        config: Optional[GPT2Config] = None,
25
        checkpoint: bool = False,
26
        lora_rank: int = 0,
27
        lora_train_bias: str = "none",
28
        **kwargs,
29
    ) -> None:
30
        if pretrained is not None:
31
            model = GPT2LMHeadModel.from_pretrained(pretrained)
32
        elif config is not None:
33
            model = GPT2LMHeadModel(config)
34
        else:
35
            model = GPT2LMHeadModel(GPT2Config())
36
        if checkpoint:
37
            model.gradient_checkpointing_enable()
38
        super().__init__(model, lora_rank, lora_train_bias, **kwargs)
39

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.