colossalai

Форк
0
50 строк · 1.4 Кб
1
from typing import Any
2

3
import torch
4
import torch.distributed as dist
5
from torch.utils._pytree import tree_map
6
from torch.utils.data import DataLoader
7

8

9
class CycledDataLoader:
10
    """
11
    Why do we need this class?
12
    In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain.
13
    However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...)
14
    NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior.
15
    """
16

17
    def __init__(
18
        self,
19
        dataloader: DataLoader,
20
    ) -> None:
21
        self.dataloader = dataloader
22

23
        self.count = 0
24
        self.dataloader_iter = None
25

26
    def next(self):
27
        # defer initialization
28
        if self.dataloader_iter is None:
29
            self.dataloader_iter = iter(self.dataloader)
30

31
        self.count += 1
32
        try:
33
            return next(self.dataloader_iter)
34
        except StopIteration:
35
            self.count = 0
36
            self.dataloader_iter = iter(self.dataloader)
37
            return next(self.dataloader_iter)
38

39

40
def is_rank_0() -> bool:
41
    return not dist.is_initialized() or dist.get_rank() == 0
42

43

44
def to_device(x: Any, device: torch.device) -> Any:
45
    def _to(t: Any):
46
        if isinstance(t, torch.Tensor):
47
            return t.to(device)
48
        return t
49

50
    return tree_map(_to, x)
51

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.