gpt-neox
45 строк · 1.5 Кб
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import os
import pathlib
import subprocess
import sys

from pathlib import Path
20
# Absolute path of the directory containing this file; it is shown in the
# `pip install ...` hint printed when the fused kernels cannot be imported.
srcpath = Path(__file__).parent.absolute()

# Force TORCH_CUDA_ARCH_LIST to an empty string as a side effect of importing
# this module. When it holds a list of architectures, the generated compile
# commands can differ only in the order of those architectures, which makes
# the fused kernels recompile needlessly. Keeping it empty avoids that, so the
# architecture flags must instead be passed explicitly wherever the extensions
# are compiled. NOTE(review): that compile step is not in this file; it
# presumably lives in the package's build script (via extra_cuda_cflags) —
# confirm there.
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
29
30
def load_fused_kernels():
    """Check that the fused CUDA kernel extensions are importable.

    Despite its name, this function does not build or load anything itself.
    It imports each fused-kernel extension module to confirm the package was
    installed. On failure it prints the ImportError and an install hint to
    stderr, then terminates the process.

    Raises:
        SystemExit: with exit status 1 if any fused-kernel module cannot be
            imported.
    """
    try:
        # Imported purely to prove they are installed; the module objects
        # are not used here.
        import scaled_upper_triang_masked_softmax_cuda  # noqa: F401
        import scaled_masked_softmax_cuda  # noqa: F401
        import fused_rotary_positional_embedding  # noqa: F401
    except ImportError as e:  # also covers ModuleNotFoundError (a subclass)
        banner = "=" * 100
        print("\n", file=sys.stderr)
        print(e, file=sys.stderr)
        print(banner, file=sys.stderr)
        print(
            "ERROR: Fused kernels configured but not properly installed. "
            f"Please run `pip install {srcpath}` to install them",
            file=sys.stderr,
        )
        print(banner, file=sys.stderr)
        # Use sys.exit rather than the site-injected `exit()`: the latter is
        # meant for the interactive shell, is absent under `python -S`, and
        # with no argument exits with status 0, hiding the failure.
        sys.exit(1)
46