gpt-neox

Форк
0
/
devutil.py 
51 строка · 1.2 Кб
1
import torch.cuda
2

3

4
class Metric:
5
    """
6
    Dumb utility to collect and report average wall-time metrics.
7
    """
8

9
    def __init__(self, label):
10
        self.label = label
11
        self.measurements = []
12

13
    def collect(self, measurement):
14
        self.measurements.append(measurement)
15

16
    def get_measurements(self):
17
        return self.measurements[:]
18

19
    def report(self):
20
        print(
21
            self.label,
22
            torch.quantile(torch.tensor(self.measurements), torch.arange(10) / 10.0),
23
        )
24

25

26
def monitor_method_cuda_wall_times(metric, obj, methodname):
27
    """
28
    Measure timings for a method on an object or class.
29

30
    For instance:
31

32
    >>> metric = Metric('!LNORM')
33
    >>> monitor_method_wall_times(metric, LayerNorm, 'forward')
34
    """
35
    oldmeth = getattr(obj, methodname)
36

37
    start_event = torch.cuda.Event(enable_timing=True)
38
    end_event = torch.cuda.Event(enable_timing=True)
39

40
    def newmeth(*args, **kw):
41
        start_event.record()
42
        try:
43
            return oldmeth(*args, **kw)
44
        finally:
45
            end_event.record()
46
            torch.cuda.synchronize()
47
            elapsed = start_event.elapsed_time(end_event)
48
            metric.collect(elapsed)
49
            metric.report()
50

51
    setattr(obj, methodname, newmeth)
52

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.