# Source: pytorch repository (forks: 0) / test_store.py
# 1086 lines · 35.8 KB
1
# Owner(s): ["oncall: distributed"]
2

3
import datetime
4
import os
5
import socket
6
import struct
7
import sys
8
import tempfile
9
import threading
10
import time
11
from datetime import timedelta
12
from sys import platform
13

14
import torch
15
import torch.distributed as dist
16
import torch.distributed.distributed_c10d as c10d
17
import torch.distributed.rpc as rpc
18
from torch.distributed import DistError, DistNetworkError, DistStoreError
19
from torch.testing._internal.common_distributed import MultiThreadedTestCase
20
from torch.testing._internal.common_utils import instantiate_parametrized_tests
21

22

23
if not dist.is_available():
24
    print("torch.distributed not available, skipping tests", file=sys.stderr)
25
    sys.exit(0)
26

27
import torch.testing._internal.common_utils as common
28
from torch.testing._internal.common_distributed import (
29
    create_tcp_store,
30
    skip_if_win32,
31
    tp_transports,
32
)
33
from torch.testing._internal.common_utils import (
34
    ADDRESS_IN_USE,
35
    CONNECT_TIMEOUT,
36
    load_tests,
37
    retry_on_connect_failures,
38
    run_tests,
39
    TestCase,
40
)
41

42

43
# load_tests from common_utils is used to automatically filter tests for
44
# sharding on sandcastle. This line silences flake warnings
45
load_tests = load_tests
46

47
# Name of the loopback network interface: "lo0" on darwin, "lo" on every
# other platform.
LOOPBACK = "lo0" if platform == "darwin" else "lo"

# Host name the store and rendezvous tests in this file bind/connect to.
DEFAULT_HOSTNAME = "localhost"

# Turn TF32 off for CUDA matmuls for the whole test module.
torch.backends.cuda.matmul.allow_tf32 = False
55

56

57
def gpus_for_rank(world_size):
    """Partition the visible CUDA devices into one contiguous slice per rank.

    Multi-GPU tests simulate several nodes with several GPUs each. The NCCL
    backend needs the same number of GPUs in every process, so on a single
    node the visible devices are divided evenly: each of the ``world_size``
    ranks receives ``device_count // world_size`` consecutive device indices
    and any remainder is left unassigned.

    Args:
        world_size: number of ranks to split the devices between.

    Returns:
        A list with ``world_size`` entries; entry ``r`` is the list of device
        indices assigned to rank ``r``.
    """
    device_total = torch.cuda.device_count()
    all_devices = list(range(device_total))
    share = device_total // world_size
    return [all_devices[r * share : (r + 1) * share] for r in range(world_size)]
71

72

73
class StoreTestBase:
    """Mixin holding the test cases shared by the concrete store test classes.

    It is meant to be combined with a ``TestCase`` class (it relies on the
    ``assert*``/``skipTest`` methods of the host class). Subclasses override
    ``_create_store`` and, when their backend reports a different key count,
    ``num_keys_total``.
    """

    def _create_store(self, i=None):
        """Return the store under test; concrete subclasses must override this.

        ``i`` is unused and optional: every caller in this class invokes the
        factory without arguments, so requiring it would turn the intended
        "not implemented" error into a ``TypeError``.
        """
        raise RuntimeError("not implemented")

    def _test_set_get_check(self, fs):
        """Exercise add/set/get/check/delete_key and the key count on ``fs``."""
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        # "key" accumulated 1 + 2 + 3, "key3" accumulated 1 + ... + 6.
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        # A key that is added and then deleted must not change the key count.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        """Check ``compare_set`` for missing keys, mismatches and matches."""
        # Missing key with a non-empty expected value: the expected value is
        # returned.
        missing_key_result = store.compare_set(
            "cs_key0", "wrong_old_value", "new_value0"
        )
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Mismatched expected value: the stored value is returned and kept.
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Matching expected value: the new value is stored and returned.
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        # Missing key with an empty expected value: the new value is stored.
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    def _test_simple_wait(self, fs):
        """Waiting on a missing key must time out; on a present key it returns."""
        with self.assertRaisesRegex(RuntimeError, "[t -i]imeout"):
            fs.wait(["bad_key"], timedelta(seconds=0.25))
        fs.add("good_key", 1)
        fs.wait(["good_key"])

    def test_simple_wait(self):
        self._test_simple_wait(self._create_store())

    def _test_append(self, store):
        """Check ``append`` on an existing key and on a fresh key."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_append(self):
        self._test_append(self._create_store())

    def _test_multi_set(self, store):
        """Check that ``multi_set`` stores each key with its matching value."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_set(self):
        self._test_multi_set(self._create_store())

    def _test_multi_get(self, store):
        """Check that ``multi_get`` returns values in the order of the keys."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_multi_get(self):
        self._test_multi_get(self._create_store())

    # This is the number of keys used in test_set_get_check. Adding this as a
    # class property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5
179

180

181
class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests, plus file-specific cases, on ``dist.FileStore``."""

    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)
        # Only the path (``self.file.name``) is used by the tests; close our
        # own handle right away so it is not leaked for the test's lifetime.
        self.file.close()

    def _create_store(self):
        """Create a single-worker ``FileStore`` on the per-test temp file."""
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        """RPC and a process group can both be initialised from one file."""
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            path = tmp.name
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{path}"
        rpc_backend_options._transports = tp_transports()
        rpc.init_rpc(
            "worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options
        )

        # Init PG using file
        dist.init_process_group(
            "gloo", rank=0, world_size=1, init_method=f"file://{path}"
        )
        dist.destroy_process_group()
        # Use a unittest assertion rather than a bare ``assert`` so the check
        # still runs when Python is started with -O.
        self.assertTrue(os.path.exists(path))

        rpc.shutdown()
        os.remove(path)

    def test_refcount(self):
        """The backing file is removed only once the last store is dropped."""
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            path = tmp.name
        store = dist.FileStore(path, 1)
        store2 = dist.FileStore(path, 1)

        del store
        self.assertTrue(os.path.exists(path))
        del store2
        self.assertFalse(os.path.exists(path))

    @property
    def num_keys_total(self):
        return 6
224

225

226
@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against ``dist.HashStore``."""

    def _create_store(self):
        """Build a fresh ``HashStore`` whose timeout is set to 300 seconds."""
        hash_store = dist.HashStore()
        hash_store.set_timeout(timedelta(minutes=5))
        return hash_store
232

233

234
class PrefixStoreTest(TestCase):
    """Checks that ``PrefixStore`` exposes the store it wraps."""

    def setUp(self):
        # Run the base-class fixture setup first, like the other test classes
        # in this file do.
        super().setUp()
        # delete is false as FileStore will automatically clean up the file
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def test_get_underlying_store(self):
        """``underlying_store`` returns the wrapped TCP, hash or file store."""
        tcp_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True
        )
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)
249

250

251
class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests on a ``PrefixStore`` wrapping a ``FileStore``."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        self.file = tempfile.NamedTemporaryFile(delete=False)
        backing_store = dist.FileStore(self.file.name, 1)
        backing_store.set_timeout(timedelta(minutes=5))
        self.filestore = backing_store

    def _create_store(self):
        """Wrap the shared file store in a new ``PrefixStore``."""
        prefixed = dist.PrefixStore(self.prefix, self.filestore)
        return prefixed

    @property
    def num_keys_total(self):
        """Key count expected by ``test_set_get_check`` for this store."""
        return 6
265

266

267
class TCPStoreTest(TestCase, StoreTestBase):
    """Shared store tests plus TCP-specific cases for ``dist.TCPStore``.

    ``_use_libuv`` is forwarded as ``use_libuv`` to the stores built here; a
    subclass can override it to rerun the same cases with the other setting.
    """

    _use_libuv = False

    def _create_store(self):
        """Create a TCP store with a 300 second timeout."""
        store = create_tcp_store(use_libuv=self._use_libuv)
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        """Create a TCP store for ``world_size`` without waiting for workers."""
        return create_tcp_store(
            addr, world_size, wait_for_workers=False, use_libuv=self._use_libuv
        )

    def test_address_already_in_use(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        err_msg_reg = f"^The server socket has failed to listen on any local .*{port}"
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(
                addr, port, 1, True, use_libuv=self._use_libuv
            )  # noqa: F841
            store2 = dist.TCPStore(
                addr, port, 1, True, use_libuv=self._use_libuv
            )  # noqa: F841
            self.assertEqual(store1.libuvBackend, self._use_libuv)
            self.assertEqual(store2.libuvBackend, self._use_libuv)

    @retry_on_connect_failures
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(
            addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv
        )  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(
            addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv
        )  # type: ignore[call-arg] # noqa: F841
        self.assertEqual(store1.libuvBackend, self._use_libuv)
        self.assertEqual(store2.libuvBackend, self._use_libuv)

    def test_repr(self) -> None:
        # server
        store1 = self._create_store()
        self.assertRegex(
            repr(store1),
            r"TCPStore\("
            r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
            r"server=TCPServer\(port=\d+\)\)",
        )

        # client
        store2 = dist.TCPStore(
            store1.host,
            store1.port,
            world_size=2,
            is_master=False,
        )
        self.assertRegex(
            repr(store2),
            r"TCPStore\("
            r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
            r"server=<nullptr>\)",
        )

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        os.environ["USE_LIBUV"] = "1" if self._use_libuv else "0"
        try:
            dist.init_process_group(
                backend="gloo",
                init_method="env://",
                rank=0,
                world_size=1,
            )

            backend_opts = rpc.TensorPipeRpcBackendOptions(
                init_method=f"tcp://{addr}:{port}", _transports=tp_transports()
            )
            rpc.init_rpc(
                name="worker0",
                rank=0,
                world_size=1,
                rpc_backend_options=backend_opts,
            )
        finally:
            # Always drop the override, even when initialisation fails, so it
            # cannot leak into a retry or into other tests in this process.
            os.environ.pop("USE_LIBUV", None)

        # A unittest assertion (unlike a bare ``assert``) also runs under -O.
        self.assertNotIn("USE_LIBUV", os.environ)
        rpc.shutdown()
        dist.destroy_process_group()

    @skip_if_win32()
    def test_take_over_listen_socket(self):
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        # detach() hands the raw descriptor over without closing it.
        listen_fd = listen_sock.detach()

        store = dist.TCPStore(
            addr,
            port,
            1,
            is_master=True,
            master_listen_fd=listen_fd,
            use_libuv=self._use_libuv,
        )

        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("key", "value")
        self.assertEqual(b"value", store.get("key"))

    # The TCPStore has 6 keys in test_set_get_check. It contains the 5 keys
    # added by the user and one additional key used to coordinate all the
    # workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        """Track ``num_keys`` across a sequence of add/set/delete_key calls."""
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # Shorten the timeout so the get() on the deleted key fails quickly.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        """Connect a client store and run a get/set/compare_set round trip."""
        client_store = dist.TCPStore(
            addr,
            port,
            world_size=world_size,
            timeout=timedelta(seconds=10),
            use_libuv=self._use_libuv,
        )
        self.assertEqual(b"value", client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(
            f"next_value{index}".encode(),
            client_store.compare_set(
                f"new_key{index}", f"new_value{index}", f"next_value{index}"
            ),
        )

    def _multi_worker_helper(self, world_size):
        """Start a server store and connect one client per rank to it."""
        addr = DEFAULT_HOSTNAME
        server_store = self._create_store_with_ws(addr, world_size)
        self.assertEqual(server_store.libuvBackend, self._use_libuv)
        server_store.set("key", "value")
        port = server_store.port

        # With no fixed world size a single client is exercised.
        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

    def test_append(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_multi_set(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_get(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_store_timeout_on_missing_clients(self):
        with self.assertRaisesRegex(
            DistStoreError,
            r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined.",
        ):
            # world_size is 2 so it should timeout
            dist.TCPStore(
                "localhost",
                0,
                2,
                True,
                timeout=timedelta(seconds=2),
                use_libuv=self._use_libuv,
            )

        # when wait_for_workers is not set, then there should be no exception raised
        dist.TCPStore(
            "localhost",
            0,
            2,
            True,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
            use_libuv=self._use_libuv,
        )
511

512

513
class LibUvTCPStoreTest(TCPStoreTest):
    """Reruns the ``TCPStoreTest`` cases with ``use_libuv=True``."""

    _use_libuv = True

    def _create_store(self):
        """Create a libuv-backed TCP store with a 300 second timeout."""
        store = create_tcp_store(use_libuv=True)
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        """Create a libuv-backed TCP store without waiting for workers."""
        return create_tcp_store(
            addr, world_size, wait_for_workers=False, use_libuv=True
        )

    def test_take_over_listen_socket(self):
        """
        Override the take_over_listen_socket test in TCPStoreTest.
        Reason: libuv TCPStore initialization from an already open socket has
        not been thoroughly tested, so this use is not supported for now and
        is expected to raise NotImplementedError.
        TODO (xilunwu): enable this use case
        """
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        listen_fd = listen_sock.detach()

        err_msg_reg = (
            "^The libuv TCPStore backend does not support "
            "initialization with an listen fd"
        )

        with self.assertRaisesRegex(NotImplementedError, err_msg_reg):
            # The constructor is expected to raise, so the result is not bound.
            dist.TCPStore(
                addr,
                port,
                1,
                is_master=True,
                master_listen_fd=listen_fd,
                use_libuv=self._use_libuv,
            )
552

553

554
class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests on a ``PrefixStore`` wrapping a TCP store."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        self.tcpstore = create_tcp_store()
        self.tcpstore.set_timeout(timedelta(minutes=5))

    def _create_store(self):
        """Wrap the shared TCP store in a new ``PrefixStore``."""
        prefixed = dist.PrefixStore(self.prefix, self.tcpstore)
        return prefixed

    # test_set_get_check sees 6 keys here: the 5 keys written by the test plus
    # one additional key used to coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        """Nested prefix stores all resolve to the same TCP store."""
        single = self._create_store()
        double = dist.PrefixStore(self.prefix, single)
        triple = dist.PrefixStore(self.prefix, double)
        for candidate in (single, triple):
            self.assertEqual(self.tcpstore, candidate._underlying_non_prefix_store)
578

579

580
class MyPythonStore(dist.Store):
    """Dict-backed ``dist.Store`` subclass implemented in Python.

    Values are kept as ``bytes`` in ``self.store``; the methods raise
    ``AssertionError`` when handed values of any other type.
    """

    def __init__(self) -> None:
        super().__init__()
        self.store = {}

    def set(self, key, value):
        """Store ``value`` (exactly ``bytes``) under ``key`` (``str`` or ``bytes``)."""
        key_is_valid = isinstance(key, (str, bytes))
        if not key_is_valid:
            raise AssertionError("Expected set to be called with string key")
        if type(value) is bytes:
            self.store[key] = value
            return
        raise AssertionError("Expected set to be called with bytes value")

    def get(self, key):
        """Return the bytes stored under ``key``, or ``b""`` when it is absent."""
        stored = self.store.get(key, b"")
        if type(stored) is bytes:
            return stored
        raise AssertionError("Expected get to return bytes value")

    def add(self, key, value):
        """Add ``value`` to the integer held under ``key`` (0 if absent)."""
        total = int(self.store.get(key, 0)) + value
        # The counter is persisted as its decimal text, encoded to bytes.
        self.set(key, str(total).encode("utf-8"))
        return total

    def compare_set(self, key, expected, newValue):
        """Write ``newValue`` when ``key`` is absent or currently holds ``expected``.

        Returns the value held under ``key`` after the call.
        """
        if type(expected) is not bytes:
            raise AssertionError("compare_set::expected not bytes")
        if type(newValue) is not bytes:
            raise AssertionError("compare_set::newValue not bytes")

        current = self.store.get(key)
        if current is not None and current != expected:
            # Present but different from what the caller expected: keep it.
            return current
        self.store[key] = newValue
        return newValue
613

614

615
class PythonStoreTest(TestCase):
    """Drives a Python-defined store through the C++ test entry point."""

    def test_set_get(self):
        # Inheriting from StoreTestBase and reusing its test functions would
        # call the Python API directly and bypass the C++ trampoline. The
        # trampoline is what we want to cover, so the equivalent of the
        # StoreTestBase test is run from C++ instead.
        # See `torch/csrc/distributed/c10d/init.cpp` for the definition
        # of this test function.
        python_store = MyPythonStore()
        dist._test_python_store(python_store)
625

626

627
class RendezvousTest(TestCase):
    """Error handling of ``dist.rendezvous`` that is independent of the scheme."""

    def test_unknown_handler(self):
        """An unregistered URL scheme is rejected with a RuntimeError."""
        expect_failure = self.assertRaisesRegex(RuntimeError, "^No rendezvous handler")
        with expect_failure:
            dist.rendezvous("invalid://")

    def test_url_with_node_params(self):
        """Rank/world size given both in the URL and as arguments is rejected."""
        url = "file://foo?rank=12&world_size=16"
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous(url, 12, 16)
635

636

637
class RendezvousEnvTest(TestCase):
    """Rendezvous through the ``env://`` init method."""

    @retry_on_connect_failures
    def test_nominal(self):
        """A single rank configured via environment variables gets a usable store."""
        env = os.environ
        env["WORLD_SIZE"] = "1"
        env["MASTER_ADDR"] = "127.0.0.1"
        env["MASTER_PORT"] = str(common.find_free_port())

        # Single rank
        env["RANK"] = "0"
        rendezvous = dist.rendezvous("env://")
        store, rank, world_size = next(rendezvous)
        self.assertEqual(0, rank)
        self.assertEqual(1, world_size)

        # Round-trip one value through the store.
        store.set("key0", "value0")
        self.assertEqual(b"value0", store.get("key0"))
655

656

657
class RendezvousFileTest(TestCase):
    """Rendezvous through the ``file://`` init method."""

    def test_common_errors(self):
        """Malformed ``file://`` URLs raise ValueError naming the missing part."""
        cases = (
            ("path missing", "file://?rank=0&world_size=1"),
            ("rank parameter missing", "file:///tmp/foo?world_size=1"),
            ("size parameter missing", "file:///tmp/foo?rank=0"),
        )
        for message, url in cases:
            with self.assertRaisesRegex(ValueError, message):
                gen = dist.rendezvous(url)
                next(gen)

    def test_nominal(self):
        """Two ranks sharing one file can read each other's values."""
        with tempfile.NamedTemporaryFile(delete=False) as file:
            path = file.name.replace(os.path.sep, "/")
            url = f"file:///{path}?world_size=2"

            # Keep the generators referenced for as long as the stores are used.
            handles = []
            stores = []
            for expected_rank in (0, 1):
                gen = dist.rendezvous(url + f"&rank={expected_rank}")
                store, rank, size = next(gen)
                self.assertEqual(expected_rank, rank)
                self.assertEqual(2, size)
                handles.append(gen)
                stores.append(store)
            store0, store1 = stores

            # Each rank writes its own key ...
            store0.set("key0", "value0")
            store1.set("key1", "value1")

            # ... and reads back the key written by the other rank.
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))
688

689

690
@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """Rendezvous through the ``tcp://`` init method."""

    def create_tcp_url(self):
        """Return a ``tcp://`` URL on a free local port with ``world_size=1``."""
        port = common.find_free_port()
        return f"tcp://{DEFAULT_HOSTNAME}:{port:d}?world_size=1"

    def test_common_errors(self):
        """Malformed ``tcp://`` URLs raise ValueError naming the missing part."""
        cases = (
            ("port number missing", "tcp://127.0.0.1?rank=0&world_size=1"),
            ("rank parameter missing", "tcp://127.0.0.1:23456?world_size=1"),
            ("size parameter missing", "tcp://127.0.0.1:23456?rank=0"),
        )
        for message, url in cases:
            with self.assertRaisesRegex(ValueError, message):
                gen = dist.rendezvous(url)
                next(gen)

    def test_dns_timeout(self):
        """Connecting to the host name ``dnsnotexist`` times out as a DistNetworkError."""
        expected = "client socket has timed out after.*dnsnotexist"
        with self.assertRaisesRegex(DistNetworkError, expected) as manager:
            gen = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(gen)
        # The raised error must also be an instance of DistError.
        self.assertIsInstance(manager.exception, DistError)

    @retry_on_connect_failures
    def test_nominal(self):
        """A single rank gets a usable store from a ``tcp://`` URL."""
        gen = dist.rendezvous(self.create_tcp_url() + "&rank=0")
        store, rank, size = next(gen)
        self.assertEqual(0, rank)
        self.assertEqual(1, size)

        # Round-trip one value through the single store.
        store.set("key0", "value0")
        self.assertEqual(b"value0", store.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        """``set_timeout`` on the store overrides the rendezvous timeout."""
        short_timeout = timedelta(seconds=0.1)
        gen = dist.rendezvous(
            self.create_tcp_url() + "&rank=0", timeout=timedelta(seconds=10)
        )
        store, _rank, _size = next(gen)
        store.set_timeout(short_timeout)
        # The get below should give up after the 0.1s store timeout; if that
        # timeout were ignored the wait would last much longer.
        started = time.time()
        with self.assertRaisesRegex(
            DistStoreError, "wait timeout after 100ms, keys: /nonexistant key"
        ):
            store.get("nonexistant key")

        elapsed = time.time() - started
        self.assertLess(elapsed, 10)

    def test_tcp_store_timeout_doest_break_client(self):
        """A timed-out get leaves the client usable for later set/get calls."""
        short_timeout = timedelta(seconds=0.1)
        gen = dist.rendezvous(
            self.create_tcp_url() + "&rank=0", timeout=timedelta(seconds=10)
        )
        store, _rank, _size = next(gen)
        store.set_timeout(short_timeout)
        # The first get hits the 0.1s store timeout because the key is absent.
        started = time.time()
        with self.assertRaisesRegex(
            DistStoreError, "wait timeout after 100ms, keys: /the_key"
        ):
            store.get("the_key")

        # The same client must still work after the timeout.
        store.set("the_key", "x")
        self.assertEqual(b"x", store.get("the_key"))

        elapsed = time.time() - started
        self.assertLess(elapsed, 10)

    def test_tcp_store_url_with_libuv(self):
        """``use_libuv=1`` in the URL yields a store reporting the libuv backend."""
        gen = dist.rendezvous(self.create_tcp_url() + "&rank=0&use_libuv=1")
        store, _rank, _size = next(gen)
        self.assertTrue(store.libuvBackend)
780

781

782
class DummyStore(dist.Store):
    """Store double that records extended-API calls instead of storing data.

    ``appends``, ``multi_sets`` and ``multi_gets`` log the arguments of each
    call; ``multi_get_res`` is a queue of canned results handed out by
    ``multi_get`` in first-in, first-out order.
    """

    def __init__(self) -> None:
        super().__init__()
        self.appends = []
        self.multi_sets = []
        self.multi_gets = []
        self.multi_get_res = []

    def append(self, key, value):
        """Record an ``append`` call as a ``(key, value)`` pair."""
        call = (key, value)
        self.appends.append(call)

    def multi_get(self, keys):
        """Record the requested keys and return the next canned result."""
        self.multi_gets.append(keys)
        canned = self.multi_get_res.pop(0)
        return canned

    def multi_set(self, keys, values):
        """Record a ``multi_set`` call as a ``(keys, values)`` pair."""
        call = (keys, values)
        self.multi_sets.append(call)

    def has_extended_api(self):
        """Report that the extended API is available."""
        return True
802

803

804
class TestPythonStore(TestCase):
    """Tests for the extended store API on stores implemented in Python.

    The extended API here means ``append``, ``multi_get`` and ``multi_set``,
    called either directly or through a ``dist.PrefixStore`` wrapper.
    """

    def _assert_extended_api_unimplemented(self, store):
        """Assert each extended-API call on ``store`` raises "Not implemented."."""
        extended_calls = (
            lambda: store.append("foo", "bar"),
            lambda: store.multi_get(["foo", "bar"]),
            lambda: store.multi_set(["foo", "bar"], [b"v", b"v"]),
        )
        for invoke in extended_calls:
            with self.assertRaisesRegex(RuntimeError, "Not implemented."):
                invoke()

    def test_optional_methods_fail(self):
        """A bare ``dist.Store`` subclass rejects every extended-API call."""

        class TestStore(dist.Store):
            pass

        bare_store = TestStore()
        self.assertFalse(bare_store.has_extended_api())
        self._assert_extended_api_unimplemented(bare_store)

    def test_has_extended_api_passthrough(self):
        """A ``PrefixStore`` over a bare store also rejects the extended API."""

        class TestStore(dist.Store):
            pass

        # Keep the inner store bound to a local for as long as the wrapper is used.
        inner_store = TestStore()
        wrapped = dist.PrefixStore("p", inner_store)
        self.assertFalse(wrapped.has_extended_api())
        self._assert_extended_api_unimplemented(wrapped)

    def test_has_extended_api_roundtrip(self):
        """``has_extended_api`` of the wrapped store is visible via the prefix."""
        backing = DummyStore()
        wrapped = dist.PrefixStore("p", backing)
        self.assertTrue(wrapped.has_extended_api())

    def test_append_roundtrip(self):
        """``append`` reaches the backing store with the prefixed key."""
        backing = DummyStore()
        wrapped = dist.PrefixStore("p", backing)
        wrapped.append("foo", "bar")
        self.assertEqual(1, len(backing.appends))
        first_call = backing.appends[0]
        self.assertEqual(("p/foo", b"bar"), first_call)

    def test_multi_get_roundtrip(self):
        """``multi_get`` forwards prefixed keys and returns the backing reply."""
        backing = DummyStore()
        wrapped = dist.PrefixStore("p", backing)
        backing.multi_get_res.append([b"x", b"y"])
        fetched = wrapped.multi_get(["foo", "bar"])
        self.assertEqual(1, len(backing.multi_gets))
        self.assertEqual(["p/foo", "p/bar"], backing.multi_gets[0])
        self.assertEqual([b"x", b"y"], fetched)

    def test_multi_set_roundtrip(self):
        """``multi_set`` forwards prefixed keys together with the values."""
        backing = DummyStore()
        wrapped = dist.PrefixStore("p", backing)
        wrapped.multi_set(["foo", "bar"], [b"x", b"y"])
        self.assertEqual(1, len(backing.multi_sets))
        first_call = backing.multi_sets[0]
        self.assertEqual(["p/foo", "p/bar"], first_call[0])
        self.assertEqual([b"x", b"y"], first_call[1])

    def test_extended_methods_fallbacks(self):
        """Extended calls still work on a store without the extended API."""
        backing = MyPythonStore()
        wrapped = dist.PrefixStore("p", backing)
        self.assertFalse(wrapped.has_extended_api())

        # Two appends to the same key concatenate.
        wrapped.append("foo", b"po")
        wrapped.append("foo", b"tato")
        self.assertEqual(wrapped.get("foo"), b"potato")

        wrapped.multi_set(["a", "b"], [b"c", b"d"])
        fetched = wrapped.multi_get(["a", "b", "foo"])
        self.assertEqual(fetched, [b"c", b"d", b"potato"])
871

872

873
class TestMultiThreadedWait(MultiThreadedTestCase):
    """Checks ``Store.wait`` between rank 0 and rank 1 for several store types.

    Rank 0 waits for a key and reads it back; rank 1 writes it.
    """

    # The stores are class attributes, created once when the class body runs,
    # so every test instance sees the same store objects.
    file_store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
    hash_store = dist.HashStore()

    tcp_store = create_tcp_store(use_libuv=False)
    tcp_store_uv = create_tcp_store(use_libuv=True)

    @property
    def world_size(self):
        """Number of ranks used by these tests."""
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def _test_wait(self, store):
        """Rank 0 waits for ``key1`` and checks its value; rank 1 sets it."""
        store.set_timeout(timedelta(seconds=2))
        rank = dist.get_rank()
        if rank == 0:
            store.wait(["key1"])
            self.assertEqual(b"value1", store.get("key1"))
        elif rank == 1:
            store.set("key1", "value1")

    def test_wait_hash_store(self):
        self._test_wait(self.hash_store)

    def test_wait_file_store(self):
        self._test_wait(self.file_store)

    def test_wait_prefix_file_store(self):
        prefixed = dist.PrefixStore("pre", self.file_store)
        self._test_wait(prefixed)

    def _test_wait_tcp_store(self, master_store):
        """Run the wait check on a TCP store, bare and then prefix-wrapped.

        Rank 0 uses ``master_store`` itself; any other rank opens a client
        ``TCPStore`` to the master's host and port.
        """
        if dist.get_rank() == 0:
            store = master_store
        else:
            store = dist.TCPStore(
                host_name=master_store.host,
                port=master_store.port,
                is_master=False,
                wait_for_workers=False,
                use_libuv=False,
            )
        self._test_wait(store)

        prefixed = dist.PrefixStore("pre", store)
        self._test_wait(prefixed)

    def test_wait_tcp_store(self):
        self._test_wait_tcp_store(self.tcp_store)

    def test_wait_tcp_store_uv(self):
        self._test_wait_tcp_store(self.tcp_store_uv)
928

929

930
# NOTE(review): no parametrized tests are visible in TestMultiThreadedWait, so
# this call looks like a no-op for this class; confirm before removing it.
instantiate_parametrized_tests(TestMultiThreadedWait)
931

932

933
@skip_if_win32()
class TimeoutTest(TestCase):
    """Checks that a signal delivered to a waiting thread does not break ``wait``."""

    def tearDown(self):
        import signal

        super().tearDown()
        # Stop routing SIGUSR1 to the handler installed by the test.
        signal.signal(signal.SIGUSR1, signal.SIG_IGN)

    def test_interrupt_doesnt_break_wait(self):
        """Send SIGUSR1 to the thread blocked in ``wait`` and expect success.

        Rank 0 sets the key after 4 seconds; rank 1 waits for it with a
        10 second timeout and is signalled about 1 second in. Both workers
        must finish without raising.
        """
        import signal

        # Per-rank outcome: True on success, the raised exception on failure,
        # None if the worker never got that far.
        rank_res = [None, None]

        def run(rank, my_store):
            try:
                if rank == 0:
                    time.sleep(4)
                    my_store.set("foo", "bar")
                else:
                    my_store.wait(["foo"], datetime.timedelta(seconds=10))
                rank_res[rank] = True
            except Exception as e:
                # Keep the failure so the main thread can report it.
                rank_res[rank] = e
            time.sleep(1)

        rank0_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=0,
            world_size=2,
            is_master=True,
            wait_for_workers=False,
        )
        rank1_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=rank0_store.port,
            world_size=2,
            is_master=False,
            wait_for_workers=False,
        )

        ths = []
        for rank, rank_store in enumerate((rank0_store, rank1_store)):
            t = threading.Thread(target=run, args=(rank, rank_store))
            t.start()
            ths.append(t)

        def handler(a, b):
            pass

        signal.signal(signal.SIGUSR1, handler)
        time.sleep(1)
        # Interrupt the rank-1 thread, which is the one calling wait().
        signal.pthread_kill(ths[1].ident, signal.SIGUSR1)

        for t in ths:
            t.join()
        # Compare against True explicitly: a stored exception object is truthy
        # and would slip through assertTrue.
        self.assertIs(rank_res[0], True, f"rank0: {rank_res[0]!r}")
        self.assertIs(rank_res[1], True, f"rank1: {rank_res[1]!r}")
997

998

999
class InitPgWithNonUvStore(TestCase):
    """
    This test shows how to use the legacy TCPStore (non-libuv) backend since libuv is now
    the default backend.
    """

    def tearDown(self):
        super().tearDown()
        # Remove the variables test_with_env_var may have set.
        os.environ.pop("USE_LIBUV", None)
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

    def test_with_url_param(self):
        """Disable libuv through the ``use_libuv=0`` query parameter."""
        port = common.find_free_port()
        dist.init_process_group(
            "gloo",
            rank=0,
            world_size=1,
            init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=0",
        )
        self._run_test()

    def test_with_env_var(self):
        """Disable libuv through the ``USE_LIBUV`` environment variable."""
        port = common.find_free_port()
        os.environ["USE_LIBUV"] = "0"
        os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://")
        self._run_test()

    def _run_test(self):
        """Check the default group's store is a non-libuv ``TCPStore``.

        The process group is destroyed afterwards even when an assertion
        fails, so a failure here does not leave an initialized group behind
        for later tests.
        """
        try:
            pg = dist.group.WORLD
            store = c10d._get_process_group_store(pg)
            self.assertIsInstance(store, dist.PrefixStore)
            # c10d does multiple levels of wrapping
            while isinstance(store, dist.PrefixStore):
                store = store.underlying_store
            self.assertIsInstance(store, dist.TCPStore)
            self.assertFalse(store.libuvBackend)
        finally:
            dist.destroy_process_group()
1039

1040

1041
class TestClientProtocol(TestCase):
    """Checks the first bytes a ``TCPStore`` client sends to a raw socket server."""

    def test_client_connect(self) -> None:
        """Accept one client connection and verify its VALIDATE and PING messages.

        The listener runs in a background thread. Any failure raised there is
        collected and re-raised on the main thread, since an exception inside
        a thread would otherwise not fail the test.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Release the listening socket even when the test fails part-way.
        self.addCleanup(sock.close)
        sock.bind(("localhost", 0))
        port = sock.getsockname()[1]

        listener_errors = []

        def listen() -> None:
            try:
                sock.listen()
                conn, _ = sock.accept()
                with conn:
                    # VALIDATE
                    # 0x3C85F7CE
                    self.assertEqual(conn.recv(5), b"\x00\xce\xf7\x85\x3c")

                    # PING
                    data = conn.recv(5)
                    self.assertEqual(data[0], 13)
                    nonce = struct.unpack("i", data[1:])[0]
                    self.assertEqual(nonce, os.getpid())

                    # send PING nonce response
                    conn.sendall(data[1:])
            except Exception as e:
                listener_errors.append(e)

        thread = threading.Thread(target=listen)
        thread.start()

        # NOTE(review): the store is kept bound until the listener is joined,
        # presumably so the client connection stays open meanwhile; confirm.
        store = dist.TCPStore(
            host_name="localhost",
            port=port,
            world_size=2,
            is_master=False,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
        )

        thread.join()
        if listener_errors:
            raise listener_errors[0]
1079

1080

1081
if __name__ == "__main__":
    # The CUDA context must not be initialized on the main process before
    # the tests start.
    assert not torch.cuda._initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )

    run_tests()
1087

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.