pytorch

Форк
0
/
test_store.py 
851 строка · 30.2 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import datetime
4
import os
5
import socket
6
import sys
7
import tempfile
8
import time
9
import threading
10
from datetime import timedelta
11
from sys import platform
12

13
import torch
14
import torch.distributed as dist
15
import torch.distributed.distributed_c10d as c10d
16
import torch.distributed.rpc as rpc
17
from torch.distributed import DistNetworkError, DistError, DistStoreError
18
from torch.testing._internal.common_distributed import MultiThreadedTestCase
19
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
20

21
if not dist.is_available():
22
    print("torch.distributed not available, skipping tests", file=sys.stderr)
23
    sys.exit(0)
24

25
import torch.testing._internal.common_utils as common
26
from torch.testing._internal.common_distributed import (
27
    skip_if_win32,
28
    create_tcp_store,
29
    tp_transports
30
)
31
from torch.testing._internal.common_utils import (
32
    TestCase,
33
    load_tests,
34
    run_tests,
35
    retry_on_connect_failures,
36
    ADDRESS_IN_USE,
37
    CONNECT_TIMEOUT,
38
)
39

40
# load_tests from common_utils is used to automatically filter tests for
41
# sharding on sandcastle. This line silences flake warnings
42
load_tests = load_tests
43

44
# Name of the loopback network interface: "lo0" on macOS ("darwin"),
# "lo" on every other platform.
LOOPBACK = "lo0" if platform == "darwin" else "lo"

# Host name used by the tests in this module whenever a store address is needed.
DEFAULT_HOSTNAME = "localhost"

# Switch TF32 off for CUDA matmuls for the whole test module.
torch.backends.cuda.matmul.allow_tf32 = False
52

53

54
def gpus_for_rank(world_size):
    """Split the visible CUDA devices into one equally sized group per rank.

    Multi-GPU tests simulate several nodes with several GPUs each, and the
    NCCL backend requires every process to own the same number of GPUs.
    On a single node the visible devices are therefore divided evenly into
    contiguous subsets, one per process. Devices left over by the integer
    division are not assigned to any rank.

    Args:
        world_size: number of processes to split the devices between.

    Returns:
        A list with ``world_size`` entries; entry ``r`` holds the device
        indices assigned to rank ``r``.
    """
    device_count = torch.cuda.device_count()
    all_devices = list(range(device_count))
    share = device_count // world_size
    return [
        all_devices[r * share:(r + 1) * share]
        for r in range(world_size)
    ]
68

69

70
class StoreTestBase:
    """Mixin holding the store checks shared by the concrete store test cases.

    It is meant to be combined with a ``TestCase`` class (it uses the
    ``assert*`` and ``skipTest`` methods of ``self``) that overrides
    ``_create_store`` to return the store under test.
    """

    def _create_store(self, i=None):
        """Return the store to exercise; concrete test classes override this.

        ``i`` is unused. It is kept, now optional, for backward compatibility:
        every caller in this mixin invokes ``self._create_store()`` without
        arguments, which previously failed with a ``TypeError`` instead of the
        intended "not implemented" error.
        """
        # NotImplementedError is a RuntimeError subclass, so code catching the
        # previous RuntimeError keeps working.
        raise NotImplementedError("not implemented")

    def _test_set_get_check(self, fs):
        """Exercise add/set/get/check/delete_key/num_keys on the store ``fs``."""
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        # "key" accumulated 1 + 2 + 3 and "key3" accumulated 1 + ... + 6.
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        # A key that is added and then deleted must not change the key count.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        """Exercise ``compare_set`` for missing, mismatching and matching values."""
        # Missing key with a non-empty expected value: the expected value is
        # what comes back.
        missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Mismatching expected value: the current value is returned and kept.
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Matching expected value: the new value is stored and returned.
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        # Missing key with an empty expected value: the new value is stored.
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    def _test_simple_wait(self, fs):
        """Check that ``wait`` times out on a missing key and returns for an existing one."""
        with self.assertRaisesRegex(RuntimeError, "[t -i]imeout"):
            fs.wait(["bad_key"], timedelta(seconds=0.25))
        fs.add("good_key", 1)
        fs.wait(["good_key"])

    def test_simple_wait(self):
        self._test_simple_wait(self._create_store())

    def _test_append(self, store):
        """Check ``append`` on an existing and on a new key (extended API only)."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_append(self):
        self._test_append(self._create_store())

    def _test_multi_set(self, store):
        """Check that ``multi_set`` stores every key/value pair (extended API only)."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_set(self):
        self._test_multi_set(self._create_store())

    def _test_multi_get(self, store):
        """Check that ``multi_get`` returns values in key order (extended API only)."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_multi_get(self):
        self._test_multi_get(self._create_store())

    # This is the number of keys used in test_set_get. Adding this as a class
    # property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5
174

175

176
class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks plus file-specific tests on ``dist.FileStore``."""

    def setUp(self):
        super().setUp()
        # delete=False keeps the file on disk independently of this handle.
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def _create_store(self):
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{file.name}"
        rpc_backend_options._transports = tp_transports()
        rpc.init_rpc("worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options)

        # Init PG using file
        dist.init_process_group("gloo", rank=0, world_size=1, init_method=f"file://{file.name}")
        dist.destroy_process_group()
        # Destroying the process group must leave the file in place while RPC
        # is still running. unittest assertions are used instead of bare
        # ``assert`` so the check survives ``python -O``.
        self.assertTrue(os.path.exists(file.name))

        rpc.shutdown()
        os.remove(file.name)

    def test_refcount(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        store = dist.FileStore(file.name, 1)
        store2 = dist.FileStore(file.name, 1)

        # The file must survive while at least one store still references it ...
        del store
        self.assertTrue(os.path.exists(file.name))
        # ... and be gone once the last store is released.
        del store2
        self.assertFalse(os.path.exists(file.name))

    # One more key than the five user keys of the base case.
    # NOTE(review): presumably an internal bookkeeping key of FileStore - confirm.
    @property
    def num_keys_total(self):
        return 6
215

216

217
@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks against ``dist.HashStore``."""

    def _create_store(self):
        # A fresh in-process store per call, with a five-minute timeout.
        hash_store = dist.HashStore()
        hash_store.set_timeout(timedelta(minutes=5))
        return hash_store
223

224

225
class PrefixStoreTest(TestCase):
    """Checks that ``dist.PrefixStore`` exposes the store it wraps."""

    def setUp(self):
        # Run the base-class setup first, like every other setUp in this file.
        super().setUp()
        # delete is false as FileStore will automatically clean up the file
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def test_get_underlying_store(self):
        tcp_store = dist.TCPStore(host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True)
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        # Each store type must come back unchanged from ``underlying_store``.
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)
238

239

240
class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks on a ``PrefixStore`` wrapping a ``FileStore``."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        self.file = tempfile.NamedTemporaryFile(delete=False)
        # A single FileStore is shared by every PrefixStore handed out below.
        backing_store = dist.FileStore(self.file.name, 1)
        self.filestore = backing_store
        backing_store.set_timeout(timedelta(minutes=5))

    def _create_store(self):
        prefixed = dist.PrefixStore(self.prefix, self.filestore)
        return prefixed

    # Expected key count for test_set_get_check on this store.
    @property
    def num_keys_total(self):
        return 6
254

255

256
class TCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks plus TCP-specific tests on ``dist.TCPStore``.

    Stores are created through ``_create_store`` and ``_create_store_with_ws``
    so that a subclass can swap in a different store configuration.
    """

    def _create_store(self):
        store = create_tcp_store()
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        """Create the server store used by the multi-worker tests."""
        return create_tcp_store(addr, world_size, wait_for_workers=False)

    def test_address_already_in_use(self):
        err_msg_reg = "^The server socket has failed to listen on any local "
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            addr = DEFAULT_HOSTNAME
            port = common.find_free_port()

            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

    @retry_on_connect_failures
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=0,
            world_size=1,
        )

        # The default process group is global state: tear it down even when
        # the RPC part fails, otherwise any later init_process_group in this
        # process (including the subclass run of this test) cannot succeed.
        try:
            backend_opts = rpc.TensorPipeRpcBackendOptions(
                init_method=f"tcp://{addr}:{port}",
                _transports=tp_transports()
            )
            rpc.init_rpc(
                name="worker0",
                rank=0,
                world_size=1,
                rpc_backend_options=backend_opts,
            )

            rpc.shutdown()
        finally:
            dist.destroy_process_group()

    @skip_if_win32()
    def test_take_over_listen_socket(self):
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        # Hand the raw descriptor over; the Python socket object lets go of it.
        listen_fd = listen_sock.detach()

        store = dist.TCPStore(addr, port, 1, is_master=True, master_listen_fd=listen_fd)

        store.set("key", "value")
        self.assertEqual(b"value", store.get("key"))

    # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
    # the user and one additional key used for coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        """Track ``num_keys`` while keys are added to and deleted from ``fs``."""
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # Shorten the timeout so that reading the deleted key fails quickly.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        """Connect a client store and check get/set/compare_set through it."""
        client_store = dist.TCPStore(addr, port, world_size=world_size, timeout=timedelta(seconds=10))
        self.assertEqual(b"value", client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(f"next_value{index}".encode(),
                         client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))

    def _multi_worker_helper(self, world_size):
        """Start one server store and connect ``world_size`` clients (one if it is falsy)."""
        addr = DEFAULT_HOSTNAME
        server_store = self._create_store_with_ws(addr, world_size)
        server_store.set("key", "value")
        port = server_store.port

        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

    # The three tests below override the StoreTestBase versions and run the
    # extended-API checks without the has_extended_api() skip.
    def test_append(self):
        store = self._create_store()
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_multi_set(self):
        store = self._create_store()
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_get(self):
        store = self._create_store()
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_store_timeout_on_missing_clients(self):
        with self.assertRaisesRegex(DistStoreError, r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined."):
            # world_size is 2 so it should timeout
            dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2))

        # when wait_for_workers is not set, then there should be no exception raised
        dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2), wait_for_workers=False)
417

418
class LibUvTCPStoreTest(TCPStoreTest):
    """Re-runs every TCPStoreTest test with stores created with ``use_libuv=True``."""

    def _create_store(self):
        uv_store = create_tcp_store(use_libuv=True)
        uv_store.set_timeout(timedelta(minutes=5))
        return uv_store

    def _create_store_with_ws(self, addr, world_size):
        return create_tcp_store(
            addr,
            world_size,
            wait_for_workers=False,
            use_libuv=True,
        )
427

428

429
class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks on a ``PrefixStore`` wrapping a TCP store."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        # One TCP store is shared by every PrefixStore handed out below.
        tcp_store = create_tcp_store()
        self.tcpstore = tcp_store
        tcp_store.set_timeout(timedelta(minutes=5))

    def _create_store(self):
        prefixed = dist.PrefixStore(self.prefix, self.tcpstore)
        return prefixed

    # The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
    # added by the user and one additional key used for coordinate all the
    # workers.
    @property
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        # Wrap the TCP store once, then twice more: the innermost non-prefix
        # store must be reported at every nesting depth.
        single = self._create_store()
        double = dist.PrefixStore(self.prefix, single)
        triple = dist.PrefixStore(self.prefix, double)
        self.assertEqual(self.tcpstore, single._underlying_non_prefix_store)
        self.assertEqual(self.tcpstore, triple._underlying_non_prefix_store)
451

452
class MyPythonStore(dist.Store):
    """Minimal dict-backed ``dist.Store`` written in Python.

    ``set``, ``get``, ``add`` and ``compare_set`` are overridden; values are
    kept as ``bytes`` in the ``self.store`` dictionary.
    """

    def __init__(self):
        super().__init__()
        self.store = {}

    def set(self, key, value):
        """Store ``value`` under ``key``; rejects non-string keys and non-bytes values."""
        key_is_text = isinstance(key, (str, bytes))
        if not key_is_text:
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        """Return the bytes stored under ``key``, or ``b""`` when the key is absent."""
        found = self.store.get(key, b"")
        if type(found) is bytes:
            return found
        raise AssertionError("Expected get to return bytes value")

    def add(self, key, value):
        """Add ``value`` to the integer stored under ``key`` (0 if absent) and return the sum."""
        total = int(self.store.get(key, 0)) + value
        # The counter is persisted as its decimal text, UTF-8 encoded.
        self.set(key, str(total).encode("utf-8"))
        return total

    def compare_set(self, key, expected, newValue):
        """Write ``newValue`` if ``key`` is absent or holds ``expected``.

        Returns the value stored under ``key`` after the call.
        """
        if type(expected) is not bytes:
            raise AssertionError("compare_set::expected not bytes")
        if type(newValue) is not bytes:
            raise AssertionError("compare_set::newValue not bytes")

        current = self.store.get(key)
        if current is not None and not (expected == current):
            # Mismatch: leave the stored value alone and report it.
            return current
        self.store[key] = newValue
        return newValue
485

486
class PythonStoreTest(TestCase):
    """Drives a Python-implemented store from the C++ side."""

    def test_set_get(self):
        # Inheriting from StoreTestBase and reusing its test_set_get would
        # call the Python API directly and bypass the C++ trampoline. The
        # trampoline is what matters here, so the equivalent of
        # StoreTestBase.test_set_get is run from C++ instead. The test
        # function is defined in `torch/csrc/distributed/c10d/init.cpp`.
        python_store = MyPythonStore()
        dist._test_python_store(python_store)
496

497

498
class RendezvousTest(TestCase):
499
    def test_unknown_handler(self):
500
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
501
            dist.rendezvous("invalid://")
502

503
    def test_url_with_node_params(self):
504
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
505
            dist.rendezvous("file://foo?rank=12&world_size=16", 12, 16)
506

507

508
class RendezvousEnvTest(TestCase):
    """Rendezvous through the ``env://`` init method."""

    @retry_on_connect_failures
    def test_nominal(self):
        from unittest import mock

        # Single rank
        rendezvous_env = {
            "WORLD_SIZE": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(common.find_free_port()),
            "RANK": "0",
        }

        # patch.dict restores os.environ on exit, so the rendezvous settings
        # do not leak into tests that run later in the same process.
        with mock.patch.dict(os.environ, rendezvous_env):
            gen0 = dist.rendezvous("env://")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(1, size0)

            store0.set("key0", "value0")

            # check with get
            self.assertEqual(b"value0", store0.get("key0"))
526

527

528
class RendezvousFileTest(TestCase):
    """Rendezvous through ``file://`` URLs."""

    def test_common_errors(self):
        # (expected error text, malformed URL) pairs, checked in this order.
        bad_urls = (
            ("path missing", "file://?rank=0&world_size=1"),
            ("rank parameter missing", "file:///tmp/foo?world_size=1"),
            ("size parameter missing", "file:///tmp/foo?rank=0"),
        )
        for expected_text, bad_url in bad_urls:
            with self.assertRaisesRegex(ValueError, expected_text):
                gen = dist.rendezvous(bad_url)
                next(gen)

    def test_nominal(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'

            # Rendezvous both ranks one after the other, checking each result
            # as it arrives.
            stores = []
            for expected_rank in (0, 1):
                gen = dist.rendezvous(url + "&rank=" + str(expected_rank))
                store, rank, size = next(gen)
                self.assertEqual(expected_rank, rank)
                self.assertEqual(2, size)
                stores.append(store)
            store0, store1 = stores

            # Set value on both stores
            store0.set("key0", "value0")
            store1.set("key1", "value1")

            # Cross check with get
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))
559

560

561
@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """Rendezvous through ``tcp://`` URLs."""

    def create_tcp_url(self):
        """Return a ``tcp://`` URL on a free local port with ``world_size=1``."""
        free_port = common.find_free_port()
        return "tcp://%s:%d?world_size=%d" % (DEFAULT_HOSTNAME, free_port, 1)

    def test_common_errors(self):
        # (expected error text, malformed URL) pairs, checked in this order.
        bad_urls = (
            ("port number missing", "tcp://127.0.0.1?rank=0&world_size=1"),
            ("rank parameter missing", "tcp://127.0.0.1:23456?world_size=1"),
            ("size parameter missing", "tcp://127.0.0.1:23456?rank=0"),
        )
        for expected_text, bad_url in bad_urls:
            with self.assertRaisesRegex(ValueError, expected_text):
                gen = dist.rendezvous(bad_url)
                next(gen)

    def test_dns_timeout(self):
        expect_timeout = self.assertRaisesRegex(
            DistNetworkError, "client socket has timed out after.*dnsnotexist"
        )
        with expect_timeout as manager:
            gen = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(gen)
        # The network error must also be catchable as the generic DistError.
        raised = manager.exception
        self.assertTrue(isinstance(raised, DistError))

    @retry_on_connect_failures
    def test_nominal(self):
        rank0_url = self.create_tcp_url() + "&rank=0"
        store0, rank0, size0 = next(dist.rendezvous(rank0_url))
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Set value on the single store
        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        rank0_url = self.create_tcp_url() + "&rank=0"
        test_store_timeout = timedelta(seconds=10)
        store0, rank0, size0 = next(dist.rendezvous(rank0_url, timeout=test_store_timeout))
        # this should time out in 10s. If the timeout passed into rendezvous was
        # not respected, it will take much longer to timeout.
        started_at = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("nonexistant key")

        elapsed = time.time() - started_at
        self.assertGreater(test_store_timeout.seconds * 10, elapsed)

    def test_tcp_store_timeout_doest_break_client(self):
        rank0_url = self.create_tcp_url() + "&rank=0"
        test_store_timeout = timedelta(seconds=10)
        store0, rank0, size0 = next(dist.rendezvous(rank0_url, timeout=test_store_timeout))
        # this should time out in 10s. If the timeout passed into rendezvous was
        # not respected, it will take much longer to timeout.
        started_at = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("the_key")

        # The same client must stay usable after the timed-out get.
        store0.set("the_key", "x")

        self.assertEqual(b"x", store0.get("the_key"))

        elapsed = time.time() - started_at
        self.assertGreater(test_store_timeout.seconds * 10, elapsed)

    def test_tcp_store_url_with_libuv(self):
        libuv_url = self.create_tcp_url() + "&rank=0&use_libuv=1"
        store0, rank0, size0 = next(dist.rendezvous(libuv_url))
        self.assertTrue(store0.libuvBackend)
643

644
class DummyStore(dist.Store):
    """Recording ``dist.Store`` double for the extended API.

    Every ``append``/``multi_get``/``multi_set`` call is logged in the list of
    the matching name. ``multi_get`` answers with the entries queued in
    ``multi_get_res``, oldest first.
    """

    def __init__(self):
        super().__init__()
        self.appends, self.multi_sets, self.multi_gets = [], [], []
        # Canned replies for multi_get, consumed from the front.
        self.multi_get_res = []

    def append(self, key, value):
        call = (key, value)
        self.appends.append(call)

    def multi_get(self, keys):
        self.multi_gets.append(keys)
        reply = self.multi_get_res.pop(0)
        return reply

    def multi_set(self, keys, values):
        call = (keys, values)
        self.multi_sets.append(call)

    def has_extended_api(self):
        return True
664

665
class TestPythonStore(TestCase):
    """Checks how stores implemented in Python interact with the extended store API."""

    def _expect_extended_calls_unimplemented(self, store):
        # append, multi_get and multi_set must each fail with "Not implemented.".
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_optional_methods_fail(self):
        class TestStore(dist.Store):
            pass

        bare_store = TestStore()
        self.assertFalse(bare_store.has_extended_api())
        self._expect_extended_calls_unimplemented(bare_store)

    def test_has_extended_api_passthrough(self):
        class TestStore(dist.Store):
            pass

        # Same expectations as above, reached through a PrefixStore wrapper.
        test_store = TestStore()
        wrapper = dist.PrefixStore("p", test_store)
        self.assertFalse(wrapper.has_extended_api())
        self._expect_extended_calls_unimplemented(wrapper)

    def test_has_extended_api_roundtrip(self):
        backing = DummyStore()
        wrapper = dist.PrefixStore("p", backing)
        self.assertTrue(wrapper.has_extended_api())

    def test_append_roundtrip(self):
        backing = DummyStore()
        wrapper = dist.PrefixStore("p", backing)
        wrapper.append("foo", "bar")
        # The wrapped store sees the prefixed key and the value as bytes.
        recorded = backing.appends
        self.assertEqual(1, len(recorded))
        self.assertEqual(("p/foo", b"bar"), recorded[0])

    def test_multi_get_roundtrip(self):
        backing = DummyStore()
        wrapper = dist.PrefixStore("p", backing)
        backing.multi_get_res.append([b"x", b"y"])
        res = wrapper.multi_get(["foo", "bar"])
        recorded = backing.multi_gets
        self.assertEqual(1, len(recorded))
        self.assertEqual(["p/foo", "p/bar"], recorded[0])
        self.assertEqual([b"x", b"y"], res)

    def test_multi_set_roundtrip(self):
        backing = DummyStore()
        wrapper = dist.PrefixStore("p", backing)
        wrapper.multi_set(["foo", "bar"], [b'x', b'y'])
        recorded = backing.multi_sets
        self.assertEqual(1, len(recorded))
        sent_keys, sent_values = recorded[0][0], recorded[0][1]
        self.assertEqual(["p/foo", "p/bar"], sent_keys)
        self.assertEqual([b'x', b'y'], sent_values)

    def test_extended_methods_fallbacks(self):
        # MyPythonStore reports no extended API, yet append/multi_set/multi_get
        # issued through the PrefixStore must still work.
        test_store = MyPythonStore()
        wrapper = dist.PrefixStore("p", test_store)
        self.assertFalse(wrapper.has_extended_api())
        for piece in (b"po", b"tato"):
            wrapper.append("foo", piece)
        self.assertEqual(wrapper.get("foo"), b"potato")

        wrapper.multi_set(["a", "b"], [b"c", b"d"])
        self.assertEqual(wrapper.multi_get(["a", "b", "foo"]), [b"c", b"d", b"potato"])
730

731

732
class TestMultiThreadedWait(MultiThreadedTestCase):
    """Checks ``wait`` on each store type with two ranks running as threads."""

    # TODO: Use less hacky means of instantiating stores.
    # Note, stores accumulate values per test.
    stores = [
        dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1),
        dist.HashStore(),
        dist.PrefixStore("pre", dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)),
        create_tcp_store(),
        create_tcp_store(use_libuv=True),
        dist.PrefixStore("pre", create_tcp_store()),
        dist.PrefixStore("pre", create_tcp_store(use_libuv=True)),
    ]

    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    # One parametrized test per entry of ``stores``. The index range is derived
    # from the list itself (visible here because the decorator is evaluated in
    # the class body), so adding or removing a store needs no further change.
    @parametrize("i", range(len(stores)))
    def test_wait(self, i):
        store = self.stores[i]
        store.set_timeout(timedelta(seconds=2))
        # Rank 0 blocks until rank 1 has published the key.
        if dist.get_rank() == 0:
            store.wait(["key1"])
            self.assertEqual(b"value1", store.get("key1"))
        if dist.get_rank() == 1:
            store.set("key1", "value1")
763

764

765
instantiate_parametrized_tests(TestMultiThreadedWait)
766

767
@skip_if_win32()
class TimeoutTest(TestCase):
    """Checks that a signal delivered to a waiting thread does not abort ``wait``."""

    def tearDown(self):
        import signal
        super().tearDown()
        # Drop the handler installed by the test; SIGUSR1 is ignored afterwards.
        signal.signal(signal.SIGUSR1, signal.SIG_IGN)

    def test_interrupt_doesnt_break_wait(self):
        import signal
        # Per-rank outcome: True on success, the raised exception on failure,
        # None if the worker never got that far.
        rank_res = [None, None]

        def run(rank, my_store):
            try:
                if rank == 0:
                    time.sleep(4)
                    my_store.set("foo", "bar")
                else:
                    my_store.wait(["foo"], datetime.timedelta(seconds=10))
                rank_res[rank] = True
            except Exception as e:
                # Keep the failure so the main thread can report it. The
                # previous ``except Error`` named an undefined identifier and
                # replaced the real exception with a NameError in the worker.
                rank_res[rank] = e
            time.sleep(1)

        rank0_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=0, world_size=2, is_master=True, wait_for_workers=False)
        rank1_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=rank0_store.port, world_size=2, is_master=False, wait_for_workers=False)

        ths = []
        for i in range(2):
            t = threading.Thread(target=run, args=(i, [rank0_store, rank1_store][i],))
            t.start()
            ths.append(t)

        def handler(a, b):
            pass

        signal.signal(signal.SIGUSR1, handler)
        # Give rank 1 time to enter wait() before interrupting its thread.
        time.sleep(1)
        signal.pthread_kill(ths[1].ident, signal.SIGUSR1)

        for t in ths:
            t.join()
        # A stored exception object is truthy, so compare against True itself
        # rather than using assertTrue.
        self.assertIs(rank_res[0], True, f"rank0: {rank_res[0]!r}")
        self.assertIs(rank_res[1], True, f"rank1: {rank_res[1]!r}")
813

814

815
class InitPgWithUvStore(TestCase):
    """Checks that ``init_process_group`` can be told to use a libuv-backed TCPStore."""

    def tearDown(self):
        super().tearDown()
        # Remove the variables test_with_env_var may have set.
        os.environ.pop("USE_LIBUV", None)
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

    def test_with_url_param(self):
        port = common.find_free_port()
        dist.init_process_group("gloo", rank=0, world_size=1, init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=1")
        self._run_test()

    def test_with_env_var(self):
        port = common.find_free_port()
        os.environ["USE_LIBUV"] = "1"
        os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://")
        self._run_test()

    def _run_test(self):
        """Inspect the default group's store, then destroy the group."""
        # The default process group is global state: destroy it even when an
        # assertion fails, so one failure cannot break later tests that
        # initialise a process group.
        try:
            pg = dist.group.WORLD
            store = c10d._get_process_group_store(pg)
            self.assertTrue(isinstance(store, dist.PrefixStore))
            # c10d does multiple levels of wrapping
            while isinstance(store, dist.PrefixStore):
                store = store.underlying_store
            self.assertTrue(isinstance(store, dist.TCPStore))
            self.assertTrue(store.libuvBackend)
        finally:
            dist.destroy_process_group()
845

846
if __name__ == "__main__":
    # The main process must not have created a CUDA context by this point.
    cuda_context_initialized = torch.cuda._initialized
    assert not cuda_context_initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )

    run_tests()
852

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.