pytorch
33 строки · 1.1 Кб
import collections
from typing import Deque, Optional

import torch
6
7class _FreeEventQueue:8"""9This tracks all pending frees corresponding to inflight all-gathers. The
10queueing pattern is iterative enqueues with a single dequeue per iteration
11once the limit ``_max_num_inflight_all_gathers`` is reached.
12"""
13
14def __init__(self) -> None:15self._queue: Deque[torch.cuda.Event] = collections.deque()16self._max_num_inflight_all_gathers = 2 # empirically chosen17
18def enqueue(self, free_event: torch.cuda.Event) -> None:19"""Enqueues a free event."""20self._queue.append(free_event)21
22def dequeue_if_needed(self) -> Optional[torch.cuda.Event]:23"""Dequeues a single event if the limit is reached."""24if len(self._queue) >= self._max_num_inflight_all_gathers:25return self._dequeue()26return None27
28def _dequeue(self) -> Optional[torch.cuda.Event]:29"""Dequeues a free event if possible."""30if self._queue:31event = self._queue.popleft()32return event33return None34