llvm-project

Форк
0
/
dot-product.cpp 
243 строки · 9.7 Кб
1
//===-- runtime/dot-product.cpp -------------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8

9
#include "float.h"
10
#include "terminator.h"
11
#include "tools.h"
12
#include "flang/Common/float128.h"
13
#include "flang/Runtime/cpp-type.h"
14
#include "flang/Runtime/descriptor.h"
15
#include "flang/Runtime/reduction.h"
16
#include <cfloat>
17
#include <cinttypes>
18

19
namespace Fortran::runtime {
20

21
// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
22
// argument; MATMUL does not.
23

24
// Suppress the warnings about calling __host__-only std::complex operators,
25
// defined in C++ STD header files, from __device__ code.
26
RT_DIAG_PUSH
27
RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
28

29
// General accumulator for any type and stride; this is not used for
30
// contiguous numeric vectors.
31
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
32
class Accumulator {
33
public:
34
  using Result = AccumulationType<RCAT, RKIND>;
35
  RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
36
      : x_{x}, y_{y} {}
37
  RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
38
    if constexpr (RCAT == TypeCategory::Logical) {
39
      sum_ = sum_ ||
40
          (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
41
    } else {
42
      const XT &xElement{*x_.Element<XT>(&xAt)};
43
      const YT &yElement{*y_.Element<YT>(&yAt)};
44
      if constexpr (RCAT == TypeCategory::Complex) {
45
        sum_ += std::conj(static_cast<Result>(xElement)) *
46
            static_cast<Result>(yElement);
47
      } else {
48
        sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
49
      }
50
    }
51
  }
52
  RT_API_ATTRS Result GetResult() const { return sum_; }
53

54
private:
55
  const Descriptor &x_, &y_;
56
  Result sum_{};
57
};
58

59
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
60
static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct(
61
    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
62
  using Result = CppTypeFor<RCAT, RKIND>;
63
  RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
64
  SubscriptValue n{x.GetDimension(0).Extent()};
65
  if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
66
    terminator.Crash(
67
        "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
68
        static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
69
  }
70
  if constexpr (RCAT != TypeCategory::Logical) {
71
    if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
72
        y.GetDimension(0).ByteStride() == sizeof(YT)) {
73
      // Contiguous numeric vectors
74
      if constexpr (std::is_same_v<XT, YT>) {
75
        // Contiguous homogeneous numeric vectors
76
        if constexpr (std::is_same_v<XT, float>) {
77
          // TODO: call BLAS-1 SDOT or SDSDOT
78
        } else if constexpr (std::is_same_v<XT, double>) {
79
          // TODO: call BLAS-1 DDOT
80
        } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
81
          // TODO: call BLAS-1 CDOTC
82
        } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
83
          // TODO: call BLAS-1 ZDOTC
84
        }
85
      }
86
      XT *xp{x.OffsetElement<XT>(0)};
87
      YT *yp{y.OffsetElement<YT>(0)};
88
      using AccumType = AccumulationType<RCAT, RKIND>;
89
      AccumType accum{};
90
      if constexpr (RCAT == TypeCategory::Complex) {
91
        for (SubscriptValue j{0}; j < n; ++j) {
92
          // std::conj() may instantiate its argument twice,
93
          // so xp has to be incremented separately.
94
          // This is a workaround for an alleged bug in clang,
95
          // that shows up as:
96
          //   warning: multiple unsequenced modifications to 'xp'
97
          accum += std::conj(static_cast<AccumType>(*xp)) *
98
              static_cast<AccumType>(*yp++);
99
          xp++;
100
        }
101
      } else {
102
        for (SubscriptValue j{0}; j < n; ++j) {
103
          accum +=
104
              static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
105
        }
106
      }
107
      return static_cast<Result>(accum);
108
    }
109
  }
110
  // Non-contiguous, heterogeneous, & LOGICAL cases
111
  SubscriptValue xAt{x.GetDimension(0).LowerBound()};
112
  SubscriptValue yAt{y.GetDimension(0).LowerBound()};
113
  Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
114
  for (SubscriptValue j{0}; j < n; ++j) {
115
    accumulator.AccumulateIndexed(xAt++, yAt++);
116
  }
117
  return static_cast<Result>(accumulator.GetResult());
118
}
119

120
RT_DIAG_POP
121

122
template <TypeCategory RCAT, int RKIND> struct DotProduct {
123
  using Result = CppTypeFor<RCAT, RKIND>;
124
  template <TypeCategory XCAT, int XKIND> struct DP1 {
125
    template <TypeCategory YCAT, int YKIND> struct DP2 {
126
      RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
127
          Terminator &terminator) const {
128
        if constexpr (constexpr auto resultType{
129
                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
130
          if constexpr (resultType->first == RCAT &&
131
              (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
132
            return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
133
                CppTypeFor<YCAT, YKIND>>(x, y, terminator);
134
          }
135
        }
136
        terminator.Crash(
137
            "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
138
            static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
139
            static_cast<int>(YCAT), YKIND);
140
      }
141
    };
142
    RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
143
        Terminator &terminator, TypeCategory yCat, int yKind) const {
144
      return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
145
    }
146
  };
147
  RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
148
      const char *source, int line) const {
149
    Terminator terminator{source, line};
150
    if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
151
      // No conversions needed, operands and result have same known type
152
      return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
153
          x, y, terminator);
154
    } else {
155
      auto xCatKind{x.type().GetCategoryAndKind()};
156
      auto yCatKind{y.type().GetCategoryAndKind()};
157
      RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
158
      return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
159
          terminator, x, y, terminator, yCatKind->first, yCatKind->second);
160
    }
161
  }
162
};
163

164
extern "C" {
165
RT_EXT_API_GROUP_BEGIN
166

167
// DOT_PRODUCT entry points for INTEGER results of kind 1, 2, 4, 8 and, when
// __SIZEOF_INT128__ is defined, 16.  Each forwards to the DotProduct
// dispatcher for its result kind; `source`/`line` are passed on for crash
// reports.
CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 1> dotProduct{};
  return dotProduct(x, y, source, line);
}
CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 2> dotProduct{};
  return dotProduct(x, y, source, line);
}
CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 4> dotProduct{};
  return dotProduct(x, y, source, line);
}
CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 8> dotProduct{};
  return dotProduct(x, y, source, line);
}
#ifdef __SIZEOF_INT128__
CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 16> dotProduct{};
  return dotProduct(x, y, source, line);
}
#endif
189

190
// TODO: REAL/COMPLEX(2 & 3)
191
// Intermediate results and operations are at least 64 bits
192
// DOT_PRODUCT entry points for REAL results of kind 4 and 8, plus kind 10
// (when long double has a 64-bit mantissa) and kind 16 (when long double has
// a 113-bit mantissa or HAS_FLOAT128 is set).  Each forwards to the
// DotProduct dispatcher for its result kind.
CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Real, 4> dotProduct{};
  return dotProduct(x, y, source, line);
}
CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Real, 8> dotProduct{};
  return dotProduct(x, y, source, line);
}
#if LDBL_MANT_DIG == 64
CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Real, 10> dotProduct{};
  return dotProduct(x, y, source, line);
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Real, 16> dotProduct{};
  return dotProduct(x, y, source, line);
}
#endif
212

213
void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
214
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
215
  result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
216
}
217
void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
218
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
219
  result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
220
}
221
#if LDBL_MANT_DIG == 64
222
void RTDEF(CppDotProductComplex10)(
223
    CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
224
    const Descriptor &y, const char *source, int line) {
225
  result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
226
}
227
#endif
228
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
229
void RTDEF(CppDotProductComplex16)(
230
    CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
231
    const Descriptor &y, const char *source, int line) {
232
  result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
233
}
234
#endif
235

236
// DOT_PRODUCT entry point for LOGICAL operands: ANY(x .AND. y), computed by
// the LOGICAL(1) instantiation of the DotProduct dispatcher.
bool RTDEF(DotProductLogical)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Logical, 1> dotProduct{};
  return dotProduct(x, y, source, line);
}
240

241
RT_EXT_API_GROUP_END
242
} // extern "C"
243
} // namespace Fortran::runtime
244

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.