from typing import Optional

from transformers import BloomConfig, BloomForCausalLM

from ..base import Actor


class BLOOMActor(Actor):
    """
    BLOOM Actor model.

    Args:
        pretrained (str, optional): Pretrained model name or path.
        config (BloomConfig, optional): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(
        self,
        pretrained: Optional[str] = None,
        config: Optional[BloomConfig] = None,
        checkpoint: bool = False,
        lora_rank: int = 0,
        lora_train_bias: str = "none",
    ) -> None:
        # Prefer loading pretrained weights; otherwise build from the given
        # config, falling back to a default BloomConfig (random init).
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if checkpoint:
            # Trade extra compute for reduced activation memory.
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
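

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): shows how each
# branch of the constructor can be exercised. "bigscience/bloom-560m" is just
# an example Hugging Face checkpoint name; any BLOOM causal-LM checkpoint
# would work. LoRA injection is assumed to happen in the Actor base class,
# as implied by the super().__init__ call above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Load pretrained weights, enable gradient checkpointing, and request a
    # rank-16 LoRA adapter.
    actor = BLOOMActor(
        pretrained="bigscience/bloom-560m",
        checkpoint=True,
        lora_rank=16,
    )

    # With no arguments, the constructor falls back to a randomly
    # initialized model built from a default BloomConfig.
    scratch_actor = BLOOMActor()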