performance_evaluator.py
from time import time
from typing import Optional

import torch
import torch.distributed as dist
from coati.experience_maker import Experience

from .base import MakerCallback, TrainerCallback


def get_world_size() -> int:
    """Return the distributed world size, or 1 if torch.distributed is not initialized."""
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def print_rank_0(*args, **kwargs) -> None:
    """Print only on rank 0 so multi-process runs do not duplicate log lines."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)


@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
    """Average a scalar across all ranks; returns x unchanged in single-process runs."""
    if world_size == 1:
        return x
    tensor = torch.tensor([x], device=torch.cuda.current_device())
    dist.all_reduce(tensor)
    tensor = tensor / world_size
    return tensor.item()

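# A quick sanity check of all_reduce_mean (hypothetical numbers, not part of the
# original file): with world_size=4 and per-rank values 1.0, 2.0, 3.0, 4.0, every
# rank gets back (1 + 2 + 3 + 4) / 4 = 2.5. This is how the per-rank timer
# durations below are collapsed into a single cross-rank average before reporting.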
class Timer:
    """Accumulate wall-clock time over repeated start()/end() sections."""

    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.duration: float = 0.0

    def start(self) -> None:
        self.start_time = time()

    def end(self) -> None:
        # Guard against end() without a matching start(), and clear the stamp
        # so a stray second end() cannot double-count the elapsed time.
        assert self.start_time is not None, "Timer.end() called before Timer.start()"
        self.duration += time() - self.start_time
        self.start_time = None

    def reset(self) -> None:
        self.duration = 0.0

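# Minimal usage sketch for Timer (a sketch, not part of the original file;
# do_work is a hypothetical workload): durations accumulate across sections
# until reset() is called.
#
#     timer = Timer()
#     for _ in range(3):
#         timer.start()
#         do_work()          # hypothetical workload
#         timer.end()
#     print(timer.duration)  # total seconds across all three sections
#     timer.reset()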
class ExperienceMakerPerformanceEvaluator(MakerCallback):
    def __init__(
        self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
    ) -> None:
        super().__init__()
        self.world_size = get_world_size()
        self.actor_num_params = actor_num_params
        self.critic_num_params = critic_num_params
        self.initial_model_num_params = initial_model_num_params
        self.reward_model_num_params = reward_model_num_params

        self.batch_timer = Timer()
        self.send_timer = Timer()
        self.make_experience_timer = Timer()
        self.total_samples: int = 0
        self.make_experience_flop: int = 0

        print_rank_0(
            f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
        )

    def on_make_experience_start(self) -> None:
        self.make_experience_timer.start()

    def on_make_experience_end(self, experience: Experience) -> None:
        self.make_experience_timer.end()

        batch_size, seq_len = experience.sequences.shape

        self.total_samples += batch_size

        # actor generate: each decoding step runs a forward pass over the current
        # prefix, whose length grows from input_len up to seq_len - 1; summing that
        # arithmetic series gives the total number of tokens processed. A dense
        # transformer forward costs roughly 2 * num_params FLOPs per token (one
        # multiply-accumulate per parameter).
        num_actions = experience.action_mask.size(1)
        input_len = seq_len - num_actions
        total_seq_len = (input_len + seq_len - 1) * num_actions / 2
        self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
        # actor forward
        self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
        # critic forward
        self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
        # initial model forward
        self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
        # reward model forward
        self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2

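    # Worked example for the accounting above (hypothetical numbers, not part of
    # the original file): with a 7e9-parameter actor, batch_size=8, input_len=256
    # and num_actions=256 (so seq_len=512), the generation term is
    #   2 * 7e9 * 8 * (256 + 511) * 256 / 2 ≈ 1.1e16 FLOPs,
    # while each full-sequence forward adds 2 * num_params * 8 * 512 FLOPs.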
    def on_send_start(self) -> None:
        self.send_timer.start()

    def on_send_end(self) -> None:
        self.send_timer.end()

    def on_batch_start(self) -> None:
        self.batch_timer.start()

    def on_batch_end(self) -> None:
        self.batch_timer.end()

    def on_loop_end(self) -> None:
        avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
        avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
        avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size)

        # The 1e-12 terms guard against division by zero when a timer never ran.
        avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
        avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
        avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
        avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
            self.total_samples * self.world_size
        )
        avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)

        print_rank_0(
            "Making Experience Performance Summary:\n"
            + f"Throughput: {avg_throughput:.3f} samples/sec\n"
            + f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
            + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
            + f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
            + f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
        )

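# Hypothetical wiring sketch (a sketch, not coati API: SomeExperienceMaker and
# its callbacks parameter are assumptions): the evaluator is a MakerCallback,
# so it would be handed to whichever experience maker accepts callbacks, and it
# reports once the make-experience loop ends.
#
#     evaluator = ExperienceMakerPerformanceEvaluator(
#         actor_num_params=7 * 10**9,
#         critic_num_params=1 * 10**9,
#         initial_model_num_params=7 * 10**9,
#         reward_model_num_params=1 * 10**9,
#     )
#     maker = SomeExperienceMaker(..., callbacks=[evaluator])  # hypothetical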
class TrainerPerformanceEvaluator(TrainerCallback):
    def __init__(
        self,
        actor_num_params: int,
        critic_num_params: int,
        enable_grad_checkpoint: bool = False,
        ignore_first_episodes: int = 1,
    ) -> None:
        super().__init__()
        self.world_size = get_world_size()
        self.actor_num_params = actor_num_params
        self.critic_num_params = critic_num_params
        self.enable_grad_checkpoint = enable_grad_checkpoint
        # Skip the first ignore_first_episodes episodes so warm-up does not skew the averages.
        self.ignore_first_episodes = ignore_first_episodes
        self.ignore_this_episode = False

        self.episode_timer = Timer()
        self.batch_timer = Timer()
        self.update_timer = Timer()
        self.total_samples: int = 0
        self.learn_flop: int = 0

        print_rank_0(
            f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
        )

    def on_episode_start(self, episodes: int) -> None:
        self.ignore_this_episode = episodes < self.ignore_first_episodes
        if self.ignore_this_episode:
            return
        self.episode_timer.start()

    def on_episode_end(self, episodes: int) -> None:
        if self.ignore_this_episode:
            return
        self.episode_timer.end()

    def on_batch_start(self) -> None:
        if self.ignore_this_episode:
            return
        self.batch_timer.start()

    def on_batch_end(self, metrics: dict, experience: Experience) -> None:
        if self.ignore_this_episode:
            return
        self.batch_timer.end()

        batch_size, seq_len = experience.sequences.shape

        self.total_samples += batch_size

        # actor forward-backward: 3x a forward pass (1x forward + 2x backward);
        # activation checkpointing recomputes the forward, adding roughly 1x more.
        self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
        # critic forward-backward
        self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))

    def on_update_start(self) -> None:
        if self.ignore_this_episode:
            return
        self.update_timer.start()

    def on_update_end(self) -> None:
        if self.ignore_this_episode:
            return
        self.update_timer.end()

    def on_fit_end(self) -> None:
        if self.total_samples == 0:
            print_rank_0("No samples collected, skipping trainer performance evaluation")
            return
        avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
        avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
        avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size)

        avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12)
        avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12)
        avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size)
        avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size)
        avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)

        print_rank_0(
            "Learning Performance Summary:\n"
            + f"Throughput: {avg_throughput:.3f} samples/sec\n"
            + f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
            + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
            + f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
            + f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
        )

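# A hand-driven smoke test of TrainerPerformanceEvaluator (a sketch, not part
# of the original file): only experience.sequences is read by on_batch_end, so
# a SimpleNamespace stand-in for Experience is enough to exercise the FLOP and
# timing accounting in a single process, where all_reduce_mean is a no-op.
# All numbers below are made up for illustration.
#
#     from types import SimpleNamespace
#
#     evaluator = TrainerPerformanceEvaluator(
#         actor_num_params=7 * 10**9,   # hypothetical 7B actor
#         critic_num_params=1 * 10**9,  # hypothetical 1B critic
#         ignore_first_episodes=0,      # measure from the first episode
#     )
#     fake_exp = SimpleNamespace(sequences=torch.zeros(8, 512))
#     evaluator.on_episode_start(0)
#     evaluator.on_batch_start()
#     evaluator.on_batch_end(metrics={}, experience=fake_exp)
#     evaluator.on_episode_end(0)
#     evaluator.on_fit_end()  # prints the Learning Performance Summary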