paddlenlp

Форк
0
110 строк · 3.9 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import copy
16
import random
17

18
import numpy as np
19
import paddle
20
from ppfleetx.data import dataset as fleetx_dataset
21
from ppfleetx.data import sampler, utils
22
from ppfleetx.distributed.apis import env
23
from ppfleetx.utils.log import logger
24

25

26
def build_auto_dataset(config, mode):
    """Build a dataset for auto parallel.

    The dataset itself is created by ``build_dataset``. The optional
    ``collate_fn`` and ``sample_split`` entries of ``config[mode]`` are then
    attached to it as attributes. They are presumably read later by the
    auto-parallel engine; confirm against the caller.

    Args:
        config: Data config mapping. ``config[mode]`` must provide a
            ``dataset`` section (see ``build_dataset``) and may provide:
            * ``collate_fn``: either the name of an attribute of
              ``ppfleetx.data.utils``, or a dict with a ``name`` entry plus
              keyword arguments used to instantiate that attribute;
            * ``sample_split``.
        mode: One of ``"Train"``, ``"Eval"`` or ``"Test"``.

    Returns:
        The built dataset with ``collate_fn`` and ``sample_split`` attributes
        set, or ``None`` when ``mode`` is not present in ``config``.

    Note:
        ``collate_fn`` and ``sample_split`` are popped from ``config[mode]``,
        so the caller's config is modified in place.
    """
    assert mode in ["Train", "Eval", "Test"], "Dataset mode should be Train, Eval, Test"

    if mode not in config:
        return None

    dataset = build_dataset(config, mode)

    collate_fn = None
    collate_fn_cfg = config[mode].pop("collate_fn", None)
    if isinstance(collate_fn_cfg, str):
        # getattr instead of eval: the name comes from a config file and must
        # only ever select an attribute of ``utils``, never execute code.
        collate_fn = getattr(utils, collate_fn_cfg)
    elif isinstance(collate_fn_cfg, dict):
        # Work on a shallow copy so popping "name" leaves the user's dict intact.
        collate_fn_kwargs = dict(collate_fn_cfg)
        collate_fn_class_name = collate_fn_kwargs.pop("name")
        collate_fn = getattr(utils, collate_fn_class_name)(**collate_fn_kwargs)
        logger.debug("build collate_fn({}) success...".format(collate_fn))

    dataset.collate_fn = collate_fn
    dataset.sample_split = config[mode].pop("sample_split", None)
    return dataset
50

51

52
def build_dataset(config, mode):
    """Instantiate the dataset described by ``config[mode].dataset``.

    The dataset section is deep-copied first, so the caller's config is left
    untouched. Its ``name`` entry selects an attribute (the dataset class) of
    ``ppfleetx.data.dataset``. Every remaining entry is passed to it as a
    keyword argument.

    Args:
        config: Data config. ``config[mode]`` must expose a ``dataset``
            attribute containing at least a ``name`` entry.
        mode: Key of the config section to read (e.g. ``"Train"``).

    Returns:
        The constructed dataset instance.
    """
    dataset_kwargs = copy.deepcopy(config[mode].dataset)
    dataset_name = dataset_kwargs.pop("name")
    # getattr instead of eval: the class name comes from a config file and
    # must only ever select an attribute of the dataset module.
    dataset_cls = getattr(fleetx_dataset, dataset_name)
    dataset = dataset_cls(**dataset_kwargs)

    logger.debug("build dataset({}) success...".format(dataset))

    return dataset
62

63

64
def build_dataloader(config, mode):
    """Build a ``paddle.io.DataLoader`` for the given mode.

    Args:
        config: Data config. ``config[mode]`` must provide a ``dataset``
            section (see ``build_dataset``) and may provide:
            * ``sampler``: its ``name`` entry selects an attribute of
              ``ppfleetx.data.sampler``. That attribute is called as
              ``cls(dataset, **remaining_entries)`` and the result is used
              as the loader's ``batch_sampler``;
            * ``loader``: extra keyword arguments for ``paddle.io.DataLoader``.
              An optional ``collate_fn`` entry in it is either the name of an
              attribute of ``ppfleetx.data.utils``, or a dict with a ``name``
              entry plus keyword arguments used to instantiate that attribute.
        mode: One of ``"Train"``, ``"Eval"`` or ``"Test"``.

    Returns:
        The built data loader, or ``None`` when ``mode`` is not in ``config``.
    """
    assert mode in ["Train", "Eval", "Test"], "Dataset mode should be Train, Eval, Test"

    if mode not in config:
        return None

    dataset = build_dataset(config, mode)

    # build sampler
    batch_sampler = None
    if "sampler" in config[mode]:
        # Deep copy so popping "name" does not modify the caller's config.
        sampler_kwargs = copy.deepcopy(config[mode].sampler)
        sampler_name = sampler_kwargs.pop("name")
        # getattr instead of eval: config strings must never execute code.
        batch_sampler = getattr(sampler, sampler_name)(dataset, **sampler_kwargs)
        logger.debug("build batch_sampler({}) success...".format(batch_sampler))

    # build dataloader
    collate_fn = None
    config_loader = {}
    if "loader" in config[mode]:
        # Deep copy: the pops below must not modify the caller's config.
        config_loader = copy.deepcopy(config[mode].loader)

        collate_fn_cfg = config_loader.pop("collate_fn", None)
        if isinstance(collate_fn_cfg, str):
            collate_fn = getattr(utils, collate_fn_cfg)
        elif isinstance(collate_fn_cfg, dict):
            collate_fn_class_name = collate_fn_cfg.pop("name")
            collate_fn = getattr(utils, collate_fn_class_name)(**collate_fn_cfg)
            logger.debug("build collate_fn({}) success...".format(collate_fn))

    def worker_init_fn(worker_id):
        """Seed numpy and ``random`` in each loader worker subprocess (num_workers > 0)."""
        # Offset by worker_id so that workers do not share the same random stream.
        seed = env.get_dp_seed() + worker_id
        np.random.seed(seed)
        random.seed(seed)

    # Any remaining loader entries are forwarded verbatim to paddle.io.DataLoader.
    data_loader = paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn,
        **config_loader,
    )

    logger.debug("build data_loader({}) success...".format(data_loader))
    return data_loader
111

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.