multi_tensor_apply.cuh 
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"

#include <assert.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// TODO:  Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
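
// Note (editor's addition, not in the original file): these capacities bound
// sizeof(TensorListMetadata<depth>) so that the struct below, which is passed
// to the kernel by value, fits within CUDA's 4 KB kernel-argument budget
// together with the kernel's other parameters.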

template <int n>
struct TensorListMetadata {
    void* addresses[n][depth_to_max_tensors[n - 1]];
    int sizes[depth_to_max_tensors[n - 1]];
    unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
    int block_to_chunk[depth_to_max_blocks[n - 1]];  // I fear this needs to be a full int.
    int start_tensor_this_launch;
};
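
// (Editor's note, not in the original file.) Layout: addresses[d][i] is the
// data pointer of the i-th packed tensor in list d, and sizes[i] is its numel.
// CUDA block b processes chunk block_to_chunk[b] of tensor block_to_tensor[b];
// block_to_tensor fits in an unsigned char because at most 110 tensors are
// packed into a single launch.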

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
                                          volatile int* noop_flag,
                                          T tl,
                                          U callable,
                                          ArgTypes... args)
{
    // Hand the chunk information to the user-supplied functor to process however it likes.
    callable(chunk_size, noop_flag, tl, args...);
}
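
// Illustrative sketch (editor's addition, not part of the original header):
// a minimal functor with the call shape multi_tensor_apply_kernel expects.
// The name ScaleFunctor and its behavior are hypothetical; real functors of
// this shape live in the fused-optimizer kernels that include this header.
template <typename scalar_t>
struct ScaleFunctor {
    __device__ void operator()(int chunk_size,
                               volatile int* noop_flag,
                               TensorListMetadata<2>& tl,
                               float scale)
    {
        if (*noop_flag == 1) return;  // another chunk flagged inf/nan; skip work
        // Recover this block's (tensor, chunk) assignment from the metadata.
        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc] - chunk_idx * chunk_size;
        n = n < chunk_size ? n : chunk_size;  // the final chunk may be short
        const scalar_t* in =
            static_cast<const scalar_t*>(tl.addresses[0][tensor_loc]) + chunk_idx * chunk_size;
        scalar_t* out =
            static_cast<scalar_t*>(tl.addresses[1][tensor_loc]) + chunk_idx * chunk_size;
        // Strided loop over the chunk's elements within this block.
        for (int i = threadIdx.x; i < n; i += blockDim.x) {
            out[i] = static_cast<scalar_t>(in[i] * scale);
        }
    }
};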

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
                        int chunk_size,
                        const at::Tensor& noop_flag,
                        const std::vector<std::vector<at::Tensor>>& tensor_lists,
                        T callable,
                        ArgTypes... args)
{
    TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
    int len0 = tensor_lists[0].size();
    TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
    auto ref_device = tensor_lists[0][0].device();
    TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
    for (int l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices
    {
        TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
        for (int t = 0; t < tensor_lists[l].size(); t++) {
            // TODO:  Print which tensor fails.
            bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
            contiguous_memory = (contiguous_memory ||
                                 tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
            TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
            TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
                        "A tensor was not on the same device as the first tensor");
            TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
        }
    }

    int ntensors = tensor_lists[0].size();

    TensorListMetadata<depth> tl;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
    auto stream = at::cuda::getCurrentCUDAStream();

    tl.start_tensor_this_launch = 0;
    int loc_block_info = 0;
    int loc_tensor_info = 0;
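    // (Editor's note, not in the original file.) The loop below packs tensor
    // pointers and (tensor, chunk) pairs into tl, launching the kernel each
    // time the tensor slots or block slots fill up, or when the final chunk
    // is reached; tl is then reset and packing resumes where it left off.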
    for (int t = 0; t < ntensors; t++) {
        tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
        for (int d = 0; d < depth; d++)
            tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
        loc_tensor_info++;

        int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

        for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
            // std::cout << chunks_this_tensor << std::endl;
            tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
            tl.block_to_chunk[loc_block_info] = chunk;
            loc_block_info++;

            bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
                                 chunk == chunks_this_tensor - 1);
            bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
            bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
            if (tensors_full || blocks_full || last_chunk) {
                // using accscalar_t = acc_type<scalar_t, true>;
                multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
                    chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);

                AT_CUDA_CHECK(cudaGetLastError());

                // Reset.  The control flow possibilities here make my brain hurt.
                loc_block_info = 0;
                if (chunk == chunks_this_tensor - 1) {
                    // We just launched the last chunk of the current tensor, so
                    // the next launch starts fresh with the next tensor.
                    loc_tensor_info = 0;
                    tl.start_tensor_this_launch = t + 1;
                } else {
                    // We are mid-tensor: carry the current tensor's size and
                    // addresses over to slot 0 so its remaining chunks can be
                    // packed into the next launch.
                    tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
                    for (int d = 0; d < depth; d++)
                        tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
                    loc_tensor_info = 1;
                    tl.start_tensor_this_launch = t;
                }
            }
        }
    }
}
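
// Usage sketch (editor's addition, not in the original file), assuming the
// hypothetical ScaleFunctor above. noop_flag is a one-element int32 CUDA
// tensor that functors may set to skip work, e.g. after detecting inf/nan.
//
//   auto noop_flag = at::zeros({1}, at::dtype(at::kInt).device(at::kCUDA));
//   std::vector<std::vector<at::Tensor>> lists = {inputs, outputs};  // depth 2
//   multi_tensor_apply<2>(512,        // block_size: threads per CUDA block
//                         2048 * 32,  // chunk_size: elements per block-chunk
//                         noop_flag,
//                         lists,
//                         ScaleFunctor<float>(),
//                         0.5f);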