# pytorch
# 851 lines · 30.2 KB
1# Owner(s): ["oncall: distributed"]
2
3import datetime
4import os
5import socket
6import sys
7import tempfile
8import time
9import threading
10from datetime import timedelta
11from sys import platform
12
13import torch
14import torch.distributed as dist
15import torch.distributed.distributed_c10d as c10d
16import torch.distributed.rpc as rpc
17from torch.distributed import DistNetworkError, DistError, DistStoreError
18from torch.testing._internal.common_distributed import MultiThreadedTestCase
19from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
20
# Nothing in this module can run without torch.distributed: report that on
# stderr and leave with exit status 0.
if not dist.is_available():
    sys.stderr.write("torch.distributed not available, skipping tests\n")
    sys.exit(0)
24
25import torch.testing._internal.common_utils as common
26from torch.testing._internal.common_distributed import (
27skip_if_win32,
28create_tcp_store,
29tp_transports
30)
31from torch.testing._internal.common_utils import (
32TestCase,
33load_tests,
34run_tests,
35retry_on_connect_failures,
36ADDRESS_IN_USE,
37CONNECT_TIMEOUT,
38)
39
40# load_tests from common_utils is used to automatically filter tests for
41# sharding on sandcastle. This line silences flake warnings
42load_tests = load_tests
43
# Loopback interface name: "lo0" when sys.platform is "darwin", "lo" otherwise.
LOOPBACK = "lo0" if platform == "darwin" else "lo"

# Host name the tests below pass whenever they need a TCP address.
DEFAULT_HOSTNAME = "localhost"

# Switch TF32 matmuls off for the whole module.
torch.backends.cuda.matmul.allow_tf32 = False
52
53
def gpus_for_rank(world_size):
    """Split the visible CUDA devices evenly across ``world_size`` processes.

    Multigpu tests are designed to simulate multiple nodes with multiple GPUs
    on each node. The NCCL backend requires an equal number of GPUs in each
    process, so on a single node all visible GPUs are evenly divided into
    subsets and each process only uses one subset.

    Args:
        world_size: Number of processes to divide the devices between.

    Returns:
        A list with ``world_size`` entries; entry ``rank`` is the list of
        device indices assigned to that rank. Devices left over when the
        device count is not a multiple of ``world_size`` are not assigned.
    """
    # Query the device count once so the index list and the per-process share
    # are derived from the same value.
    device_count = torch.cuda.device_count()
    visible_devices = list(range(device_count))
    gpus_per_process = device_count // world_size
    # A comprehension also avoids a local that shadows this function's name.
    return [
        visible_devices[rank * gpus_per_process: (rank + 1) * gpus_per_process]
        for rank in range(world_size)
    ]
68
69
class StoreTestBase:
    """Store-agnostic test bodies shared by the concrete store test classes.

    Subclasses combine this mixin with ``TestCase``, override
    ``_create_store`` to build the store under test and, when their store
    holds additional keys, override ``num_keys_total``.
    """

    def _create_store(self, i=None):
        """Build the store under test; must be overridden by subclasses.

        Args:
            i: Unused. Kept (now optional) for backward compatibility. Every
                ``test_*`` method of this class calls ``self._create_store()``
                without arguments, so a required parameter made an
                un-overridden factory fail with ``TypeError`` instead of the
                ``RuntimeError`` raised below.

        Raises:
            RuntimeError: Always, in this base implementation.
        """
        raise RuntimeError("not implemented")

    def _test_set_get_check(self, fs):
        """Exercise add/set/get/check/delete_key/num_keys on store ``fs``."""
        # "key" accumulates 1 + 2 + 3 = 6; "key3" accumulates 1 + ... + 6 = 21.
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        # Adding and then deleting a key must leave the key count unchanged.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        """Exercise ``compare_set`` on missing, mismatching and matching keys."""
        # Missing key with a non-empty expected value: the expected value is
        # handed back.
        missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Mismatching expected value: nothing is written, current value returned.
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Matching expected value: the new value is written and returned.
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        # Missing key with an empty expected value: the new value is stored.
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    def _test_simple_wait(self, fs):
        """``wait`` must time out on a missing key and return on a present one."""
        with self.assertRaisesRegex(RuntimeError, "[t -i]imeout"):
            fs.wait(["bad_key"], timedelta(seconds=0.25))
        fs.add("good_key", 1)
        fs.wait(["good_key"])

    def test_simple_wait(self):
        self._test_simple_wait(self._create_store())

    def _test_append(self, store):
        """``append`` must extend an existing value and create a missing key."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_append(self):
        self._test_append(self._create_store())

    def _test_multi_set(self, store):
        """``multi_set`` must store each value under its matching key."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_set(self):
        self._test_multi_set(self._create_store())

    def _test_multi_get(self, store):
        """``multi_get`` must return the values in the order of the keys."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_multi_get(self):
        self._test_multi_get(self._create_store())

    # This is the number of keys used in test_set_get. Adding this as a class
    # property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5
174
175
class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against ``dist.FileStore``."""

    def setUp(self):
        super().setUp()
        # delete=False: the temporary file outlives this handle.
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def _create_store(self):
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        """RPC and a process group can both rendezvous through one file."""
        file = tempfile.NamedTemporaryFile(delete=False)
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{file.name}"
        rpc_backend_options._transports = tp_transports()
        rpc.init_rpc("worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options)

        try:
            # Init PG using file
            dist.init_process_group("gloo", rank=0, world_size=1, init_method=f"file://{file.name}")
            dist.destroy_process_group()
            # unittest assertion instead of a bare ``assert`` so the check
            # still runs under ``python -O``.
            self.assertTrue(os.path.exists(file.name))
        finally:
            # Shut RPC down even if a step above failed, so a failure here
            # does not leave RPC initialized for the tests that follow.
            rpc.shutdown()
        os.remove(file.name)

    def test_refcount(self):
        """The backing file must exist until the last store using it is gone."""
        file = tempfile.NamedTemporaryFile(delete=False)
        store = dist.FileStore(file.name, 1)
        store2 = dist.FileStore(file.name, 1)

        del store
        self.assertTrue(os.path.exists(file.name))
        del store2
        self.assertFalse(os.path.exists(file.name))

    # One more key than the base case is expected for this store.
    @property
    def num_keys_total(self):
        return 6
215
216
@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against ``dist.HashStore``."""

    def _create_store(self):
        """Return a fresh ``HashStore`` with a 300 second timeout."""
        hash_store = dist.HashStore()
        five_minutes = timedelta(seconds=300)
        hash_store.set_timeout(five_minutes)
        return hash_store
223
224
class PrefixStoreTest(TestCase):
    """Checks ``PrefixStore.underlying_store`` for each kind of wrapped store."""

    def setUp(self):
        # Run the base-class setup first, like the other test classes in this
        # file that define setUp do.
        super().setUp()
        # delete is false as FileStore will automatically clean up the file
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def test_get_underlying_store(self):
        """``underlying_store`` must hand back the store that was wrapped."""
        tcp_store = dist.TCPStore(host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True)
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)
238
239
class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against a ``PrefixStore`` over a ``FileStore``."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        self.file = tempfile.NamedTemporaryFile(delete=False)
        backing_store = dist.FileStore(self.file.name, 1)
        backing_store.set_timeout(timedelta(seconds=300))
        self.filestore = backing_store

    def _create_store(self):
        """Wrap the shared ``FileStore`` in a new ``PrefixStore``."""
        prefixed = dist.PrefixStore(self.prefix, self.filestore)
        return prefixed

    # One more key than the base case is expected for this store.
    @property
    def num_keys_total(self):
        return 6
254
255
class TCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests plus TCP-specific cases against ``TCPStore``."""

    def _create_store(self):
        store = create_tcp_store()
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        """Create a server store for ``world_size`` without waiting for workers."""
        return create_tcp_store(addr, world_size, wait_for_workers=False)

    def test_address_already_in_use(self):
        err_msg_reg = "^The server socket has failed to listen on any local "
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            addr = DEFAULT_HOSTNAME
            port = common.find_free_port()

            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

    @retry_on_connect_failures
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=0,
            world_size=1,
        )
        try:
            backend_opts = rpc.TensorPipeRpcBackendOptions(
                init_method=f"tcp://{addr}:{port}",
                _transports=tp_transports()
            )
            rpc.init_rpc(
                name="worker0",
                rank=0,
                world_size=1,
                rpc_backend_options=backend_opts,
            )

            rpc.shutdown()
        finally:
            # Tear the process group down again, as the other tests in this
            # file that call init_process_group do; otherwise the default
            # group stays initialized after this test (and across retries).
            dist.destroy_process_group()

    @skip_if_win32()
    def test_take_over_listen_socket(self):
        """A store can be built on top of an already bound listening socket."""
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        # detach() hands the raw descriptor over; the store gets it as
        # master_listen_fd.
        listen_fd = listen_sock.detach()

        store = dist.TCPStore(addr, port, 1, is_master=True, master_listen_fd=listen_fd)

        store.set("key", "value")
        self.assertEqual(b"value", store.get("key"))

    # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
    # the user and one additional key used for coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        """Track ``num_keys`` while keys are added to and deleted from ``fs``."""
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # Short timeout so the get() of the deleted key fails quickly.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        """Connect a client store and check get/set/compare_set through it."""
        client_store = dist.TCPStore(addr, port, world_size=world_size, timeout=timedelta(seconds=10))
        self.assertEqual(b"value", client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(
            f"next_value{index}".encode(),
            client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"),
        )

    def _multi_worker_helper(self, world_size):
        """Start one server store and connect the clients one after another."""
        addr = DEFAULT_HOSTNAME
        server_store = self._create_store_with_ws(addr, world_size)
        server_store.set("key", "value")
        port = server_store.port

        # A falsy world_size (None) still gets a single client.
        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

    # The three tests below override the base-class versions and run the same
    # checks without the has_extended_api() skip.
    def test_append(self):
        store = self._create_store()
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_multi_set(self):
        store = self._create_store()
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_get(self):
        store = self._create_store()
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_store_timeout_on_missing_clients(self):
        with self.assertRaisesRegex(DistStoreError, r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined."):
            # world_size is 2 so it should timeout
            dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2))

        # when wait_for_workers is False, then there should be no exception raised
        dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2), wait_for_workers=False)
417
class LibUvTCPStoreTest(TCPStoreTest):
    """Inherits the ``TCPStoreTest`` cases; both store factories pass ``use_libuv=True``."""

    def _create_store(self):
        libuv_store = create_tcp_store(use_libuv=True)
        five_minutes = timedelta(seconds=300)
        libuv_store.set_timeout(five_minutes)
        return libuv_store

    def _create_store_with_ws(self, addr, world_size):
        """Create a libuv server store for ``world_size`` without waiting for workers."""
        store_options = {"wait_for_workers": False, "use_libuv": True}
        return create_tcp_store(addr, world_size, **store_options)
427
428
class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against a ``PrefixStore`` over a TCP store."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        backing_store = create_tcp_store()
        backing_store.set_timeout(timedelta(seconds=300))
        self.tcpstore = backing_store

    def _create_store(self):
        """Wrap the shared TCP store in a new ``PrefixStore``."""
        prefixed = dist.PrefixStore(self.prefix, self.tcpstore)
        return prefixed

    # The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
    # added by the user and one additional key used for coordinate all the
    # workers.
    @property
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        """However many prefix layers there are, the TCP store is at the bottom."""
        single_layer = self._create_store()
        double_layer = dist.PrefixStore(self.prefix, single_layer)
        triple_layer = dist.PrefixStore(self.prefix, double_layer)
        for layered in (single_layer, triple_layer):
            self.assertEqual(self.tcpstore, layered._underlying_non_prefix_store)
451
class MyPythonStore(dist.Store):
    """``dist.Store`` subclass written in Python that keeps its data in a dict."""

    def __init__(self):
        super().__init__()
        # Maps key -> bytes value.
        self.store = {}

    def set(self, key, value):
        """Store ``value`` under ``key`` after checking both argument types."""
        key_is_text = isinstance(key, (str, bytes))
        if not key_is_text:
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        """Return the bytes stored under ``key``, or ``b""`` when it is absent."""
        stored = self.store.get(key, b"")
        if type(stored) is not bytes:
            raise AssertionError("Expected get to return bytes value")
        return stored

    def add(self, key, value):
        """Add ``value`` to the integer under ``key`` (0 if absent); return the sum."""
        current = int(self.store.get(key, 0))
        total = current + value
        # The counter is kept as its decimal text, encoded to bytes.
        self.set(key, str(total).encode("utf-8"))
        return total

    def compare_set(self, key, expected, newValue):
        """Write ``newValue`` if ``key`` is absent or currently holds ``expected``.

        Returns the value held under ``key`` after the call.
        """
        if type(expected) is not bytes:
            raise AssertionError("compare_set::expected not bytes")
        if type(newValue) is not bytes:
            raise AssertionError("compare_set::newValue not bytes")

        current = self.store.get(key, None)
        if current is None or current == expected:
            self.store[key] = newValue
            return newValue
        return current
485
class PythonStoreTest(TestCase):
    """Drives a Python-defined store from the C++ side."""

    def test_set_get(self):
        # Inheriting from StoreTestBase and reusing its test_set_get would
        # call the Python API directly and bypass the C++ trampoline. The
        # trampoline is what we want to cover, so the equivalent of
        # StoreTestBase.test_set_get is run from C++ instead.
        # See `torch/csrc/distributed/c10d/init.cpp` for the definition
        # of this test function.
        python_store = MyPythonStore()
        dist._test_python_store(python_store)
496
497
class RendezvousTest(TestCase):
    """Error handling of ``dist.rendezvous`` that is independent of the handler."""

    def test_unknown_handler(self):
        """A URL scheme without a registered handler is rejected."""
        unknown_scheme_url = "invalid://"
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous(unknown_scheme_url)

    def test_url_with_node_params(self):
        """Rank and world size may not be given both in the URL and as arguments."""
        url_with_params = "file://foo?rank=12&world_size=16"
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous(url_with_params, 12, 16)
506
507
class RendezvousEnvTest(TestCase):
    """Rendezvous through the ``env://`` handler."""

    @retry_on_connect_failures
    def test_nominal(self):
        """A single rank configured via environment variables can set and get."""
        env_overrides = {
            "WORLD_SIZE": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(common.find_free_port()),
            # Single rank
            "RANK": "0",
        }
        # Remember what was there before so the variables written here do not
        # stay in the process environment after the test.
        previous = {name: os.environ.get(name) for name in env_overrides}
        os.environ.update(env_overrides)
        try:
            gen0 = dist.rendezvous("env://")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(1, size0)

            store0.set("key0", "value0")

            # check with get
            self.assertEqual(b"value0", store0.get("key0"))
        finally:
            for name, value in previous.items():
                if value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = value
526
527
class RendezvousFileTest(TestCase):
    """Rendezvous through ``file://`` URLs."""

    def test_common_errors(self):
        """Each malformed URL is rejected with its own ``ValueError`` message."""
        malformed_cases = (
            ("path missing", "file://?rank=0&world_size=1"),
            ("rank parameter missing", "file:///tmp/foo?world_size=1"),
            ("size parameter missing", "file:///tmp/foo?rank=0"),
        )
        for expected_message, bad_url in malformed_cases:
            with self.assertRaisesRegex(ValueError, expected_message):
                next(dist.rendezvous(bad_url))

    def test_nominal(self):
        """Two ranks that rendezvous on one file see each other's keys."""
        with tempfile.NamedTemporaryFile(delete=False) as file:
            url_path = file.name.replace(os.path.sep, "/")
            url = f"file:///{url_path}?world_size=2"

            rank0_gen = dist.rendezvous(url + "&rank=0")
            store0, rank0, size0 = next(rank0_gen)
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)

            rank1_gen = dist.rendezvous(url + "&rank=1")
            store1, rank1, size1 = next(rank1_gen)
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Set value on both stores
            store0.set("key0", "value0")
            store1.set("key1", "value1")

            # Cross check with get
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))
559
560
@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """Rendezvous through ``tcp://`` URLs."""

    def create_tcp_url(self):
        """Return a ``tcp://`` URL for a free local port with ``world_size=1``."""
        free_port = common.find_free_port()
        return f"tcp://{DEFAULT_HOSTNAME}:{free_port:d}?world_size=1"

    def test_common_errors(self):
        """Each malformed URL is rejected with its own ``ValueError`` message."""
        malformed_cases = (
            ("port number missing", "tcp://127.0.0.1?rank=0&world_size=1"),
            ("rank parameter missing", "tcp://127.0.0.1:23456?world_size=1"),
            ("size parameter missing", "tcp://127.0.0.1:23456?rank=0"),
        )
        for expected_message, bad_url in malformed_cases:
            with self.assertRaisesRegex(ValueError, expected_message):
                next(dist.rendezvous(bad_url))

    def test_dns_timeout(self):
        """An unresolvable host raises ``DistNetworkError``, which is a ``DistError``."""
        expected_message = "client socket has timed out after.*dnsnotexist"
        with self.assertRaisesRegex(DistNetworkError, expected_message) as raised:
            unreachable = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(unreachable)
        self.assertTrue(isinstance(raised.exception, DistError))

    @retry_on_connect_failures
    def test_nominal(self):
        rank0_url = self.create_tcp_url() + "&rank=0"
        store0, rank0, size0 = next(dist.rendezvous(rank0_url))
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Set value on the single store
        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        rank0_url = self.create_tcp_url() + "&rank=0"
        test_store_timeout = timedelta(seconds=10)
        store0, rank0, size0 = next(dist.rendezvous(rank0_url, timeout=test_store_timeout))
        # this should time out in 10s. If the timeout passed into rendezvous was
        # not respected, it will take much longer to timeout.
        started_at = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("nonexistant key")

        elapsed = time.time() - started_at
        # The accepted upper bound is ten times the configured timeout.
        self.assertGreater(test_store_timeout.seconds * 10, elapsed)

    def test_tcp_store_timeout_doest_break_client(self):
        rank0_url = self.create_tcp_url() + "&rank=0"
        test_store_timeout = timedelta(seconds=10)
        store0, rank0, size0 = next(dist.rendezvous(rank0_url, timeout=test_store_timeout))
        # this should time out in 10s. If the timeout passed into rendezvous was
        # not respected, it will take much longer to timeout.
        started_at = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("the_key")

        # The same client must still work after the timed-out get().
        store0.set("the_key", "x")

        self.assertEqual(b"x", store0.get("the_key"))

        elapsed = time.time() - started_at
        self.assertGreater(test_store_timeout.seconds * 10, elapsed)

    def test_tcp_store_url_with_libuv(self):
        """``use_libuv=1`` in the URL yields a store reporting ``libuvBackend``."""
        libuv_url = self.create_tcp_url() + "&rank=0&use_libuv=1"
        store0, rank0, size0 = next(dist.rendezvous(libuv_url))
        self.assertTrue(store0.libuvBackend)
643
class DummyStore(dist.Store):
    """``dist.Store`` subclass that records the extended-API calls it receives."""

    def __init__(self):
        # One list per recorded method, plus the canned multi_get answers.
        self.appends = []
        self.multi_sets = []
        self.multi_gets = []
        self.multi_get_res = []
        super().__init__()

    def append(self, key, value):
        """Record the ``(key, value)`` pair."""
        recorded_call = (key, value)
        self.appends.append(recorded_call)

    def multi_get(self, keys):
        """Record ``keys`` and return the oldest queued ``multi_get_res`` entry."""
        self.multi_gets.append(keys)
        canned_answer = self.multi_get_res.pop(0)
        return canned_answer

    def multi_set(self, keys, values):
        """Record the ``(keys, values)`` pair."""
        recorded_call = (keys, values)
        self.multi_sets.append(recorded_call)

    def has_extended_api(self):
        return True
664
class TestPythonStore(TestCase):
    """Extended store API (append/multi_get/multi_set) seen from Python stores."""

    def _assert_extended_api_not_implemented(self, store):
        """Check that ``store`` reports no extended API and rejects its calls."""
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_optional_methods_fail(self):
        class TestStore(dist.Store):
            pass

        self._assert_extended_api_not_implemented(TestStore())

    def test_has_extended_api_passthrough(self):
        class TestStore(dist.Store):
            pass

        bare_store = TestStore()
        self._assert_extended_api_not_implemented(dist.PrefixStore("p", bare_store))

    def test_has_extended_api_roundtrip(self):
        recorder = DummyStore()
        prefixed = dist.PrefixStore("p", recorder)
        self.assertTrue(prefixed.has_extended_api())

    def test_append_roundtrip(self):
        """``append`` reaches the wrapped store with a prefixed key and bytes value."""
        recorder = DummyStore()
        prefixed = dist.PrefixStore("p", recorder)
        prefixed.append("foo", "bar")
        self.assertEqual(1, len(recorder.appends))
        self.assertEqual(("p/foo", b"bar"), recorder.appends[0])

    def test_multi_get_roundtrip(self):
        """``multi_get`` prefixes every key and passes the answer back unchanged."""
        recorder = DummyStore()
        prefixed = dist.PrefixStore("p", recorder)
        recorder.multi_get_res.append([b"x", b"y"])
        answer = prefixed.multi_get(["foo", "bar"])
        self.assertEqual(1, len(recorder.multi_gets))
        self.assertEqual(["p/foo", "p/bar"], recorder.multi_gets[0])
        self.assertEqual([b"x", b"y"], answer)

    def test_multi_set_roundtrip(self):
        """``multi_set`` prefixes every key and forwards the values."""
        recorder = DummyStore()
        prefixed = dist.PrefixStore("p", recorder)
        prefixed.multi_set(["foo", "bar"], [b'x', b'y'])
        self.assertEqual(1, len(recorder.multi_sets))
        seen_keys, seen_values = recorder.multi_sets[0][0], recorder.multi_sets[0][1]
        self.assertEqual(["p/foo", "p/bar"], seen_keys)
        self.assertEqual([b'x', b'y'], seen_values)

    def test_extended_methods_fallbacks(self):
        """The extended calls still work on a store without the extended API."""
        basic_store = MyPythonStore()
        prefixed = dist.PrefixStore("p", basic_store)
        self.assertFalse(prefixed.has_extended_api())
        prefixed.append("foo", b"po")
        prefixed.append("foo", b"tato")
        self.assertEqual(prefixed.get("foo"), b"potato")

        prefixed.multi_set(["a", "b"], [b"c", b"d"])
        self.assertEqual(prefixed.multi_get(["a", "b", "foo"]), [b"c", b"d", b"potato"])
730
731
class TestMultiThreadedWait(MultiThreadedTestCase):
    """``wait`` on one rank must return once another rank sets the key."""

    # TODO: Use less hacky means of instantiating stores.
    # Note, stores accumulate values per test.
    stores = [
        dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1),
        dist.HashStore(),
        dist.PrefixStore("pre", dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)),
        create_tcp_store(),
        create_tcp_store(use_libuv=True),
        dist.PrefixStore("pre", create_tcp_store()),
        dist.PrefixStore("pre", create_tcp_store(use_libuv=True)),
    ]

    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    # One parametrization per entry of ``stores``. The range is derived from
    # the list (visible here because the decorator argument is evaluated in
    # the class body), so it cannot get out of sync with it.
    @parametrize("i", range(len(stores)))
    def test_wait(self, i):
        store = self.stores[i]
        store.set_timeout(timedelta(seconds=2))
        # Rank 0 waits for the key that rank 1 sets.
        if dist.get_rank() == 0:
            store.wait(["key1"])
            self.assertEqual(b"value1", store.get("key1"))
        if dist.get_rank() == 1:
            store.set("key1", "value1")


instantiate_parametrized_tests(TestMultiThreadedWait)
766
@skip_if_win32()
class TimeoutTest(TestCase):
    """A signal delivered to a thread blocked in ``wait`` must not break the wait."""

    def tearDown(self):
        import signal
        super().tearDown()
        # The test installs a SIGUSR1 handler; set the signal to ignored afterwards.
        signal.signal(signal.SIGUSR1, signal.SIG_IGN)

    def test_interrupt_doesnt_break_wait(self):
        import signal
        # Slot ``rank`` ends up as True on success or as the exception raised.
        rank_res = [None, None]

        def run(rank, my_store):
            try:
                if rank == 0:
                    time.sleep(4)
                    my_store.set("foo", "bar")
                else:
                    my_store.wait(["foo"], datetime.timedelta(seconds=10))
                rank_res[rank] = True
            # ``Exception`` instead of the undefined name ``Error``: with the
            # latter any failure turned into a NameError inside the thread and
            # the real error was never recorded.
            except Exception as e:
                rank_res[rank] = e
            time.sleep(1)

        rank0_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=0, world_size=2, is_master=True, wait_for_workers=False)
        rank1_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=rank0_store.port, world_size=2, is_master=False, wait_for_workers=False)

        ths = []
        for i in range(2):
            t = threading.Thread(target=run, args=(i, [rank0_store, rank1_store][i],))
            t.start()
            ths.append(t)

        def handler(a, b):
            pass

        signal.signal(signal.SIGUSR1, handler)
        # Give the rank 1 thread time to block in wait() before signalling it.
        time.sleep(1)
        signal.pthread_kill(ths[1].ident, signal.SIGUSR1)

        for t in ths:
            t.join()
        # Compare against True itself: a recorded exception object is truthy
        # and would slip through assertTrue.
        self.assertIs(rank_res[0], True, "rank0")
        self.assertIs(rank_res[1], True, "rank1")
813
814
class InitPgWithUvStore(TestCase):
    """``init_process_group`` must end up on a libuv-backed ``TCPStore`` when asked to."""

    def tearDown(self):
        super().tearDown()
        # Drop the variables test_with_env_var sets; harmless if they are absent.
        os.environ.pop("USE_LIBUV", None)
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

    def test_with_url_param(self):
        port = common.find_free_port()
        dist.init_process_group("gloo", rank=0, world_size=1, init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=1")
        self._run_test()

    def test_with_env_var(self):
        port = common.find_free_port()
        os.environ["USE_LIBUV"] = "1"
        os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://")
        self._run_test()

    def _run_test(self):
        """Check the default group's store, then destroy the group.

        The caller must have initialized the default process group. It is
        destroyed in a ``finally`` so a failing assertion does not leave it
        initialized for the next test.
        """
        try:
            pg = dist.group.WORLD
            store = c10d._get_process_group_store(pg)
            self.assertIsInstance(store, dist.PrefixStore)
            # c10d does multiple levels of wrapping
            while isinstance(store, dist.PrefixStore):
                store = store.underlying_store
            self.assertIsInstance(store, dist.TCPStore)
            self.assertTrue(store.libuvBackend)
        finally:
            dist.destroy_process_group()
845
if __name__ == "__main__":
    # Refuse to start the tests if CUDA has already been initialized in this
    # (main) process.
    cuda_already_initialized = torch.cuda._initialized
    assert not cuda_already_initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )

    run_tests()
852