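# Inference-serving benchmark: a FrontendWorker process sends batches of
# random image data to a BackendWorker that runs ResNet18 on the GPU, while
# recording warmup latency, steady-state latency, throughput, and GPU
# utilization.
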
import argparse
import asyncio
import os.path
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Empty

import numpy as np
import pandas as pd

import torch
import torch.multiprocessing as mp


class FrontendWorker(mp.Process):
    """
    This worker will send requests to a backend process, and measure the
    throughput and latency of those requests as well as GPU utilization.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_iters=10,
    ):
        super().__init__()
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        self.warmup_event = mp.Event()
        self.batch_size = batch_size
        self.num_iters = num_iters
        self.poll_gpu = True
        self.start_send_time = None
        self.end_recv_time = None

    def _run_metrics(self, metrics_lock):
        """
        This function will poll the response queue until it has received all
        responses. It records the startup latency and the average, max, and
        min latency, as well as the throughput of requests.
        """
        warmup_response_time = None
        response_times = []

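        # The first response corresponds to the warmup request; its latency is
        # recorded separately from the steady-state response times.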
        for i in range(self.num_iters + 1):
            response, request_time = self.response_queue.get()
            if warmup_response_time is None:
                self.warmup_event.set()
                warmup_response_time = time.time() - request_time
            else:
                response_times.append(time.time() - request_time)

        self.end_recv_time = time.time()
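        # Stop the GPU-polling thread started in run()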
        self.poll_gpu = False

        response_times = np.array(response_times)
        with metrics_lock:
            self.metrics_dict["warmup_latency"] = warmup_response_time
            self.metrics_dict["average_latency"] = response_times.mean()
            self.metrics_dict["max_latency"] = response_times.max()
            self.metrics_dict["min_latency"] = response_times.min()
            self.metrics_dict["throughput"] = (self.num_iters * self.batch_size) / (
                self.end_recv_time - self.start_send_time
            )

    def _run_gpu_utilization(self, metrics_lock):
        """
        This function will poll nvidia-smi for GPU utilization every 100ms to
        record the average GPU utilization.
        """

        def get_gpu_utilization():
            try:
                nvidia_smi_output = subprocess.check_output(
                    [
                        "nvidia-smi",
                        "--query-gpu=utilization.gpu",
                        "--id=0",
                        "--format=csv,noheader,nounits",
                    ]
                )
                gpu_utilization = nvidia_smi_output.decode().strip()
                return gpu_utilization
            except subprocess.CalledProcessError:
                return "N/A"

        gpu_utilizations = []

        while self.poll_gpu:
            gpu_utilization = get_gpu_utilization()
            if gpu_utilization != "N/A":
                gpu_utilizations.append(float(gpu_utilization))
            # Sleep so nvidia-smi is polled at the 100ms cadence the docstring promises
            time.sleep(0.1)

        with metrics_lock:
            self.metrics_dict["gpu_util"] = torch.tensor(gpu_utilizations).mean().item()

    def _send_requests(self):
        """
        This function will send one warmup request, and then num_iters requests
        to the backend process.
        """

        fake_data = torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
        other_data = [
            torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
            for i in range(self.num_iters)
        ]

        # Send one batch of warmup data
        self.request_queue.put((fake_data, time.time()))
        # Tell backend to poll queue for warmup request
        self.read_requests_event.set()
        self.warmup_event.wait()
        # Tell backend to poll queue for rest of requests
        self.read_requests_event.set()

        # Send fake data
        self.start_send_time = time.time()
        for i in range(self.num_iters):
            self.request_queue.put((other_data[i], time.time()))

    def run(self):
        # Lock for writing to metrics_dict
        metrics_lock = threading.Lock()
        requests_thread = threading.Thread(target=self._send_requests)
        metrics_thread = threading.Thread(
            target=self._run_metrics, args=(metrics_lock,)
        )
        gpu_utilization_thread = threading.Thread(
            target=self._run_gpu_utilization, args=(metrics_lock,)
        )

        requests_thread.start()
        metrics_thread.start()

        # only start polling GPU utilization after the warmup request is complete
        self.warmup_event.wait()
        gpu_utilization_thread.start()

        requests_thread.join()
        metrics_thread.join()
        gpu_utilization_thread.join()


class BackendWorker:
    """
    This worker will take tensors from the request queue, do some computation,
    and then return the result back in the response queue.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_workers,
        model_dir=".",
        compile_model=True,
    ):
        super().__init__()
        self.device = "cuda:0"
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_dir = model_dir
        self.compile_model = compile_model
        self._setup_complete = False
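        # Dedicated CUDA streams for host-to-device and device-to-host copies,
        # so data movement can overlap with compute on the worker streams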
        self.h2d_stream = torch.cuda.Stream()
        self.d2h_stream = torch.cuda.Stream()
        # maps thread_id to the cuda.Stream associated with that worker thread
        self.stream_map = dict()

    def _setup(self):
        import time

        import torch
        from torchvision.models.resnet import BasicBlock, ResNet

        # Create ResNet18 on meta device
        with torch.device("meta"):
            m = ResNet(BasicBlock, [2, 2, 2, 2])

        # Load pretrained weights
        start_load_time = time.time()
        state_dict = torch.load(
            f"{self.model_dir}/resnet18-f37072fd.pth",
            mmap=True,
            map_location=self.device,
        )
        self.metrics_dict["torch_load_time"] = time.time() - start_load_time
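        # assign=True swaps the meta-device parameters for the mmap'd checkpoint
        # tensors instead of copying into preallocated storage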
        m.load_state_dict(state_dict, assign=True)
        m.eval()

        if self.compile_model:
            start_compile_time = time.time()
            m.compile()
            end_compile_time = time.time()
            self.metrics_dict["m_compile_time"] = end_compile_time - start_compile_time
        return m

    def model_predict(
        self,
        model,
        input_buffer,
        copy_event,
        compute_event,
        copy_sem,
        compute_sem,
        response_list,
        request_time,
    ):
        # copy_sem makes sure copy_event has been recorded in the data copying thread
        copy_sem.acquire()
        self.stream_map[threading.get_native_id()].wait_event(copy_event)
        with torch.cuda.stream(self.stream_map[threading.get_native_id()]):
            with torch.no_grad():
                response_list.append(model(input_buffer))
                compute_event.record()
                compute_sem.release()
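        # Drop this thread's reference to the input buffer so its memory can be
        # reclaimed once the request has been fully processed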
        del input_buffer

    def copy_data(self, input_buffer, data, copy_event, copy_sem):
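        # Pinned (page-locked) host memory is required for the non_blocking H2D
        # copy below to actually run asynchronously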
        data = data.pin_memory()
        with torch.cuda.stream(self.h2d_stream):
            input_buffer.copy_(data, non_blocking=True)
            copy_event.record()
            copy_sem.release()

    def respond(self, compute_event, compute_sem, response_list, request_time):
        # compute_sem makes sure compute_event has been recorded in the model_predict thread
        compute_sem.acquire()
        self.d2h_stream.wait_event(compute_event)
        with torch.cuda.stream(self.d2h_stream):
            self.response_queue.put((response_list[0].cpu(), request_time))

    async def run(self):
        def worker_initializer():
            self.stream_map[threading.get_native_id()] = torch.cuda.Stream()

        worker_pool = ThreadPoolExecutor(
            max_workers=self.num_workers, initializer=worker_initializer
        )
        h2d_pool = ThreadPoolExecutor(max_workers=1)
        d2h_pool = ThreadPoolExecutor(max_workers=1)

        self.read_requests_event.wait()
        # Clear as we will wait for this event again before continuing to
        # poll the request_queue for the non-warmup requests
        self.read_requests_event.clear()
        while True:
            try:
                data, request_time = self.request_queue.get(timeout=5)
            except Empty:
                break

            if not self._setup_complete:
                model = self._setup()

            copy_sem = threading.Semaphore(0)
            compute_sem = threading.Semaphore(0)
            copy_event = torch.cuda.Event()
            compute_event = torch.cuda.Event()
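            # model_predict appends its output here; respond() reads it once
            # compute_sem/compute_event signal that the result is ready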
            response_list = []
            input_buffer = torch.empty(
                [self.batch_size, 3, 250, 250], dtype=torch.float32, device="cuda"
            )
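            # Pipeline each request across three executors: H2D copy, model
            # compute, and D2H response. Semaphores order event recording
            # across threads; CUDA events order the work across streams. The
            # returned futures are not awaited.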
            asyncio.get_running_loop().run_in_executor(
                h2d_pool,
                self.copy_data,
                input_buffer,
                data,
                copy_event,
                copy_sem,
            )
            asyncio.get_running_loop().run_in_executor(
                worker_pool,
                self.model_predict,
                model,
                input_buffer,
                copy_event,
                compute_event,
                copy_sem,
                compute_sem,
                response_list,
                request_time,
            )
            asyncio.get_running_loop().run_in_executor(
                d2h_pool,
                self.respond,
                compute_event,
                compute_sem,
                response_list,
                request_time,
            )

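            # After dispatching the warmup request, wait until the frontend
            # signals that the warmup response has arrived before polling for
            # the remaining requests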
            if not self._setup_complete:
                self.read_requests_event.wait()
                self._setup_complete = True


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--model_dir", type=str, default=".")
    parser.add_argument(
        "--compile", default=True, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--output_file", type=str, default="output.csv")
    parser.add_argument(
        "--profile", default=False, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()
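    # Example invocation (script name assumed):
    #   python server.py --num_iters 100 --batch_size 32 --num_workers 4 \
    #       --output_file output.csv --no-compile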

    downloaded_checkpoint = False
    if not os.path.isfile(f"{args.model_dir}/resnet18-f37072fd.pth"):
        p = subprocess.run(
            [
                "wget",
                # Download into model_dir so the path checked above (and
                # removed in the finally block) matches
                "-P",
                args.model_dir,
                "https://download.pytorch.org/models/resnet18-f37072fd.pth",
            ]
        )
        if p.returncode == 0:
            downloaded_checkpoint = True
        else:
            raise RuntimeError("Failed to download checkpoint")

    try:
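        # CUDA contexts do not survive fork(), so use forkserver to give the
        # frontend process a fresh interpreter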
        mp.set_start_method("forkserver")
        request_queue = mp.Queue()
        response_queue = mp.Queue()
        read_requests_event = mp.Event()

        manager = mp.Manager()
        metrics_dict = manager.dict()
        metrics_dict["batch_size"] = args.batch_size
        metrics_dict["compile"] = args.compile

        frontend = FrontendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            num_iters=args.num_iters,
        )
        backend = BackendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            args.num_workers,
            args.model_dir,
            args.compile,
        )

        frontend.start()

        if args.profile:

            def trace_handler(prof):
                prof.export_chrome_trace("trace.json")

            with torch.profiler.profile(on_trace_ready=trace_handler) as prof:
                asyncio.run(backend.run())
        else:
            asyncio.run(backend.run())

        frontend.join()

        metrics_dict = {k: [v] for k, v in metrics_dict._getvalue().items()}
        output = pd.DataFrame.from_dict(metrics_dict, orient="columns")
        # Make sure the results directory exists before appending to the CSV
        os.makedirs("./results", exist_ok=True)
        output_file = "./results/" + args.output_file
        is_empty = not os.path.isfile(output_file)

        with open(output_file, "a+", newline="") as file:
            output.to_csv(file, header=is_empty, index=False)

    finally:
        # Clean up the checkpoint file if we downloaded it
        if downloaded_checkpoint:
            os.remove(f"{args.model_dir}/resnet18-f37072fd.pth")