# pytorch — torch/distributed/device_mesh.py (553 lines · 24.1 KB)
1# Copyright (c) Meta Platforms, Inc. and affiliates
2import logging3import math4from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union5
6import torch7
8from torch.distributed import is_available9
10from ..utils._typing_utils import not_none11
12__all__ = ["init_device_mesh", "DeviceMesh"]13
14
if not is_available():
    import sys

    # We need to create the stubs when distributed is not available.
    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
    # since it would try to import ``torch.distributed.device_mesh`` or
    # ``torch.distributed.init_device_mesh`` but cannot find them.

    class _DeviceMeshStub:
        pass

    def _init_device_mesh_stub():
        pass

    # Patch the stubs directly onto the already-registered module object so
    # that ``torch.distributed.device_mesh.DeviceMesh`` resolves even when
    # the distributed backend was compiled out.
    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
    sys.modules[
        "torch.distributed.device_mesh"
    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]

else:
    from torch.distributed.distributed_c10d import (
        _find_pg_by_ranks_and_tag,
        _get_default_group,
        _get_group_tag,
        get_rank,
        get_world_size,
        init_process_group,
        is_initialized,
        new_group,
        ProcessGroup,
    )
    # NOTE(review): in the upstream file the remainder of the module is nested
    # under this ``else`` branch — verify indentation when reassembling.
    # Module-level logger for DeviceMesh diagnostics.
    logger = logging.getLogger(__name__)

    # only import numpy typing when type checking
    if TYPE_CHECKING:
        try:
            from numpy.typing import ArrayLike
        except ImportError:
            logger.warning(
                "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
            )
59class _MeshEnv:60def __init__(self) -> None:61self.mesh_stack: List[DeviceMesh] = []62self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}63
64def get_current_mesh(self) -> "DeviceMesh":65if len(self.mesh_stack) == 0:66raise RuntimeError("No device mesh is currently active!")67return self.mesh_stack[-1]68
69def create_child_mesh(70self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str71) -> "DeviceMesh":72# swap the current dim to the last dim then reshape to flatten out other73# dims, so we can just extract the list of ranks which contains cur_rank.74cur_rank = device_mesh.get_rank()75pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(76-1, device_mesh.mesh.size(mesh_dim)77)78
79for mesh_1d in pg_ranks_by_dim:80sub_mesh = DeviceMesh(81device_mesh.device_type,82mesh_1d,83mesh_dim_names=(mesh_dim_name,),84)85if cur_rank in mesh_1d:86res_sub_mesh = sub_mesh87
88res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined]89# Assign the current DeviceMesh as the parent of the child DeviceMesh.90self.child_to_parent_mapping[res_sub_mesh] = device_mesh91return res_sub_mesh92
93def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:94return self.child_to_parent_mapping.get(device_mesh, None)95
96def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:97"""98Return the index of the mesh dim in the parent mesh.
99The device_mesh passed in needs to be sliced out from a parent mesh.
100"""
101parent_mesh = self.get_parent_mesh(device_mesh)102child_mesh_dim_names = device_mesh.mesh_dim_names103if parent_mesh and child_mesh_dim_names:104assert (105len(child_mesh_dim_names) == 1106), "The child mesh can only be a 1D mesh."107child_mesh_dim_name = child_mesh_dim_names[0]108return self.get_mesh_dim_by_name(parent_mesh, child_mesh_dim_name)109return None110
111@staticmethod112def num_devices_per_host(device_type: str) -> int:113return _get_device_handle(device_type).device_count()114
115@staticmethod116def num_hosts(device_type: str) -> int:117# ProcessGroup can't tell us this info so we have to infer it, assume118# homogeneous hardware for now119return get_world_size() // _MeshEnv.num_devices_per_host(device_type)120
121def get_mesh_dim_by_name(122self, device_mesh: "DeviceMesh", mesh_dim_name: str123) -> int:124if (125device_mesh.mesh_dim_names is None126or len(device_mesh.mesh_dim_names) == 0127):128raise KeyError(129"No `mesh_dim_names` found.",130)131if mesh_dim_name not in device_mesh.mesh_dim_names:132raise KeyError(133f"Mesh dimension '{mesh_dim_name}' does not exist.",134f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}",135)136return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name))137
138_mesh_resources: _MeshEnv = _MeshEnv()139
140def _get_device_handle(device_type: str = "cuda"):141"""142Get the module corresponding to the device_type which is cuda or cuda-like device.
143For example, when the device_type is cuda, the module `torch.cuda` is returned.
144Return None when there is no corresponding module for device_type, otherwise
145return the corresponding module.
146"""
147return getattr(torch, device_type, None)148
149class DeviceMesh:150"""151DeviceMesh represents a mesh of devices, where layout of devices could be
152represented as a n-d dimension array, and each value of the n-d dimensional
153array is the global id of the default process group ranks.
154
155DeviceMesh could be used to describe the layout of devices across the cluster,
156and serves as a proxy for communication among the device lists within the cluster.
157
158DeviceMesh can be used as a context manager.
159
160.. note::
161DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
162is running on all processes/ranks in the cluster. Therefore, users need to make sure the
163`mesh` array (which describes the layout of devices) should be identical across all ranks.
164Inconsistent `mesh` will lead to silent hang.
165
166Args:
167device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
168mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
169of devices, where the IDs are global IDs of the default process group.
170
171Returns:
172DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
173
174The following program runs on each process/rank in an SPMD manner. In this example, we have 2
175hosts with 4 GPUs each.
176A reduction over the first dimension of mesh will reduce across
177columns (0, 4), .. and (3, 7), a reduction over the second dimension
178of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
179
180Example::
181>>> # xdoctest: +SKIP("no rank")
182>>> from torch.distributed.device_mesh import DeviceMesh
183>>>
184>>> # Initialize device mesh as (2, 4) to represent the topology
185>>> # of cross-host(dim 0), and within-host (dim 1).
186>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
187"""
188
189device_type: str190mesh: torch.Tensor191mesh_dim_names: Optional[Tuple[str, ...]]192
193def __init__(194self,195device_type: str,196mesh: Union[torch.Tensor, "ArrayLike"],197*,198mesh_dim_names: Optional[Tuple[str, ...]] = None,199) -> None:200self.device_type = device_type201if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":202raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")203self.mesh = (204mesh.detach().cpu()205if isinstance(mesh, torch.Tensor)206else torch.tensor(mesh, dtype=torch.int)207)208self.mesh_dim_names = mesh_dim_names209
210# private field to pre-generate DeviceMesh's hash211self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())212self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))213
214# Skip process group initialization if xla device.215# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.216if device_type != "xla":217# always try to create default (world) pg, even if it is not initialized218# already. The world pg is used for device mesh identity (rank) on each219# process (we need to know if the current global rank is in the mesh or not).220self._get_or_create_default_group()221self._init_process_groups()222
223def _get_or_create_default_group(self):224default_initialized = is_initialized()225if not default_initialized:226init_process_group()227
228world_size = get_world_size()229if self.mesh.numel() > world_size:230raise RuntimeError(231f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"232)233
234device_handle = _get_device_handle(self.device_type)235# TODO: if user want to pass pg_options, offer a way to do it236if not default_initialized and device_handle:237# automatically set the current cuda/cuda-like device base on num of gpu devices available in each host238# NOTE: This device selection would only work for homogeneous hardware.239num_devices_per_host = device_handle.device_count()240if (241world_size > num_devices_per_host242and world_size % num_devices_per_host != 0243):244raise RuntimeError(245f"DeviceMesh only support homogeneous hardware, but found "246f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"247)248device_handle.set_device(get_rank() % num_devices_per_host)249
250# calculate the coordinates of the current global rank on the mesh251rank_coords = (self.mesh == get_rank()).nonzero()252assert rank_coords.size(0) in (0, 1)253self._coordinate_on_dim: Optional[List[int]] = (254rank_coords[0].tolist() if rank_coords.size(0) > 0 else None255)256return _get_default_group()257
258def _init_process_groups(self):259# tag/ranks/group_name associated with each mesh dimension, each260# mesh dimension should have one sub-group per rank261#262# TODO(yifu): remove tag and ranks once we fully migrate to native263# functional collectives. See details in:264# https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208265dim_group_infos: List[Tuple[str, List[int], str]] = []266
267if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():268# if the mesh is the same as world_pg, we just append the default269# pg to the first dim groups, as new_group cannot have the exact270# same ranks as world271dim_group_infos.append(272(273_get_group_tag(_get_default_group()),274list(range(get_world_size())),275_get_default_group().group_name,276)277)278else:279# create sub pgs base on the mesh argument specified280for dim in range(self.mesh.ndim):281# swap the current dim to the last dim282# then reshape to flatten out other dims283pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(284-1, self.mesh.size(dim)285)286# multi-dim mesh, create subgroups by looping over the pg_ranks287# for each dim and append the groups288for dim_mesh in pg_ranks_by_dim:289subgroup_ranks = dim_mesh.tolist()290
291# We temporarily revert the re-use subgroup, since it breaks two internal tests.292# Temporarily reverting to resolve test timeout while root-causing.293# TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.294dim_group = new_group(ranks=subgroup_ranks)295
296# only add to dim_groups if the current rank in the subgroup297if self.get_rank() in subgroup_ranks:298if len(dim_group_infos) > dim:299raise RuntimeError(300f"Each device mesh dimension should get only one process group, but got {self.get_rank} "301f"in {subgroup_ranks}!"302)303dim_group_infos.append(304(305_get_group_tag(not_none(dim_group)),306subgroup_ranks,307dim_group.group_name,308)309)310self._dim_group_infos = dim_group_infos311
312def __enter__(self) -> "DeviceMesh":313# set this mesh as the current mesh in mesh env314_mesh_resources.mesh_stack.append(self)315return self316
317# pyre-fixme[2]: Parameter must be annotated.318def __exit__(self, exc_type, exc_value, exc_traceback) -> None:319# pop this mesh from mesh env320_mesh_resources.mesh_stack.pop()321
322def __repr__(self) -> str:323device_mesh_repr = (324f"DeviceMesh({self.mesh.tolist()})"325if not self.mesh_dim_names326else f"DeviceMesh({self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})"327)328return device_mesh_repr329
330def __hash__(self):331return self._hash332
333def __eq__(self, other: object) -> bool:334if not isinstance(other, DeviceMesh):335return False336if id(self.mesh) == id(other.mesh):337return True338return (339self.mesh.shape == other.mesh.shape340and self._flatten_mesh_list == other._flatten_mesh_list341)342
343def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":344"""345Slice the current DeviceMesh based on the mesh_dim_name given to create a child
346DeviceMesh.
347
348Args:
349mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
350to create a child DeviceMesh for.
351Returns:
352A :class:`DeviceMesh` object
353
354The following program runs on each process/rank in an SPMD manner. In this example, we have 2
355hosts with 4 GPUs each.
356Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
357Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
358Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
359Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
360Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
361Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
362
363Example::
364>>> # xdoctest: +SKIP("no rank")
365>>> from torch.distributed.device_mesh import DeviceMesh
366>>>
367>>> # Initialize device mesh as (2, 4) to represent the topology
368>>> # of cross-host(dim 0), and within-host (dim 1).
369>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
370"""
371if self.mesh.ndim == 1:372if self.mesh_dim_names and mesh_dim_name == self.mesh_dim_names[0]:373return self374else:375raise RuntimeError(376f"Invalid mesh_dim_name {mesh_dim_name} specified."377)378
379mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name)380submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)381
382return submesh383
384def get_group(385self, mesh_dim: Optional[Union[int, str]] = None386) -> Union[ProcessGroup, List[ProcessGroup]]:387"""388Returns a list of ProcessGroups corresponding to the mesh dimensions, or
389returns a single ProcessGroup if mesh_dim is specified or the given mesh has
390only one mesh dimension.
391
392Args:
393mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
394of the mesh dimension. Default is None.
395
396Returns:
397A list of :class:`ProcessGroup` object when `mesh_dim` is not specified for
398a DeviceMesh with more than 1 dimension; otherwise, returns a single
399:class:`ProcessGroup` object.
400"""
401if not hasattr(self, "_dim_group_infos"):402raise RuntimeError("DeviceMesh process groups not initialized!")403
404if self.mesh.ndim == 1:405return not_none(406_find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])407)408
409if mesh_dim is not None:410if isinstance(mesh_dim, str):411mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim)412return not_none(413_find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2])414)415else:416dim_groups = []417for ith_dim in range(self.mesh.ndim):418dim_groups.append(419not_none(420_find_pg_by_ranks_and_tag(421*self._dim_group_infos[ith_dim][:2]422)423)424)425return dim_groups426
427def size(self, mesh_dim: Optional[int] = None) -> int:428return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)429
430@property431def ndim(self) -> int:432return self.mesh.ndim433
434@property435def shape(self) -> Tuple[int, ...]:436return tuple(self.mesh.shape)437
438def get_rank(self) -> int:439"""440Returns the current global rank.
441"""
442return get_rank()443
444def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:445"""446Returns the local rank of the given mesh_dim of the DeviceMesh.
447
448Args:
449mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
450of the mesh dimension. Default is None.
451
452Returns:
453An integer denotes the local rank.
454
455The following program runs on each process/rank in an SPMD manner. In this example, we have 2
456hosts with 4 GPUs each.
457Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
458Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
459Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
460Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
461Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
462Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.
463
464Example::
465>>> # xdoctest: +SKIP("no rank")
466>>> from torch.distributed.device_mesh import DeviceMesh
467>>>
468>>> # Initialize device mesh as (2, 4) to represent the topology
469>>> # of cross-host(dim 0), and within-host (dim 1).
470>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
471"""
472if self.ndim > 1 and mesh_dim is None:473raise RuntimeError(474f"Found the DeviceMesh have {self.mesh.ndim} dimensions",475"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",476)477elif mesh_dim is None:478mesh_dim = 0479
480mesh_dim_group = not_none(self.get_group(mesh_dim))481assert isinstance(482mesh_dim_group, ProcessGroup483), "We expect ProcessGroup before calling `get_rank`!"484return not_none(get_rank(mesh_dim_group))485
486def get_coordinate(self) -> Optional[List[int]]:487"""488Return the relative indices of this rank relative to all
489dimensions of the mesh. If this rank is not part of the mesh, return None.
490"""
491return self._coordinate_on_dim if self._coordinate_on_dim else None492
493def init_device_mesh(494device_type: str,495mesh_shape: Tuple[int, ...],496*,497mesh_dim_names: Optional[Tuple[str, ...]] = None,498) -> DeviceMesh:499"""500Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
501
502This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`.
503If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`.
504
505.. note::
506`init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program
507runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
508describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging.
509
510.. note::
511If no process group is found, init_device_mesh will initialize distributed process group/groups
512required for distributed communications behind the scene.
513
514Args:
515device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
516mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
517describing the layout of devices.
518mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension
519of the multi-dimensional array describing the layout of devices. Its length must match the length
520of `mesh_shape`. Each string in `mesh_dim_names` must be unique.
521
522Returns:
523DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
524
525Example::
526>>> # xdoctest: +SKIP("no rank")
527>>> from torch.distributed.device_mesh import init_device_mesh
528>>>
529>>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
530>>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
531
532"""
533if mesh_dim_names is not None:534if len(set(mesh_dim_names)) != len(mesh_dim_names):535raise RuntimeError(536"Each mesh_dim_name must be unique.",537f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",538)539
540if len(mesh_shape) != len(mesh_dim_names):541raise RuntimeError(542"mesh_shape and mesh_dim_names should have same length!",543f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",544)545
546mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)547device_mesh = DeviceMesh(548device_type=device_type,549mesh=mesh,550mesh_dim_names=mesh_dim_names,551)552
553return device_mesh554