DeepSpeed
Зеркало из https://github.com/microsoft/DeepSpeed
52 строки · 1.5 Кб
1# Copyright (c) Microsoft Corporation.
2# SPDX-License-Identifier: Apache-2.0
3
4# DeepSpeed Team
5
6import torch
7from deepspeed.pt.deepspeed_linear import LinearModuleForZeroStage3
8from deepspeed.pt.log_utils import logger
9from deepspeed.accelerator import get_accelerator
10
11
def see_memory_usage(message):
    """Log ``message`` followed by four accelerator memory statistics.

    Each statistic is read from ``get_accelerator()`` as a raw count,
    divided by 1024 * 1024 * 1024 and emitted through ``logger.info`` on
    its own line.

    This function does not check the process rank itself; any per-rank
    filtering would have to come from how ``logger`` is configured.

    Args:
        message: Text logged first, ahead of the statistics.
    """
    bytes_per_gigabyte = 1024 * 1024 * 1024

    # (log format, accelerator method to query), in reporting order.
    reports = (
        ("Memory Allocated %s GigaBytes ", "memory_allocated"),
        ("Max Memory Allocated %s GigaBytes", "max_memory_allocated"),
        ("Cache Allocated %s GigaBytes", "memory_cached"),
        ("Max cache Allocated %s GigaBytes", "max_memory_cached"),
    )

    logger.info(message)
    for template, query_name in reports:
        # Query right before logging so each reading is taken at the same
        # point in the sequence as a hand-written call would take it.
        query = getattr(get_accelerator(), query_name)
        logger.info(template, query() / bytes_per_gigabyte)
32
33
# Driver script: run one forward/backward pass through
# LinearModuleForZeroStage3 and log accelerator memory at each stage.

# Accelerator device name, looked up once and reused for every placement.
_device_name = get_accelerator().device_name()

_ROWS = 1024
_DIM = 16384

# Half-precision random input on the accelerator, plus a detached copy that
# is later passed to backward() as the incoming gradient.
tens = torch.rand(
    _ROWS,
    _DIM,
    dtype=torch.half,
    device=torch.device(_device_name),
)
tens_back = tens.detach().clone()

# Disabled alternative kept from the original script:
#linear_bk = torch.nn.functional.linear
#torch.nn.functional.linear = deepspeed.pt.deepspeed_linear.LinearFunctionForZeroStage3.apply
model = LinearModuleForZeroStage3(_DIM, _DIM)

# Move the module to the accelerator, then convert it to half precision.
model.to(_device_name)
model.half()

see_memory_usage("Before forward")
y = model(tens)

see_memory_usage("After forward")

# Swap the weight's storage for a one-element tensor before running backward.
# NOTE(review): presumably this checks that backward does not rely on the
# original weight memory still being resident -- confirm intent.
model.weight.data = torch.zeros(
    1,
    dtype=torch.half,
    device=torch.device(_device_name),
)

see_memory_usage("After weight zero")

y.backward(tens_back)
53