colossalai
50 строк · 1.4 Кб
1from typing import Any
2
3import torch
4import torch.distributed as dist
5from torch.utils._pytree import tree_map
6from torch.utils.data import DataLoader
7
8
class CycledDataLoader:
    """Endlessly serve batches from a dataloader, restarting it when exhausted.

    Why do we need this class?
    In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" was
    used to sample a batch of prompts/pretrain data. That may be inefficient
    because it re-initializes the dataloader (and its workers) on every call.
    This wrapper keeps a single iterator alive and only re-creates it once the
    underlying dataloader is exhausted.

    NOTE: next(iter(dataloader)) is not equivalent to
    "for batch in dataloader: break"; it causes slightly different behavior.
    """

    def __init__(
        self,
        dataloader: DataLoader,
    ) -> None:
        """Wrap ``dataloader``; the iterator is created lazily on first ``next()``."""
        self.dataloader = dataloader

        # Number of batches served so far from the current pass over the data.
        self.count = 0
        # Live iterator over ``dataloader``; None until the first ``next()`` call.
        self.dataloader_iter = None

    def next(self):
        """Return the next batch, restarting the dataloader when it runs out.

        Returns:
            Whatever the wrapped dataloader's iterator yields.

        Raises:
            StopIteration: if a freshly created iterator yields nothing
                (i.e. the dataloader is empty).
        """
        # Defer initialization so that constructing the wrapper stays cheap.
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)

        try:
            batch = next(self.dataloader_iter)
        except StopIteration:
            # Current pass is exhausted: start a new one and reset the counter.
            self.count = 0
            self.dataloader_iter = iter(self.dataloader)
            batch = next(self.dataloader_iter)

        # Count only after a batch was actually obtained, so that ``count`` is
        # 1 after the first batch of every pass (including after wrap-around).
        self.count += 1
        return batch
38
39
def is_rank_0() -> bool:
    """Return True when this process is rank 0 or no process group is initialized."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    # Without an initialized process group, treat this process as rank 0.
    return True
42
43
def to_device(x: Any, device: torch.device) -> Any:
    """Move every tensor leaf of ``x`` to ``device``; leave other leaves untouched.

    ``x`` is traversed with ``tree_map``, and the mapped structure is returned.
    """

    def _move(leaf: Any) -> Any:
        # Only tensors are moved; every other leaf is passed through as-is.
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move, x)
51