lmops

Форк
0
/
dataset_utils.py 
32 строки · 913.0 Байт
1
import random
2
import torch
3

4
def load_train_dataset(dataset,size=None,listify=True):
5
    if size is not None and size<len(dataset['train']):
6
        data = dataset['train']
7
        rand = random.Random(x=42)
8
        index_list = list(range(len(data))) 
9
        rand.shuffle(index_list) #shuffle index_list 
10
        x = data.select(index_list[:size])
11

12
    else:
13
        x = dataset['train']
14
    if listify:
15
        return list(x)
16
    else:
17
        return x
18

19
def pad2sameLen(
20
    values,
21
    pad_idx=0,
22
    left_pad=False
23
):
24
    """Convert a list of 1d tensors into a padded 2d tensor.
25
    ensuring same lengths
26
    """
27
    size = max(v.shape[-1] for v in values)
28
    if left_pad:
29
        res=torch.stack([torch.nn.functional.pad(v,(size-v.shape[-1],0),value=pad_idx) for v in values])
30
    else:
31
        res=torch.stack([torch.nn.functional.pad(v,(0,size-v.shape[-1]),value=pad_idx) for v in values])
32
    return res
33

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.