gpt-neox
1import torch.cuda
2
3
class Metric:
    """
    Dumb utility to collect wall-time measurements and report their distribution.

    Measurements are kept in insertion order; ``report`` prints the label
    followed by the 0.0, 0.1, ..., 0.9 quantiles of everything collected so far.
    """

    def __init__(self, label):
        # Printed as the first field of every report line.
        self.label = label
        self.measurements = []

    def collect(self, measurement):
        """Append one measurement (a real number, e.g. milliseconds)."""
        self.measurements.append(measurement)

    def get_measurements(self):
        """Return a shallow copy of the collected measurements."""
        return self.measurements[:]

    def report(self):
        """
        Print the label and the 0.0, 0.1, ..., 0.9 quantiles of the measurements.

        Prints a placeholder instead of raising when nothing has been collected
        yet (``torch.quantile`` rejects an empty input tensor).
        """
        if not self.measurements:
            print(self.label, "(no measurements)")
            return
        # torch.quantile needs a floating-point input whose dtype matches q;
        # integer measurements would otherwise yield an int64 tensor and fail.
        dtype = torch.get_default_dtype()
        values = torch.tensor(self.measurements, dtype=dtype)
        q = torch.arange(10, dtype=dtype) / 10.0
        print(self.label, torch.quantile(values, q))
24
25
def monitor_method_cuda_wall_times(metric, obj, methodname):
    """
    Measure CUDA timings for a method on an object or class.

    Replaces ``obj.<methodname>`` with a wrapper that records a pair of CUDA
    timing events around each call, synchronizes the device, passes the
    elapsed time (``torch.cuda.Event.elapsed_time``) to ``metric.collect``
    and then calls ``metric.report()`` -- even if the wrapped call raises.

    For instance:

    >>> metric = Metric('!LNORM')
    >>> original = monitor_method_cuda_wall_times(metric, LayerNorm, 'forward')

    :param metric: object providing ``collect(value)`` and ``report()``.
    :param obj: class or instance whose attribute is patched in place.
    :param methodname: name of the attribute to wrap.
    :return: the original attribute, so callers can undo the patch with
        ``setattr(obj, methodname, original)``.
    """
    import functools

    oldmeth = getattr(obj, methodname)

    @functools.wraps(oldmeth)
    def newmeth(*args, **kw):
        # Fresh events per call: events shared across calls would be
        # overwritten by nested/recursive invocations, corrupting timings.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        try:
            return oldmeth(*args, **kw)
        finally:
            end_event.record()
            # elapsed_time is only valid once both events have completed.
            torch.cuda.synchronize()
            metric.collect(start_event.elapsed_time(end_event))
            metric.report()

    setattr(obj, methodname, newmeth)
    return oldmeth
52