pytorch

Форк
0
/
test_public_bindings.py 
636 строк · 27.0 Кб
1
# Owner(s): ["module: autograd"]
2

3
import importlib
4
import inspect
5
import json
6
import os
7
import pkgutil
8
import unittest
9
from typing import Callable
10

11
import torch
12
from torch._utils_internal import get_file_path_2
13
from torch.testing._internal.common_utils import (
14
    IS_JETSON,
15
    IS_MACOS,
16
    IS_WINDOWS,
17
    run_tests,
18
    skipIfTorchDynamo,
19
    TestCase,
20
)
21

22

23
class TestPublicBindings(TestCase):
24
    def test_no_new_reexport_callables(self):
25
        """
26
        This test aims to stop the introduction of new re-exported callables into
27
        torch whose names do not start with _. Such callables are made available as
28
        torch.XXX, which may not be desirable.
29
        """
30
        reexported_callables = sorted(
31
            k
32
            for k, v in vars(torch).items()
33
            if callable(v) and not v.__module__.startswith("torch")
34
        )
35
        self.assertTrue(
36
            all(k.startswith("_") for k in reexported_callables), reexported_callables
37
        )
38

39
    def test_no_new_bindings(self):
40
        """
41
        This test aims to stop the introduction of new JIT bindings into torch._C
42
        whose names do not start with _. Such bindings are made available as
43
        torch.XXX, which may not be desirable.
44

45
        If your change causes this test to fail, add your new binding to a relevant
46
        submodule of torch._C, such as torch._C._jit (or other relevant submodule of
47
        torch._C). If your binding really needs to be available as torch.XXX, add it
48
        to torch._C and add it to the allowlist below.
49

50
        If you have removed a binding, remove it from the allowlist as well.
51
        """
52

53
        # This allowlist contains every binding in torch._C that is copied into torch at
54
        # the time of writing. It was generated with
55
        #
56
        #   {elem for elem in dir(torch._C) if not elem.startswith("_")}
57
        torch_C_allowlist_superset = {
58
            "AggregationType",
59
            "AliasDb",
60
            "AnyType",
61
            "Argument",
62
            "ArgumentSpec",
63
            "AwaitType",
64
            "autocast_decrement_nesting",
65
            "autocast_increment_nesting",
66
            "AVG",
67
            "BenchmarkConfig",
68
            "BenchmarkExecutionStats",
69
            "Block",
70
            "BoolType",
71
            "BufferDict",
72
            "StorageBase",
73
            "CallStack",
74
            "Capsule",
75
            "ClassType",
76
            "clear_autocast_cache",
77
            "Code",
78
            "CompilationUnit",
79
            "CompleteArgumentSpec",
80
            "ComplexType",
81
            "ConcreteModuleType",
82
            "ConcreteModuleTypeBuilder",
83
            "cpp",
84
            "CudaBFloat16TensorBase",
85
            "CudaBoolTensorBase",
86
            "CudaByteTensorBase",
87
            "CudaCharTensorBase",
88
            "CudaComplexDoubleTensorBase",
89
            "CudaComplexFloatTensorBase",
90
            "CudaDoubleTensorBase",
91
            "CudaFloatTensorBase",
92
            "CudaHalfTensorBase",
93
            "CudaIntTensorBase",
94
            "CudaLongTensorBase",
95
            "CudaShortTensorBase",
96
            "DeepCopyMemoTable",
97
            "default_generator",
98
            "DeserializationStorageContext",
99
            "device",
100
            "DeviceObjType",
101
            "DictType",
102
            "DisableTorchFunction",
103
            "DisableTorchFunctionSubclass",
104
            "DispatchKey",
105
            "DispatchKeySet",
106
            "dtype",
107
            "EnumType",
108
            "ErrorReport",
109
            "ExcludeDispatchKeyGuard",
110
            "ExecutionPlan",
111
            "FatalError",
112
            "FileCheck",
113
            "finfo",
114
            "FloatType",
115
            "fork",
116
            "FunctionSchema",
117
            "Future",
118
            "FutureType",
119
            "Generator",
120
            "GeneratorType",
121
            "get_autocast_cpu_dtype",
122
            "get_autocast_dtype",
123
            "get_autocast_ipu_dtype",
124
            "get_default_dtype",
125
            "get_num_interop_threads",
126
            "get_num_threads",
127
            "Gradient",
128
            "Graph",
129
            "GraphExecutorState",
130
            "has_cuda",
131
            "has_cudnn",
132
            "has_lapack",
133
            "has_mkl",
134
            "has_mkldnn",
135
            "has_mps",
136
            "has_openmp",
137
            "has_spectral",
138
            "iinfo",
139
            "import_ir_module_from_buffer",
140
            "import_ir_module",
141
            "InferredType",
142
            "init_num_threads",
143
            "InterfaceType",
144
            "IntType",
145
            "SymFloatType",
146
            "SymBoolType",
147
            "SymIntType",
148
            "IODescriptor",
149
            "is_anomaly_enabled",
150
            "is_anomaly_check_nan_enabled",
151
            "is_autocast_cache_enabled",
152
            "is_autocast_cpu_enabled",
153
            "is_autocast_ipu_enabled",
154
            "is_autocast_enabled",
155
            "is_grad_enabled",
156
            "is_inference_mode_enabled",
157
            "JITException",
158
            "layout",
159
            "ListType",
160
            "LiteScriptModule",
161
            "LockingLogger",
162
            "LoggerBase",
163
            "memory_format",
164
            "merge_type_from_type_comment",
165
            "ModuleDict",
166
            "Node",
167
            "NoneType",
168
            "NoopLogger",
169
            "NumberType",
170
            "OperatorInfo",
171
            "OptionalType",
172
            "OutOfMemoryError",
173
            "ParameterDict",
174
            "parse_ir",
175
            "parse_schema",
176
            "parse_type_comment",
177
            "PyObjectType",
178
            "PyTorchFileReader",
179
            "PyTorchFileWriter",
180
            "qscheme",
181
            "read_vitals",
182
            "RRefType",
183
            "ScriptClass",
184
            "ScriptClassFunction",
185
            "ScriptDict",
186
            "ScriptDictIterator",
187
            "ScriptDictKeyIterator",
188
            "ScriptList",
189
            "ScriptListIterator",
190
            "ScriptFunction",
191
            "ScriptMethod",
192
            "ScriptModule",
193
            "ScriptModuleSerializer",
194
            "ScriptObject",
195
            "ScriptObjectProperty",
196
            "SerializationStorageContext",
197
            "set_anomaly_enabled",
198
            "set_autocast_cache_enabled",
199
            "set_autocast_cpu_dtype",
200
            "set_autocast_dtype",
201
            "set_autocast_ipu_dtype",
202
            "set_autocast_cpu_enabled",
203
            "set_autocast_ipu_enabled",
204
            "set_autocast_enabled",
205
            "set_flush_denormal",
206
            "set_num_interop_threads",
207
            "set_num_threads",
208
            "set_vital",
209
            "Size",
210
            "StaticModule",
211
            "Stream",
212
            "StreamObjType",
213
            "Event",
214
            "StringType",
215
            "SUM",
216
            "SymFloat",
217
            "SymInt",
218
            "TensorType",
219
            "ThroughputBenchmark",
220
            "TracingState",
221
            "TupleType",
222
            "Type",
223
            "unify_type_list",
224
            "UnionType",
225
            "Use",
226
            "Value",
227
            "set_autocast_gpu_dtype",
228
            "get_autocast_gpu_dtype",
229
            "vitals_enabled",
230
            "wait",
231
            "Tag",
232
            "set_autocast_xla_enabled",
233
            "set_autocast_xla_dtype",
234
            "get_autocast_xla_dtype",
235
            "is_autocast_xla_enabled",
236
        }
237

238
        torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
239

240
        # torch.TensorBase is explicitly removed in torch/__init__.py, so included here (#109940)
241
        explicitly_removed_torch_C_bindings = {"TensorBase"}
242

243
        torch_C_bindings = torch_C_bindings - explicitly_removed_torch_C_bindings
244

245
        # Check that the torch._C bindings are all in the allowlist. Since
246
        # bindings can change based on how PyTorch was compiled (e.g. with/without
247
        # CUDA), the two may not be an exact match but the bindings should be
248
        # a subset of the allowlist.
249
        difference = torch_C_bindings.difference(torch_C_allowlist_superset)
250
        msg = f"torch._C had bindings that are not present in the allowlist:\n{difference}"
251
        self.assertTrue(torch_C_bindings.issubset(torch_C_allowlist_superset), msg)
252

253
    @staticmethod
254
    def _is_mod_public(modname):
255
        split_strs = modname.split(".")
256
        for elem in split_strs:
257
            if elem.startswith("_"):
258
                return False
259
        return True
260

261
    @unittest.skipIf(
262
        IS_WINDOWS or IS_MACOS,
263
        "Inductor/Distributed modules hard fail on windows and macos",
264
    )
265
    @skipIfTorchDynamo("Broken and not relevant for now")
266
    def test_modules_can_be_imported(self):
267
        failures = []
268

269
        def onerror(modname):
270
            failures.append((modname, ImportError))
271

272
        for mod in pkgutil.walk_packages(torch.__path__, "torch.", onerror=onerror):
273
            modname = mod.name
274
            try:
275
                # TODO: fix "torch/utils/model_dump/__main__.py"
276
                # which calls sys.exit() when we try to import it
277
                if "__main__" in modname:
278
                    continue
279
                importlib.import_module(modname)
280
            except Exception as e:
281
                # Some current failures are not ImportError
282

283
                failures.append((modname, type(e)))
284

285
        # It is ok to add new entries here but please be careful that these modules
286
        # do not get imported by public code.
287
        private_allowlist = {
288
            "torch._inductor.codegen.cuda.cuda_kernel",
289
            # TODO(#133647): Remove the onnx._internal entries after
290
            # onnx and onnxscript are installed in CI.
291
            "torch.onnx._internal.exporter",
292
            "torch.onnx._internal.exporter._analysis",
293
            "torch.onnx._internal.exporter._building",
294
            "torch.onnx._internal.exporter._capture_strategies",
295
            "torch.onnx._internal.exporter._compat",
296
            "torch.onnx._internal.exporter._core",
297
            "torch.onnx._internal.exporter._decomp",
298
            "torch.onnx._internal.exporter._dispatching",
299
            "torch.onnx._internal.exporter._fx_passes",
300
            "torch.onnx._internal.exporter._ir_passes",
301
            "torch.onnx._internal.exporter._isolated",
302
            "torch.onnx._internal.exporter._onnx_program",
303
            "torch.onnx._internal.exporter._registration",
304
            "torch.onnx._internal.exporter._reporting",
305
            "torch.onnx._internal.exporter._schemas",
306
            "torch.onnx._internal.exporter._tensors",
307
            "torch.onnx._internal.exporter._verification",
308
            "torch.onnx._internal.fx._pass",
309
            "torch.onnx._internal.fx.analysis",
310
            "torch.onnx._internal.fx.analysis.unsupported_nodes",
311
            "torch.onnx._internal.fx.decomposition_skip",
312
            "torch.onnx._internal.fx.diagnostics",
313
            "torch.onnx._internal.fx.fx_onnx_interpreter",
314
            "torch.onnx._internal.fx.fx_symbolic_graph_extractor",
315
            "torch.onnx._internal.fx.onnxfunction_dispatcher",
316
            "torch.onnx._internal.fx.op_validation",
317
            "torch.onnx._internal.fx.passes",
318
            "torch.onnx._internal.fx.passes._utils",
319
            "torch.onnx._internal.fx.passes.decomp",
320
            "torch.onnx._internal.fx.passes.functionalization",
321
            "torch.onnx._internal.fx.passes.modularization",
322
            "torch.onnx._internal.fx.passes.readability",
323
            "torch.onnx._internal.fx.passes.type_promotion",
324
            "torch.onnx._internal.fx.passes.virtualization",
325
            "torch.onnx._internal.fx.type_utils",
326
            "torch.testing._internal.common_distributed",
327
            "torch.testing._internal.common_fsdp",
328
            "torch.testing._internal.dist_utils",
329
            "torch.testing._internal.distributed.common_state_dict",
330
            "torch.testing._internal.distributed._shard.sharded_tensor",
331
            "torch.testing._internal.distributed._shard.test_common",
332
            "torch.testing._internal.distributed._tensor.common_dtensor",
333
            "torch.testing._internal.distributed.ddp_under_dist_autograd_test",
334
            "torch.testing._internal.distributed.distributed_test",
335
            "torch.testing._internal.distributed.distributed_utils",
336
            "torch.testing._internal.distributed.fake_pg",
337
            "torch.testing._internal.distributed.multi_threaded_pg",
338
            "torch.testing._internal.distributed.nn.api.remote_module_test",
339
            "torch.testing._internal.distributed.rpc.dist_autograd_test",
340
            "torch.testing._internal.distributed.rpc.dist_optimizer_test",
341
            "torch.testing._internal.distributed.rpc.examples.parameter_server_test",
342
            "torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test",
343
            "torch.testing._internal.distributed.rpc.faulty_agent_rpc_test",
344
            "torch.testing._internal.distributed.rpc.faulty_rpc_agent_test_fixture",
345
            "torch.testing._internal.distributed.rpc.jit.dist_autograd_test",
346
            "torch.testing._internal.distributed.rpc.jit.rpc_test",
347
            "torch.testing._internal.distributed.rpc.jit.rpc_test_faulty",
348
            "torch.testing._internal.distributed.rpc.rpc_agent_test_fixture",
349
            "torch.testing._internal.distributed.rpc.rpc_test",
350
            "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture",
351
            "torch.testing._internal.distributed.rpc_utils",
352
            "torch._inductor.codegen.cuda.cuda_template",
353
            "torch._inductor.codegen.cuda.gemm_template",
354
            "torch._inductor.codegen.cpp_template",
355
            "torch._inductor.codegen.cpp_gemm_template",
356
            "torch._inductor.codegen.cpp_micro_gemm",
357
            "torch._inductor.codegen.cpp_template_kernel",
358
            "torch._inductor.runtime.triton_helpers",
359
            "torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks.data_sparsity",
360
            "torch.backends._coreml.preprocess",
361
            "torch.contrib._tensorboard_vis",
362
            "torch.distributed._composable",
363
            "torch.distributed._functional_collectives",
364
            "torch.distributed._functional_collectives_impl",
365
            "torch.distributed._shard",
366
            "torch.distributed._sharded_tensor",
367
            "torch.distributed._sharding_spec",
368
            "torch.distributed._spmd.api",
369
            "torch.distributed._spmd.batch_dim_utils",
370
            "torch.distributed._spmd.comm_tensor",
371
            "torch.distributed._spmd.data_parallel",
372
            "torch.distributed._spmd.distribute",
373
            "torch.distributed._spmd.experimental_ops",
374
            "torch.distributed._spmd.parallel_mode",
375
            "torch.distributed._tensor",
376
            "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
377
            "torch.distributed.algorithms._optimizer_overlap",
378
            "torch.distributed.rpc._testing.faulty_agent_backend_registry",
379
            "torch.distributed.rpc._utils",
380
            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.dlrm_utils",
381
            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_disk_savings",
382
            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_forward_time",
383
            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_model_metrics",
384
            "torch.ao.pruning._experimental.data_sparsifier.lightning.tests.test_callbacks",
385
            "torch.csrc.jit.tensorexpr.scripts.bisect",
386
            "torch.csrc.lazy.test_mnist",
387
            "torch.distributed._shard.checkpoint._fsspec_filesystem",
388
            "torch.distributed._tensor.examples.visualize_sharding_example",
389
            "torch.distributed.checkpoint._fsspec_filesystem",
390
            "torch.distributed.examples.memory_tracker_example",
391
            "torch.testing._internal.distributed.rpc.fb.thrift_rpc_agent_test_fixture",
392
            "torch.utils._cxx_pytree",
393
            "torch.utils.tensorboard._convert_np",
394
            "torch.utils.tensorboard._embedding",
395
            "torch.utils.tensorboard._onnx_graph",
396
            "torch.utils.tensorboard._proto_graph",
397
            "torch.utils.tensorboard._pytorch_graph",
398
            "torch.utils.tensorboard._utils",
399
        }
400

401
        # No new entries should be added to this list.
402
        # All public modules should be importable on all platforms.
403
        public_allowlist = {
404
            "torch.distributed.algorithms.ddp_comm_hooks",
405
            "torch.distributed.algorithms.model_averaging.averagers",
406
            "torch.distributed.algorithms.model_averaging.hierarchical_model_averager",
407
            "torch.distributed.algorithms.model_averaging.utils",
408
            "torch.distributed.checkpoint",
409
            "torch.distributed.constants",
410
            "torch.distributed.distributed_c10d",
411
            "torch.distributed.elastic.agent.server",
412
            "torch.distributed.elastic.rendezvous",
413
            "torch.distributed.fsdp",
414
            "torch.distributed.launch",
415
            "torch.distributed.launcher",
416
            "torch.distributed.nn",
417
            "torch.distributed.nn.api.remote_module",
418
            "torch.distributed.optim",
419
            "torch.distributed.optim.optimizer",
420
            "torch.distributed.rendezvous",
421
            "torch.distributed.rpc.api",
422
            "torch.distributed.rpc.backend_registry",
423
            "torch.distributed.rpc.constants",
424
            "torch.distributed.rpc.internal",
425
            "torch.distributed.rpc.options",
426
            "torch.distributed.rpc.rref_proxy",
427
            "torch.distributed.elastic.rendezvous.etcd_rendezvous",
428
            "torch.distributed.elastic.rendezvous.etcd_rendezvous_backend",
429
            "torch.distributed.elastic.rendezvous.etcd_store",
430
            "torch.distributed.rpc.server_process_global_profiler",
431
            "torch.distributed.run",
432
            "torch.distributed.tensor.parallel",
433
            "torch.distributed.utils",
434
            "torch.utils.tensorboard",
435
            "torch.utils.tensorboard.summary",
436
            "torch.utils.tensorboard.writer",
437
            "torch.ao.quantization.experimental.fake_quantize",
438
            "torch.ao.quantization.experimental.linear",
439
            "torch.ao.quantization.experimental.observer",
440
            "torch.ao.quantization.experimental.qconfig",
441
        }
442

443
        errors = []
444
        for mod, excep_type in failures:
445
            if mod in public_allowlist:
446
                # TODO: Ensure this is the right error type
447

448
                continue
449
            if mod in private_allowlist:
450
                continue
451
            errors.append(f"{mod} failed to import with error {excep_type}")
452
        self.assertEqual("", "\n".join(errors))
453

454
    # AttributeError: module 'torch.distributed' has no attribute '_shard'
455
    @unittest.skipIf(IS_WINDOWS or IS_JETSON or IS_MACOS, "Distributed Attribute Error")
456
    @skipIfTorchDynamo("Broken and not relevant for now")
457
    def test_correct_module_names(self):
458
        """
459
        An API is considered public, if  its  `__module__` starts with `torch.`
460
        and there is no name in `__module__` or the object itself that starts with "_".
461
        Each public package should either:
462
        - (preferred) Define `__all__` and all callables and classes in there must have their
463
         `__module__` start with the current submodule's path. Things not in `__all__` should
464
          NOT have their `__module__` start with the current submodule.
465
        - (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their
466
          `__module__` that start with the current submodule.
467
        """
468

469
        failure_list = []
470
        with open(
471
            get_file_path_2(os.path.dirname(__file__), "allowlist_for_publicAPI.json")
472
        ) as json_file:
473
            # no new entries should be added to this allow_dict.
474
            # New APIs must follow the public API guidelines.
475

476
            allow_dict = json.load(json_file)
477
            # Because we want minimal modifications to the `allowlist_for_publicAPI.json`,
478
            # we are adding the entries for the migrated modules here from the original
479
            # locations.
480

481
            for modname in allow_dict["being_migrated"]:
482
                if modname in allow_dict:
483
                    allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[
484
                        modname
485
                    ]
486

487
        def test_module(modname):
488
            try:
489
                if "__main__" in modname:
490
                    return
491
                mod = importlib.import_module(modname)
492
            except Exception:
493
                # It is ok to ignore here as we have a test above that ensures
494
                # this should never happen
495

496
                return
497
            if not self._is_mod_public(modname):
498
                return
499
            # verifies that each public API has the correct module name and naming semantics
500

501
            def check_one_element(elem, modname, mod, *, is_public, is_all):
502
                obj = getattr(mod, elem)
503

504
                # torch.dtype is not a class nor callable, so we need to check for it separately
505
                if not (
506
                    isinstance(obj, (Callable, torch.dtype)) or inspect.isclass(obj)
507
                ):
508
                    return
509
                elem_module = getattr(obj, "__module__", None)
510

511
                # Only used for nice error message below
512
                why_not_looks_public = ""
513
                if elem_module is None:
514
                    why_not_looks_public = (
515
                        "because it does not have a `__module__` attribute"
516
                    )
517

518
                # If a module is being migrated from foo.a to bar.a (that is entry {"foo": "bar"}),
519
                # the module's starting package would be referred to as the new location even
520
                # if there is a "from foo import a" inside the "bar.py".
521
                modname = allow_dict["being_migrated"].get(modname, modname)
522
                elem_modname_starts_with_mod = (
523
                    elem_module is not None
524
                    and elem_module.startswith(modname)
525
                    and "._" not in elem_module
526
                )
527
                if not why_not_looks_public and not elem_modname_starts_with_mod:
528
                    why_not_looks_public = (
529
                        f"because its `__module__` attribute (`{elem_module}`) is not within the "
530
                        f"torch library or does not start with the submodule where it is defined (`{modname}`)"
531
                    )
532

533
                # elem's name must NOT begin with an `_` and it's module name
534
                # SHOULD start with it's current module since it's a public API
535
                looks_public = not elem.startswith("_") and elem_modname_starts_with_mod
536
                if not why_not_looks_public and not looks_public:
537
                    why_not_looks_public = f"because it starts with `_` (`{elem}`)"
538
                if is_public != looks_public:
539
                    if modname in allow_dict and elem in allow_dict[modname]:
540
                        return
541
                    if is_public:
542
                        why_is_public = (
543
                            f"it is inside the module's (`{modname}`) `__all__`"
544
                            if is_all
545
                            else "it is an attribute that does not start with `_` on a module that "
546
                            "does not have `__all__` defined"
547
                        )
548
                        fix_is_public = (
549
                            f"remove it from the modules's (`{modname}`) `__all__`"
550
                            if is_all
551
                            else f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name"
552
                        )
553
                    else:
554
                        assert is_all
555
                        why_is_public = (
556
                            f"it is not inside the module's (`{modname}`) `__all__`"
557
                        )
558
                        fix_is_public = (
559
                            f"add it from the modules's (`{modname}`) `__all__`"
560
                        )
561
                    if looks_public:
562
                        why_looks_public = (
563
                            "it does look public because it follows the rules from the doc above "
564
                            "(does not start with `_` and has a proper `__module__`)."
565
                        )
566
                        fix_looks_public = "make its name start with `_`"
567
                    else:
568
                        why_looks_public = why_not_looks_public
569
                        if not elem_modname_starts_with_mod:
570
                            fix_looks_public = (
571
                                "make sure the `__module__` is properly set and points to a submodule "
572
                                f"of `{modname}`"
573
                            )
574
                        else:
575
                            fix_looks_public = (
576
                                "remove the `_` at the beginning of the name"
577
                            )
578
                    failure_list.append(f"# {modname}.{elem}:")
579
                    is_public_str = "" if is_public else " NOT"
580
                    failure_list.append(
581
                        f"  - Is{is_public_str} public: {why_is_public}"
582
                    )
583
                    looks_public_str = "" if looks_public else " NOT"
584
                    failure_list.append(
585
                        f"  - Does{looks_public_str} look public: {why_looks_public}"
586
                    )
587
                    # Swap the str below to avoid having to create the NOT again
588
                    failure_list.append(
589
                        "  - You can do either of these two things to fix this problem:"
590
                    )
591
                    failure_list.append(
592
                        f"    - To make it{looks_public_str} public: {fix_is_public}"
593
                    )
594
                    failure_list.append(
595
                        f"    - To make it{is_public_str} look public: {fix_looks_public}"
596
                    )
597

598
            if hasattr(mod, "__all__"):
599
                public_api = mod.__all__
600
                all_api = dir(mod)
601
                for elem in all_api:
602
                    check_one_element(
603
                        elem, modname, mod, is_public=elem in public_api, is_all=True
604
                    )
605
            else:
606
                all_api = dir(mod)
607
                for elem in all_api:
608
                    if not elem.startswith("_"):
609
                        check_one_element(
610
                            elem, modname, mod, is_public=True, is_all=False
611
                        )
612

613
        for mod in pkgutil.walk_packages(torch.__path__, "torch."):
614
            modname = mod.name
615
            test_module(modname)
616
        test_module("torch")
617

618
        msg = (
619
            "All the APIs below do not meet our guidelines for public API from "
620
            "https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation.\n"
621
        )
622
        msg += (
623
            "Make sure that everything that is public is expected (in particular that the module "
624
            "has a properly populated `__all__` attribute) and that everything that is supposed to be public "
625
            "does look public (it does not start with `_` and has a `__module__` that is properly populated)."
626
        )
627

628
        msg += "\n\nFull list:\n"
629
        msg += "\n".join(map(str, failure_list))
630

631
        # empty lists are considered false in python
632
        self.assertTrue(not failure_list, msg)
633

634

635
if __name__ == "__main__":
    # Run this file's tests through the shared runner imported from
    # torch.testing._internal.common_utils.
    run_tests()
637

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.