# colossalai
# 31 lines · 1.1 KB
import math

import numpy as np

class DistributedSampler:
    """Shard a dataset's indices across replicas and sample from one shard.

    The index range ``0 .. len(dataset) - 1`` is truncated so that it divides
    evenly by ``num_replicas`` (the trailing ``len(dataset) % num_replicas``
    items are dropped), then split in a strided fashion: replica ``rank``
    owns indices ``rank, rank + num_replicas, rank + 2 * num_replicas, ...``.

    Attributes set by ``__init__``:
        dataset: The object passed in; must support ``len()`` and indexing.
        num_replicas: Number of shards the index range is split into.
        rank: Which shard this sampler owns.
        num_samples: Number of indices in this shard.
        total_size: ``num_samples * num_replicas``; indices at or beyond this
            value are not assigned to any replica.
        indices: List of the dataset indices owned by this shard.
    """

    def __init__(self, dataset, num_replicas: int, rank: int) -> None:
        """Compute the shard of dataset indices owned by ``rank``.

        Args:
            dataset: Sized, indexable collection to draw items from.
            num_replicas: Total number of shards; must be positive.
            rank: Zero-based shard id; must satisfy ``0 <= rank < num_replicas``.

        Raises:
            ValueError: If ``num_replicas`` is not positive or ``rank`` is
                outside ``[0, num_replicas)``.
        """
        # Validate explicitly instead of relying on assert statements, which
        # are stripped under ``python -O`` and would let a bad rank silently
        # produce a wrong or short shard.
        if num_replicas <= 0:
            raise ValueError(
                f"num_replicas must be a positive integer, got {num_replicas}"
            )
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"rank must be in the interval [0, {num_replicas - 1}], got {rank}"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

        # Every replica gets the same number of items; the remainder that
        # does not divide evenly is dropped. Integer floor division gives the
        # per-replica count directly without going through floats.
        self.num_samples = len(self.dataset) // self.num_replicas
        self.total_size = self.num_samples * self.num_replicas

        # Strided subsample: rank, rank + num_replicas, ... below total_size.
        self.indices = list(range(self.rank, self.total_size, self.num_replicas))

    def sample(self, batch_size: int) -> list:
        """Return ``batch_size`` dataset items drawn at random from this shard.

        Indices are drawn without replacement via ``np.random.choice``, so no
        item appears twice within one returned batch. NumPy raises
        ``ValueError`` when ``batch_size`` exceeds the number of indices in
        the shard.

        Args:
            batch_size: Number of items to draw.

        Returns:
            A list of ``dataset[idx]`` values for the drawn indices.
        """
        chosen = np.random.choice(self.indices, batch_size, replace=False)
        return [self.dataset[idx] for idx in chosen]