colossalai

Форк
0
31 строка · 1.1 Кб
1
import math
2

3
import numpy as np
4

5

6
class DistributedSampler:
7
    def __init__(self, dataset, num_replicas: int, rank: int) -> None:
8
        self.dataset = dataset
9
        self.num_replicas = num_replicas
10
        self.rank = rank
11

12
        if len(self.dataset) % self.num_replicas != 0:
13
            self.num_samples = math.ceil(
14
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
15
            )
16
        else:
17
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
18

19
        self.total_size = self.num_samples * self.num_replicas
20

21
        indices = list(range(len(self.dataset)))
22
        indices = indices[: self.total_size]
23
        assert len(indices) == self.total_size
24
        # subsample
25
        indices = indices[self.rank : self.total_size : self.num_replicas]
26
        assert len(indices) == self.num_samples
27
        self.indices = indices
28

29
    def sample(self, batch_size: int) -> list:
30
        sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
31
        return [self.dataset[idx] for idx in sampled_indices]
32

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.