# Owner(s): ["oncall: distributed"]

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import sys
import unittest
from contextlib import nullcontext
from typing import Any, cast, List

import numpy as np

import torch
import torch.distributed as dist


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW, SGD
from torch.testing._internal import common_distributed
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    TEST_WITH_ASAN,
    TEST_WITH_DEV_DBG_ASAN,
)


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False


# Use GLOO on GPU when running CUDA + Windows
def _get_backend_for_tests():
    return (
        dist.Backend.NCCL
        if not IS_WINDOWS and torch.cuda.is_available()
        # Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when
        # no GPUs are available.
        else dist.Backend.GLOO
    )


BACKEND = _get_backend_for_tests()


@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.")
class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    @property
    def world_size(self):
        return 1

    def tearDown(self):
        try:
            torch.distributed.destroy_process_group()
        except AssertionError:
            pass
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def dist_init(self, rank, world_size=-1, backend=BACKEND):
        if world_size < 1:
            world_size = self.world_size
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend,
            store=store,
            rank=rank,
            world_size=world_size,
        )


# TODO: skip_but_pass_in_sandcastle_if does not work here.
@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.")
class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
    def test_state_dict(self):
        """Check that ZeroRedundancyOptimizer exposes the expected state dict
        interface, irrespective of the sharding."""
        self.dist_init(self.rank)
        LR1 = 0.1
        LR2 = 0.01
        MOMENTUM = 0.9
        RECIPIENT_RANK = 0  # rank 0 is the only rank since the world size is 1
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGD,
            lr=LR1,
            momentum=MOMENTUM,
        )
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=self.device))
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=self.device),
        )

        o.zero_grad()
        o.consolidate_state_dict(to=RECIPIENT_RANK)
        state_dict = o.state_dict()

        # Check that the state dict has keys compliant with PyTorch
        self.assertIn("param_groups", state_dict.keys())
        self.assertIn("state", state_dict.keys())

        # Check that the state has the expected keys
        self.assertEqual(state_dict["param_groups"][0]["lr"], 0.1)
        self.assertEqual(state_dict["param_groups"][0]["momentum"], 0.9)
        self.assertFalse(state_dict["param_groups"][0]["nesterov"])
        self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
        self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

        # Check that the state and the `param_groups` attribute are in sync
        for k in state_dict["param_groups"][0]:
            if k != "params":
                self.assertEqual(
                    state_dict["param_groups"][0][k],
                    o.param_groups[0][k],
                )

        # Check that the state is reloaded with the correct values and device
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR2)
        o.load_state_dict(state_dict)
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=self.device),
        )

        # We should be using `LR1` and not `LR2` after reloading, both within
        # the optimizer and as exposed by the `param_groups` attribute
        self.assertEqual(o.param_groups[0]["lr"], LR1)
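        # After reloading, the momentum buffer is 1.0 and the new gradient is
        # 1.0, so with momentum 0.9 and lr 0.1 the next step gives
        # buf = 0.9 * 1.0 + 1.0 = 1.9 and x = 0.9 - 0.1 * 1.9 = 0.71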
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.71], device=self.device))
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.9], device=self.device),
        )

        # Check that the exposed `param_groups` are on the proper device
        self.assertEqual(o.param_groups[0]["params"][0].device, x.device)

    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        NUM_ITERS = 5
        LR = 0.01
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR)
        o2 = torch.optim.SGD([x2], lr=LR)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(NUM_ITERS):
            x.backward()
            o.zero_grad()
            o.step()
            s.step()
            x2.backward()
            o2.zero_grad()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)

    def test_step_with_kwargs(self):
        """Check that the ``step(**kwargs)`` interface is properly exposed."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithStepKWArg(torch.optim.SGD):
            def step(self, closure=None, kwarg=None):
                super().step()
                kwarg.append(5)

        kwarg: List[Any] = []
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGDWithStepKWArg,
            lr=LR,
        )
        x.backward()
        o.step(0, kwarg=kwarg)
        self.assertEqual(kwarg, [5])
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_step_with_extra_inner_key(self):
        """Check that ZeroRedundancyOptimizer wrapping an optimizer that adds
        extra keys to ``param_groups`` exposes those keys through ZeRO's own
        ``param_groups``."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithNewKey(torch.optim.SGD):
            # Dummy optimizer which adds a new key to the param groups
            def step(self, closure=None):
                super().step()
                self.param_groups[0]["new_key"] = 0.1

        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGDWithNewKey, lr=LR)
        x.backward()
        o.step()
        self.assertEqual(o.param_groups[0]["new_key"], 0.1)
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_step_without_closure(self):
        """Check that the ``step()`` method (without closure) is handled as
        expected."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithoutClosure(torch.optim.SGD):
            def step(self):
                return super().step()

        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGDWithoutClosure,
            lr=LR,
        )
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_zero_grad(self):
        """Check that the ``zero_grad`` method is properly handled."""
        self.dist_init(self.rank)
        LR = 0.01
        x = torch.rand(1)
        m = torch.nn.Linear(1, 1)
        o = ZeroRedundancyOptimizer(m.parameters(), optimizer_class=SGD, lr=LR)
        y = m(x)
        y.backward(x)
        self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
        self.assertNotEqual(m.bias.grad, torch.zeros_like(m.bias))
        o.zero_grad()
        self.assertIsNone(m.weight.grad)
        self.assertIsNone(m.bias.grad)

    def test_constructor(self):
        """Check the robustness of the ZeroRedundancyOptimizer constructor by
        passing different values for the ``params`` argument."""
        self.dist_init(self.rank)
        LR = 0.01
        m = torch.nn.Sequential(
            torch.nn.Linear(5, 10),
            torch.nn.Linear(10, 10),
            torch.nn.Linear(10, 10),
        )
        # Test various constructor inputs in the form: (input, expected error)
        ctor_inputs = [
            ([], ValueError),  # empty parameter list
            (torch.randn(1), TypeError),  # non-iterable: `torch.Tensor`
            (1.2, TypeError),  # non-iterable: `float`
            (
                [
                    {"params": [l.weight for l in m]},
                    {"params": [l.bias for l in m]},
                ],
                None,
            ),  # iterable of dict
            (
                list(m.parameters()) + [42],
                TypeError,
            ),  # iterable containing invalid type
            (m.parameters(), None),  # `params` as a generator
            (list(m.parameters()), None),  # `params` as a list
        ]
        for ctor_input, error in ctor_inputs:
            context = self.assertRaises(error) if error else nullcontext()
            with context:
                ZeroRedundancyOptimizer(
                    ctor_input,
                    optimizer_class=SGD,
                    lr=LR,
                )

        # Test constructing with multiple parameter groups more thoroughly
        WD = 0.01
        BETAS = (0.9, 0.999)
        EPS = 1e-8
        params = [
            {"params": [l.weight for l in m], "weight_decay": 0.0},
            {"params": [l.bias for l in m], "weight_decay": WD},
        ]
        o = ZeroRedundancyOptimizer(
            params,
            optimizer_class=AdamW,
            lr=LR,
            betas=BETAS,
            eps=EPS,
        )
        assert (
            len(o.param_groups) == 2
        ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
        assert len(o.optim.param_groups) == 2, (
            "Expected 2 local optimizer param groups, but got "
            f"{len(o.optim.param_groups)}"
        )

    def test_same_dense_param_type(self):
        """Check that ZeroRedundancyOptimizer raises an exception if the input
        parameters include sparse tensors or different dense types.

        NOTE: This test should be removed once support for sparse parameters
        and varying parameter types is added.
        """
        self.dist_init(self.rank)
        LR = 0.01
        inputs = [
            [torch.sparse_coo_tensor(size=(2, 3))],
            [torch.FloatTensor(1), torch.DoubleTensor(1)],
            [
                torch.FloatTensor(1),
                torch.FloatTensor(1),
                torch.sparse_coo_tensor(size=(2, 3)),
            ],
        ]
        for input in inputs:
            with self.assertRaises(ValueError):
                ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=LR)


class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
    @property
    def device(self):
        return (
            torch.device(self.rank)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return min(4, max(2, torch.cuda.device_count()))

    @property
    def context(self):
        return (
            nullcontext()
            if not torch.cuda.is_available()
            else torch.cuda.device(self.rank)
        )

    def _check_same_model_params(
        self,
        model_a: torch.nn.Module,
        model_b: torch.nn.Module,
        message: str = "",
    ) -> None:
        # Check that model parameters match
        for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
            torch.testing.assert_close(
                p_a,
                p_b,
                atol=1e-3,
                rtol=1e-5,
                msg=f"Model parameters differ:\n{p_a} {p_b}\n" + message,
            )
        # Check that model buffers match
        for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
            torch.testing.assert_close(
                b_a,
                b_b,
                msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message,
            )

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_step(self):
        """Check that ZeroRedundancyOptimizer properly exposes the ``step()``
        interface."""
        self.dist_init(self.rank, world_size=self.world_size)
        LR = 0.01

        with self.context:
            x = torch.tensor([float(self.rank + 1)], device=self.device)
            m = torch.nn.Linear(1, 1)
            m.weight.data = torch.tensor([[1.0]])
            m.bias.data = torch.tensor([2.0])
            m = m.to(self.device)
            m_zero = copy.deepcopy(m).to(self.device)

            o = SGD(m.parameters(), lr=LR)
            o_zero = ZeroRedundancyOptimizer(
                m_zero.parameters(),
                optimizer_class=SGD,
                lr=LR,
            )

            y = m(x)
            y.backward(x)
            y_zero = m_zero(x)
            y_zero.backward(x)

            for p in m.parameters():
                dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                p.grad.data /= self.world_size
            o.step()
            for p in m_zero.parameters():
                dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                p.grad.data /= self.world_size
            o_zero.step()

            self.assertEqual(m.weight, m_zero.weight)
            self.assertEqual(m.bias, m_zero.bias)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_step_with_closure(self):
        """Check that ZeroRedundancyOptimizer properly exposes the
        ``step(closure)`` interface."""
        self.dist_init(self.rank, world_size=self.world_size)

        with self.context:
            for bucket_view in [False, True]:
                x_val = self.rank + 1
                weight = 1.0
                bias = 2.0
                error = 1.0
                target = torch.tensor(
                    [x_val * weight + bias + error],
                    device=self.device,
                )
                loss_fn = torch.nn.L1Loss()

                x = torch.tensor([float(x_val)], device=self.device)
                m = torch.nn.Linear(1, 1)
                m.weight.data = torch.tensor([[weight]])
                m.bias.data = torch.tensor([bias])
                m.to(self.device)

                o = ZeroRedundancyOptimizer(
                    m.parameters(),
                    optimizer_class=SGD,
                    parameters_as_bucket_view=bucket_view,
                    lr=0.1,
                )

                y = m(x)
                y.backward(x)
                for p in m.parameters():
                    dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                    p.grad.data /= self.world_size

                def closure():
                    o.zero_grad()
                    output = m(x)
                    loss = loss_fn(output, target)
                    loss.backward()
                    return loss

                loss = o.step(closure=closure)

                self.assertEqual(loss, torch.tensor(error))
                self.assertEqual(m.weight, torch.tensor([[1.1]]))
                self.assertEqual(m.bias, torch.tensor([2.1]))

    @common_distributed.skip_if_no_gpu
    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        o2 = torch.optim.SGD([x2], lr=0.01)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(5):
            x.backward()
            o.zero_grad()
            o.step()
            s.step()
            x2.backward()
            o2.zero_grad()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)

    def test_sharding(self):
        """
        Check ZeroRedundancyOptimizer's parameter sharding at construction
        time.

        NOTE: The correctness of this test depends on the ZeRO implementation
        using the sorted-greedy partitioning algorithm. For details, see
        ``ZeroRedundancyOptimizer._partition_parameters()`` in
        zero_redundancy_optimizer.py.
        """
        self.dist_init(self.rank)
        LR = 0.01
        sizes = [9, 7, 5, 3]
        params = []
        for size in sizes * self.world_size:
            params.append(torch.rand(size, 1))
        o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
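        # Since `sizes` is repeated `world_size` times, the sorted-greedy
        # partitioning should give every rank one parameter of each size, so
        # each local shard holds sum(sizes) elements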
        self.assertEqual(
            sum(x.numel() for x in o.optim.param_groups[0]["params"]),
            sum(sizes),
        )

    def test_add_param_group(self):
        """Check that ZeroRedundancyOptimizer properly handles adding a new
        parameter group a posteriori and that all ranks get a shard of the
        contained parameters.

        NOTE: The correctness of this test depends on the ZeRO implementation
        using the sorted-greedy partitioning algorithm. For details, see
        ``ZeroRedundancyOptimizer._partition_parameters()`` in
        zero_redundancy_optimizer.py.
        """
        self.dist_init(self.rank)
        LR = 0.01

        # Test with all parameters trainable to begin with
        def all_trainable():
            params = []
            sizes = [9, 7, 5, 3]
            sizes_world = sizes * self.world_size
            for size in sizes_world[:-1]:
                params.append(torch.rand(size, 1))

            # Make sure that the params are trainable so that they are factored
            # into the size-based parameter partitioning
            for p in params:
                p.requires_grad = True

            o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
            self.assertEqual(len(o.param_groups), 1)
            o.add_param_group({"params": [torch.rand(3, 1)]})
            # Verify that new group is added to the correct partition, making
            # all partitions have the same elements
            self.assertEqual(len(o.param_groups), 2)
            self.assertEqual(
                sum(x.numel() for g in o.optim.param_groups for x in g["params"]),
                sum(sizes),
            )
            self.assertEqual(len(o.optim.param_groups), 2)

        # Test a pathological config with a first big non-trainable param
        def some_trainable():
            params = []
            for size in [100, 3, 5, 2, 6, 4]:
                params.append(torch.rand(size, 1))

            # Make sure that all but the first param are trainable so that they
            # are factored into the size-based parameter partitioning
            for p in params[1:]:
                p.requires_grad = True

            o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
            self.assertEqual(len(o.param_groups), 1)
            o.add_param_group({"params": [torch.rand(3, 1)]})
            self.assertEqual(len(o.param_groups), 2)
            self.assertEqual(len(o.optim.param_groups), 2)

        all_trainable()
        some_trainable()

    @common_distributed.skip_if_no_gpu
    def test_multiple_param_groups(self):
        """
        Check parity between constructing ZeRO with multiple parameter groups
        upfront versus adding parameter groups to ZeRO after construction
        versus a non-sharded optimizer.
        """
        self.dist_init(self.rank)
        BATCH_SIZE, NUM_ITERS = 8, 3
        INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5
        WD, LR = 0.01, 0.01
        model1 = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        )
        model2 = copy.deepcopy(model1)
        model3 = copy.deepcopy(model1)
        model1 = model1.to(self.device)
        model2 = model2.to(self.device)
        model3 = model3.to(self.device)
        inputs = [
            torch.randn(BATCH_SIZE, INPUT_DIM).to(self.device) for _ in range(NUM_ITERS)
        ]
        # Construct `optim1` with both parameter groups upfront
        optim1 = ZeroRedundancyOptimizer(
            [
                {"params": [l.weight for l in model1], "weight_decay": 0.0},
                {"params": [l.bias for l in model1], "weight_decay": WD},
            ],
            optimizer_class=AdamW,
            lr=LR,
        )
        # Construct `optim2` by adding the second parameter group after
        # construction
        optim2 = ZeroRedundancyOptimizer(
            [l.weight for l in model2],
            optimizer_class=AdamW,
            lr=LR,
            weight_decay=0.0,
        )
        optim2.add_param_group({"params": [l.bias for l in model2], "weight_decay": WD})
        # Construct `optim3` as a non-sharded optimizer
        optim3 = AdamW(
            [
                {"params": [l.weight for l in model3], "weight_decay": 0.0},
                {"params": [l.bias for l in model3], "weight_decay": WD},
            ],
            lr=LR,
        )
        # Check parity over a few iterations
        for input in inputs:
            for model, optim in (
                (model1, optim1),
                (model2, optim2),
                (model3, optim3),
            ):
                optim.zero_grad()
                out = model(input)
                loss = out.sum()
                loss.backward()
                optim.step()
            for layer1, layer2, layer3 in zip(model1, model2, model3):
                torch.testing.assert_close(layer1.weight, layer2.weight)
                torch.testing.assert_close(layer1.weight, layer3.weight)
                torch.testing.assert_close(layer1.bias, layer2.bias)
                torch.testing.assert_close(layer1.bias, layer3.bias)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_collect_shards(self):
        """Check the state consolidation mechanism and the state dict exposed
        by ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        LR = 1e-3
        MOMENTUM = 0.99
        BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5
        REFERENCE_RANK = 0
        target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=self.device)
        inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=self.device)
        model = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ).to(self.device)
        loss_fn = torch.nn.L1Loss()
        loss_fn.to(self.device)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,  # ensure there exists state to shard
        )

        def closure():
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            return loss

        # Run a dummy step so that the optimizer state dict exists
        _ = optimizer.step(closure=closure)

        # Get the optimizer state on the reference rank
        optimizer.consolidate_state_dict(to=REFERENCE_RANK)
        if self.rank == REFERENCE_RANK:
            # Check that the state has the correct size
            optimizer_state_dict = optimizer.state_dict()
            self.assertEqual(
                len(optimizer_state_dict["state"]),
                len(list(model.parameters())),
            )
        else:
            optimizer_state_dict = {}

        # Load the optimizer state on all ranks without any exceptions
        optimizer_state_dict = _broadcast_object(
            optimizer_state_dict,
            src_rank=REFERENCE_RANK,
            group=dist.group.WORLD,
            device=self.device,
        )
        optimizer.load_state_dict(optimizer_state_dict)

    def test_nondefault_process_group(self):
        """Check that ZeroRedundancyOptimizer works with a non-default process
        group consisting only of even ranks."""
        # Skip the test if below the minimum world size since then the test is
        # trivial
        MIN_WORLD_SIZE = 4
        if self.world_size < MIN_WORLD_SIZE:
            common_distributed.logger.info(
                "Skipping `test_nondefault_process_group()` since world size "
                "of %s is less than %s",
                self.world_size,
                MIN_WORLD_SIZE,
            )
            return
        BACKEND = dist.Backend.GLOO
        self.dist_init(self.rank, self.world_size, BACKEND)
        # Use GPU if enough are available, or fall back to CPU otherwise, which
        # is fine since Gloo backend supports both
        if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size:
            device = torch.device(self.rank)
        else:
            device = torch.device("cpu")
        # Create a new process group consisting of the even ranks to exercise
        # the case where the global and local ranks do not necessarily match
        subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0]
        process_group = dist.new_group(
            ranks=subgroup_ranks,
            backend=BACKEND,
        )
        # Ranks not participating in the new process group are no longer needed
        if self.rank not in subgroup_ranks:
            return

        # Set different seeds across ranks so that each rank gets different
        # training data and hence the model sync check is meaningful
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)

        EPOCHS, BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 3, 20, 10, 5
        LR = 1e-3
        MOMENTUM = 0.99
        REFERENCE_RANK = 0
        assert (
            REFERENCE_RANK in subgroup_ranks
        ), "Reference rank must be in the new process group"
        loss_fn = torch.nn.L1Loss().to(device)

        def check(optimizer):
            for _ in range(EPOCHS):
                target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=device)
                inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=device)

                def closure():
                    optimizer.zero_grad()
                    output = model(inputs)
                    loss = loss_fn(output, target)
                    loss /= self.world_size
                    loss.backward()
                    dist.all_reduce(loss, group=process_group)
                    return loss

                _ = optimizer.step(closure=closure)

                # Check that the parameters match across ranks after a step
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        receptacle = (
                            [p.clone() for _ in subgroup_ranks]
                            if self.rank == REFERENCE_RANK
                            else []
                        )
                        dist.gather(
                            p,
                            receptacle,
                            dst=REFERENCE_RANK,
                            group=process_group,
                        )
                        if self.rank == REFERENCE_RANK:
                            reference_param = receptacle[0]
                            for param in receptacle[1:]:
                                torch.testing.assert_close(
                                    reference_param,
                                    param,
                                    msg="Models differ between ranks",
                                )

        model = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ).to(device)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,  # ensure there exists state to shard
            process_group=process_group,
        )
        check(optimizer)

    @common_distributed.skip_if_no_gpu
    @parametrize(
        "optimizer_class_str",
        ["Adam", "AdamW", "SGD"],
        # Use string to appease the internal test name parser
    )
    @parametrize(
        "maximize",
        [False, True],
    )
    def test_local_optimizer_parity(
        self,
        optimizer_class_str: str,
        maximize: bool,
    ):
        """When combined with DDP, check that a local optimizer gives the same
        results as wrapping that optimizer with ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        BATCHES = 20
        BATCH_SIZE = 64
        LR = 1e-3
        INPUT_DIM = 2
        HIDDEN_DIM = 3
        OUTPUT_DIM = 3
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)
        if optimizer_class_str == "Adam":
            optimizer_class = torch.optim.Adam
        elif optimizer_class_str == "AdamW":
            optimizer_class = torch.optim.AdamW
        elif optimizer_class_str == "SGD":
            optimizer_class = torch.optim.SGD
        else:
            assert 0, f"Unsupported optimizer class: {optimizer_class_str}"

        with self.context:
            # Define a base model with a different buffer for each rank
            model = torch.nn.Sequential(
                torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
            ).to(self.device)
            model.test_buffer = torch.nn.Buffer(
                torch.ones((1), device=self.device) * self.rank,
            )
            # Define models/optimizers for DDP with ZeRO and DDP with local
            # optimizer
            defaults = {"maximize": True} if maximize else {}
            sharded_optimizer = ZeroRedundancyOptimizer(
                params=model.parameters(),
                optimizer_class=optimizer_class,
                lr=LR,
                **defaults,
            )
            sharded_ddp_model = DDP(
                module=model,
                device_ids=[self.rank],
                broadcast_buffers=True,
                find_unused_parameters=True,
            )
            local_model = copy.deepcopy(model).to(self.device)
            ddp_optimizer = optimizer_class(
                local_model.parameters(),
                lr=LR,
                **defaults,
            )
            ddp_model = DDP(
                local_model,
                device_ids=[self.rank],
                broadcast_buffers=True,
                find_unused_parameters=True,
            )
            # Check that the model is properly synchronized between ranks
            # at construction time
            self._check_same_model_params(
                sharded_ddp_model,
                ddp_model,
                "Models differ from the start",
            )

            def check_step():
                input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM))

                def closure_ddp(input_tensor=input_tensor):
                    ddp_optimizer.zero_grad()
                    ddp_loss = ddp_model(input_tensor).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                def closure_sharded(input_tensor=input_tensor):
                    sharded_optimizer.zero_grad()
                    sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                    sharded_loss.backward()
                    return sharded_loss

                loss_ddp = cast(
                    torch.Tensor,
                    ddp_optimizer.step(closure=closure_ddp),
                )
                loss_sharded_optim = cast(
                    torch.Tensor,
                    sharded_optimizer.step(closure=closure_sharded),
                )
                torch.testing.assert_close(
                    loss_ddp,
                    loss_sharded_optim,
                    msg="Losses differ between local optimizer and ZeRO",
                )
                self._check_same_model_params(
                    sharded_ddp_model,
                    ddp_model,
                    "Models differ after a step",
                )

            # Check that parity is maintained
            for i in range(BATCHES):
                check_step()
                # For the second half of batches, change the parameter
                # trainability to further test parity
                if i > BATCHES // 2:
                    next(ddp_model.parameters()).requires_grad = bool(i % 2)
                    next(sharded_ddp_model.parameters()).requires_grad = bool(i % 2)

            # Check that the `state_dict` checkpoints are compatible between
            # the local optimizer and ZeRO
            REFERENCE_RANK = 0
            # - Get states
            ddp_state_dict = ddp_optimizer.state_dict()
            sharded_optimizer.consolidate_state_dict(to=REFERENCE_RANK)
            sharded_optim_state_dict = [
                sharded_optimizer.state_dict() if self.rank == REFERENCE_RANK else {}
            ]
            dist.broadcast_object_list(
                sharded_optim_state_dict,
                src=REFERENCE_RANK,
                group=dist.group.WORLD,
            )
            sharded_optim_state_dict = sharded_optim_state_dict[0]

            # - Cross-load the states
            # Run one step and check that the models are still the same
            ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)
            ddp_optimizer.load_state_dict(sharded_optim_state_dict)
            sharded_optimizer.load_state_dict(ddp_state_dict)
            check_step()

            # - Reload their respective states
            # Run one step and check that the models are still the same
            ddp_optimizer.load_state_dict(ddp_state_dict_ref)
            sharded_optimizer.load_state_dict(sharded_optim_state_dict)
            check_step()

    def _test_zero_join(self, device):
        """Check that the ZeRO join hook allows training with uneven inputs
        when using the given device."""
        NUM_INPUTS = 3
        NUM_EPOCHS = 2
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        world_size = self.world_size
        is_gpu = device.type == "cuda"
        backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
        self.dist_init(rank, world_size, backend)

        model = torch.nn.Sequential(
            torch.nn.Linear(2, 3),
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 3),
        )
        model.to(device)

        # DDP ensures correct gradients in data parallel training, so DDP with
        # local optimizers on uneven inputs should be equivalent to ZeRO on
        # uneven inputs with gradients being manually set
        ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
        local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        zero_model = copy.deepcopy(model)
        zero_model.to(device)
        zero_optim = ZeroRedundancyOptimizer(
            zero_model.parameters(),
            torch.optim.Adam,
            lr=LR,
        )
        loss_fn = torch.nn.MSELoss()

        # Use uneven inputs: rank i has i extra inputs
        inputs = [torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank)]
        labels = torch.randn(20, 3).to(device)

        # Save the gradients and parameters from DDP as the ground truth; do
        # so on the last-joining rank (in this case, the largest rank)
        grads_at_each_iter = []
        params_at_each_iter = []
        with ddp_model.join():
            for _ in range(NUM_EPOCHS):
                for input in inputs:
                    output = ddp_model(input)
                    loss_fn(output, labels).backward()
                    if rank == world_size - 1:
                        grads = []
                        for p in ddp_model.parameters():
                            grads.append(p.grad.detach().clone().to(device))
                    local_optim.step()
                    if rank == world_size - 1:
                        params = []
                        for p in ddp_model.parameters():
                            params.append(p.detach().clone().to(device))
                        grads_at_each_iter.append(grads)
                        params_at_each_iter.append(params)

        # Broadcast the saved gradients and parameters to all of the other
        # ranks (which joined early)
        grads_and_params = [grads_at_each_iter, params_at_each_iter]
        grads_and_params = _broadcast_object(
            grads_and_params,
            src_rank=world_size - 1,
            group=dist.group.WORLD,
            device=device,
        )
        grads_at_each_iter = grads_and_params[0]
        params_at_each_iter = grads_and_params[1]
        # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
        # once the latter supports loading to the destination device instead
        # of the source device

        # A process must still set the remaining gradients after joining, so we
        # define a join hook to do this before the ZeRO join hook
        class _JoinGradInfo:
            def __init__(self, grads):
                self.grads = grads  # remaining gradients to set (in order)
                self.index = 0

        class _SetGradsJoinHook(JoinHook):
            def __init__(self, zero_optim, grads):
                zero_optim._join_grad_info = _JoinGradInfo(grads)
                self.zero = zero_optim
                super().__init__()

            def main_hook(self):
                join_grad_info = self.zero._join_grad_info
                grads = self.zero._join_grad_info.grads[join_grad_info.index]
                join_grad_info.index += 1
                for p, grad in zip(self.zero._all_params, grads):
                    p.grad = grad.detach().clone().to(device)

        class _GradientSetter(Joinable):
            def __init__(self) -> None:
                super().__init__()

            def join_hook(self, **kwargs):
                assert "zero_optim" in kwargs
                assert "grads" in kwargs
                zero_optim = kwargs["zero_optim"]
                grads = kwargs["grads"]
                return _SetGradsJoinHook(zero_optim, grads)

            @property
            def join_device(self):
                return device

            @property
            def join_process_group(self):
                return dist.group.WORLD

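        # Rank r sees NUM_INPUTS + r inputs per epoch, so relative to the
        # last-joining rank it misses (world_size - rank - 1) iterations per
        # epoch; those are the gradient sets it must still apply after joining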
        num_grads_after_joining = NUM_EPOCHS * (world_size - rank - 1)
        grads = grads_at_each_iter[-num_grads_after_joining:]
        gradient_setter = _GradientSetter()
        iter = 0
        with Join(
            [gradient_setter, zero_optim],
            zero_optim=zero_optim,
            grads=grads,
        ):
            for _ in range(NUM_EPOCHS):
                for input in inputs:
                    # Notify join context that this process has not joined
                    Join.notify_join_context(gradient_setter)
                    # Set gradients manually
                    for p, grad in zip(
                        zero_model.parameters(),
                        grads_at_each_iter[iter],
                    ):
                        p.grad = grad.detach().clone().to(device)
                    # Perform optimizer step and check parity
                    zero_optim.step()
                    for p, ddp_p in zip(
                        zero_model.parameters(),
                        params_at_each_iter[iter],
                    ):
                        torch.testing.assert_close(
                            p,
                            ddp_p,
                            msg="Parameters differ between using ZeRO and "
                            "local optimizer",
                        )
                    iter += 1

    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    def test_zero_join_gpu(self):
        """Check that the ZeRO join hook allows training with uneven inputs
        on GPU."""
        self._test_zero_join(self.device)

    @common_distributed.requires_gloo()
    def test_zero_join_cpu(self):
        """Check that the ZeRO join hook allows training with uneven inputs
        on CPU."""
        self._test_zero_join(torch.device("cpu"))

    def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
        # Use two processes each with two GPUs
        assert self.rank < 2
        NUM_EPOCHS = 2
        NUM_INPUTS = 4
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        class ModelParallelModel(torch.nn.Module):
            def __init__(self, dev0, dev1):
                super().__init__()
                self.dev0 = dev0
                self.dev1 = dev1
                self.net0 = torch.nn.Linear(10, 10).to(dev0)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5).to(dev1)

            def forward(self, x):
                x = x.to(self.dev0)
                x = self.relu(self.net0(x))
                x = x.to(self.dev1)
                return self.net1(x)

        class LocalModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net0 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5)

            def forward(self, x):
                return self.net1(self.relu(self.net0(x)))

        dev0 = torch.device(2 * self.rank)
        dev1 = torch.device(2 * self.rank + 1)
        mp_model = ModelParallelModel(dev0, dev1)
        ddp_model = DDP(mp_model)
        local_model = LocalModel().to(dev0)

        # Ensure the parameters are the same across the two models
        def copy_param(p):
            return torch.nn.Parameter(p.detach().clone().to(dev0))

        local_model.net0.weight = copy_param(mp_model.net0.weight)
        local_model.net0.bias = copy_param(mp_model.net0.bias)
        local_model.net1.weight = copy_param(mp_model.net1.weight)
        local_model.net1.bias = copy_param(mp_model.net1.bias)

        # Compare parity between DDP with model parallelism using ZeRO and
        # a local model using a local optimizer
        zero_optim = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            parameters_as_bucket_view=parameters_as_bucket_view,
            lr=LR,
        )
        local_optim = torch.optim.Adam(local_model.parameters(), lr=LR)
        inputs = [torch.randn(20, 10).to(dev0) for _ in range(NUM_INPUTS)]

        for _ in range(NUM_EPOCHS):
            for input in inputs:

                def closure_local():
                    local_optim.zero_grad()
                    local_loss = local_model(input).abs().sum()
                    local_loss.backward()
                    return local_loss

                def closure_ddp():
                    zero_optim.zero_grad()
                    ddp_loss = ddp_model(input).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                local_loss = cast(torch.Tensor, local_optim.step(closure=closure_local))
                ddp_loss = cast(torch.Tensor, zero_optim.step(closure=closure_ddp))

                # Increased tolerances are needed to pass when using TF32
                # See: https://github.com/pytorch/pytorch/issues/67764
                torch.testing.assert_close(
                    local_loss.cpu(),
                    ddp_loss.cpu(),
                    rtol=1e-03,
                    atol=1e-08,
                    msg="Losses differ between local optimizer and ZeRO",
                )

                for local_p, ddp_p in zip(
                    local_model.parameters(), ddp_model.parameters()
                ):
                    torch.testing.assert_close(
                        local_p.cpu(),
                        ddp_p.cpu(),
                        rtol=1e-03,
                        atol=1e-04,
                        msg="Models differ after a step",
                    )

    @common_distributed.skip_if_lt_x_gpu(4)
    @parametrize(
        "parameters_as_bucket_view",
        [False, True],
    )
    def test_zero_model_parallel(
        self,
        parameters_as_bucket_view: bool,
    ):
        """Check that ZeRO works with model parallelism where the model's
        layers are assigned to different devices."""
        if self.rank >= 2:
            return
        self.dist_init(self.rank, world_size=2)
        self._test_zero_model_parallel(parameters_as_bucket_view)

    def _test_ddp_zero_overlap(
        self,
        device,
        hook_constructor,
        gradient_as_bucket_view,
        static_graph,
        **kwargs,
    ):
        SGD_LR = 0.01
        SGD_MOMENTUM = 0.9
        SGD_WEIGHT_DECAY = 0.001
        NUM_INPUTS = 5
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        is_gpu = device.type == "cuda"
        if is_gpu:
            torch.cuda.set_device(device)
        models_to_test = [
            (
                torch.nn.Sequential(
                    torch.nn.Linear(1000, 2000),
                    torch.nn.Linear(2000, 500),
                ),
                [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)],
            )
        ]
        if HAS_TORCHVISION:
            models_to_test.append(
                (
                    torchvision.models.resnet50(),
                    [torch.randn(1, 3, 3, 1000).to(device) for _ in range(NUM_INPUTS)],
                )
            )
        for model, inputs in models_to_test:
            # Enable determinism in cudnn operators
            with torch.backends.cudnn.flags(
                enabled=True, deterministic=True, benchmark=False
            ):
                device_ids = [rank] if is_gpu else None
                # Set up the DDP model overlapping with ZeRO
                ddp_model_overlap = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view,
                )
                if static_graph:
                    ddp_model_overlap._set_static_graph()
                zero_optim = ZeroRedundancyOptimizer(
                    ddp_model_overlap.parameters(),
                    optimizer_class=torch.optim.SGD,
                    overlap_with_ddp=True,
                    lr=SGD_LR,
                    momentum=SGD_MOMENTUM,
                    weight_decay=SGD_WEIGHT_DECAY,
                )
                ddp_model_overlap.register_comm_hook(
                    None,
                    hook_constructor(
                        allreduce_hook,
                        ddp_model_overlap,
                        zero_optim,
                        **kwargs,
                    ),
                )

                # Set up the DDP model with local optimizer
                ddp_model_local = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view,
                )
                if static_graph:
                    ddp_model_local._set_static_graph()
                local_optim = torch.optim.SGD(
                    ddp_model_local.parameters(),
                    lr=SGD_LR,
                    momentum=SGD_MOMENTUM,
                    weight_decay=SGD_WEIGHT_DECAY,
                )

                # Check that the parameters match initially
                for p1, p2 in zip(
                    ddp_model_overlap.parameters(), ddp_model_local.parameters()
                ):
                    self.assertEqual(p1, p2)

                # Save the parameters to ensure they were updated
                init_params_overlap = copy.deepcopy(
                    list(ddp_model_overlap.parameters())
                )

                # Ensure that this test runs independently
                dist.barrier()

                # Run the DDP model overlapping with ZeRO
                # NOTE: Overlapping currently requires 2 or 3 warmup iterations
                # to ensure DDP buckets have been rebuilt (depending on the
                # value of `static_graph`)
                num_warmup_inputs = 2 if not static_graph else 3
                for input in inputs[:num_warmup_inputs]:
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()
                for input in inputs:
                    zero_optim.zero_grad()
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()

                # Run the DDP model with local optimizer
                for input in inputs:
                    local_optim.zero_grad()
                    output = ddp_model_local(input)
                    loss = output.sum()
                    loss.backward()
                    local_optim.step()
                dist.barrier()

                # Check that the parameters are equal
                for p1, p2 in zip(
                    ddp_model_overlap.parameters(), ddp_model_local.parameters()
                ):
                    self.assertEqual(p1, p2)

                # Check that the parameters were updated
                self.assertNotEqual(
                    init_params_overlap,
                    list(ddp_model_overlap.parameters()),
                )

                # Ensure that this test runs independently
                dist.barrier()

    # NOTE: The test is skipped if using Windows since functional optimizers
    # are not currently supported.
    @common_distributed.skip_if_win32()
    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    @parametrize(
        "use_gpu",
        [True],
        # Add `False` once the Gloo sync issue causing hangs is fixed
        # See: https://github.com/pytorch/pytorch/issues/62300
    )
    @parametrize(
        "use_interleaved_hook",
        [False, True],
    )
    @parametrize(
        "gradient_as_bucket_view",
        [False, True],
    )
    @parametrize(
        "static_graph",
        [False, True],
    )
    @parametrize(
        "shard_buckets",
        [False, True],
    )
    def test_ddp_zero_overlap(
        self,
        use_gpu: bool,
        use_interleaved_hook: bool,
        gradient_as_bucket_view: bool,
        static_graph: bool,
        shard_buckets: bool,
    ):
        """
        Check that overlapping DDP with ZeRO using the given method determined
        by ``hook_constructor`` and ``shard_buckets`` and using the given ZeRO
        and DDP arguments achieves parity with DDP using a local optimizer.
        """
        device = torch.device(self.rank) if use_gpu else torch.device("cpu")
        backend = _get_backend_for_tests()
        self.dist_init(self.rank, self.world_size, backend)
        hook_constructor = (
            hook_with_zero_step
            if not use_interleaved_hook
            else hook_with_zero_step_interleaved
        )

        self._test_ddp_zero_overlap(
            device,
            hook_constructor,
            gradient_as_bucket_view,
            static_graph,
            shard_buckets=shard_buckets,
        )


instantiate_parametrized_tests(TestZeroRedundancyOptimizerSingleRank)
instantiate_parametrized_tests(TestZeroRedundancyOptimizerDistributed)

if __name__ == "__main__":
    # ! unittest should not be used here, else the tests are not properly registered
    run_tests()