# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops

from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import MeshResource

_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]


def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
    """Return whether fp8 support is available"""
    gpu_arch = get_device_compute_capability(gpu_id)
    if gpu_arch >= 90:    # hopper and above
        return True, ""
    if gpu_arch < 89:    # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if get_cuda_version() < 12010:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
    """Return whether fp8 support is available"""
    if gpu_id is not None:
        return _check_fp8_support(gpu_id)

    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:
        _is_fp8_available = True
        # JAX doesn't provide the local GPU id.
        for local_gpu_id in range(len(jax.local_devices())):
            ret, msg = _check_fp8_support(local_gpu_id)
            if ret is False:
                _is_fp8_available = ret
                _reason_for_no_fp8 = msg
            break

    return _is_fp8_available, _reason_for_no_fp8


def _format2dtypes(format_: Format):
    if format_ == Format.E4M3:
        return jnp.float8_e4m3fn, jnp.float8_e4m3fn
    if format_ == Format.E5M2:
        return jnp.float8_e5m2, jnp.float8_e5m2
    if format_ == Format.HYBRID:
        return jnp.float8_e4m3fn, jnp.float8_e5m2
    return jnp.bfloat16, jnp.bfloat16


# fm32 is a custom dtype to specify the "add" rules as max operation.
# This is typically used in Pipeline Parallelism + "MicroBatching > 1",
# which is implemented via nn.scan. Without this custom dtype, nn.scan
# would sum gradients from all micro-batches, and this is not the expected
# behavior for FP8 meta. Instead, the summation of FP8 meta gradients should
# be "MAX".
FlaxFloatMeta32 = fp8_ops.fm32


class FP8MetaPackage:
    """
    A container that contains all required meta data for FP8
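
    For example (an illustrative sketch; the exact shapes are assumptions,
    with one row per meta and the amax history along the last axis):

    .. code-block:: python

        num_gemm = 1
        num_meta = num_gemm * FP8Helper.NUM_META_PER_GEMM
        package = FP8MetaPackage(
            num_gemm,
            FP8Helper.generate_fp8_max_array(num_meta),
            jnp.zeros((num_meta, FP8Helper.AMAX_HISTORY_LEN), jnp.float32),
            jnp.ones((num_meta, 1), jnp.float32),
            jnp.ones((num_meta, 1), jnp.float32),
        )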
    """

    def __init__(
        self,
        num_of_gemm: int,
        fp8_max: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
    ) -> None:
        total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        self._num_of_gemm = num_of_gemm
        assert fp8_max.shape[0] == total_num_of_meta
        self._fp8_max = fp8_max
        assert amax.shape[0] == total_num_of_meta
        self._amax = amax
        assert scale.shape[0] == total_num_of_meta
        self._scale = scale
        assert scale_inv.shape[0] == total_num_of_meta
        self._scale_inv = scale_inv

    @property
    def num_of_gemm(self) -> int:
        """
        num_of_gemm of this package
        """
        return self._num_of_gemm

    @property
    def fp8_max(self) -> jnp.ndarray:
        """
        fp8_max of this package
        """
        return self._fp8_max

    @property
    def amax(self) -> jnp.ndarray:
        """
        amax of this package
        """
        return self._amax

    @property
    def scale(self) -> jnp.ndarray:
        """
        scale of this package
        """
        return self._scale

    @property
    def scale_inv(self) -> jnp.ndarray:
        """
        scale_inv of this package
        """
        return self._scale_inv

    def get_package_by_gemm_idx(self, gemm_idx):
        """
        Get a sub package by gemm_idx
        """
        assert self.num_of_gemm > gemm_idx

        meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM
        meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM
        return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx],
                              self.amax[meta_start_idx:meta_end_idx],
                              self.scale[meta_start_idx:meta_end_idx],
                              self.scale_inv[meta_start_idx:meta_end_idx])


class AmaxComputeAlgo(Enum):
    """AmaxComputeAlgo."""
    MAX = "max"
    MOST_RECENT = "most_recent"


NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"


class FP8Helper:
    """
    FP8 helper to manage the FP8 meta
    """
    INITIALIZED = False
    MARGIN: float = 0.0
    FP8_FORMAT: Format = Format.HYBRID
    FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
    BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
    UPDATE_FP8META_INTERVAL: int = 1
    AMAX_HISTORY_LEN: int = 1024
    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
    NUM_META_PER_GEMM: int = 3
    INPUT_META_IDX_PER_GEMM: int = 0
    KERNEL_META_IDX_PER_GEMM: int = 1
    GRAD_META_IDX_PER_GEMM: int = 2
    FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
    FP8_AMAX_NAME: str = "fp8_meta_amax"
    FP8_SCALE_NAME: str = "fp8_meta_scale"
    FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
    FP8_MAX_NAME: str = "fp8_max"
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = True
    FP8_2X_ACC_WGRAD: bool = True

    @staticmethod
    def is_fp8_enabled():
        """
        Indicate whether fp8 training is enabled or not.
        """
        return FP8Helper.INITIALIZED

    @staticmethod
    def initialize(margin: float = 0.0,
                   fp8_format: Format = Format.HYBRID,
                   update_fp8meta_interval: int = 1,
                   amax_history_len: int = 1,
                   amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None:
        """
        Initialize the FP8 meta
        """
        FP8Helper.INITIALIZED = True
        FP8Helper.MARGIN = margin
        FP8Helper.FP8_FORMAT = fp8_format
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
        FP8Helper.AMAX_HISTORY_LEN = amax_history_len
        FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
        FP8Helper.FP8_2X_ACC_FPROP = False
        FP8Helper.FP8_2X_ACC_DGRAD = True
        FP8Helper.FP8_2X_ACC_WGRAD = True

    @staticmethod
    def finalize() -> None:
        """
        FP8 helper finalize
        """
        FP8Helper.INITIALIZED = False
        FP8Helper.MARGIN = 0.0
        FP8Helper.FP8_FORMAT = Format.HYBRID
        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
            _format2dtypes(FP8Helper.FP8_FORMAT)
        FP8Helper.UPDATE_FP8META_INTERVAL = 1
        FP8Helper.AMAX_HISTORY_LEN = 1024
        FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX

    @staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
        """
        Update the collections
        """
        assert isinstance(original, (dict, FrozenDict))
        assert isinstance(new, (dict, FrozenDict))
        frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
        for key in new:
            if key in frozen_original:
                frozen_original, _ = frozen_original.pop(key)
        new_coll = FrozenDict({**new, **frozen_original})
        if not isinstance(original, FrozenDict):
            new_coll = new_coll.unfreeze()
        return new_coll

    @staticmethod
    def update_fp8_metas(state: Collection) -> Collection:
        """
        Update the FP8 metas
        """
        assert isinstance(state, (dict, FrozenDict))
        if FP8Helper.FP8_COLLECTION_NAME in state:
            frozen_state = FrozenDict(state) if not isinstance(state, FrozenDict) else state
            others, fp8_metas = frozen_state.pop(FP8Helper.FP8_COLLECTION_NAME)
            fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
            new_state = FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})

            if not isinstance(state, FrozenDict):
                new_state = new_state.unfreeze()
            return new_state
        return state

    @staticmethod
    def generate_fp8_max_array(num_of_meta):
        """
        Generate the FP8 max array
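
        For example (illustrative), with the default HYBRID format one GEMM
        (``num_of_meta = 3``) gets the E4M3 maximum for its input and kernel
        rows and the E5M2 maximum for its gradient row:

        .. code-block:: python

            fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
            # fp8_max has shape (3, 1); rows 0 and 1 hold the E4M3 max and
            # row 2 (GRAD_META_IDX_PER_GEMM) holds the E5M2 max.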
        """
        num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
        fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
        fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
        fp8_max_per_gemm = []
        for i in range(FP8Helper.NUM_META_PER_GEMM):
            val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
                else fp8_max_fwd
            fp8_max_per_gemm.append([val])
        fp8_max_per_gemm = jnp.asarray(fp8_max_per_gemm, dtype=jnp.float32)
        return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)

    @staticmethod
    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
        """
        Obtain the indices of the FP8 metas for the given GEMM index.
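
        For example, with ``NUM_META_PER_GEMM == 3``, ``gemm_idx = 1`` yields
        ``(3, 4, 5)`` for the input, kernel and gradient metas respectively.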
        """
        input_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.INPUT_META_IDX_PER_GEMM
        kernel_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.KERNEL_META_IDX_PER_GEMM
        grad_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.GRAD_META_IDX_PER_GEMM
        return input_idx, kernel_idx, grad_idx

    @staticmethod
    @jax.jit
    def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
        fp8_meta_arrays, treedef = jax.tree_util.tree_flatten(fp8_metas)
        num_of_meta_with_max = FP8Helper.NUM_META_PER_GEMM + 1
        num_of_gemm = len(fp8_meta_arrays) // num_of_meta_with_max
        for i in range(num_of_gemm):
            # the flattened arrays are ordered in alphabetical order of collection names
            fp8_max_idx = i * num_of_meta_with_max
            fp8_amax_idx = fp8_max_idx + 1
            fp8_scale_idx = fp8_amax_idx + 1
            fp8_scale_inv_idx = fp8_scale_idx + 1

            fp8_max = fp8_meta_arrays[fp8_max_idx]
            if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
                amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=-1, keepdims=True)
            else:
                amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
            scale = fp8_meta_arrays[fp8_scale_idx]

            sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
            sf = jnp.where(amax > 0.0, sf, scale)
            sf = jnp.where(jnp.isfinite(amax), sf, scale)
            fp8_meta_arrays[fp8_scale_idx] = sf
            fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf

        return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)

    @staticmethod
    def generate_fp8_meta_dtype_converter_pair(*args):
        """
        Generate a pair of conversion functions between fm32 and fp32.
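
        A minimal usage sketch (illustrative): convert fm32 metas to fp32
        before doing arithmetic on them, then convert the results back.

        .. code-block:: python

            metas = [jax.lax.convert_element_type(jnp.zeros((1, 1)), FlaxFloatMeta32)]
            to_fp32, from_fp32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(*metas)
            metas = to_fp32(*metas)     # now jnp.float32
            # ... update the metas ...
            metas = from_fp32(*metas)   # back to fm32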
        """

        def identical_fun(*metas):
            return metas

        def fm32_to_fp32_fun(*metas):
            for meta in metas:
                assert meta.dtype == FlaxFloatMeta32
            return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas]

        def fp32_to_fm32_fun(*metas):
            for meta in metas:
                assert meta.dtype == jnp.float32
            return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas]

        # Make the functions a valid JAX type
        partial_identical_fun = jax.tree_util.Partial(identical_fun)
        partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun)
        partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun)

        if len(args) < 1:
            return partial_identical_fun, partial_identical_fun

        original_dtype = args[0].dtype
        for arg in args:
            assert arg.dtype == original_dtype

        if original_dtype == FlaxFloatMeta32:
            return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun

        return partial_identical_fun, partial_identical_fun

    @staticmethod
    def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
        """
        Update the amax history
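
        For example (illustrative), a history ``[[3., 1., 2.]]`` becomes
        ``[[0., 2., 3.]]``: slot 0 is cleared to receive the next recorded
        amax, the previous slot-0 value moves to the end of the rolling
        window, and the oldest entry is dropped.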
        """
        updated_amax = jnp.roll(amax, -1, -1)
        updated_amax = updated_amax.at[..., 0].set(0)
        return updated_amax

    @staticmethod
    @jax.jit
    def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
                         scale: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Calculate fp8 scale and scale_inv based on given amax.
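
        A small worked example (illustrative), assuming ``MARGIN = 0`` and the
        ``MAX`` amax algorithm: ``fp8_max = [[448.]]`` with an amax history of
        ``[[2., 4.]]`` gives ``scale = [[112.]]`` (i.e. ``448 / 4``) and
        ``scale_inv = [[1 / 112.]]``; the previous ``scale`` argument is only
        kept where amax is zero or non-finite.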
        """
        if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax, axis=-1, keepdims=True)
        else:
            amax = amax[..., 0:1]

        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
        sf = jnp.where(amax > 0.0, sf, scale)
        sf = jnp.where(jnp.isfinite(amax), sf, scale)
        scale = sf
        scale_inv = 1 / sf

        return scale, scale_inv


@contextmanager
def fp8_autocast(enabled: bool = False,
                 fp8_recipe: Optional[DelayedScaling] = None,
                 mesh_resource: Optional[MeshResource] = None) -> None:
    r"""
    Context manager for FP8 usage.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        We only support :attr:`margin`, :attr:`fp8_format`,
        :attr:`interval`, :attr:`amax_history_len` and
        :attr:`amax_compute_algo` (with values 'max' and 'most_recent')
        in recipe.DelayedScaling currently. Other parameters in
        recipe.DelayedScaling will trigger an assertion.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.

    """
    if fp8_recipe is None:
        fp8_recipe = DelayedScaling()

    assert fp8_recipe.amax_compute_algo in [
        "max", "most_recent"
    ], ("DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
    assert fp8_recipe.scaling_factor_compute_algo is None, (
        "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.")
    assert fp8_recipe.override_linear_precision == (False, False, False), (
        "DelayedScaling override_linear_precision isn't supported by TE/JAX.")
    assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")

    if mesh_resource is None:
        mesh_resource = MeshResource()

    try:
        with global_shard_guard(mesh_resource):
            if enabled:
                fp8_available, reason_for_no_fp8 = is_fp8_available()
                assert fp8_available, reason_for_no_fp8

                amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
                if fp8_recipe.amax_compute_algo == 'max':
                    amax_compute_algo = AmaxComputeAlgo.MAX

                FP8Helper.initialize(margin=fp8_recipe.margin,
                                     fp8_format=fp8_recipe.fp8_format,
                                     update_fp8meta_interval=fp8_recipe.interval,
                                     amax_history_len=fp8_recipe.amax_history_len,
                                     amax_compute_algo=amax_compute_algo)
            yield
    finally:
        FP8Helper.finalize()


# Function Wrappers
def update_collections(new: Collection, original: Collection) -> FrozenDict:
    r"""
    A helper to update Flax's Collection.

    Collection = [dict, flax.core.frozen_dict.FrozenDict]
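
    For example (an illustrative sketch):

    .. code-block:: python

        original = {"params": {"w0": 1}, "fp8_meta_collection": {"m": 2}}
        new = {"fp8_meta_collection": {"m": 3}}
        merged = update_collections(new, original)
        # merged == {"params": {"w0": 1}, "fp8_meta_collection": {"m": 3}}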

    Parameters
    ----------
    new: Collection
        A collection that includes new data.
    original: Collection
        The base collection.

    Returns
    -------
    outputs : Collection
        The updated collection.
    """
    return FP8Helper.update_collections(new, original)


def update_fp8_metas(state: Collection) -> Collection:
    r"""
    Calculate new fp8 scales and their inverses via the following formula:

    .. code-block:: python

        sf = (fp8_max / amax) / (2 ** margin)
        sf = sf if amax > 0.0 else original_scale
        updated_scale = sf if isfinite(amax) else original_scale
        updated_scale_inv = 1 / updated_scale

    Collection = [dict, flax.core.frozen_dict.FrozenDict]
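
    A usage sketch (illustrative; ``variables`` is assumed to be a Flax
    variable collection that carries the FP8 meta collection produced by
    TE/JAX modules):

    .. code-block:: python

        # Recompute scale and scale_inv from the recorded amax values.
        variables = update_fp8_metas(variables)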

    Parameters
    ----------
    state: Collection
        A collection that includes FP8 metas.

    Returns
    -------
    outputs : Collection
        The collection with updated FP8 metas.
    """
    return FP8Helper.update_fp8_metas(state)


def get_delayed_scaling():
    r"""
    Obtain an instance of DelayedScaling which is set via fp8_autocast.

    .. note::
        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
        :attr:`amax_history_len` and :attr:`amax_compute_algo` via fp8_autocast.
        Other parameters in recipe.DelayedScaling would be returned as the default
        values.
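
    For example (illustrative):

    .. code-block:: python

        with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(margin=2)):
            recipe = get_delayed_scaling()
            # recipe.margin == 2 inside the autocast region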

    Returns
    -------
    delay_scaling : DelayedScaling
        An instance of DelayedScaling which is set via fp8_autocast.
    """
    amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
                        else "most_recent"
    return DelayedScaling(margin=int(FP8Helper.MARGIN),
                          interval=FP8Helper.UPDATE_FP8META_INTERVAL,
                          fp8_format=FP8Helper.FP8_FORMAT,
                          amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
                          amax_compute_algo=amax_compute_algo)