colossalai
70 строк · 2.3 Кб
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Callable, Optional

import torch
from coati.models.base import Actor, Critic, RewardModel
8
@dataclass
class Experience:
    """Experience is a batch of data.

    All tensors share the batch size ``B``. Sequence-level tensors share the
    sequence length ``S``; action-level tensors share the number of actions
    ``A``. Sequences are left-padded.

    Shapes of each tensor:
    sequences: (B, S)
    action_log_probs: (B, A)
    values: (B)
    reward: (B)
    advantages: (B)
    attention_mask: (B, S)
    action_mask: (B, A)

    ``attention_mask`` and ``action_mask`` are optional and may be ``None``.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]

    def _map_tensors(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> None:
        """Replace every tensor-valued field with ``fn(field)``, in place.

        Driving this from ``dataclasses.fields`` keeps ``to_device`` and
        ``pin_memory`` in sync with the field list, including fields added
        later or by subclasses. ``None`` (an unset optional mask) and
        non-tensor values are left untouched.
        """
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, torch.Tensor):
                setattr(self, f.name, fn(value))

    @torch.no_grad()
    def to_device(self, device: torch.device, non_blocking: bool = False) -> None:
        """Move every tensor field to ``device`` in place.

        Args:
            device: Target device.
            non_blocking: Forwarded to ``Tensor.to``. When the source tensors
                live in pinned host memory (see :meth:`pin_memory`), this lets
                host-to-GPU copies run asynchronously. Defaults to ``False``,
                which matches the previous behavior.
        """
        self._map_tensors(lambda t: t.to(device, non_blocking=non_blocking))

    def pin_memory(self) -> "Experience":
        """Copy every tensor field into page-locked (pinned) host memory.

        Returns:
            ``self``, so the call can be chained. Presumably this is meant for
            a DataLoader ``pin_memory`` hook; confirm against the caller.
        """
        self._map_tensors(lambda t: t.pin_memory())
        return self
59
class ExperienceMaker(ABC):
    """Abstract base for producers of :class:`Experience` batches.

    The constructor only stores the four models as public attributes.
    Concrete subclasses must implement :meth:`make_experience`.
    """

    def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
        super().__init__()
        # Attributes are assigned in the same order as the parameters.
        self.actor, self.critic = actor, critic
        self.reward_model, self.initial_model = reward_model, initial_model

    @abstractmethod
    def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
        """Build one :class:`Experience` batch; implemented by subclasses.

        Args:
            input_ids: Input token ids.
            attention_mask: Mask for ``input_ids``.
            **generate_kwargs: Extra keyword arguments. Presumably these are
                forwarded to generation; confirm in concrete subclasses.

        Returns:
            The resulting :class:`Experience`.
        """