1
# Owner(s): ["module: autograd"]
9
from typing import Callable
12
from torch._utils_internal import get_file_path_2
13
from torch.testing._internal.common_utils import (
23
class TestPublicBindings(TestCase):
24
def test_no_new_reexport_callables(self):
26
This test aims to stop the introduction of new re-exported callables into
27
torch whose names do not start with _. Such callables are made available as
28
torch.XXX, which may not be desirable.
30
reexported_callables = sorted(
32
for k, v in vars(torch).items()
33
if callable(v) and not v.__module__.startswith("torch")
36
all(k.startswith("_") for k in reexported_callables), reexported_callables
39
def test_no_new_bindings(self):
41
This test aims to stop the introduction of new JIT bindings into torch._C
42
whose names do not start with _. Such bindings are made available as
43
torch.XXX, which may not be desirable.
45
If your change causes this test to fail, add your new binding to a relevant
46
submodule of torch._C, such as torch._C._jit (or other relevant submodule of
47
torch._C). If your binding really needs to be available as torch.XXX, add it
48
to torch._C and add it to the allowlist below.
50
If you have removed a binding, remove it from the allowlist as well.
53
# This allowlist contains every binding in torch._C that is copied into torch at
54
# the time of writing. It was generated with
56
# {elem for elem in dir(torch._C) if not elem.startswith("_")}
57
torch_C_allowlist_superset = {
64
"autocast_decrement_nesting",
65
"autocast_increment_nesting",
68
"BenchmarkExecutionStats",
76
"clear_autocast_cache",
79
"CompleteArgumentSpec",
82
"ConcreteModuleTypeBuilder",
84
"CudaBFloat16TensorBase",
88
"CudaComplexDoubleTensorBase",
89
"CudaComplexFloatTensorBase",
90
"CudaDoubleTensorBase",
91
"CudaFloatTensorBase",
95
"CudaShortTensorBase",
98
"DeserializationStorageContext",
102
"DisableTorchFunction",
103
"DisableTorchFunctionSubclass",
109
"ExcludeDispatchKeyGuard",
121
"get_autocast_cpu_dtype",
122
"get_autocast_dtype",
123
"get_autocast_ipu_dtype",
125
"get_num_interop_threads",
129
"GraphExecutorState",
139
"import_ir_module_from_buffer",
149
"is_anomaly_enabled",
150
"is_anomaly_check_nan_enabled",
151
"is_autocast_cache_enabled",
152
"is_autocast_cpu_enabled",
153
"is_autocast_ipu_enabled",
154
"is_autocast_enabled",
156
"is_inference_mode_enabled",
164
"merge_type_from_type_comment",
176
"parse_type_comment",
184
"ScriptClassFunction",
186
"ScriptDictIterator",
187
"ScriptDictKeyIterator",
189
"ScriptListIterator",
193
"ScriptModuleSerializer",
195
"ScriptObjectProperty",
196
"SerializationStorageContext",
197
"set_anomaly_enabled",
198
"set_autocast_cache_enabled",
199
"set_autocast_cpu_dtype",
200
"set_autocast_dtype",
201
"set_autocast_ipu_dtype",
202
"set_autocast_cpu_enabled",
203
"set_autocast_ipu_enabled",
204
"set_autocast_enabled",
205
"set_flush_denormal",
206
"set_num_interop_threads",
219
"ThroughputBenchmark",
227
"set_autocast_gpu_dtype",
228
"get_autocast_gpu_dtype",
232
"set_autocast_xla_enabled",
233
"set_autocast_xla_dtype",
234
"get_autocast_xla_dtype",
235
"is_autocast_xla_enabled",
238
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
240
# torch.TensorBase is explicitly removed in torch/__init__.py, so included here (#109940)
241
explicitly_removed_torch_C_bindings = {"TensorBase"}
243
torch_C_bindings = torch_C_bindings - explicitly_removed_torch_C_bindings
245
# Check that the torch._C bindings are all in the allowlist. Since
246
# bindings can change based on how PyTorch was compiled (e.g. with/without
247
# CUDA), the two may not be an exact match but the bindings should be
248
# a subset of the allowlist.
249
difference = torch_C_bindings.difference(torch_C_allowlist_superset)
250
msg = f"torch._C had bindings that are not present in the allowlist:\n{difference}"
251
self.assertTrue(torch_C_bindings.issubset(torch_C_allowlist_superset), msg)
254
def _is_mod_public(modname):
255
split_strs = modname.split(".")
256
for elem in split_strs:
257
if elem.startswith("_"):
262
IS_WINDOWS or IS_MACOS,
263
"Inductor/Distributed modules hard fail on windows and macos",
265
@skipIfTorchDynamo("Broken and not relevant for now")
266
def test_modules_can_be_imported(self):
269
def onerror(modname):
270
failures.append((modname, ImportError))
272
for mod in pkgutil.walk_packages(torch.__path__, "torch.", onerror=onerror):
275
# TODO: fix "torch/utils/model_dump/__main__.py"
276
# which calls sys.exit() when we try to import it
277
if "__main__" in modname:
279
importlib.import_module(modname)
280
except Exception as e:
281
# Some current failures are not ImportError
283
failures.append((modname, type(e)))
285
# It is ok to add new entries here but please be careful that these modules
286
# do not get imported by public code.
287
private_allowlist = {
288
"torch._inductor.codegen.cuda.cuda_kernel",
289
# TODO(#133647): Remove the onnx._internal entries after
290
# onnx and onnxscript are installed in CI.
291
"torch.onnx._internal.exporter",
292
"torch.onnx._internal.exporter._analysis",
293
"torch.onnx._internal.exporter._building",
294
"torch.onnx._internal.exporter._capture_strategies",
295
"torch.onnx._internal.exporter._compat",
296
"torch.onnx._internal.exporter._core",
297
"torch.onnx._internal.exporter._decomp",
298
"torch.onnx._internal.exporter._dispatching",
299
"torch.onnx._internal.exporter._fx_passes",
300
"torch.onnx._internal.exporter._ir_passes",
301
"torch.onnx._internal.exporter._isolated",
302
"torch.onnx._internal.exporter._onnx_program",
303
"torch.onnx._internal.exporter._registration",
304
"torch.onnx._internal.exporter._reporting",
305
"torch.onnx._internal.exporter._schemas",
306
"torch.onnx._internal.exporter._tensors",
307
"torch.onnx._internal.exporter._verification",
308
"torch.onnx._internal.fx._pass",
309
"torch.onnx._internal.fx.analysis",
310
"torch.onnx._internal.fx.analysis.unsupported_nodes",
311
"torch.onnx._internal.fx.decomposition_skip",
312
"torch.onnx._internal.fx.diagnostics",
313
"torch.onnx._internal.fx.fx_onnx_interpreter",
314
"torch.onnx._internal.fx.fx_symbolic_graph_extractor",
315
"torch.onnx._internal.fx.onnxfunction_dispatcher",
316
"torch.onnx._internal.fx.op_validation",
317
"torch.onnx._internal.fx.passes",
318
"torch.onnx._internal.fx.passes._utils",
319
"torch.onnx._internal.fx.passes.decomp",
320
"torch.onnx._internal.fx.passes.functionalization",
321
"torch.onnx._internal.fx.passes.modularization",
322
"torch.onnx._internal.fx.passes.readability",
323
"torch.onnx._internal.fx.passes.type_promotion",
324
"torch.onnx._internal.fx.passes.virtualization",
325
"torch.onnx._internal.fx.type_utils",
326
"torch.testing._internal.common_distributed",
327
"torch.testing._internal.common_fsdp",
328
"torch.testing._internal.dist_utils",
329
"torch.testing._internal.distributed.common_state_dict",
330
"torch.testing._internal.distributed._shard.sharded_tensor",
331
"torch.testing._internal.distributed._shard.test_common",
332
"torch.testing._internal.distributed._tensor.common_dtensor",
333
"torch.testing._internal.distributed.ddp_under_dist_autograd_test",
334
"torch.testing._internal.distributed.distributed_test",
335
"torch.testing._internal.distributed.distributed_utils",
336
"torch.testing._internal.distributed.fake_pg",
337
"torch.testing._internal.distributed.multi_threaded_pg",
338
"torch.testing._internal.distributed.nn.api.remote_module_test",
339
"torch.testing._internal.distributed.rpc.dist_autograd_test",
340
"torch.testing._internal.distributed.rpc.dist_optimizer_test",
341
"torch.testing._internal.distributed.rpc.examples.parameter_server_test",
342
"torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test",
343
"torch.testing._internal.distributed.rpc.faulty_agent_rpc_test",
344
"torch.testing._internal.distributed.rpc.faulty_rpc_agent_test_fixture",
345
"torch.testing._internal.distributed.rpc.jit.dist_autograd_test",
346
"torch.testing._internal.distributed.rpc.jit.rpc_test",
347
"torch.testing._internal.distributed.rpc.jit.rpc_test_faulty",
348
"torch.testing._internal.distributed.rpc.rpc_agent_test_fixture",
349
"torch.testing._internal.distributed.rpc.rpc_test",
350
"torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture",
351
"torch.testing._internal.distributed.rpc_utils",
352
"torch._inductor.codegen.cuda.cuda_template",
353
"torch._inductor.codegen.cuda.gemm_template",
354
"torch._inductor.codegen.cpp_template",
355
"torch._inductor.codegen.cpp_gemm_template",
356
"torch._inductor.codegen.cpp_micro_gemm",
357
"torch._inductor.codegen.cpp_template_kernel",
358
"torch._inductor.runtime.triton_helpers",
359
"torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks.data_sparsity",
360
"torch.backends._coreml.preprocess",
361
"torch.contrib._tensorboard_vis",
362
"torch.distributed._composable",
363
"torch.distributed._functional_collectives",
364
"torch.distributed._functional_collectives_impl",
365
"torch.distributed._shard",
366
"torch.distributed._sharded_tensor",
367
"torch.distributed._sharding_spec",
368
"torch.distributed._spmd.api",
369
"torch.distributed._spmd.batch_dim_utils",
370
"torch.distributed._spmd.comm_tensor",
371
"torch.distributed._spmd.data_parallel",
372
"torch.distributed._spmd.distribute",
373
"torch.distributed._spmd.experimental_ops",
374
"torch.distributed._spmd.parallel_mode",
375
"torch.distributed._tensor",
376
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
377
"torch.distributed.algorithms._optimizer_overlap",
378
"torch.distributed.rpc._testing.faulty_agent_backend_registry",
379
"torch.distributed.rpc._utils",
380
"torch.ao.pruning._experimental.data_sparsifier.benchmarks.dlrm_utils",
381
"torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_disk_savings",
382
"torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_forward_time",
383
"torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_model_metrics",
384
"torch.ao.pruning._experimental.data_sparsifier.lightning.tests.test_callbacks",
385
"torch.csrc.jit.tensorexpr.scripts.bisect",
386
"torch.csrc.lazy.test_mnist",
387
"torch.distributed._shard.checkpoint._fsspec_filesystem",
388
"torch.distributed._tensor.examples.visualize_sharding_example",
389
"torch.distributed.checkpoint._fsspec_filesystem",
390
"torch.distributed.examples.memory_tracker_example",
391
"torch.testing._internal.distributed.rpc.fb.thrift_rpc_agent_test_fixture",
392
"torch.utils._cxx_pytree",
393
"torch.utils.tensorboard._convert_np",
394
"torch.utils.tensorboard._embedding",
395
"torch.utils.tensorboard._onnx_graph",
396
"torch.utils.tensorboard._proto_graph",
397
"torch.utils.tensorboard._pytorch_graph",
398
"torch.utils.tensorboard._utils",
401
# No new entries should be added to this list.
402
# All public modules should be importable on all platforms.
404
"torch.distributed.algorithms.ddp_comm_hooks",
405
"torch.distributed.algorithms.model_averaging.averagers",
406
"torch.distributed.algorithms.model_averaging.hierarchical_model_averager",
407
"torch.distributed.algorithms.model_averaging.utils",
408
"torch.distributed.checkpoint",
409
"torch.distributed.constants",
410
"torch.distributed.distributed_c10d",
411
"torch.distributed.elastic.agent.server",
412
"torch.distributed.elastic.rendezvous",
413
"torch.distributed.fsdp",
414
"torch.distributed.launch",
415
"torch.distributed.launcher",
416
"torch.distributed.nn",
417
"torch.distributed.nn.api.remote_module",
418
"torch.distributed.optim",
419
"torch.distributed.optim.optimizer",
420
"torch.distributed.rendezvous",
421
"torch.distributed.rpc.api",
422
"torch.distributed.rpc.backend_registry",
423
"torch.distributed.rpc.constants",
424
"torch.distributed.rpc.internal",
425
"torch.distributed.rpc.options",
426
"torch.distributed.rpc.rref_proxy",
427
"torch.distributed.elastic.rendezvous.etcd_rendezvous",
428
"torch.distributed.elastic.rendezvous.etcd_rendezvous_backend",
429
"torch.distributed.elastic.rendezvous.etcd_store",
430
"torch.distributed.rpc.server_process_global_profiler",
431
"torch.distributed.run",
432
"torch.distributed.tensor.parallel",
433
"torch.distributed.utils",
434
"torch.utils.tensorboard",
435
"torch.utils.tensorboard.summary",
436
"torch.utils.tensorboard.writer",
437
"torch.ao.quantization.experimental.fake_quantize",
438
"torch.ao.quantization.experimental.linear",
439
"torch.ao.quantization.experimental.observer",
440
"torch.ao.quantization.experimental.qconfig",
444
for mod, excep_type in failures:
445
if mod in public_allowlist:
446
# TODO: Ensure this is the right error type
449
if mod in private_allowlist:
451
errors.append(f"{mod} failed to import with error {excep_type}")
452
self.assertEqual("", "\n".join(errors))
454
# AttributeError: module 'torch.distributed' has no attribute '_shard'
455
@unittest.skipIf(IS_WINDOWS or IS_JETSON or IS_MACOS, "Distributed Attribute Error")
456
@skipIfTorchDynamo("Broken and not relevant for now")
457
def test_correct_module_names(self):
459
An API is considered public, if its `__module__` starts with `torch.`
460
and there is no name in `__module__` or the object itself that starts with "_".
461
Each public package should either:
462
- (preferred) Define `__all__` and all callables and classes in there must have their
463
`__module__` start with the current submodule's path. Things not in `__all__` should
464
NOT have their `__module__` start with the current submodule.
465
- (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their
466
`__module__` that start with the current submodule.
471
get_file_path_2(os.path.dirname(__file__), "allowlist_for_publicAPI.json")
473
# no new entries should be added to this allow_dict.
474
# New APIs must follow the public API guidelines.
476
allow_dict = json.load(json_file)
477
# Because we want minimal modifications to the `allowlist_for_publicAPI.json`,
478
# we are adding the entries for the migrated modules here from the original
481
for modname in allow_dict["being_migrated"]:
482
if modname in allow_dict:
483
allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[
487
def test_module(modname):
489
if "__main__" in modname:
491
mod = importlib.import_module(modname)
493
# It is ok to ignore here as we have a test above that ensures
494
# this should never happen
497
if not self._is_mod_public(modname):
499
# verifies that each public API has the correct module name and naming semantics
501
def check_one_element(elem, modname, mod, *, is_public, is_all):
502
obj = getattr(mod, elem)
504
# torch.dtype is not a class nor callable, so we need to check for it separately
506
isinstance(obj, (Callable, torch.dtype)) or inspect.isclass(obj)
509
elem_module = getattr(obj, "__module__", None)
511
# Only used for nice error message below
512
why_not_looks_public = ""
513
if elem_module is None:
514
why_not_looks_public = (
515
"because it does not have a `__module__` attribute"
518
# If a module is being migrated from foo.a to bar.a (that is entry {"foo": "bar"}),
519
# the module's starting package would be referred to as the new location even
520
# if there is a "from foo import a" inside the "bar.py".
521
modname = allow_dict["being_migrated"].get(modname, modname)
522
elem_modname_starts_with_mod = (
523
elem_module is not None
524
and elem_module.startswith(modname)
525
and "._" not in elem_module
527
if not why_not_looks_public and not elem_modname_starts_with_mod:
528
why_not_looks_public = (
529
f"because its `__module__` attribute (`{elem_module}`) is not within the "
530
f"torch library or does not start with the submodule where it is defined (`{modname}`)"
533
# elem's name must NOT begin with an `_` and it's module name
534
# SHOULD start with it's current module since it's a public API
535
looks_public = not elem.startswith("_") and elem_modname_starts_with_mod
536
if not why_not_looks_public and not looks_public:
537
why_not_looks_public = f"because it starts with `_` (`{elem}`)"
538
if is_public != looks_public:
539
if modname in allow_dict and elem in allow_dict[modname]:
543
f"it is inside the module's (`{modname}`) `__all__`"
545
else "it is an attribute that does not start with `_` on a module that "
546
"does not have `__all__` defined"
549
f"remove it from the modules's (`{modname}`) `__all__`"
551
else f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name"
556
f"it is not inside the module's (`{modname}`) `__all__`"
559
f"add it from the modules's (`{modname}`) `__all__`"
563
"it does look public because it follows the rules from the doc above "
564
"(does not start with `_` and has a proper `__module__`)."
566
fix_looks_public = "make its name start with `_`"
568
why_looks_public = why_not_looks_public
569
if not elem_modname_starts_with_mod:
571
"make sure the `__module__` is properly set and points to a submodule "
576
"remove the `_` at the beginning of the name"
578
failure_list.append(f"# {modname}.{elem}:")
579
is_public_str = "" if is_public else " NOT"
581
f" - Is{is_public_str} public: {why_is_public}"
583
looks_public_str = "" if looks_public else " NOT"
585
f" - Does{looks_public_str} look public: {why_looks_public}"
587
# Swap the str below to avoid having to create the NOT again
589
" - You can do either of these two things to fix this problem:"
592
f" - To make it{looks_public_str} public: {fix_is_public}"
595
f" - To make it{is_public_str} look public: {fix_looks_public}"
598
if hasattr(mod, "__all__"):
599
public_api = mod.__all__
603
elem, modname, mod, is_public=elem in public_api, is_all=True
608
if not elem.startswith("_"):
610
elem, modname, mod, is_public=True, is_all=False
613
for mod in pkgutil.walk_packages(torch.__path__, "torch."):
619
"All the APIs below do not meet our guidelines for public API from "
620
"https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation.\n"
623
"Make sure that everything that is public is expected (in particular that the module "
624
"has a properly populated `__all__` attribute) and that everything that is supposed to be public "
625
"does look public (it does not start with `_` and has a `__module__` that is properly populated)."
628
msg += "\n\nFull list:\n"
629
msg += "\n".join(map(str, failure_list))
631
# empty lists are considered false in python
632
self.assertTrue(not failure_list, msg)
635
if __name__ == "__main__":