pytorch

Форк
0
117 строк · 3.3 Кб
1
#!/usr/bin/env python3
2
# Owner(s): ["oncall: r2p"]
3

4
# Copyright (c) Facebook, Inc. and its affiliates.
5
# All rights reserved.
6
#
7
# This source code is licensed under the BSD-style license found in the
8
# LICENSE file in the root directory of this source tree.abs
9
import abc
10
import unittest.mock as mock
11

12
from torch.distributed.elastic.metrics.api import (
13
    _get_metric_name,
14
    MetricData,
15
    MetricHandler,
16
    MetricStream,
17
    prof,
18
)
19
from torch.testing._internal.common_utils import run_tests, TestCase
20

21

22
def foo_1():
23
    pass
24

25

26
class TestMetricsHandler(MetricHandler):
27
    def __init__(self) -> None:
28
        self.metric_data = {}
29

30
    def emit(self, metric_data: MetricData):
31
        self.metric_data[metric_data.name] = metric_data
32

33

34
class Parent(abc.ABC):
35
    @abc.abstractmethod
36
    def func(self):
37
        raise NotImplementedError
38

39
    def base_func(self):
40
        self.func()
41

42

43
class Child(Parent):
44
    # need to decorate the implementation not the abstract method!
45
    @prof
46
    def func(self):
47
        pass
48

49

50
class MetricsApiTest(TestCase):
51
    def foo_2(self):
52
        pass
53

54
    @prof
55
    def bar(self):
56
        pass
57

58
    @prof
59
    def throw(self):
60
        raise RuntimeError
61

62
    @prof(group="torchelastic")
63
    def bar2(self):
64
        pass
65

66
    def test_get_metric_name(self):
67
        # Note: since pytorch uses main method to launch tests,
68
        # the module will be different between fb and oss, this
69
        # allows keeping the module name consistent.
70
        foo_1.__module__ = "api_test"
71
        self.assertEqual("api_test.foo_1", _get_metric_name(foo_1))
72
        self.assertEqual("MetricsApiTest.foo_2", _get_metric_name(self.foo_2))
73

74
    def test_profile(self):
75
        handler = TestMetricsHandler()
76
        stream = MetricStream("torchelastic", handler)
77
        # patch instead of configure to avoid conflicts when running tests in parallel
78
        with mock.patch(
79
            "torch.distributed.elastic.metrics.api.getStream", return_value=stream
80
        ):
81
            self.bar()
82

83
            self.assertEqual(1, handler.metric_data["MetricsApiTest.bar.success"].value)
84
            self.assertNotIn("MetricsApiTest.bar.failure", handler.metric_data)
85
            self.assertIn("MetricsApiTest.bar.duration.ms", handler.metric_data)
86

87
            with self.assertRaises(RuntimeError):
88
                self.throw()
89

90
            self.assertEqual(
91
                1, handler.metric_data["MetricsApiTest.throw.failure"].value
92
            )
93
            self.assertNotIn("MetricsApiTest.bar_raise.success", handler.metric_data)
94
            self.assertIn("MetricsApiTest.throw.duration.ms", handler.metric_data)
95

96
            self.bar2()
97
            self.assertEqual(
98
                "torchelastic",
99
                handler.metric_data["MetricsApiTest.bar2.success"].group_name,
100
            )
101

102
    def test_inheritance(self):
103
        handler = TestMetricsHandler()
104
        stream = MetricStream("torchelastic", handler)
105
        # patch instead of configure to avoid conflicts when running tests in parallel
106
        with mock.patch(
107
            "torch.distributed.elastic.metrics.api.getStream", return_value=stream
108
        ):
109
            c = Child()
110
            c.base_func()
111

112
            self.assertEqual(1, handler.metric_data["Child.func.success"].value)
113
            self.assertIn("Child.func.duration.ms", handler.metric_data)
114

115

116
if __name__ == "__main__":
117
    run_tests()
118

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.