pytorch-lightning

Форк
0
25 строк · 828.0 Байт
1
from unittest import mock
2

3
import pytest
4
from lightning.fabric.plugins.collectives import SingleDeviceCollective
5

6

7
def test_can_instantiate_without_args():
8
    SingleDeviceCollective()
9

10

11
def test_create_group():
12
    collective = SingleDeviceCollective()
13
    assert collective.is_initialized()
14

15
    with pytest.raises(RuntimeError, match=r"SingleDeviceCollective` does not own a group"):
16
        _ = collective.group
17

18
    with mock.patch("lightning.fabric.plugins.collectives.single_device.SingleDeviceCollective.new_group") as new_mock:
19
        collective.create_group(arg1=15, arg3=10)
20

21
    group_kwargs = {"arg3": 10, "arg1": 15}
22
    new_mock.assert_called_once_with(**group_kwargs)
23

24
    with mock.patch("lightning.fabric.plugins.collectives.single_device.SingleDeviceCollective.destroy_group"):
25
        collective.teardown()
26

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.