pytorch
117 lines · 3.3 KB
#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
9import abc10import unittest.mock as mock11
12from torch.distributed.elastic.metrics.api import (13_get_metric_name,14MetricData,15MetricHandler,16MetricStream,17prof,18)
19from torch.testing._internal.common_utils import run_tests, TestCase20
21
def foo_1():
    """Module-level no-op used as the subject of metric-name tests.

    The body intentionally does nothing; callers only look at the
    function object itself (its name and ``__module__``).
    """
25
class TestMetricsHandler(MetricHandler):
    """In-memory ``MetricHandler`` that keeps the latest datum per metric name.

    Install it behind a ``MetricStream``. Afterwards, inspect ``metric_data``
    to see which metrics were emitted and with what values.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect this helper as a test class (and emit a collection warning
    # because it defines ``__init__``). Opt out explicitly.
    __test__ = False

    def __init__(self) -> None:
        super().__init__()
        # metric name -> most recent MetricData emitted under that name
        self.metric_data: dict[str, MetricData] = {}

    def emit(self, metric_data: MetricData) -> None:
        """Record ``metric_data``, replacing any earlier datum with the same name.

        Args:
            metric_data: the emitted datum; only its ``name`` is used as the key.
        """
        self.metric_data[metric_data.name] = metric_data
33
class Parent(abc.ABC):
    """Abstract base whose concrete helper dispatches to an overridable hook.

    ``base_func`` is defined here, but the ``func`` it calls is resolved on
    the instance. A subclass's (possibly decorated) override is therefore
    what actually runs.
    """

    @abc.abstractmethod
    def func(self):
        """Hook that concrete subclasses must provide."""
        raise NotImplementedError()

    def base_func(self):
        """Call ``func`` through normal method dispatch and discard its result."""
        hook = self.func
        hook()
42
class Child(Parent):
    """Concrete ``Parent`` whose ``func`` implementation is profiled."""

    # Put @prof on this concrete override, not on the abstract declaration in
    # Parent. An override replaces the base attribute, so a decorator on the
    # abstract method would never wrap the code that base_func() ends up calling.
    @prof
    def func(self):
        """No-op; only the metrics produced by ``@prof`` matter."""
49
class MetricsApiTest(TestCase):
    """Tests for metric naming and the ``@prof`` decorator in ``metrics.api``."""

    def foo_2(self):
        """Undecorated method used to check metric names of bound methods."""

    @prof
    def bar(self):
        """Profiled method that completes successfully."""

    @prof
    def throw(self):
        """Profiled method that always raises."""
        raise RuntimeError

    @prof(group="torchelastic")
    def bar2(self):
        """Profiled method that passes an explicit ``group`` to ``@prof``."""

    def test_get_metric_name(self):
        # Note: since pytorch uses main method to launch tests,
        # the module will be different between fb and oss, this
        # allows keeping the module name consistent.
        # The original module is restored afterwards so the override of the
        # shared module-level function does not outlive this test.
        original_module = foo_1.__module__
        foo_1.__module__ = "api_test"
        try:
            self.assertEqual("api_test.foo_1", _get_metric_name(foo_1))
        finally:
            foo_1.__module__ = original_module
        self.assertEqual("MetricsApiTest.foo_2", _get_metric_name(self.foo_2))

    def test_profile(self):
        handler = TestMetricsHandler()
        stream = MetricStream("torchelastic", handler)
        # patch instead of configure to avoid conflicts when running tests in parallel
        with mock.patch(
            "torch.distributed.elastic.metrics.api.getStream", return_value=stream
        ):
            # Successful call: .success and .duration.ms are emitted, .failure is not.
            self.bar()

            self.assertEqual(1, handler.metric_data["MetricsApiTest.bar.success"].value)
            self.assertNotIn("MetricsApiTest.bar.failure", handler.metric_data)
            self.assertIn("MetricsApiTest.bar.duration.ms", handler.metric_data)

            # Failing call: the exception propagates, .failure and .duration.ms
            # are emitted, and .success must NOT be (the previous check named a
            # non-existent "bar_raise" metric and so could never fail).
            with self.assertRaises(RuntimeError):
                self.throw()

            self.assertEqual(
                1, handler.metric_data["MetricsApiTest.throw.failure"].value
            )
            self.assertNotIn("MetricsApiTest.throw.success", handler.metric_data)
            self.assertIn("MetricsApiTest.throw.duration.ms", handler.metric_data)

            # NOTE(review): the stream above is also named "torchelastic", the
            # same string passed as @prof's group for bar2, so this check may
            # not distinguish the two sources of group_name. Confirm intent.
            self.bar2()
            self.assertEqual(
                "torchelastic",
                handler.metric_data["MetricsApiTest.bar2.success"].group_name,
            )

    def test_inheritance(self):
        handler = TestMetricsHandler()
        stream = MetricStream("torchelastic", handler)
        # patch instead of configure to avoid conflicts when running tests in parallel
        with mock.patch(
            "torch.distributed.elastic.metrics.api.getStream", return_value=stream
        ):
            # base_func() lives on Parent but dispatches to Child.func, which is
            # the method carrying @prof, so metrics are keyed by "Child.func".
            c = Child()
            c.base_func()

            self.assertEqual(1, handler.metric_data["Child.func.success"].value)
            self.assertIn("Child.func.duration.ms", handler.metric_data)
115
if __name__ == "__main__":
    # Executed directly (not imported): hand control to PyTorch's shared
    # test runner.
    run_tests()