DeepSpeed
Mirror of https://github.com/microsoft/DeepSpeed
hpu_accelerator.py · 329 lines · 10.6 KB
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import functools
import os
import pkgutil
import importlib
import torch

from .abstract_accelerator import DeepSpeedAccelerator


class HPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'hpu'
        self._communication_backend_name = 'hccl'
        self._compile_backend = "hpu_backend"
        self.apply_hpu_workarounds()
        try:
            import habana_frameworks.torch.hpu as hpu
            hpu.setDeterministic(True)
            self.hpu = hpu
        except ImportError as e:
            raise ValueError(
                "HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") from e

        self.fp16_supported = None

    def apply_hpu_workarounds(self):

        def update_wa_env_var(key, value):
            if key not in os.environ.keys():
                os.environ[key] = value

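        # These are defaults only: update_wa_env_var() never overwrites a value the
        # user has already exported, so both workarounds can be tuned or disabled
        # from the launch environment.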
        update_wa_env_var("PT_HPU_LAZY_ACC_PAR_MODE", "0")
39
        update_wa_env_var("PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES", "0")
40

41
    # Device APIs
    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return False

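    # Returning True from the two queries below tells DeepSpeed that the HPU runtime
    # is expected to resolve pending data dependencies and relieve memory
    # backpressure on its own, so the engine can skip synchronization it would
    # otherwise insert.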
    def resolves_data_dependency(self):
        return True

    def handles_memory_backpressure(self):
        return True

    def device_name(self, device_index=None):
        # ignoring device_index.
        return 'hpu'

    def device(self, device_index=None):
        return torch.device(self.device_name(device_index))

    def set_device(self, device_index):
        self.hpu.set_device(device_index)

    def current_device(self):
        return self.hpu.current_device()

    def current_device_name(self):
        return 'hpu:{}'.format(self.current_device())

    def device_count(self):
        return self.hpu.device_count()

    def synchronize(self, device_index=None):
        return self.hpu.synchronize()

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        self.hpu.random.set_rng_state(new_state)

    def get_rng_state(self, device_index=None):
        return self.hpu.random.get_rng_state()

    def manual_seed(self, seed):
        return self.hpu.random.manual_seed(seed)

    def manual_seed_all(self, seed):
        self.hpu.random.manual_seed_all(seed)

    def initial_seed(self):
        return self.hpu.random.initial_seed()

    def default_generator(self, device_index):
        return self.hpu.random.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return self.hpu.Stream

    def stream(self, stream):
        return self.hpu.stream(stream)

    def current_stream(self, device_index=None):
        return self.hpu.current_stream()

    def default_stream(self, device_index=None):
        return self.hpu.default_stream()

    @property
    def Event(self):
        import habana_frameworks.torch.core as htcore
        return htcore.hpu.Event

    # Memory management
    def empty_cache(self):
        return

    def memory_allocated(self, device_index=None):
        return self.hpu.memory_allocated()

    def max_memory_allocated(self, device_index=None):
        return self.hpu.max_memory_allocated()

    def reset_max_memory_allocated(self, device_index=None):
        return self.hpu.reset_max_memory_allocated()

    def memory_cached(self, device_index=None):
        return self.hpu.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return self.hpu.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return None

    def memory_stats(self, device_index=None):
        return self.hpu.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        self.hpu.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        return self.hpu.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        return self.hpu.max_memory_reserved(device_index)

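    # memory_stats() reports total device memory under the 'Limit' key; available
    # memory is then approximated as that limit minus the currently allocated bytes.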
    def total_memory(self, device_index=None):
        return self.memory_stats(device_index)['Limit']

    def available_memory(self, device_index=None):
        return self.total_memory(device_index) - self.memory_allocated(device_index)

    # Data types
    def is_bf16_supported(self):
        return True

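    # bf16 is available on every Gaudi generation, while fp16 support varies by
    # device generation, so it is queried once from the runtime and cached.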
    def is_fp16_supported(self):
        if self.fp16_supported is None:
            import habana_frameworks.torch.utils.experimental as htexp
            self.fp16_supported = htexp._is_fp16_supported()
        return self.fp16_supported

    def supported_dtypes(self):
        supported_dtypes = [torch.float, torch.bfloat16]
        if self.is_fp16_supported():
            supported_dtypes.append(torch.half)
        return supported_dtypes

    # Misc
    def amp(self):
        return None

    def is_available(self):
        return self.hpu.is_available()

    def range_push(self, msg):
        return

    def range_pop(self):
        return

    def lazy_call(self, callback):
        callback()

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

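    # Graph capture on HPU mirrors the CUDA-graph pattern: create_graph() returns a
    # fresh HPUGraph, capture_to_graph() is used as a context manager to record work
    # into it, and replay_graph() re-executes the recorded work without re-tracing.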
    # Graph operations
    def create_graph(self):
        return self.hpu.HPUGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return self.hpu.graph(graph, stream=stream)

    def replay_graph(self, graph):
        graph.replay()
        return

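    # Each Tensor property below is a factory built with functools.partial, so e.g.
    # FloatTensor([1.0, 2.0]) creates a float32 tensor directly on the HPU device.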
    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')

    @property
    def ByteTensor(self):
        return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')

    @property
    def DoubleTensor(self):
        return functools.partial(torch.tensor, dtype=torch.double, device='hpu')

    @property
    def FloatTensor(self):
        return functools.partial(torch.tensor, dtype=torch.float, device='hpu')

    @property
    def HalfTensor(self):
        return functools.partial(torch.tensor, dtype=torch.half, device='hpu')

    @property
    def IntTensor(self):
        return functools.partial(torch.tensor, dtype=torch.int, device='hpu')

    @property
    def LongTensor(self):
        return functools.partial(torch.tensor, dtype=torch.long, device='hpu')

    def pin_memory(self, tensor, align_bytes=1):
        return tensor.pin_memory(self.device())

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        return str(tensor.device).startswith('hpu:')

    def op_builder_dir(self):
        try:
            # Is op_builder from DeepSpeed or a third-party version? This import
            # should only succeed for DeepSpeed itself, which also means we are on
            # a local-install path rather than the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder.hpu"
        except ImportError:
            return "deepspeed.ops.op_builder.hpu"

    # Dict that holds the class name <--> class type mapping, e.g.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>.
    # It is filled lazily on first use.
    class_dict = None

    def _lazy_init_class_dict(self):
        if self.class_dict is not None:
            return
        else:
            self.class_dict = {}
            # begin initialize for create_op_builder()
            # put every valid class name <--> class type mapping into class_dict
            op_builder_dir = self.op_builder_dir()
            op_builder_module = importlib.import_module(op_builder_dir)
            op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
            for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
                # avoid self references and skip subdirectories that contain ops
                # for other backends (cpu, npu, etc.)
                if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
                        os.path.join(op_builder_absolute_path, module_name)):
                    module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                    for member_name in module.__dir__():
                        if member_name.endswith(
                                'Builder'
                        ) and member_name != "OpBuilder" and member_name != "CPUOpBuilder" and member_name != "TorchCPUOpBuilder":  # avoid abstract classes
                            if member_name not in self.class_dict:
                                self.class_dict[member_name] = getattr(module, member_name)
            # end initialize for create_op_builder()

    # create an instance of the op builder specified by class_name and return it
    def create_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
            return None

    # return the op builder class specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None
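
    # Typical use (sketch; 'FusedAdamBuilder' is just an example name): kernels are
    # loaded via get_accelerator().get_op_builder('FusedAdamBuilder')().load(),
    # which JIT-compiles the extension or returns a prebuilt one.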

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        return []

    def visible_devices_envs(self):
        # The way DeepSpeed currently sets this env var is not applicable to all
        # HPU instances; users have to follow the instructions in:
        # https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Workloads_Single_Docker.html
        # Keeping CUDA_VISIBLE_DEVICES for now.
        return ['CUDA_VISIBLE_DEVICES']  #['HABANA_VISIBLE_MODULES']


    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))
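            # e.g. local_accelerator_ids=[0, 1] sets CUDA_VISIBLE_DEVICES="0,1"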


    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported backends are {supported_backends}")
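
# Typical entry point (sketch): DeepSpeed normally selects and constructs this
# accelerator through its accelerator abstraction rather than directly, e.g.:
#
#   from deepspeed.accelerator import get_accelerator
#   acc = get_accelerator()          # HPU_Accelerator on a machine with HPUs
#   device = acc.device()            # torch.device('hpu')
#   x = acc.FloatTensor([1.0, 2.0])  # float32 tensor created on the HPU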
