stable-diffusion-webui

Форк
0
92 строки · 2.7 Кб
1
import threading
2
import time
3
from collections import defaultdict
4

5
import torch
6

7

8
class MemUsageMonitor(threading.Thread):
9
    run_flag = None
10
    device = None
11
    disabled = False
12
    opts = None
13
    data = None
14

15
    def __init__(self, name, device, opts):
16
        threading.Thread.__init__(self)
17
        self.name = name
18
        self.device = device
19
        self.opts = opts
20

21
        self.daemon = True
22
        self.run_flag = threading.Event()
23
        self.data = defaultdict(int)
24

25
        try:
26
            self.cuda_mem_get_info()
27
            torch.cuda.memory_stats(self.device)
28
        except Exception as e:  # AMD or whatever
29
            print(f"Warning: caught exception '{e}', memory monitor disabled")
30
            self.disabled = True
31

32
    def cuda_mem_get_info(self):
33
        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
34
        return torch.cuda.mem_get_info(index)
35

36
    def run(self):
37
        if self.disabled:
38
            return
39

40
        while True:
41
            self.run_flag.wait()
42

43
            torch.cuda.reset_peak_memory_stats()
44
            self.data.clear()
45

46
            if self.opts.memmon_poll_rate <= 0:
47
                self.run_flag.clear()
48
                continue
49

50
            self.data["min_free"] = self.cuda_mem_get_info()[0]
51

52
            while self.run_flag.is_set():
53
                free, total = self.cuda_mem_get_info()
54
                self.data["min_free"] = min(self.data["min_free"], free)
55

56
                time.sleep(1 / self.opts.memmon_poll_rate)
57

58
    def dump_debug(self):
59
        print(self, 'recorded data:')
60
        for k, v in self.read().items():
61
            print(k, -(v // -(1024 ** 2)))
62

63
        print(self, 'raw torch memory stats:')
64
        tm = torch.cuda.memory_stats(self.device)
65
        for k, v in tm.items():
66
            if 'bytes' not in k:
67
                continue
68
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
69

70
        print(torch.cuda.memory_summary())
71

72
    def monitor(self):
73
        self.run_flag.set()
74

75
    def read(self):
76
        if not self.disabled:
77
            free, total = self.cuda_mem_get_info()
78
            self.data["free"] = free
79
            self.data["total"] = total
80

81
            torch_stats = torch.cuda.memory_stats(self.device)
82
            self.data["active"] = torch_stats["active.all.current"]
83
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
84
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
85
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
86
            self.data["system_peak"] = total - self.data["min_free"]
87

88
        return self.data
89

90
    def stop(self):
91
        self.run_flag.clear()
92
        return self.read()
93

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.