# TransformerEngine — FP8 meta management helpers for JAX
1# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# See LICENSE for license information.
4"""
5Helper module for fp8 meta management
6"""
7from contextlib import contextmanager8from enum import Enum9from typing import Dict, Optional, Tuple, Union10
11import jax12import jax.numpy as jnp13from flax.core.frozen_dict import FrozenDict14from flax.linen import fp8_ops15
16from transformer_engine_jax import DType17from transformer_engine_jax import get_cublasLt_version18from transformer_engine_jax import get_cuda_version, get_device_compute_capability19from transformer_engine.common.recipe import DelayedScaling, Format20from transformer_engine.jax.sharding import global_shard_guard21from transformer_engine.jax.sharding import MeshResource22
23_is_fp8_available = None24_reason_for_no_fp8 = ""25Collection = Union[Dict, FrozenDict]26
27
def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
    """Report whether the given local GPU can execute FP8, with a reason when it cannot."""
    arch = get_device_compute_capability(gpu_id)
    # Hopper (sm90) and newer support FP8 unconditionally.
    if arch >= 90:
        return True, ""
    # Anything older than Ada (sm89) cannot run FP8 at all.
    if arch < 89:
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    # Ada additionally requires new-enough cublasLt and CUDA versions.
    if get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if get_cuda_version() < 12010:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""
41
def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
    """Report whether FP8 is available, either for one GPU or across all local devices.

    With an explicit ``gpu_id`` the check is performed directly; otherwise the
    result over all local devices is computed once and cached in module globals.
    """
    if gpu_id is not None:
        return _check_fp8_support(gpu_id)

    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:
        _is_fp8_available = True
        # JAX doesn't provide the local GPU id, so probe devices by ordinal.
        for device_ordinal in range(len(jax.local_devices())):
            supported, why_not = _check_fp8_support(device_ordinal)
            if not supported:
                _is_fp8_available = supported
                _reason_for_no_fp8 = why_not
                break

    return _is_fp8_available, _reason_for_no_fp8
60
def _format2dtypes(format_: Format):
    """Map an FP8 recipe format to its (forward, backward) jnp dtypes.

    Unrecognized formats fall back to (bfloat16, bfloat16), i.e. no FP8.
    """
    dtype_pairs = {
        Format.E4M3: (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
        Format.E5M2: (jnp.float8_e5m2, jnp.float8_e5m2),
        # HYBRID: e4m3 forward (activations/weights), e5m2 backward (grads).
        Format.HYBRID: (jnp.float8_e4m3fn, jnp.float8_e5m2),
    }
    return dtype_pairs.get(format_, (jnp.bfloat16, jnp.bfloat16))
70
# fm32 is a custom dtype to specify the "add" rules as max operation.
# This is typically used in Pipeline Parallelism + "MicroBatching > 1",
# which is implemented via nn.scan. Without this custom dtype, nn.scan
# would sum gradients from all micro-batches, and this is not the expected
# behavior for FP8 meta. Instead, the summation of FP8 meta gradients should
# be "MAX".
FlaxFloatMeta32 = fp8_ops.fm32
79
class FP8MetaPackage:
    """
    A container that bundles every FP8 meta tensor (fp8_max, amax, scale,
    scale_inv) required for a group of GEMMs.  Each GEMM contributes
    ``FP8Helper.NUM_META_PER_GEMM`` rows to every tensor.
    """

    def __init__(
        self,
        num_of_gemm: int,
        fp8_max: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
    ) -> None:
        expected_rows = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        self._num_of_gemm = num_of_gemm
        # Validate and store each buffer in turn; every one must carry one
        # row per (gemm, meta) pair.
        for attr_name, buf in (("_fp8_max", fp8_max), ("_amax", amax),
                               ("_scale", scale), ("_scale_inv", scale_inv)):
            assert buf.shape[0] == expected_rows
            setattr(self, attr_name, buf)

    @property
    def num_of_gemm(self) -> int:
        """Number of GEMMs covered by this package."""
        return self._num_of_gemm

    @property
    def fp8_max(self) -> jnp.ndarray:
        """Per-meta FP8 representable maximum."""
        return self._fp8_max

    @property
    def amax(self) -> jnp.ndarray:
        """Per-meta amax history."""
        return self._amax

    @property
    def scale(self) -> jnp.ndarray:
        """Per-meta quantization scale."""
        return self._scale

    @property
    def scale_inv(self) -> jnp.ndarray:
        """Per-meta inverse quantization scale."""
        return self._scale_inv

    def get_package_by_gemm_idx(self, gemm_idx):
        """Slice out a single-GEMM sub-package for the given GEMM index."""
        assert self.num_of_gemm > gemm_idx

        begin = gemm_idx * FP8Helper.NUM_META_PER_GEMM
        window = slice(begin, begin + FP8Helper.NUM_META_PER_GEMM)
        return FP8MetaPackage(1, self.fp8_max[window], self.amax[window],
                              self.scale[window], self.scale_inv[window])
152
class AmaxComputeAlgo(Enum):
    """How to reduce the amax history to the single value used for scaling."""
    MAX = "max"    # use the maximum over the whole history window
    MOST_RECENT = "most_recent"    # use only the latest recorded amax
158
# Name of the Flax variable collection that stores all FP8 meta tensors.
NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"
161
class FP8Helper:
    """
    FP8 helper to manage the FP8 meta.

    All state lives in class attributes; :func:`fp8_autocast` configures it
    through :meth:`initialize` and restores defaults with :meth:`finalize`.
    """
    INITIALIZED = False    # True only while fp8_autocast(enabled=True) is active
    MARGIN: float = 0.0    # scale is divided by 2**MARGIN for extra headroom
    FP8_FORMAT: Format = Format.HYBRID
    # NOTE(review): annotated as DType but _format2dtypes returns jnp float8
    # dtypes — confirm whether the annotation should be jnp dtype instead.
    FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
    BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
    UPDATE_FP8META_INTERVAL: int = 1
    AMAX_HISTORY_LEN: int = 1024
    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
    # Each GEMM tracks three metas: input, kernel and grad (indices below).
    NUM_META_PER_GEMM: int = 3
    INPUT_META_IDX_PER_GEMM: int = 0
    KERNEL_META_IDX_PER_GEMM: int = 1
    GRAD_META_IDX_PER_GEMM: int = 2
    FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
    FP8_AMAX_NAME: str = "fp8_meta_amax"
    FP8_SCALE_NAME: str = "fp8_meta_scale"
    FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
    FP8_MAX_NAME: str = "fp8_max"
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = True
    FP8_2X_ACC_WGRAD: bool = True

    @staticmethod
    def is_fp8_enabled():
        """
        Return True if FP8 training is currently enabled (inside fp8_autocast).
        """
        return FP8Helper.INITIALIZED

    @staticmethod
    def initialize(margin: float = 0.0,
                   fp8_format: Format = Format.HYBRID,
                   update_fp8meta_interval: int = 1,
                   amax_history_len: int = 1,
                   amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None:
        """
        Initialize the FP8 meta configuration from the given recipe values.
        """
        FP8Helper.INITIALIZED = True
        FP8Helper.MARGIN = margin
        FP8Helper.FP8_FORMAT = fp8_format
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
        FP8Helper.AMAX_HISTORY_LEN = amax_history_len
        FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
        FP8Helper.FP8_2X_ACC_FPROP = False
        FP8Helper.FP8_2X_ACC_DGRAD = True
        FP8Helper.FP8_2X_ACC_WGRAD = True

    @staticmethod
    def finalize() -> None:
        """
        Reset all FP8 meta configuration back to the defaults.
        """
        FP8Helper.INITIALIZED = False
        FP8Helper.MARGIN = 0.0
        FP8Helper.FP8_FORMAT = Format.HYBRID
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = 1
        FP8Helper.AMAX_HISTORY_LEN = 1024
        FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX

    @staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
        """
        Merge `new` into `original`; on key collision the value from `new`
        wins.  The result has the same (frozen or plain dict) type as
        `original`.
        """
        assert isinstance(original, (dict, FrozenDict))
        assert isinstance(new, (dict, FrozenDict))
        frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
        for key in new:
            if key in frozen_original:
                frozen_original, _ = frozen_original.pop(key)
        new_coll = FrozenDict({**new, **frozen_original})
        if not isinstance(original, FrozenDict):
            new_coll = new_coll.unfreeze()
        return new_coll

    @staticmethod
    def update_fp8_metas(state: Collection) -> Collection:
        """
        Recompute scale/scale_inv for the FP8 meta collection inside `state`.
        Returns `state` unchanged if it carries no FP8 collection.
        """
        assert isinstance(state, (dict, FrozenDict))
        if FP8Helper.FP8_COLLECTION_NAME in state:
            frozen_state = FrozenDict(state) if not isinstance(state, FrozenDict) else state
            others, fp8_metas = frozen_state.pop(FP8Helper.FP8_COLLECTION_NAME)
            fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
            new_state = FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})

            # Preserve the caller's container type (dict in, dict out).
            if not isinstance(state, FrozenDict):
                new_state = new_state.unfreeze()
            return new_state
        return state

    @staticmethod
    def generate_fp8_max_array(num_of_meta):
        """
        Generate the (num_of_meta, 1) FP8 max array: the grad meta row of each
        GEMM uses the backward dtype's max, the other rows the forward one.
        """
        num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
        fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
        fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
        fp8_max_per_gemm = []
        for i in range(FP8Helper.NUM_META_PER_GEMM):
            val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
                else fp8_max_fwd
            fp8_max_per_gemm.append([val])
        fp8_max_per_gemm = jnp.asarray(fp8_max_per_gemm, dtype=jnp.float32)
        return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)

    @staticmethod
    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
        """
        Return the (input, kernel, grad) FP8 meta row indices for a GEMM index.
        """
        input_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.INPUT_META_IDX_PER_GEMM
        kernel_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.KERNEL_META_IDX_PER_GEMM
        grad_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.GRAD_META_IDX_PER_GEMM
        return input_idx, kernel_idx, grad_idx

    @staticmethod
    @jax.jit
    def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
        """Jitted core of update_fp8_metas: recompute scale and scale_inv per GEMM."""
        fp8_meta_arrays, treedef = jax.tree_util.tree_flatten(fp8_metas)
        num_of_meta_with_max = FP8Helper.NUM_META_PER_GEMM + 1
        num_of_gemm = len(fp8_meta_arrays) // num_of_meta_with_max
        for i in range(num_of_gemm):
            # flattened arrays are ordered in alphabetical order of collection names:
            # fp8_max, fp8_meta_amax, fp8_meta_scale, fp8_meta_scale_inv
            fp8_max_idx = i * num_of_meta_with_max
            fp8_amax_idx = fp8_max_idx + 1
            fp8_scale_idx = fp8_amax_idx + 1
            fp8_scale_inv_idx = fp8_scale_idx + 1

            fp8_max = fp8_meta_arrays[fp8_max_idx]
            if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
                amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=-1, keepdims=True)
            else:
                amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
            scale = fp8_meta_arrays[fp8_scale_idx]

            # Keep the previous scale when amax is zero or non-finite.
            sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
            sf = jnp.where(amax > 0.0, sf, scale)
            sf = jnp.where(jnp.isfinite(amax), sf, scale)
            fp8_meta_arrays[fp8_scale_idx] = sf
            fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf

        return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)

    @staticmethod
    def generate_fp8_meta_dtype_converter_pair(*args):
        """
        Generate a pair of conversion fun in-between fm32 and fp32.

        If the given metas are fm32, returns (fm32->fp32, fp32->fm32)
        converters; otherwise returns two identity functions.
        """

        def identical_fun(*metas):
            return metas

        def fm32_to_fp32_fun(*metas):
            for meta in metas:
                assert meta.dtype == FlaxFloatMeta32
            return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas]

        def fp32_to_fm32_fun(*metas):
            for meta in metas:
                assert meta.dtype == jnp.float32
            return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas]

        # Make functions to be a valid JAX type
        partial_identical_fun = jax.tree_util.Partial(identical_fun)
        partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun)
        partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun)

        if len(args) < 1:
            return partial_identical_fun, partial_identical_fun

        original_dtype = args[0].dtype
        for arg in args:
            assert arg.dtype == original_dtype

        if original_dtype == FlaxFloatMeta32:
            return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun

        return partial_identical_fun, partial_identical_fun

    @staticmethod
    def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
        """
        Advance the amax history one step: shift left along the last axis and
        zero the slot for the next measurement.
        """
        updated_amax = jnp.roll(amax, -1, -1)
        updated_amax = updated_amax.at[..., 0].set(0)
        return updated_amax

    @staticmethod
    @jax.jit
    def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
                         scale: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Calculate fp8 scale and scale_inv based on given amax.

        Returns the pair (scale, scale_inv); the previous scale is kept when
        amax is zero or non-finite.
        """
        if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax, axis=-1, keepdims=True)
        else:
            amax = amax[..., 0:1]

        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
        sf = jnp.where(amax > 0.0, sf, scale)
        sf = jnp.where(jnp.isfinite(amax), sf, scale)
        scale = sf
        scale_inv = 1 / sf

        return scale, scale_inv
381
@contextmanager
def fp8_autocast(enabled: bool = False,
                 fp8_recipe: Optional[DelayedScaling] = None,
                 mesh_resource: Optional[MeshResource] = None) -> None:
    r"""
    Context manager that turns FP8 execution on for TE/JAX modules created
    inside it, and installs the given mesh resource for sharding.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            mesh_resource = MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        Only :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
        :attr:`amax_history_len` and :attr:`amax_compute_algo` (with value
        'max' or 'most_recent') of recipe.DelayedScaling are honored here.
        Any other non-default parameter triggers an assertion.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.

    """
    recipe_in_use = DelayedScaling() if fp8_recipe is None else fp8_recipe

    # Reject recipe options TE/JAX does not implement.
    assert recipe_in_use.amax_compute_algo in ("max", "most_recent"), (
        "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
    assert recipe_in_use.scaling_factor_compute_algo is None, (
        "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.")
    assert recipe_in_use.override_linear_precision == (False, False, False), (
        "DelayedScaling override_linear_precision isn't supported by TE/JAX.")
    assert recipe_in_use.reduce_amax, (
        "DelayedScaling reduce_amax should be enabled for TE/JAX.")

    mesh = MeshResource() if mesh_resource is None else mesh_resource

    try:
        with global_shard_guard(mesh):
            if enabled:
                available, why_not = is_fp8_available()
                assert available, why_not

                algo = AmaxComputeAlgo.MAX if recipe_in_use.amax_compute_algo == 'max' \
                    else AmaxComputeAlgo.MOST_RECENT

                FP8Helper.initialize(margin=recipe_in_use.margin,
                                     fp8_format=recipe_in_use.fp8_format,
                                     update_fp8meta_interval=recipe_in_use.interval,
                                     amax_history_len=recipe_in_use.amax_history_len,
                                     amax_compute_algo=algo)
            yield
    finally:
        # Always restore defaults, even when the body raises.
        FP8Helper.finalize()
458
459# Function Wrappers
def update_collections(new: Collection, original: Collection) -> FrozenDict:
    r"""
    Merge two Flax collections, with entries of `new` taking precedence over
    entries of `original` on key collisions.

    Collection = [dict, flax.core.frozen_dict.FrozenDict]

    Parameters
    ----------
    new: Collection
        A collection that includes new data.
    original: Collection
        The base collection.

    Returns
    -------
    outputs : Collection
        The updated collection, in the same container type as `original`.
    """
    return FP8Helper.update_collections(new, original)
480
def update_fp8_metas(state: Collection) -> Collection:
    r"""
    Recompute the FP8 scales and their inverses stored in `state` using

    .. code-block:: python

        sf = (fp8_max / amax) / (2 ^ margin)
        sf = sf if amax > 0.0, else original_scale
        updated_scale = sf if isfinite(amax), else original_scale)
        updated_scale_inv = 1/updated_scale

    Collection = [dict, flax.core.frozen_dict.FrozenDict]

    Parameters
    ----------
    state: Collection
        A collection that includes FP8 metas.

    Returns
    -------
    outputs : Collection
        The collection with updated FP8 metas.
    """
    return FP8Helper.update_fp8_metas(state)
506
def get_delayed_scaling():
    r"""
    Build a DelayedScaling instance reflecting the state set via fp8_autocast.

    .. note::
        Only :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
        :attr:`amax_history_len` and :attr:`amax_compute_algo` are stored by
        fp8_autocast; every other recipe.DelayedScaling field comes back with
        its default value.

    Returns
    -------
    delay_scaling : DelayedScaling
        an instance of DelayedScaling which is set via fp8_autocast.
    """
    if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
        amax_compute_algo = "max"
    else:
        amax_compute_algo = "most_recent"
    return DelayedScaling(margin=int(FP8Helper.MARGIN),
                          interval=FP8Helper.UPDATE_FP8META_INTERVAL,
                          fp8_format=FP8Helper.FP8_FORMAT,
                          amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
                          amax_compute_algo=amax_compute_algo)