pytorch
398 строк · 12.9 Кб
1import argparse2import asyncio3import os.path4import subprocess5import threading6import time7from concurrent.futures import ThreadPoolExecutor8from queue import Empty9
10import numpy as np11import pandas as pd12
13import torch14import torch.multiprocessing as mp15
16
class FrontendWorker(mp.Process):
    """
    This worker will send requests to a backend process, and measure the
    throughput and latency of those requests as well as GPU utilization.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_iters=10,
    ):
        super().__init__()
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        self.warmup_event = mp.Event()
        self.batch_size = batch_size
        self.num_iters = num_iters
        # Flag polled by the GPU-utilization thread; cleared by the metrics
        # thread once all responses have been received.
        self.poll_gpu = True
        self.start_send_time = None
        self.end_recv_time = None

    def _run_metrics(self, metrics_lock):
        """
        Poll the response queue until all responses have been received.

        Records the warmup (startup) latency, the average/max/min latency of
        the non-warmup requests, and the overall throughput into
        ``self.metrics_dict`` under ``metrics_lock``.
        """
        warmup_response_time = None
        response_times = []

        # Expect num_iters + 1 responses: one warmup batch + num_iters real ones.
        for _ in range(self.num_iters + 1):
            response, request_time = self.response_queue.get()
            if warmup_response_time is None:
                # First response is the warmup batch: unblock the request
                # sender and the GPU-utilization poller.
                self.warmup_event.set()
                warmup_response_time = time.time() - request_time
            else:
                response_times.append(time.time() - request_time)

        self.end_recv_time = time.time()
        # Signal the GPU-utilization thread to stop polling.
        self.poll_gpu = False

        response_times = np.array(response_times)
        with metrics_lock:
            self.metrics_dict["warmup_latency"] = warmup_response_time
            self.metrics_dict["average_latency"] = response_times.mean()
            self.metrics_dict["max_latency"] = response_times.max()
            self.metrics_dict["min_latency"] = response_times.min()
            self.metrics_dict["throughput"] = (self.num_iters * self.batch_size) / (
                self.end_recv_time - self.start_send_time
            )

    def _run_gpu_utilization(self, metrics_lock):
        """
        Poll nvidia-smi for GPU utilization every 100ms and record the
        average utilization into ``self.metrics_dict``.
        """

        def get_gpu_utilization():
            try:
                nvidia_smi_output = subprocess.check_output(
                    [
                        "nvidia-smi",
                        "--query-gpu=utilization.gpu",
                        "--id=0",
                        "--format=csv,noheader,nounits",
                    ]
                )
                return nvidia_smi_output.decode().strip()
            # FileNotFoundError: the nvidia-smi binary is not installed at
            # all; previously this escaped the except clause and killed the
            # polling thread.
            except (subprocess.CalledProcessError, FileNotFoundError):
                return "N/A"

        gpu_utilizations = []

        while self.poll_gpu:
            gpu_utilization = get_gpu_utilization()
            if gpu_utilization != "N/A":
                gpu_utilizations.append(float(gpu_utilization))
            # Honor the documented 100ms polling interval instead of
            # busy-spinning on nvidia-smi.
            time.sleep(0.1)

        with metrics_lock:
            if gpu_utilizations:
                self.metrics_dict["gpu_util"] = (
                    torch.tensor(gpu_utilizations).mean().item()
                )
            else:
                # No samples collected (e.g. nvidia-smi unavailable): record
                # NaN explicitly instead of taking the mean of an empty
                # tensor, which warns and yields NaN anyway.
                self.metrics_dict["gpu_util"] = float("nan")

    def _send_requests(self):
        """
        Send one warmup request, and then num_iters requests to the backend
        process.
        """
        fake_data = torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
        other_data = [
            torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
            for i in range(self.num_iters)
        ]

        # Send one batch of warmup data
        self.request_queue.put((fake_data, time.time()))
        # Tell backend to poll queue for warmup request
        self.read_requests_event.set()
        self.warmup_event.wait()
        # Tell backend to poll queue for rest of requests
        self.read_requests_event.set()

        # Send fake data
        self.start_send_time = time.time()
        for i in range(self.num_iters):
            self.request_queue.put((other_data[i], time.time()))

    def run(self):
        # Lock for writing to metrics_dict
        metrics_lock = threading.Lock()
        requests_thread = threading.Thread(target=self._send_requests)
        metrics_thread = threading.Thread(
            target=self._run_metrics, args=(metrics_lock,)
        )
        gpu_utilization_thread = threading.Thread(
            target=self._run_gpu_utilization, args=(metrics_lock,)
        )

        requests_thread.start()
        metrics_thread.start()

        # only start polling GPU utilization after the warmup request is complete
        self.warmup_event.wait()
        gpu_utilization_thread.start()

        requests_thread.join()
        metrics_thread.join()
        gpu_utilization_thread.join()
152
class BackendWorker:
    """
    This worker will take tensors from the request queue, do some computation,
    and then return the result back in the response queue.

    Work for each request is split across three thread pools: one H2D-copy
    thread, a pool of model-compute threads (each with its own CUDA stream),
    and one D2H/respond thread. CUDA events plus semaphores order the three
    stages per request.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_workers,
        model_dir=".",
        compile_model=True,
    ):
        super().__init__()
        self.device = "cuda:0"
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_dir = model_dir
        self.compile_model = compile_model
        # Set to True after the warmup request has been processed in run().
        self._setup_complete = False
        # Dedicated streams for host-to-device and device-to-host copies so
        # they can overlap with compute on the worker streams.
        self.h2d_stream = torch.cuda.Stream()
        self.d2h_stream = torch.cuda.Stream()
        # maps thread_id to the cuda.Stream associated with that worker thread
        self.stream_map = {}

    def _setup(self):
        """Build ResNet18, load the checkpoint, optionally compile; return the model.

        Also records "torch_load_time" (and "m_compile_time" when compiling)
        into metrics_dict.
        """
        import time

        from torchvision.models.resnet import BasicBlock, ResNet

        import torch

        # Create ResNet18 on meta device
        with torch.device("meta"):
            m = ResNet(BasicBlock, [2, 2, 2, 2])

        # Load pretrained weights
        start_load_time = time.time()
        state_dict = torch.load(
            f"{self.model_dir}/resnet18-f37072fd.pth",
            mmap=True,
            map_location=self.device,
        )
        self.metrics_dict["torch_load_time"] = time.time() - start_load_time
        m.load_state_dict(state_dict, assign=True)
        m.eval()

        if self.compile_model:
            start_compile_time = time.time()
            m.compile()
            end_compile_time = time.time()
            self.metrics_dict["m_compile_time"] = end_compile_time - start_compile_time
        return m

    def model_predict(
        self,
        model,
        input_buffer,
        copy_event,
        compute_event,
        copy_sem,
        compute_sem,
        response_list,
        request_time,
    ):
        """Run inference on input_buffer on this worker thread's own CUDA stream.

        Appends the model output to response_list and records compute_event
        so the respond stage can wait on it. request_time is unused here but
        kept for a uniform executor call signature.
        """
        # copy_sem makes sure copy_event has been recorded in the data copying thread
        copy_sem.acquire()
        # Make this worker's stream wait until the H2D copy has finished.
        self.stream_map[threading.get_native_id()].wait_event(copy_event)
        with torch.cuda.stream(self.stream_map[threading.get_native_id()]):
            with torch.no_grad():
                response_list.append(model(input_buffer))
                compute_event.record()
                # Release only after compute_event.record() so respond() can
                # safely wait on the event.
                compute_sem.release()
        del input_buffer

    def copy_data(self, input_buffer, data, copy_event, copy_sem):
        """Copy host data into the preallocated device input_buffer on the H2D stream."""
        # Pin so the non_blocking copy below can actually be asynchronous.
        data = data.pin_memory()
        with torch.cuda.stream(self.h2d_stream):
            input_buffer.copy_(data, non_blocking=True)
            copy_event.record()
            copy_sem.release()

    def respond(self, compute_event, compute_sem, response_list, request_time):
        """Copy the result back to CPU and enqueue (response, request_time)."""
        # compute_sem makes sure compute_event has been recorded in the model_predict thread
        compute_sem.acquire()
        self.d2h_stream.wait_event(compute_event)
        with torch.cuda.stream(self.d2h_stream):
            # NOTE(review): .cpu() presumably synchronizes on the D2H stream
            # here before the tensor is enqueued — confirm against torch
            # stream semantics.
            self.response_queue.put((response_list[0].cpu(), request_time))

    async def run(self):
        """Poll the request queue and fan each request out to the three stages.

        Processes the warmup request first (doing model setup on it), then
        waits for the frontend's signal before consuming the remaining
        requests. Exits when the queue stays empty for 5 seconds.
        """

        def worker_initializer():
            # One CUDA stream per compute worker thread, keyed by native
            # thread id.
            self.stream_map[threading.get_native_id()] = torch.cuda.Stream()

        worker_pool = ThreadPoolExecutor(
            max_workers=self.num_workers, initializer=worker_initializer
        )
        h2d_pool = ThreadPoolExecutor(max_workers=1)
        d2h_pool = ThreadPoolExecutor(max_workers=1)

        self.read_requests_event.wait()
        # Clear as we will wait for this event again before continuing to
        # poll the request_queue for the non-warmup requests
        self.read_requests_event.clear()
        while True:
            try:
                data, request_time = self.request_queue.get(timeout=5)
            except Empty:
                break

            # Lazily build the model on the first (warmup) request.
            if not self._setup_complete:
                model = self._setup()

            # Per-request synchronization primitives: semaphores guarantee
            # the corresponding event has been recorded before it is waited
            # on from another thread.
            copy_sem = threading.Semaphore(0)
            compute_sem = threading.Semaphore(0)
            copy_event = torch.cuda.Event()
            compute_event = torch.cuda.Event()
            response_list = []
            input_buffer = torch.empty(
                [self.batch_size, 3, 250, 250], dtype=torch.float32, device="cuda"
            )
            asyncio.get_running_loop().run_in_executor(
                h2d_pool,
                self.copy_data,
                input_buffer,
                data,
                copy_event,
                copy_sem,
            )
            asyncio.get_running_loop().run_in_executor(
                worker_pool,
                self.model_predict,
                model,
                input_buffer,
                copy_event,
                compute_event,
                copy_sem,
                compute_sem,
                response_list,
                request_time,
            )
            asyncio.get_running_loop().run_in_executor(
                d2h_pool,
                self.respond,
                compute_event,
                compute_sem,
                response_list,
                request_time,
            )

            # After the warmup request, block until the frontend signals that
            # the real requests are coming.
            if not self._setup_complete:
                self.read_requests_event.wait()
                self._setup_complete = True
314
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--model_dir", type=str, default=".")
    parser.add_argument(
        "--compile", default=True, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--output_file", type=str, default="output.csv")
    parser.add_argument(
        "--profile", default=False, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    checkpoint_path = f"{args.model_dir}/resnet18-f37072fd.pth"

    # Download the checkpoint if it is not already present. Use urllib
    # instead of shelling out to wget (which may not be installed); this also
    # writes directly to checkpoint_path, whereas wget saved into the current
    # working directory even when --model_dir pointed elsewhere.
    downloaded_checkpoint = False
    if not os.path.isfile(checkpoint_path):
        import urllib.request

        try:
            urllib.request.urlretrieve(
                "https://download.pytorch.org/models/resnet18-f37072fd.pth",
                checkpoint_path,
            )
        except OSError as err:
            raise RuntimeError("Failed to download checkpoint") from err
        downloaded_checkpoint = True

    try:
        mp.set_start_method("forkserver")
        request_queue = mp.Queue()
        response_queue = mp.Queue()
        read_requests_event = mp.Event()

        # Manager dict shared between the frontend process and the backend
        # (which runs in this process) for collecting metrics.
        manager = mp.Manager()
        metrics_dict = manager.dict()
        metrics_dict["batch_size"] = args.batch_size
        metrics_dict["compile"] = args.compile

        frontend = FrontendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            num_iters=args.num_iters,
        )
        backend = BackendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            args.num_workers,
            args.model_dir,
            args.compile,
        )

        frontend.start()

        # The backend runs in this process, optionally under the torch
        # profiler with a chrome-trace export.
        if args.profile:

            def trace_handler(prof):
                prof.export_chrome_trace("trace.json")

            with torch.profiler.profile(on_trace_ready=trace_handler) as prof:
                asyncio.run(backend.run())
        else:
            asyncio.run(backend.run())

        frontend.join()

        # Append one row of metrics to ./results/<output_file>; write the CSV
        # header only when the file does not already exist.
        metrics_dict = {k: [v] for k, v in metrics_dict._getvalue().items()}
        output = pd.DataFrame.from_dict(metrics_dict, orient="columns")
        # Ensure the output directory exists instead of crashing on open().
        os.makedirs("./results", exist_ok=True)
        output_file = "./results/" + args.output_file
        is_empty = not os.path.isfile(output_file)

        with open(output_file, "a+", newline="") as file:
            output.to_csv(file, header=is_empty, index=False)

    finally:
        # Cleanup checkpoint file if we downloaded it
        if downloaded_checkpoint:
            os.remove(checkpoint_path)