colossalai

Форк
0
31 строка · 1011.0 Байт
1
from typing import Optional
2

3
from ..base import Actor
4
from .configuration_chatglm import ChatGLMConfig
5
from .modeling_chatglm import ChatGLMForConditionalGeneration
6

7

8
class ChatGLMActor(Actor):
9
    """
10
    ChatGLM Actor model.
11

12
    Args:
13
        pretrained (str): Pretrained model name or path.
14
        config (ChatGLMConfig): Model config.
15
        checkpoint (bool): Enable gradient checkpointing.
16

17
    do not support lora for now.
18
    """
19

20
    def __init__(
21
        self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False
22
    ) -> None:
23
        if pretrained is not None:
24
            model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
25
        elif config is not None:
26
            model = ChatGLMForConditionalGeneration(config)
27
        else:
28
            model = ChatGLMForConditionalGeneration(ChatGLMConfig())
29
        if checkpoint:
30
            model.gradient_checkpointing_enable()
31
        super().__init__(model, lora_rank=0, lora_train_bias="none")
32

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.