# pytorch-lightning
# 25 lines · 828.0 bytes
from unittest import mock

import pytest
from lightning.fabric.plugins.collectives import SingleDeviceCollective
def test_can_instantiate_without_args():
    """``SingleDeviceCollective`` must be constructible without any arguments."""
    SingleDeviceCollective()


def test_create_group():
    """``create_group`` forwards its kwargs to ``new_group``; ``teardown`` destroys that group."""
    collective = SingleDeviceCollective()
    assert collective.is_initialized()

    # No group has been created yet, so accessing ``.group`` must fail loudly.
    with pytest.raises(RuntimeError, match=r"SingleDeviceCollective` does not own a group"):
        _ = collective.group

    with mock.patch("lightning.fabric.plugins.collectives.single_device.SingleDeviceCollective.new_group") as new_mock:
        collective.create_group(arg1=15, arg3=10)

    # Keyword order intentionally differs from the call above: only the kwargs mapping matters.
    group_kwargs = {"arg3": 10, "arg1": 15}
    new_mock.assert_called_once_with(**group_kwargs)

    with mock.patch(
        "lightning.fabric.plugins.collectives.single_device.SingleDeviceCollective.destroy_group"
    ) as destroy_mock:
        collective.teardown()

    # ``teardown`` is expected to hand the group produced by ``new_group`` to ``destroy_group``.
    destroy_mock.assert_called_once_with(new_mock.return_value)