# pytorch
# 1086 lines · 35.8 KB
1# Owner(s): ["oncall: distributed"]
2
3import datetime
4import os
5import socket
6import struct
7import sys
8import tempfile
9import threading
10import time
11from datetime import timedelta
12from sys import platform
13
14import torch
15import torch.distributed as dist
16import torch.distributed.distributed_c10d as c10d
17import torch.distributed.rpc as rpc
18from torch.distributed import DistError, DistNetworkError, DistStoreError
19from torch.testing._internal.common_distributed import MultiThreadedTestCase
20from torch.testing._internal.common_utils import instantiate_parametrized_tests
21
22
# Nothing in this module can run without torch.distributed, so report that on
# stderr and exit successfully instead of failing at import time.
if not dist.is_available():
    sys.stderr.write("torch.distributed not available, skipping tests\n")
    sys.exit(0)
26
27import torch.testing._internal.common_utils as common
28from torch.testing._internal.common_distributed import (
29create_tcp_store,
30skip_if_win32,
31tp_transports,
32)
33from torch.testing._internal.common_utils import (
34ADDRESS_IN_USE,
35CONNECT_TIMEOUT,
36load_tests,
37retry_on_connect_failures,
38run_tests,
39TestCase,
40)
41
42
43# load_tests from common_utils is used to automatically filter tests for
44# sharding on sandcastle. This line silences flake warnings
45load_tests = load_tests
46
47if platform == "darwin":
48LOOPBACK = "lo0"
49else:
50LOOPBACK = "lo"
51
52DEFAULT_HOSTNAME = "localhost"
53
54torch.backends.cuda.matmul.allow_tf32 = False
55
56
def gpus_for_rank(world_size):
    """Split the visible CUDA devices into one equal-sized list per rank.

    Multigpu tests are designed to simulate multiple nodes with multiple GPUs
    on each node. The NCCL backend requires an equal number of GPUs in each
    process, so on a single node the visible GPUs are divided evenly into
    subsets and each process only uses its own subset.

    Args:
        world_size: number of ranks to split the devices across.

    Returns:
        A list with ``world_size`` entries; entry ``r`` is the list of device
        indices assigned to rank ``r``. Devices left over after the integer
        division are not assigned to any rank.
    """
    device_count = torch.cuda.device_count()
    device_ids = list(range(device_count))
    per_rank = device_count // world_size
    return [
        device_ids[rank * per_rank : (rank + 1) * per_rank]
        for rank in range(world_size)
    ]
71
72
class StoreTestBase:
    """Mixin holding store tests that are shared by the concrete test classes.

    Concrete classes combine this mixin with ``TestCase``, override
    :meth:`_create_store` to return the store under test, and override
    :attr:`num_keys_total` when their store holds a different number of keys.
    """

    def _create_store(self, i=None):
        """Return the store under test; must be overridden by subclasses.

        ``i`` is unused. It is kept (now optional) only for backward
        compatibility: every caller and override in this file uses the
        zero-argument form.

        Raises:
            NotImplementedError: always (a subclass of ``RuntimeError``).
        """
        raise NotImplementedError("not implemented")

    def _test_set_get_check(self, fs):
        """Exercise ``add``/``set``/``get``/``check``/``delete_key``/``num_keys``."""
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        # "key" accumulated 1 + 2 + 3 and "key3" accumulated 1 + ... + 6.
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        # A key that is set and then deleted must not change the key count.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        """Check ``compare_set`` for missing, mismatching and matching values."""
        # Missing key with a non-empty expected value: the expected value is
        # echoed back.
        missing_key_result = store.compare_set(
            "cs_key0", "wrong_old_value", "new_value0"
        )
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Mismatching expected value: the stored value is returned unchanged.
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Matching expected value: the value is swapped.
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        # Missing key with an empty expected value: the new value is stored.
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    def _test_simple_wait(self, fs):
        """``wait`` must time out on a missing key and return for a present one."""
        # NOTE(review): the class ``[t -i]`` matches "t" or any character in
        # the range " ".."i" (which includes "T"); presumably ``[tT]`` was
        # intended - confirm before tightening the pattern.
        with self.assertRaisesRegex(RuntimeError, "[t -i]imeout"):
            fs.wait(["bad_key"], timedelta(seconds=0.25))
        fs.add("good_key", 1)
        fs.wait(["good_key"])

    def test_simple_wait(self):
        self._test_simple_wait(self._create_store())

    def _test_append(self, store):
        """``append`` must extend an existing value and create a missing key."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_append(self):
        self._test_append(self._create_store())

    def _test_multi_set(self, store):
        """``multi_set`` must store each value under its matching key."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_set(self):
        self._test_multi_set(self._create_store())

    def _test_multi_get(self, store):
        """``multi_get`` must return the values in the order of the keys."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_multi_get(self):
        self._test_multi_get(self._create_store())

    # This is the number of keys used in test_set_get. Adding this as a class
    # property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5
179
180
class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests, plus file-specific ones, on ``FileStore``."""

    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)
        # Only the path is needed (``self.file.name`` stays valid after the
        # handle is closed), so do not keep the descriptor open.
        self.file.close()

    def _create_store(self):
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        """RPC and a process group can both be initialized from one file."""
        with tempfile.NamedTemporaryFile(delete=False) as file:
            file_name = file.name
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{file_name}"
        rpc_backend_options._transports = tp_transports()
        rpc.init_rpc(
            "worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options
        )

        # Init PG using file
        dist.init_process_group(
            "gloo", rank=0, world_size=1, init_method=f"file://{file_name}"
        )
        dist.destroy_process_group()
        # The file must still exist after the process group is destroyed,
        # while RPC has not been shut down yet.
        self.assertTrue(os.path.exists(file_name))

        rpc.shutdown()
        os.remove(file_name)

    def test_refcount(self):
        """The backing file is removed only once the last store is deleted."""
        with tempfile.NamedTemporaryFile(delete=False) as file:
            file_name = file.name
        store = dist.FileStore(file_name, 1)
        store2 = dist.FileStore(file_name, 1)

        del store
        self.assertTrue(os.path.exists(file_name))
        del store2
        self.assertFalse(os.path.exists(file_name))

    # One more key than the 5 user keys of the base test.
    @property
    def num_keys_total(self):
        return 6
224
225
@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against an in-process ``HashStore``."""

    def _create_store(self):
        hash_store = dist.HashStore()
        hash_store.set_timeout(timedelta(seconds=300))
        return hash_store
232
233
class PrefixStoreTest(TestCase):
    """Checks that ``PrefixStore`` exposes the store it wraps."""

    def setUp(self):
        # Run the base-class setup like the other test classes in this file.
        super().setUp()
        # delete is false as FileStore will automatically clean up the file
        self.file = tempfile.NamedTemporaryFile(delete=False)
        # Only the path is used below, so release the descriptor right away.
        self.file.close()

    def test_get_underlying_store(self):
        tcp_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True
        )
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)
249
250
class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests on a ``PrefixStore`` over a ``FileStore``."""

    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)
        # Only the path is needed by FileStore; do not keep the handle open.
        self.file.close()
        self.filestore = dist.FileStore(self.file.name, 1)
        self.prefix = "test_prefix"
        self.filestore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.filestore)

    # One more key than the 5 user keys of the base test.
    @property
    def num_keys_total(self):
        return 6
265
266
class TCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests, plus TCP-specific ones, on ``TCPStore``.

    ``_use_libuv`` selects the backend passed to every store created here;
    subclasses flip it to run the same tests against the other backend.
    """

    _use_libuv = False

    def _create_store(self):
        store = create_tcp_store(use_libuv=self._use_libuv)
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        return create_tcp_store(
            addr, world_size, wait_for_workers=False, use_libuv=self._use_libuv
        )

    def test_address_already_in_use(self):
        """A second server store on the same port must fail to listen."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        err_msg_reg = f"^The server socket has failed to listen on any local .*{port}"
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            # Keep the first store in a variable so it is not destroyed
            # before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv)
            # Checked before the failing call: anything placed after it would
            # never execute. A failure here raises AssertionError, which the
            # enclosing context manager does not swallow.
            self.assertEqual(store1.libuvBackend, self._use_libuv)
            # Expected to raise: the port is already bound by store1.
            dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv)

    @retry_on_connect_failures
    def test_multitenancy(self):
        """Two multi-tenant server stores can share one address and port."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(
            addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv
        )  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(
            addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv
        )  # type: ignore[call-arg] # noqa: F841
        self.assertEqual(store1.libuvBackend, self._use_libuv)
        self.assertEqual(store2.libuvBackend, self._use_libuv)

    def test_repr(self) -> None:
        """``repr`` shows the client socket and, for a server, the port."""
        # server
        store1 = self._create_store()
        self.assertRegex(
            repr(store1),
            r"TCPStore\("
            r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
            r"server=TCPServer\(port=\d+\)\)",
        )

        # client
        store2 = dist.TCPStore(
            store1.host,
            store1.port,
            world_size=2,
            is_master=False,
        )
        self.assertRegex(
            repr(store2),
            r"TCPStore\("
            r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
            r"server=<nullptr>\)",
        )

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        os.environ["USE_LIBUV"] = "1" if self._use_libuv else "0"
        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=0,
            world_size=1,
        )

        backend_opts = rpc.TensorPipeRpcBackendOptions(
            init_method=f"tcp://{addr}:{port}", _transports=tp_transports()
        )
        rpc.init_rpc(
            name="worker0",
            rank=0,
            world_size=1,
            rpc_backend_options=backend_opts,
        )

        del os.environ["USE_LIBUV"]
        self.assertNotIn("USE_LIBUV", os.environ)
        rpc.shutdown()
        dist.destroy_process_group()

    @skip_if_win32()
    def test_take_over_listen_socket(self):
        """A server store can be built on an already-bound listening socket."""
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        # Detach so the raw descriptor can be handed to the store.
        listen_fd = listen_sock.detach()

        store = dist.TCPStore(
            addr,
            port,
            1,
            is_master=True,
            master_listen_fd=listen_fd,
            use_libuv=self._use_libuv,
        )

        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("key", "value")
        self.assertEqual(b"value", store.get("key"))

    # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
    # the user and one additional key used for coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        """Track ``num_keys`` across a sequence of sets, adds and deletes."""
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # Shorten the timeout so the get() on the deleted key fails quickly.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        """Connect a client store and check get/set/compare_set through it."""
        client_store = dist.TCPStore(
            addr,
            port,
            world_size=world_size,
            timeout=timedelta(seconds=10),
            use_libuv=self._use_libuv,
        )
        self.assertEqual(b"value", client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(
            f"next_value{index}".encode(),
            client_store.compare_set(
                f"new_key{index}", f"new_value{index}", f"next_value{index}"
            ),
        )

    def _multi_worker_helper(self, world_size):
        """Start one server store and connect ``world_size`` clients (or one)."""
        addr = DEFAULT_HOSTNAME
        server_store = self._create_store_with_ws(addr, world_size)
        self.assertEqual(server_store.libuvBackend, self._use_libuv)
        server_store.set("key", "value")
        port = server_store.port

        # A falsy world size (None) still exercises a single client.
        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

    def test_append(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_multi_set(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_get(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_store_timeout_on_missing_clients(self):
        with self.assertRaisesRegex(
            DistStoreError,
            r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined.",
        ):
            # world_size is 2 so it should timeout
            dist.TCPStore(
                "localhost",
                0,
                2,
                True,
                timeout=timedelta(seconds=2),
                use_libuv=self._use_libuv,
            )

        # when wait_for_workers is not set, then there should be no exception raised
        dist.TCPStore(
            "localhost",
            0,
            2,
            True,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
            use_libuv=self._use_libuv,
        )
511
512
class LibUvTCPStoreTest(TCPStoreTest):
    """Re-runs the ``TCPStoreTest`` suite with the libuv backend selected."""

    _use_libuv = True

    def _create_store(self):
        store = create_tcp_store(use_libuv=True)
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        return create_tcp_store(
            addr, world_size, wait_for_workers=False, use_libuv=True
        )

    def test_take_over_listen_socket(self):
        """
        override the take_over_listen_socket test in TCPStoreTest.
        Reason: we have not thoroughly tested libuv TCPStore initialization using
        open Socket so we decide to not support this use for now.
        TODO (xilunwu): enable this use case
        """
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        listen_fd = listen_sock.detach()

        err_msg_reg = (
            "^The libuv TCPStore backend does not support "
            "initialization with an listen fd"
        )

        # The constructor is expected to raise, so its result is not bound.
        with self.assertRaisesRegex(NotImplementedError, err_msg_reg):
            dist.TCPStore(
                addr,
                port,
                1,
                is_master=True,
                master_listen_fd=listen_fd,
                use_libuv=self._use_libuv,
            )
552
553
class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests on a ``PrefixStore`` over a TCP store."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        tcp_store = create_tcp_store()
        tcp_store.set_timeout(timedelta(seconds=300))
        self.tcpstore = tcp_store

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.tcpstore)

    # The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
    # added by the user and one additional key used for coordinate all the
    # workers.
    @property
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        """Nested prefix stores all resolve to the same non-prefix store."""
        single = self._create_store()
        double = dist.PrefixStore(self.prefix, single)
        triple = dist.PrefixStore(self.prefix, double)
        self.assertEqual(self.tcpstore, single._underlying_non_prefix_store)
        self.assertEqual(self.tcpstore, triple._underlying_non_prefix_store)
578
579
class MyPythonStore(dist.Store):
    """Dict-backed ``dist.Store`` subclass written in Python.

    Values must be exactly of type ``bytes``; anything else raises
    ``AssertionError``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.store = {}

    @staticmethod
    def _require_bytes(value, message):
        """Raise ``AssertionError(message)`` unless ``value`` is exactly bytes."""
        if type(value) is not bytes:
            raise AssertionError(message)

    def set(self, key, value):
        """Store ``value`` (bytes) under ``key`` (str or bytes)."""
        if not isinstance(key, (str, bytes)):
            raise AssertionError("Expected set to be called with string key")
        self._require_bytes(value, "Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        """Return the value stored under ``key``, or ``b""`` when missing."""
        found = self.store.get(key, b"")
        self._require_bytes(found, "Expected get to return bytes value")
        return found

    def add(self, key, value):
        """Add ``value`` to the integer stored under ``key`` and return the sum.

        A missing key counts as 0; the sum is stored back as UTF-8 digits.
        """
        total = int(self.store.get(key, 0)) + value
        self.set(key, str(total).encode("utf-8"))
        return total

    def compare_set(self, key, expected, newValue):
        """Store ``newValue`` if ``key`` is missing or currently ``expected``.

        Returns the value held under ``key`` after the call.
        """
        self._require_bytes(expected, "compare_set::expected not bytes")
        self._require_bytes(newValue, "compare_set::newValue not bytes")

        current = self.store.get(key)
        if current is None or current == expected:
            self.store[key] = newValue
            return newValue
        return current
613
614
class PythonStoreTest(TestCase):
    """Exercises a Python-defined store through the C++ trampoline."""

    def test_set_get(self):
        # If we were to inherit from StoreTestBase and try to use
        # its test_set_get function, we would exercise the Python
        # API directly, instead of going through the C++ trampoline.
        # We care about testing the C++ trampoline, so run the
        # equivalent of StoreTestBase.test_set_get from C++.
        # See `torch/csrc/distributed/c10d/init.cpp` for the definition
        # of this test function.
        python_store = MyPythonStore()
        dist._test_python_store(python_store)
625
626
class RendezvousTest(TestCase):
    """Error handling of ``dist.rendezvous`` that is independent of the scheme."""

    def test_unknown_handler(self):
        """An unregistered URL scheme is rejected."""
        unknown_url = "invalid://"
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous(unknown_url)

    def test_url_with_node_params(self):
        """Rank/world size may not be given both in the URL and as arguments."""
        url_with_params = "file://foo?rank=12&world_size=16"
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous(url_with_params, 12, 16)
635
636
class RendezvousEnvTest(TestCase):
    """Checks the ``env://`` rendezvous handler for a single rank."""

    @retry_on_connect_failures
    def test_nominal(self):
        """A single rank rendezvouses via environment variables and uses the store."""
        # Imported locally so this block does not depend on a file-level import.
        from unittest import mock

        env = {
            "WORLD_SIZE": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(common.find_free_port()),
            # Single rank
            "RANK": "0",
        }
        # patch.dict restores os.environ on exit, so these variables do not
        # leak into tests that run afterwards.
        with mock.patch.dict(os.environ, env):
            gen0 = dist.rendezvous("env://")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(1, size0)

            store0.set("key0", "value0")

            # check with get
            self.assertEqual(b"value0", store0.get("key0"))
655
656
class RendezvousFileTest(TestCase):
    """Checks the ``file://`` rendezvous handler."""

    def test_common_errors(self):
        """Malformed ``file://`` URLs raise ``ValueError`` with a clear message."""
        bad_urls = (
            ("file://?rank=0&world_size=1", "path missing"),
            ("file:///tmp/foo?world_size=1", "rank parameter missing"),
            ("file:///tmp/foo?rank=0", "size parameter missing"),
        )
        for bad_url, expected_message in bad_urls:
            with self.assertRaisesRegex(ValueError, expected_message):
                next(dist.rendezvous(bad_url))

    def test_nominal(self):
        """Two ranks sharing one file see each other's keys."""
        with tempfile.NamedTemporaryFile(delete=False) as file:
            # Normalize the platform path separator to URL-style slashes.
            url_path = file.name.replace(os.path.sep, "/")
            url = f"file:///{url_path}?world_size=2"

            store0, rank0, size0 = next(dist.rendezvous(url + "&rank=0"))
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)

            store1, rank1, size1 = next(dist.rendezvous(url + "&rank=1"))
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Set value on both stores
            store0.set("key0", "value0")
            store1.set("key1", "value1")

            # Cross check with get
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))
688
689
@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """Checks the ``tcp://`` rendezvous handler."""

    def create_tcp_url(self):
        """Return a ``tcp://`` URL on a free local port with ``world_size=1``."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()
        return f"tcp://{addr}:{port}?world_size=1"

    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, "port number missing"):
            gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
            next(gen)

    def test_dns_timeout(self):
        """An unresolvable host name fails with a network error after the timeout."""
        with self.assertRaisesRegex(
            DistNetworkError, "client socket has timed out after.*dnsnotexist"
        ) as manager:
            gen = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(gen)
        self.assertIsInstance(manager.exception, DistError)

    @retry_on_connect_failures
    def test_nominal(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Set value on the single store
        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        """``set_timeout`` overrides the timeout passed to ``rendezvous``."""
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=0.1)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
        store0, rank0, size0 = next(gen0)
        store0.set_timeout(test_store_timeout)
        # this should time out in 0.1s. If the timeout set on the store was
        # not respected, it will take much longer to timeout.
        start = time.time()
        with self.assertRaisesRegex(
            DistStoreError, "wait timeout after 100ms, keys: /nonexistant key"
        ):
            store0.get("nonexistant key")

        end = time.time()
        time_diff = end - start
        self.assertGreater(10, time_diff)

    def test_tcp_store_timeout_doest_break_client(self):
        """The store stays usable after one of its requests has timed out."""
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=0.1)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
        store0, rank0, size0 = next(gen0)
        store0.set_timeout(test_store_timeout)
        # The get below should time out after the 0.1s store timeout, not the
        # 10s rendezvous timeout; the whole sequence must finish in under 10s.
        start = time.time()
        with self.assertRaisesRegex(
            DistStoreError, "wait timeout after 100ms, keys: /the_key"
        ):
            store0.get("the_key")

        # The same client must still be able to set and read the key.
        store0.set("the_key", "x")

        self.assertEqual(b"x", store0.get("the_key"))

        end = time.time()
        time_diff = end - start
        self.assertGreater(10, time_diff)

    def test_tcp_store_url_with_libuv(self):
        """``use_libuv=1`` in the URL selects the libuv backend."""
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
        store0, rank0, size0 = next(gen0)
        self.assertTrue(store0.libuvBackend)
780
781
class DummyStore(dist.Store):
    """Recording store: logs extended-API calls instead of storing anything.

    ``multi_get`` replies with values queued in ``multi_get_res`` (first in,
    first out).
    """

    def __init__(self) -> None:
        super().__init__()
        # Each list records the arguments of one kind of call, in call order.
        self.appends = []
        self.multi_sets = []
        self.multi_gets = []
        # Canned replies handed out by multi_get.
        self.multi_get_res = []

    def append(self, key, value):
        """Record an ``append`` call as a ``(key, value)`` tuple."""
        call = (key, value)
        self.appends.append(call)

    def multi_get(self, keys):
        """Record the requested keys and return the oldest queued reply."""
        self.multi_gets.append(keys)
        reply = self.multi_get_res.pop(0)
        return reply

    def multi_set(self, keys, values):
        """Record a ``multi_set`` call as a ``(keys, values)`` tuple."""
        call = (keys, values)
        self.multi_sets.append(call)

    def has_extended_api(self):
        """Always report that the extended API is available."""
        return True
802
803
class TestPythonStore(TestCase):
    """Extended store API (append / multi_get / multi_set) on Python stores."""

    def _assert_extended_api_unimplemented(self, store):
        """Check that each extended method of ``store`` raises "Not implemented."."""
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_optional_methods_fail(self):
        """A bare ``dist.Store`` subclass has no extended API."""

        class TestStore(dist.Store):
            pass

        bare_store = TestStore()
        self.assertFalse(bare_store.has_extended_api())
        self._assert_extended_api_unimplemented(bare_store)

    def test_has_extended_api_passthrough(self):
        """``PrefixStore`` reports and forwards the lack of an extended API."""

        class TestStore(dist.Store):
            pass

        bare_store = TestStore()
        wrapped = dist.PrefixStore("p", bare_store)
        self.assertFalse(wrapped.has_extended_api())
        self._assert_extended_api_unimplemented(wrapped)

    def test_has_extended_api_roundtrip(self):
        backing = DummyStore()
        wrapped = dist.PrefixStore("p", backing)
        self.assertTrue(wrapped.has_extended_api())

    def test_append_roundtrip(self):
        """``append`` reaches the wrapped store with a prefixed key and bytes value."""
        backing = DummyStore()
        wrapped = dist.PrefixStore("p", backing)
        wrapped.append("foo", "bar")
        self.assertEqual(1, len(backing.appends))
        self.assertEqual(("p/foo", b"bar"), backing.appends[0])

    def test_multi_get_roundtrip(self):
        """``multi_get`` prefixes the keys and returns the wrapped store's reply."""
        backing = DummyStore()
        wrapped = dist.PrefixStore("p", backing)
        backing.multi_get_res.append([b"x", b"y"])
        result = wrapped.multi_get(["foo", "bar"])
        self.assertEqual(1, len(backing.multi_gets))
        self.assertEqual(["p/foo", "p/bar"], backing.multi_gets[0])
        self.assertEqual([b"x", b"y"], result)

    def test_multi_set_roundtrip(self):
        """``multi_set`` prefixes the keys and passes the values through."""
        backing = DummyStore()
        wrapped = dist.PrefixStore("p", backing)
        wrapped.multi_set(["foo", "bar"], [b"x", b"y"])
        self.assertEqual(1, len(backing.multi_sets))
        recorded_keys, recorded_values = backing.multi_sets[0][0], backing.multi_sets[0][1]
        self.assertEqual(["p/foo", "p/bar"], recorded_keys)
        self.assertEqual([b"x", b"y"], recorded_values)

    def test_extended_methods_fallbacks(self):
        """Extended calls still work through ``PrefixStore`` on a basic store."""
        backing = MyPythonStore()
        wrapped = dist.PrefixStore("p", backing)
        self.assertFalse(wrapped.has_extended_api())
        wrapped.append("foo", b"po")
        wrapped.append("foo", b"tato")
        self.assertEqual(wrapped.get("foo"), b"potato")

        wrapped.multi_set(["a", "b"], [b"c", b"d"])
        self.assertEqual(wrapped.multi_get(["a", "b", "foo"]), [b"c", b"d", b"potato"])
871
872
class TestMultiThreadedWait(MultiThreadedTestCase):
    """Checks ``wait`` across two ranks running as threads, per store type."""

    # The stores are class attributes, created once when the class is defined.
    file_store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
    hash_store = dist.HashStore()

    tcp_store = create_tcp_store(use_libuv=False)
    tcp_store_uv = create_tcp_store(use_libuv=True)

    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def _test_wait(self, store):
        """Rank 0 waits for and reads "key1"; rank 1 is the one that sets it."""
        store.set_timeout(timedelta(seconds=2))
        rank = dist.get_rank()
        if rank == 0:
            store.wait(["key1"])
            self.assertEqual(b"value1", store.get("key1"))
        elif rank == 1:
            store.set("key1", "value1")

    def test_wait_hash_store(self):
        self._test_wait(self.hash_store)

    def test_wait_file_store(self):
        self._test_wait(self.file_store)

    def test_wait_prefix_file_store(self):
        prefixed = dist.PrefixStore("pre", self.file_store)
        self._test_wait(prefixed)

    def _test_wait_tcp_store(self, master_store):
        """Run the wait check on a TCP store, bare and behind a prefix.

        Rank 0 uses ``master_store`` directly; any other rank connects to it
        with its own client store.
        """
        if dist.get_rank() == 0:
            store = master_store
        else:
            store = dist.TCPStore(
                host_name=master_store.host,
                port=master_store.port,
                is_master=False,
                wait_for_workers=False,
                use_libuv=False,
            )
        self._test_wait(store)

        self._test_wait(dist.PrefixStore("pre", store))

    def test_wait_tcp_store(self):
        self._test_wait_tcp_store(self.tcp_store)

    def test_wait_tcp_store_uv(self):
        self._test_wait_tcp_store(self.tcp_store_uv)


instantiate_parametrized_tests(TestMultiThreadedWait)
931
932
@skip_if_win32()
class TimeoutTest(TestCase):
    """Checks that a signal delivered during ``wait`` does not abort the wait."""

    def tearDown(self):
        import signal

        super().tearDown()
        # Replace the handler installed by the test with "ignore".
        signal.signal(signal.SIGUSR1, signal.SIG_IGN)

    def test_interrupt_doesnt_break_wait(self):
        """Rank 1 waits on a key while being signalled; rank 0 sets it later."""
        import signal

        # Per-rank outcome: True on success, or the exception that was raised.
        rank_res = [None, None]

        def run(rank, my_store):
            try:
                if rank == 0:
                    time.sleep(4)
                    my_store.set("foo", "bar")
                else:
                    my_store.wait(["foo"], datetime.timedelta(seconds=10))
                rank_res[rank] = True
            except Exception as e:
                # Record the failure so the main thread can report it.
                rank_res[rank] = e
            time.sleep(1)

        rank0_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=0,
            world_size=2,
            is_master=True,
            wait_for_workers=False,
        )
        rank1_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=rank0_store.port,
            world_size=2,
            is_master=False,
            wait_for_workers=False,
        )

        ths = []
        for i in range(2):
            t = threading.Thread(
                target=run,
                args=(
                    i,
                    [rank0_store, rank1_store][i],
                ),
            )
            t.start()
            ths.append(t)

        def handler(a, b):
            pass

        signal.signal(signal.SIGUSR1, handler)
        # Give rank 1 time to enter wait() before interrupting it.
        time.sleep(1)
        signal.pthread_kill(ths[1].ident, signal.SIGUSR1)

        for t in ths:
            t.join()
        # Compare against True by identity: a recorded exception object is
        # truthy and would slip through assertTrue.
        self.assertIs(rank_res[0], True, f"rank0: {rank_res[0]!r}")
        self.assertIs(rank_res[1], True, f"rank1: {rank_res[1]!r}")
997
998
class InitPgWithNonUvStore(TestCase):
    """
    This test shows how to use the legacy TCPStore (non-libuv) backend since libuv is now
    the default backend.
    """

    def tearDown(self):
        super().tearDown()
        # Remove the variables the tests may have set; pop() tolerates
        # variables that were never set.
        os.environ.pop("USE_LIBUV", None)
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

    def test_with_url_param(self):
        """Select the non-libuv backend through ``use_libuv=0`` in the URL."""
        port = common.find_free_port()
        dist.init_process_group(
            "gloo",
            rank=0,
            world_size=1,
            init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=0",
        )
        self._run_test()

    def test_with_env_var(self):
        """Select the non-libuv backend through the ``USE_LIBUV`` variable."""
        port = common.find_free_port()
        os.environ["USE_LIBUV"] = "0"
        os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://")
        self._run_test()

    def _run_test(self):
        """Check the world group's store is a non-libuv ``TCPStore``, then tear down."""
        try:
            pg = dist.group.WORLD
            store = c10d._get_process_group_store(pg)
            self.assertIsInstance(store, dist.PrefixStore)
            # c10d does multiple levels of wrapping
            while isinstance(store, dist.PrefixStore):
                store = store.underlying_store
            self.assertIsInstance(store, dist.TCPStore)
            self.assertFalse(store.libuvBackend)
        finally:
            # Destroy the group even when an assertion fails, so it does not
            # stay initialized for the tests that follow.
            dist.destroy_process_group()
1039
1040
class TestClientProtocol(TestCase):
    """Checks the first bytes a ``TCPStore`` client sends when it connects."""

    def test_client_connect(self) -> None:
        """A fake server verifies the VALIDATE and PING messages of a client."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Close the listening socket when the test ends, pass or fail.
        self.addCleanup(sock.close)
        sock.bind(("localhost", 0))
        port = sock.getsockname()[1]

        # Failures raised in the listener thread; re-raised in the main thread
        # below, because an exception inside a thread does not fail the test.
        errors = []

        def listen() -> None:
            try:
                sock.listen()
                conn, _ = sock.accept()
                with conn:
                    # VALIDATE
                    # 0x3C85F7CE
                    self.assertEqual(conn.recv(5), b"\x00\xce\xf7\x85\x3c")

                    # PING
                    data = conn.recv(5)
                    self.assertEqual(data[0], 13)
                    nonce = struct.unpack("i", data[1:])[0]
                    self.assertEqual(nonce, os.getpid())

                    # send PING nonce response
                    conn.sendall(data[1:])
            except Exception as e:
                errors.append(e)

        # Daemon thread: a listener stuck in accept() must not keep the
        # process alive after the test has failed.
        thread = threading.Thread(target=listen, daemon=True)
        thread.start()

        try:
            # Bound to a name so the client stays alive until the listener
            # thread has been joined.
            store = dist.TCPStore(  # noqa: F841
                host_name="localhost",
                port=port,
                world_size=2,
                is_master=False,
                timeout=timedelta(seconds=2),
                wait_for_workers=False,
            )
        finally:
            thread.join(timeout=10)

        self.assertFalse(thread.is_alive(), "listener thread did not finish")
        if errors:
            raise errors[0]
1079
1080
if __name__ == "__main__":
    # The test runner must start without a CUDA context in the main process.
    cuda_already_initialized = torch.cuda._initialized
    assert (
        not cuda_already_initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
1087