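"""
Benchmark a simple asynchronous GPU inference server.

A FrontendWorker process sends batches of fake image data to a BackendWorker,
which runs ResNet18 inference while pipelining H2D copies, compute, and D2H
copies across separate CUDA streams. The frontend records warmup latency,
per-request latencies, throughput, and GPU utilization, and the results are
appended to a csv file.
"""
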
import argparse
import asyncio
import os.path
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Empty

import numpy as np
import pandas as pd

import torch
import torch.multiprocessing as mp


class FrontendWorker(mp.Process):
    """
    This worker sends requests to a backend process and measures the
    throughput and latency of those requests, as well as GPU utilization.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_iters=10,
    ):
        super().__init__()
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        self.warmup_event = mp.Event()
        self.batch_size = batch_size
        self.num_iters = num_iters
        self.poll_gpu = True
        self.start_send_time = None
        self.end_recv_time = None

    def _run_metrics(self, metrics_lock):
        """
        This function polls the response queue until it has received all
        responses. It records the startup (warmup) latency, the average, max,
        and min latency, as well as the throughput of requests.
        """
        warmup_response_time = None
        response_times = []

        for i in range(self.num_iters + 1):
            response, request_time = self.response_queue.get()
            if warmup_response_time is None:
                # The first response is to the warmup request; unblock the
                # sender and record its latency separately
                self.warmup_event.set()
                warmup_response_time = time.time() - request_time
            else:
                response_times.append(time.time() - request_time)

        self.end_recv_time = time.time()
        self.poll_gpu = False

        response_times = np.array(response_times)
        with metrics_lock:
            self.metrics_dict["warmup_latency"] = warmup_response_time
            self.metrics_dict["average_latency"] = response_times.mean()
            self.metrics_dict["max_latency"] = response_times.max()
            self.metrics_dict["min_latency"] = response_times.min()
            self.metrics_dict["throughput"] = (self.num_iters * self.batch_size) / (
                self.end_recv_time - self.start_send_time
            )

    def _run_gpu_utilization(self, metrics_lock):
        """
        This function polls nvidia-smi for GPU utilization every 100ms and
        records the average GPU utilization.
        """

        def get_gpu_utilization():
            try:
                nvidia_smi_output = subprocess.check_output(
                    [
                        "nvidia-smi",
                        "--query-gpu=utilization.gpu",
                        "--id=0",
                        "--format=csv,noheader,nounits",
                    ]
                )
                gpu_utilization = nvidia_smi_output.decode().strip()
                return gpu_utilization
            except subprocess.CalledProcessError:
                return "N/A"

        gpu_utilizations = []

        while self.poll_gpu:
            gpu_utilization = get_gpu_utilization()
            if gpu_utilization != "N/A":
                gpu_utilizations.append(float(gpu_utilization))
            # Sleep so we sample at roughly the 100ms interval promised by the
            # docstring instead of busy-looping on nvidia-smi
            time.sleep(0.1)

        with metrics_lock:
            self.metrics_dict["gpu_util"] = torch.tensor(gpu_utilizations).mean().item()

    def _send_requests(self):
        """
        This function will send one warmup request, and then num_iters requests
        to the backend process.
        """
        fake_data = torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
        other_data = [
            torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
            for i in range(self.num_iters)
        ]

        # Send one batch of warmup data
        self.request_queue.put((fake_data, time.time()))
        # Tell backend to poll queue for warmup request
        self.read_requests_event.set()
        self.warmup_event.wait()
        # Tell backend to poll queue for rest of requests
        self.read_requests_event.set()
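
        # The warmup request above triggers lazy model setup (and, when
        # enabled, torch.compile) on the backend; its latency is recorded
        # separately as "warmup_latency" and excluded from the throughput
        # window measured below.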

        # Send the timed requests, recording the send time of each
        self.start_send_time = time.time()
        for i in range(self.num_iters):
            self.request_queue.put((other_data[i], time.time()))

    def run(self):
        # Lock for writing to metrics_dict
        metrics_lock = threading.Lock()
        requests_thread = threading.Thread(target=self._send_requests)
        metrics_thread = threading.Thread(
            target=self._run_metrics, args=(metrics_lock,)
        )
        gpu_utilization_thread = threading.Thread(
            target=self._run_gpu_utilization, args=(metrics_lock,)
        )

        requests_thread.start()
        metrics_thread.start()

        # Only start polling GPU utilization after the warmup request is complete
        self.warmup_event.wait()
        gpu_utilization_thread.start()

        requests_thread.join()
        metrics_thread.join()
        gpu_utilization_thread.join()


class BackendWorker:
    """
    This worker takes tensors from the request queue, runs ResNet18 inference
    on them, and puts the results on the response queue.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_workers,
        model_dir=".",
        compile_model=True,
    ):
        super().__init__()
        self.device = "cuda:0"
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_dir = model_dir
        self.compile_model = compile_model
        self._setup_complete = False
        self.h2d_stream = torch.cuda.Stream()
        self.d2h_stream = torch.cuda.Stream()
        # Maps thread_id to the cuda.Stream associated with that worker thread
        self.stream_map = dict()

    def _setup(self):
        import time

        import torch
        from torchvision.models.resnet import BasicBlock, ResNet

        # Create ResNet18 on meta device
        with torch.device("meta"):
            m = ResNet(BasicBlock, [2, 2, 2, 2])

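        # The meta-device module has no real storage; below, the checkpoint is
        # memory-mapped and load_state_dict(assign=True) swaps those tensors
        # in directly, materializing the weights without an extra copy.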
        # Load pretrained weights
        start_load_time = time.time()
        state_dict = torch.load(
            f"{self.model_dir}/resnet18-f37072fd.pth",
            mmap=True,
            map_location=self.device,
        )
        self.metrics_dict["torch_load_time"] = time.time() - start_load_time
        m.load_state_dict(state_dict, assign=True)
        m.eval()

        if self.compile_model:
            start_compile_time = time.time()
            m.compile()
            end_compile_time = time.time()
            self.metrics_dict["m_compile_time"] = end_compile_time - start_compile_time
        return m

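    # Each request flows through a three-stage pipeline, one CUDA stream per
    # stage: copy_data (H2D), model_predict (compute, on a per-worker-thread
    # stream), and respond (D2H). CUDA events order the stages on the GPU;
    # the semaphores guarantee an event has been recorded before another
    # thread waits on it.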
    def model_predict(
        self,
        model,
        input_buffer,
        copy_event,
        compute_event,
        copy_sem,
        compute_sem,
        response_list,
        request_time,
    ):
        # copy_sem makes sure copy_event has been recorded in the data copying thread
        copy_sem.acquire()
        self.stream_map[threading.get_native_id()].wait_event(copy_event)
        with torch.cuda.stream(self.stream_map[threading.get_native_id()]):
            with torch.no_grad():
                response_list.append(model(input_buffer))
                compute_event.record()
                compute_sem.release()
        del input_buffer

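    # The source tensor is pinned so that copy_ with non_blocking=True below
    # can be a true asynchronous DMA transfer on the dedicated H2D stream.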
    def copy_data(self, input_buffer, data, copy_event, copy_sem):
        data = data.pin_memory()
        with torch.cuda.stream(self.h2d_stream):
            input_buffer.copy_(data, non_blocking=True)
            copy_event.record()
            copy_sem.release()

    def respond(self, compute_event, compute_sem, response_list, request_time):
        # compute_sem makes sure compute_event has been recorded in the model_predict thread
        compute_sem.acquire()
        self.d2h_stream.wait_event(compute_event)
        with torch.cuda.stream(self.d2h_stream):
            self.response_queue.put((response_list[0].cpu(), request_time))

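    # h2d_pool and d2h_pool are single-threaded on purpose: all H2D and D2H
    # copies are serialized onto their dedicated streams, while the compute
    # stage fans out across num_workers threads, each with its own stream.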
    async def run(self):
        def worker_initializer():
            self.stream_map[threading.get_native_id()] = torch.cuda.Stream()

        worker_pool = ThreadPoolExecutor(
            max_workers=self.num_workers, initializer=worker_initializer
        )
        h2d_pool = ThreadPoolExecutor(max_workers=1)
        d2h_pool = ThreadPoolExecutor(max_workers=1)

        self.read_requests_event.wait()
        # Clear as we will wait for this event again before continuing to
        # poll the request_queue for the non-warmup requests
        self.read_requests_event.clear()
        while True:
            try:
                data, request_time = self.request_queue.get(timeout=5)
            except Empty:
                break

            if not self._setup_complete:
                # Lazily build (and optionally compile) the model on the
                # first (warmup) request
                model = self._setup()

            copy_sem = threading.Semaphore(0)
            compute_sem = threading.Semaphore(0)
            copy_event = torch.cuda.Event()
            compute_event = torch.cuda.Event()
            response_list = []
            input_buffer = torch.empty(
                [self.batch_size, 3, 250, 250], dtype=torch.float32, device="cuda"
            )
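
            # Submit the three pipeline stages for this request. The returned
            # futures are not awaited, so the loop keeps pulling requests
            # while earlier ones are still in flight.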
            asyncio.get_running_loop().run_in_executor(
                h2d_pool,
                self.copy_data,
                input_buffer,
                data,
                copy_event,
                copy_sem,
            )
            asyncio.get_running_loop().run_in_executor(
                worker_pool,
                self.model_predict,
                model,
                input_buffer,
                copy_event,
                compute_event,
                copy_sem,
                compute_sem,
                response_list,
                request_time,
            )
            asyncio.get_running_loop().run_in_executor(
                d2h_pool,
                self.respond,
                compute_event,
                compute_sem,
                response_list,
                request_time,
            )

            if not self._setup_complete:
                # Do not start consuming the timed requests until the frontend
                # has finished measuring the warmup latency
                self.read_requests_event.wait()
                self._setup_complete = True


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--model_dir", type=str, default=".")
    parser.add_argument(
        "--compile", default=True, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--output_file", type=str, default="output.csv")
    parser.add_argument(
        "--profile", default=False, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()
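
    # BooleanOptionalAction generates paired flags, so compilation and
    # profiling can be toggled with --compile/--no-compile and
    # --profile/--no-profile. Example invocation (assuming this file is
    # saved as server.py):
    #   python server.py --batch_size 32 --num_iters 100 --no-compile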

    downloaded_checkpoint = False
    if not os.path.isfile(f"{args.model_dir}/resnet18-f37072fd.pth"):
        # Download the checkpoint into model_dir so that _setup (and the
        # cleanup in the finally block) can find it
        p = subprocess.run(
            [
                "wget",
                "-P",
                args.model_dir,
                "https://download.pytorch.org/models/resnet18-f37072fd.pth",
            ]
        )
        if p.returncode == 0:
            downloaded_checkpoint = True
        else:
            raise RuntimeError("Failed to download checkpoint")

    try:
        mp.set_start_method("forkserver")
        request_queue = mp.Queue()
        response_queue = mp.Queue()
        read_requests_event = mp.Event()

        manager = mp.Manager()
        metrics_dict = manager.dict()
        metrics_dict["batch_size"] = args.batch_size
        metrics_dict["compile"] = args.compile

        frontend = FrontendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            num_iters=args.num_iters,
        )
        backend = BackendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            args.num_workers,
            args.model_dir,
            args.compile,
        )

        frontend.start()

        if args.profile:

            def trace_handler(prof):
                prof.export_chrome_trace("trace.json")

            with torch.profiler.profile(on_trace_ready=trace_handler) as prof:
                asyncio.run(backend.run())
        else:
            asyncio.run(backend.run())

        frontend.join()

        # Turn each scalar metric into a single-element list so the DataFrame
        # has exactly one row, then append that row to the output csv
        metrics_dict = {k: [v] for k, v in metrics_dict._getvalue().items()}
        output = pd.DataFrame.from_dict(metrics_dict, orient="columns")
        # Make sure the results directory exists before opening the file
        os.makedirs("./results", exist_ok=True)
        output_file = "./results/" + args.output_file
        is_empty = not os.path.isfile(output_file)

        with open(output_file, "a+", newline="") as file:
            output.to_csv(file, header=is_empty, index=False)

    finally:
        # Cleanup checkpoint file if we downloaded it
        if downloaded_checkpoint:
            os.remove(f"{args.model_dir}/resnet18-f37072fd.pth")