colossalai

Форк
0
36 строк · 1.0 Кб
1
from typing import Optional
2

3
from transformers import BloomConfig, BloomForCausalLM
4

5
from ..base import Actor
6

7

8
class BLOOMActor(Actor):
9
    """
10
    BLOOM Actor model.
11

12
    Args:
13
        pretrained (str): Pretrained model name or path.
14
        config (BloomConfig): Model config.
15
        checkpoint (bool): Enable gradient checkpointing.
16
        lora_rank (int): LoRA rank.
17
        lora_train_bias (str): LoRA bias training mode.
18
    """
19

20
    def __init__(
21
        self,
22
        pretrained: str = None,
23
        config: Optional[BloomConfig] = None,
24
        checkpoint: bool = False,
25
        lora_rank: int = 0,
26
        lora_train_bias: str = "none",
27
    ) -> None:
28
        if pretrained is not None:
29
            model = BloomForCausalLM.from_pretrained(pretrained)
30
        elif config is not None:
31
            model = BloomForCausalLM(config)
32
        else:
33
            model = BloomForCausalLM(BloomConfig())
34
        if checkpoint:
35
            model.gradient_checkpointing_enable()
36
        super().__init__(model, lora_rank, lora_train_bias)
37

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.