DeepSpeed
Mirror of https://github.com/microsoft/DeepSpeed
329 lines · 10.6 KB
1# Copyright (c) Microsoft Corporation.
2# SPDX-License-Identifier: Apache-2.0
3
4# DeepSpeed Team
5
6import functools
7import os
8import pkgutil
9import importlib
10import torch
11
12from .abstract_accelerator import DeepSpeedAccelerator
13
14
15class HPU_Accelerator(DeepSpeedAccelerator):
16
17def __init__(self):
18self._name = 'hpu'
19self._communication_backend_name = 'hccl'
20self._compile_backend = "hpu_backend"
21self.apply_hpu_workarounds()
22try:
23import habana_frameworks.torch.hpu as hpu
24hpu.setDeterministic(True)
25self.hpu = hpu
26except ImportError as e:
27raise ValueError(
28f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
29
30self.fp16_supported = None
31
32def apply_hpu_workarounds(self):
33
34def update_wa_env_var(key, value):
35if key not in os.environ.keys():
36os.environ[key] = value
37
38update_wa_env_var("PT_HPU_LAZY_ACC_PAR_MODE", "0")
39update_wa_env_var("PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES", "0")
40
41# Device APIs
42def is_synchronized_device(self):
43return False
44
45def use_host_timers(self):
46return False
47
48def resolves_data_dependency(self):
49return True
50
51def handles_memory_backpressure(self):
52return True
53
54def device_name(self, device_index=None):
55# ignoring device_index.
56return 'hpu'
57
58def device(self, device_index=None):
59return torch.device(self.device_name(device_index))
60
61def set_device(self, device_index):
62self.hpu.set_device(device_index)
63
64def current_device(self):
65return (self.hpu.current_device())
66
67def current_device_name(self):
68return 'hpu:{}'.format(self.current_device())
69
70def device_count(self):
71return self.hpu.device_count()
72
73def synchronize(self, device_index=None):
74return self.hpu.synchronize()
75
76# RNG APIs
77def random(self):
78return torch.random
79
80def set_rng_state(self, new_state, device_index=None):
81self.hpu.random.set_rng_state(new_state)
82
83def get_rng_state(self, device_index=None):
84return self.hpu.random.get_rng_state()
85
86def manual_seed(self, seed):
87return self.hpu.random.manual_seed(seed)
88
89def manual_seed_all(self, seed):
90self.hpu.random.manual_seed_all(seed)
91
92def initial_seed(self):
93return self.hpu.random.initial_seed()
94
95def default_generator(self, device_index):
96return self.hpu.random.default_generators[device_index]
97
98# Streams/Events
99@property
100def Stream(self):
101return self.hpu.Stream
102
103def stream(self, stream):
104return self.hpu.stream(stream)
105
106def current_stream(self, device_index=None):
107return self.hpu.current_stream()
108
109def default_stream(self, device_index=None):
110return self.hpu.default_stream()
111
112@property
113def Event(self):
114import habana_frameworks.torch.core as htcore
115return htcore.hpu.Event
116
117# Memory management
118def empty_cache(self):
119return
120
121def memory_allocated(self, device_index=None):
122return self.hpu.memory_allocated()
123
124def max_memory_allocated(self, device_index=None):
125return self.hpu.max_memory_allocated()
126
127def reset_max_memory_allocated(self, device_index=None):
128return self.hpu.reset_max_memory_allocated()
129
130def memory_cached(self, device_index=None):
131return self.hpu.memory_cached(device_index)
132
133def max_memory_cached(self, device_index=None):
134return self.hpu.max_memory_cached(device_index)
135
136def reset_max_memory_cached(self, device_index=None):
137return None
138
139def memory_stats(self, device_index=None):
140return self.hpu.memory_stats(device_index)
141
142def reset_peak_memory_stats(self, device_index=None):
143self.hpu.reset_peak_memory_stats(device_index)
144
145def memory_reserved(self, device_index=None):
146return self.hpu.memory_reserved(device_index)
147
148def max_memory_reserved(self, device_index=None):
149return self.hpu.max_memory_reserved(device_index)
150
151def total_memory(self, device_index=None):
152return self.memory_stats(device_index)['Limit']
153
154def available_memory(self, device_index=None):
155return self.total_memory(device_index) - self.memory_allocated(device_index)
156
157# Data types
158def is_bf16_supported(self):
159return True
160
161def is_fp16_supported(self):
162if self.fp16_supported is None:
163import habana_frameworks.torch.utils.experimental as htexp
164self.fp16_supported = htexp._is_fp16_supported()
165return self.fp16_supported
166
167def supported_dtypes(self):
168supported_dtypes = [torch.float, torch.bfloat16]
169if self.is_fp16_supported():
170supported_dtypes.append(torch.half)
171return supported_dtypes
172
173# Misc
174def amp(self):
175return None
176
177def is_available(self):
178return self.hpu.is_available()
179
180def range_push(self, msg):
181return
182
183def range_pop(self):
184return
185
186def lazy_call(self, callback):
187callback()
188
189def communication_backend_name(self):
190return self._communication_backend_name
191
192def is_triton_supported(self):
193return False
194
195# Graph operations
196def create_graph(self):
197return self.hpu.HPUGraph()
198
199def capture_to_graph(self, graph, pool=None, stream=None):
200return self.hpu.graph(graph, stream=stream)
201
202def replay_graph(self, graph):
203graph.replay()
204return
205
206# Tensor operations
207@property
208def BFloat16Tensor(self):
209return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')
210
211@property
212def ByteTensor(self):
213return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')
214
215@property
216def DoubleTensor(self):
217return functools.partial(torch.tensor, dtype=torch.double, device='hpu')
218
219@property
220def FloatTensor(self):
221return functools.partial(torch.tensor, dtype=torch.float, device='hpu')
222
223@property
224def HalfTensor(self):
225return functools.partial(torch.tensor, dtype=torch.half, device='hpu')
226
227@property
228def IntTensor(self):
229return functools.partial(torch.tensor, dtype=torch.int, device='hpu')
230
231@property
232def LongTensor(self):
233return functools.partial(torch.tensor, dtype=torch.long, device='hpu')
234
235def pin_memory(self, tensor, align_bytes=1):
236return tensor.pin_memory(self.device())
237
238def is_pinned(self, tensor):
239return tensor.is_pinned()
240
241def on_accelerator(self, tensor):
242device_str = str(tensor.device)
243if device_str.startswith('hpu:'):
244return True
245else:
246return False
247
248def op_builder_dir(self):
249try:
250# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
251# if successful this also means we're doing a local install and not JIT compile path
252from op_builder import __deepspeed__ # noqa: F401 # type: ignore
253return "op_builder.hpu"
254except ImportError:
255return "deepspeed.ops.op_builder.hpu"
256
257# dict that holds class name <--> class type mapping i.e.
258# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
259# this dict will be filled at init stage
260class_dict = None
261
262def _lazy_init_class_dict(self):
263if self.class_dict is not None:
264return
265else:
266self.class_dict = {}
267# begin initialize for create_op_builder()
268# put all valid class name <--> class type mapping into class_dict
269op_builder_dir = self.op_builder_dir()
270op_builder_module = importlib.import_module(op_builder_dir)
271op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
272for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
273# avoid self references,
274# skip sub_directories which contains ops for other backend(cpu, npu, etc.).
275if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
276os.path.join(op_builder_absolute_path, module_name)):
277module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
278for member_name in module.__dir__():
279if member_name.endswith(
280'Builder'
281) and member_name != "OpBuilder" and member_name != "CPUOpBuilder" and member_name != "TorchCPUOpBuilder": # avoid abstract classes
282if not member_name in self.class_dict:
283self.class_dict[member_name] = getattr(module, member_name)
284# end initialize for create_op_builder()
285
286# create an instance of op builder and return, name specified by class_name
287def create_op_builder(self, class_name):
288self._lazy_init_class_dict()
289if class_name in self.class_dict:
290return self.class_dict[class_name]()
291else:
292return None
293
294# return an op builder class, name specified by class_name
295def get_op_builder(self, class_name):
296self._lazy_init_class_dict()
297if class_name in self.class_dict:
298return self.class_dict[class_name]
299else:
300return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None
301
302def build_extension(self):
303from torch.utils.cpp_extension import BuildExtension
304return BuildExtension
305
306def export_envs(self):
307return []
308
309def visible_devices_envs(self):
310# Current way deepspeed set this env var is not applicable with all HPU instances
311# User has to follow instructions in:
312# https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Workloads_Single_Docker.html
313# keeping CUDA_VISIBLE_DEVICES
314return ['CUDA_VISIBLE_DEVICES'] #['HABANA_VISIBLE_MODULES']
315
316def set_visible_devices_envs(self, current_env, local_accelerator_ids):
317for env in self.visible_devices_envs():
318current_env[env] = ",".join(map(str, local_accelerator_ids))
319
320def get_compile_backend(self):
321return self._compile_backend
322
323def set_compile_backend(self, backend):
324supported_backends = torch._dynamo.list_backends(exclude_tags=())
325if backend in supported_backends:
326self._compile_backend = backend
327else:
328raise ValueError(
329f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
330