from time import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor

from colossalai.logging import DistributedLogger


def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
    B = 1024**3
    M = 1024**2
    K = 1024
    outputs = "Model param count: "
    model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if model_param >= B:
        outputs += f"{model_param / B:.2f} B\n"
    elif model_param >= M:
        outputs += f"{model_param / M:.2f} M\n"
    elif model_param >= K:
        outputs += f"{model_param / K:.2f} K\n"
    else:
        outputs += f"{model_param}\n"
    logger.info(outputs, ranks=[0])

def get_model_numel(model: nn.Module) -> int:
    # Count only trainable parameters.
    model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return model_param


def divide(x: float, y: float) -> float:
    if y == 0:
        return float("inf")
    elif y == float("inf"):
        return float("nan")
    return x / y


@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
    if world_size == 1:
        return x
    tensor = torch.tensor([x], device=torch.cuda.current_device())
    dist.all_reduce(tensor)
    tensor = tensor / world_size
    return tensor.item()
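
# Worked example for all_reduce_mean (for illustration): with four ranks
# reporting 1.0, 2.0, 3.0 and 4.0, every rank receives 2.5. The caller is
# responsible for initializing the default process group and selecting the
# CUDA device on each rank before calling it.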


class Timer:
    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.duration: float = 0.0

    def start(self) -> None:
        self.start_time = time()

    def end(self) -> None:
        assert self.start_time is not None
        self.duration += time() - self.start_time
        self.start_time = None

    def reset(self) -> None:
        self.duration = 0.0
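
# Minimal usage sketch for Timer: it accumulates the wall-clock time of every
# start()/end() interval, so measurement can be paused and resumed across steps.
#
#     timer = Timer()
#     for _ in range(3):
#         timer.start()
#         ...  # measured work
#         timer.end()
#     total_seconds = timer.duration
#     timer.reset()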


class PerformanceEvaluator:
    """
    Callback for evaluating the performance of the model.

    Args:
        model_numel: The number of parameters of the model.
        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
        ignore_steps: The number of warm-up steps to ignore when calculating the performance.
        dp_world_size: The data-parallel world size. If not given, the global world size is used.
    """

    def __init__(
        self,
        model_numel: int,
        enable_grad_checkpoint: bool = False,
        ignore_steps: int = 0,
        dp_world_size: Optional[int] = None,
    ) -> None:
        self.model_numel = model_numel
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_steps = ignore_steps
        self.world_size = dist.get_world_size()
        # Assume pure data parallelism when dp_world_size is not given, so the
        # arithmetic in on_fit_end never divides by None.
        self.dp_world_size = dp_world_size if dp_world_size is not None else self.world_size
        self.disable: bool = False
        self.timer = Timer()
        self.num_samples: int = 0
        self.flop: int = 0

    def on_step_start(self, step: int) -> None:
        self.disable = self.ignore_steps > 0 and step < self.ignore_steps
        if self.disable:
            return
        torch.cuda.synchronize()
        self.timer.start()

    def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
        if self.disable:
            return
        torch.cuda.synchronize()
        self.timer.end()

        batch_size, seq_len = input_ids.shape

        self.num_samples += batch_size
        self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
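        # The factor above is the usual dense-model estimate: a forward pass
        # costs roughly 2 FLOPs per parameter per token (one multiply-accumulate),
        # the backward pass about twice that, so forward + backward is
        # ~6 * params * tokens, i.e. 2 * 3. Full activation checkpointing
        # recomputes the forward, adding one more factor of 2, hence
        # 2 * (3 + 1) = 8. The attention O(seq_len^2) term is ignored.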

    def on_fit_end(self) -> None:
        avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
        avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
        # self.flop counts the work of one model replica; dividing by the
        # model-parallel world size spreads it over the GPUs sharing that replica.
        mp_world_size = self.world_size // self.dp_world_size
        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
        if dist.get_rank() == 0:
            print(
                f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, "
                f"flop: {self.flop}, avg_duration: {avg_duration}, "
                f"avg_throughput: {avg_throughput}"
            )
            print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")