colossalai

Форк
0
70 строк · 2.3 Кб
1
from abc import ABC, abstractmethod
2
from dataclasses import dataclass
3
from typing import Optional
4

5
import torch
6
from coati.models.base import Actor, Critic, RewardModel
7

8

9
@dataclass
class Experience:
    """A batch of experience data.

    Tensors are described in terms of the batch size (B), the sequence
    length (S) and the number of actions (A).
    Left padding for sequences is applied.

    Shapes of each tensor:
    sequences: (B, S)
    action_log_probs: (B, A)
    values: (B)
    reward: (B)
    advantages: (B)
    attention_mask: (B, S)
    action_mask: (B, A)

    "A" is the number of actions.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]

    @torch.no_grad()
    def to_device(self, device: torch.device) -> None:
        """Move every tensor of the batch to ``device``.

        Each attribute is rebound to the result of ``Tensor.to(device)``.
        The two optional masks are skipped when they are ``None``.

        Args:
            device: Target device for all tensors.
        """
        self.sequences = self.sequences.to(device)
        self.action_log_probs = self.action_log_probs.to(device)
        self.values = self.values.to(device)
        self.reward = self.reward.to(device)
        self.advantages = self.advantages.to(device)
        # The masks are optional, so only move the ones that are present.
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.to(device)
        if self.action_mask is not None:
            self.action_mask = self.action_mask.to(device)

    def pin_memory(self) -> "Experience":
        """Rebind every tensor to its ``Tensor.pin_memory()`` result.

        The two optional masks are skipped when they are ``None``.

        Returns:
            This same instance, so the call can be chained.
        """
        self.sequences = self.sequences.pin_memory()
        self.action_log_probs = self.action_log_probs.pin_memory()
        self.values = self.values.pin_memory()
        self.reward = self.reward.pin_memory()
        self.advantages = self.advantages.pin_memory()
        # The masks are optional, so only pin the ones that are present.
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.pin_memory()
        if self.action_mask is not None:
            self.action_mask = self.action_mask.pin_memory()
        return self
58

59

60
class ExperienceMaker(ABC):
    """Abstract base for objects that build ``Experience`` batches.

    The constructor only stores the four models it is given; subclasses
    implement ``make_experience`` to decide how they are used.
    """

    def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
        """Store the models on the instance.

        Args:
            actor: Stored as ``self.actor``.
            critic: Stored as ``self.critic``.
            reward_model: Stored as ``self.reward_model``.
            initial_model: Stored as ``self.initial_model``.
        """
        super().__init__()
        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.initial_model = initial_model

    @abstractmethod
    def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
        """Build an ``Experience`` from the given inputs.

        Must be overridden by concrete subclasses.

        Args:
            input_ids: Input token ids tensor.
            attention_mask: Attention mask tensor for ``input_ids``.
            **generate_kwargs: Extra keyword arguments; presumably forwarded
                to generation by subclasses — confirm against implementations.

        Returns:
            The ``Experience`` batch produced by the subclass.
        """
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.