pytorch
120 lines · 3.5 KB
1# Owner(s): ["oncall: distributed"]
2
3import random4import sys5import unittest6from collections import OrderedDict7from dataclasses import dataclass8from typing import List9
10import torch11import torch.nn as nn12from torch import distributed as dist13from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix14from torch.testing._internal.common_utils import (15instantiate_parametrized_tests,16parametrize,17run_tests,18subtest,19TEST_WITH_DEV_DBG_ASAN,20TestCase,21)
22
# torch.distributed is an optional build component; if it is absent, exit the
# whole module up front with status 0 so the test run counts as skipped, not
# failed.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# dev-asan builds have known issues with torch + multiprocessing spawn, so the
# module is skipped entirely there as well (again exiting with status 0).
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
34
35class TestUtils(TestCase):36@parametrize(37"devices", [["cpu"], ["cuda"], subtest(["cpu", "cuda"], name="cpu_cuda")]38)39def test_apply_to_tensors(self, devices):40if "cuda" in devices and (41not torch.cuda.is_available() or torch.cuda.device_count() < 142):43raise unittest.SkipTest("Skipped due to lack of GPU")44
45expected = 046
47def get_a_tensor():48"""Return a random tensor on random device."""49dev = random.choice(devices)50shape = random.choice(((1), (2, 3), (4, 5, 6), (7, 8, 9, 10)))51t = torch.rand(shape).to(dev)52nonlocal expected53expected += t.numel()54return t55
56@dataclass57class SomeDataClass:58some_key: str59some_float: float60some_tensor: List[torch.Tensor]61
62# create a mixed bag of data.63data = [1, "str"]64data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})65data.insert(0, {"x", get_a_tensor(), get_a_tensor()})66data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2}))67data.append({"abc": SomeDataClass("some_key", 1.0, [get_a_tensor()])})68od = OrderedDict()69od["k"] = "value"70data.append(od)71
72total = 073
74def fn(t):75nonlocal total76total += t.numel()77return t78
79new_data = _apply_to_tensors(fn, data)80self.assertEqual(total, expected)81for i, v in enumerate(data):82self.assertEqual(type(new_data[i]), type(v))83
84def test_replace_by_prefix(self):85state_dict = {86"layer.a": torch.tensor(1),87"abc.layer.def": torch.tensor(2),88"layer.b": torch.tensor(3),89}90original_state_dict = state_dict.copy()91_replace_by_prefix(state_dict, "layer.", "module.layer.")92assert state_dict == {93"module.layer.a": torch.tensor(1),94"abc.layer.def": torch.tensor(2),95"module.layer.b": torch.tensor(3),96}97_replace_by_prefix(state_dict, "module.layer.", "layer.")98assert state_dict == original_state_dict99
100def test_packed_sequence(self):101"""Test to ensure RNN packed sequences are modified correctly."""102rnn = nn.RNN(5, 5)103
104x = torch.rand((5, 1, 5), dtype=torch.float)105seq_length = torch.tensor([4], dtype=torch.int)106
107def fill_fn(x):108x.fill_(0)109
110x = nn.utils.rnn.pack_padded_sequence(x, seq_length)111x, h = rnn(x)112x = _apply_to_tensors(fill_fn, x)113x, _ = nn.utils.rnn.pad_packed_sequence(x)114self.assertEqual(torch.sum(x), 0)115
116
# Expand each @parametrize-decorated method into one concrete test per
# configuration (e.g. test_apply_to_tensors_cpu / _cuda / _cpu_cuda).
instantiate_parametrized_tests(TestUtils)
if __name__ == "__main__":
    # Standard PyTorch test entry point (handles CLI flags, result reporting).
    run_tests()