# stable-diffusion-webui — memory-usage monitor module (92 lines, 2.7 KB)
import threading
import time
from collections import defaultdict

import torch
class MemUsageMonitor(threading.Thread):
    """Background daemon thread that samples CUDA memory usage during a task.

    Call monitor() to open a measurement window, then read() / stop() to
    collect the figures.  If CUDA memory introspection is unavailable
    (non-NVIDIA hardware, CPU-only builds, ...), the monitor disables
    itself at construction time and read() returns an empty mapping.
    """

    run_flag = None   # threading.Event; set while a measurement window is active
    device = None     # torch device whose memory is sampled
    disabled = False  # True when CUDA memory stats are unavailable
    opts = None       # options object; only memmon_poll_rate is read here
    data = None       # defaultdict(int) holding the collected figures

    def __init__(self, name, device, opts):
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        try:
            # Probe once up front so we disable ourselves cleanly instead of
            # crashing the polling thread later.
            self.cuda_mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def cuda_mem_get_info(self):
        """Return (free_bytes, total_bytes) for the monitored device."""
        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
        return torch.cuda.mem_get_info(index)

    def run(self):
        """Polling loop: track the minimum free memory seen while run_flag is set."""
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            # Fix: reset peak stats on the monitored device, not whichever
            # device happens to be current on this thread.
            torch.cuda.reset_peak_memory_stats(self.device)
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                self.run_flag.clear()
                continue

            self.data["min_free"] = self.cuda_mem_get_info()[0]

            while self.run_flag.is_set():
                free, _total = self.cuda_mem_get_info()
                self.data["min_free"] = min(self.data["min_free"], free)

                # Re-read the rate each pass so live option changes apply,
                # and guard against it dropping to zero mid-run — the bare
                # division would raise ZeroDivisionError and kill the thread.
                poll_rate = self.opts.memmon_poll_rate
                if poll_rate <= 0:
                    self.run_flag.clear()
                    break

                time.sleep(1 / poll_rate)

    def dump_debug(self):
        """Print collected data and raw torch memory stats (values in MiB, rounded up)."""
        print(self, 'recorded data:')
        for k, v in self.read().items():
            print(k, -(v // -(1024 ** 2)))  # ceiling-divide bytes to MiB

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        """Open a measurement window (wakes the polling loop)."""
        self.run_flag.set()

    def read(self):
        """Refresh current/peak figures and return the stats mapping.

        Returns the (mutable) internal defaultdict; empty when disabled.
        """
        if not self.disabled:
            free, total = self.cuda_mem_get_info()
            self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            # NOTE(review): "active" is a block count ("active.all.current"),
            # unlike its *_peak siblings which are byte counts — kept as-is
            # for backward compatibility with existing consumers.
            self.data["active"] = torch_stats["active.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            # Memory consumed by everything else on the device during the window.
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        """Close the measurement window and return the final stats."""
        self.run_flag.clear()
        return self.read()