llvm-project

Форк
0
197 строк · 7.5 Кб
1
//===-- runtime/sum.cpp ---------------------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8

9
// Implements SUM for all required operand types and shapes.
//
// Real and complex SUM reductions attempt to limit the growth of
// floating-point rounding error in intermediate results by using
// compensated ("Kahan") summation, which is basically the same as
// manual "double-double" arithmetic.
14

15
#include "reduction-templates.h"
16
#include "flang/Common/float128.h"
17
#include "flang/Runtime/reduction.h"
18
#include <cfloat>
19
#include <cinttypes>
20
#include <complex>
21

22
namespace Fortran::runtime {
23

24
template <typename INTERMEDIATE> class IntegerSumAccumulator {
25
public:
26
  explicit RT_API_ATTRS IntegerSumAccumulator(const Descriptor &array)
27
      : array_{array} {}
28
  void RT_API_ATTRS Reinitialize() { sum_ = 0; }
29
  template <typename A>
30
  RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
31
    *p = static_cast<A>(sum_);
32
  }
33
  template <typename A>
34
  RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
35
    sum_ += *array_.Element<A>(at);
36
    return true;
37
  }
38

39
private:
40
  const Descriptor &array_;
41
  INTERMEDIATE sum_{0};
42
};
43

44
template <typename INTERMEDIATE> class RealSumAccumulator {
45
public:
46
  explicit RT_API_ATTRS RealSumAccumulator(const Descriptor &array)
47
      : array_{array} {}
48
  void RT_API_ATTRS Reinitialize() { sum_ = correction_ = 0; }
49
  template <typename A> RT_API_ATTRS A Result() const { return sum_; }
50
  template <typename A>
51
  RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
52
    *p = Result<A>();
53
  }
54
  template <typename A> RT_API_ATTRS bool Accumulate(A x) {
55
    // Kahan summation
56
    auto next{x + correction_};
57
    auto oldSum{sum_};
58
    sum_ += next;
59
    correction_ = (sum_ - oldSum) - next; // algebraically zero
60
    return true;
61
  }
62
  template <typename A>
63
  RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
64
    return Accumulate(*array_.Element<A>(at));
65
  }
66

67
private:
68
  const Descriptor &array_;
69
  INTERMEDIATE sum_{0.0}, correction_{0.0};
70
};
71

72
template <typename PART> class ComplexSumAccumulator {
73
public:
74
  explicit RT_API_ATTRS ComplexSumAccumulator(const Descriptor &array)
75
      : array_{array} {}
76
  void RT_API_ATTRS Reinitialize() {
77
    reals_.Reinitialize();
78
    imaginaries_.Reinitialize();
79
  }
80
  template <typename A>
81
  RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
82
    using ResultPart = typename A::value_type;
83
    *p = {reals_.template Result<ResultPart>(),
84
        imaginaries_.template Result<ResultPart>()};
85
  }
86
  template <typename A> RT_API_ATTRS bool Accumulate(const A &z) {
87
    reals_.Accumulate(z.real());
88
    imaginaries_.Accumulate(z.imag());
89
    return true;
90
  }
91
  template <typename A>
92
  RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
93
    return Accumulate(*array_.Element<A>(at));
94
  }
95

96
private:
97
  const Descriptor &array_;
98
  RealSumAccumulator<PART> reals_{array_}, imaginaries_{array_};
99
};
100

101
extern "C" {
102
RT_EXT_API_GROUP_BEGIN
103

104
// Total SUM of an INTEGER(1) array; the running total is kept in INTEGER(4).
CppTypeFor<TypeCategory::Integer, 1> RTDEF(SumInteger1)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Intermediate = CppTypeFor<TypeCategory::Integer, 4>;
  return GetTotalReduction<TypeCategory::Integer, 1>(x, source, line, dim,
      mask, IntegerSumAccumulator<Intermediate>{x}, "SUM");
}
109
// Total SUM of an INTEGER(2) array; the running total is kept in INTEGER(4).
CppTypeFor<TypeCategory::Integer, 2> RTDEF(SumInteger2)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Intermediate = CppTypeFor<TypeCategory::Integer, 4>;
  return GetTotalReduction<TypeCategory::Integer, 2>(x, source, line, dim,
      mask, IntegerSumAccumulator<Intermediate>{x}, "SUM");
}
114
// Total SUM of an INTEGER(4) array; the running total is kept in INTEGER(4).
CppTypeFor<TypeCategory::Integer, 4> RTDEF(SumInteger4)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Intermediate = CppTypeFor<TypeCategory::Integer, 4>;
  return GetTotalReduction<TypeCategory::Integer, 4>(x, source, line, dim,
      mask, IntegerSumAccumulator<Intermediate>{x}, "SUM");
}
119
// Total SUM of an INTEGER(8) array; the running total is kept in INTEGER(8).
CppTypeFor<TypeCategory::Integer, 8> RTDEF(SumInteger8)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Intermediate = CppTypeFor<TypeCategory::Integer, 8>;
  return GetTotalReduction<TypeCategory::Integer, 8>(x, source, line, dim,
      mask, IntegerSumAccumulator<Intermediate>{x}, "SUM");
}
124
#ifdef __SIZEOF_INT128__
125
// Total SUM of an INTEGER(16) array; the running total is kept in INTEGER(16).
CppTypeFor<TypeCategory::Integer, 16> RTDEF(SumInteger16)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Intermediate = CppTypeFor<TypeCategory::Integer, 16>;
  return GetTotalReduction<TypeCategory::Integer, 16>(x, source, line, dim,
      mask, IntegerSumAccumulator<Intermediate>{x}, "SUM");
}
131
#endif
132

133
// TODO: real/complex(2 & 3)
134
// Total SUM of a REAL(4) array, accumulated with compensation in float.
CppTypeFor<TypeCategory::Real, 4> RTDEF(SumReal4)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator = RealSumAccumulator<float>;
  return GetTotalReduction<TypeCategory::Real, 4>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
139
// Total SUM of a REAL(8) array, accumulated with compensation in double.
CppTypeFor<TypeCategory::Real, 8> RTDEF(SumReal8)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator = RealSumAccumulator<double>;
  return GetTotalReduction<TypeCategory::Real, 8>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
144
#if LDBL_MANT_DIG == 64
145
// Total SUM of a REAL(10) array, accumulated with compensation in long double
// (this definition is only compiled when long double has a 64-bit mantissa).
CppTypeFor<TypeCategory::Real, 10> RTDEF(SumReal10)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator = RealSumAccumulator<long double>;
  return GetTotalReduction<TypeCategory::Real, 10>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
150
#endif
151
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
152
// Total SUM of a REAL(16) array, accumulated with compensation in the
// REAL(16) C++ type itself.  The enclosing guard also admits the
// HAS_FLOAT128 case, in which long double need not have a 113-bit mantissa.
// A long double accumulator there could hold fewer bits than the elements
// being summed.  When LDBL_MANT_DIG == 113, the REAL(16) type is presumably
// long double, so behavior on such targets is unchanged.
CppTypeFor<TypeCategory::Real, 16> RTDEF(SumReal16)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim, mask,
      RealSumAccumulator<CppTypeFor<TypeCategory::Real, 16>>{x}, "SUM");
}
157
#endif
158

159
// Total SUM of a COMPLEX(4) array, returned through 'result'; the real and
// imaginary parts are accumulated separately in float.
void RTDEF(CppSumComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator = ComplexSumAccumulator<float>;
  result = GetTotalReduction<TypeCategory::Complex, 4>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
165
// Total SUM of a COMPLEX(8) array, returned through 'result'; the real and
// imaginary parts are accumulated separately in double.
void RTDEF(CppSumComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator = ComplexSumAccumulator<double>;
  result = GetTotalReduction<TypeCategory::Complex, 8>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
171
#if LDBL_MANT_DIG == 64
172
// Total SUM of a COMPLEX(10) array, returned through 'result'; the real and
// imaginary parts are accumulated separately in long double.
void RTDEF(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator = ComplexSumAccumulator<long double>;
  result = GetTotalReduction<TypeCategory::Complex, 10>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
178
#endif
179
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
180
// Total SUM of a COMPLEX(16) array, returned through 'result'.  Each part is
// accumulated in the REAL(16) C++ type rather than long double.  Under the
// HAS_FLOAT128 branch of the enclosing guard, long double need not have a
// 113-bit mantissa, and would truncate every element's parts.  When
// LDBL_MANT_DIG == 113 the two types presumably coincide.
void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  result = GetTotalReduction<TypeCategory::Complex, 16>(x, source, line, dim,
      mask, ComplexSumAccumulator<CppTypeFor<TypeCategory::Real, 16>>{x},
      "SUM");
}
186
#endif
187

188
// Entry point for SUM with a DIM= argument.  Delegates to
// TypedPartialNumericReduction with this file's integer, real, and complex
// accumulators, passing 'result' and the other arguments through.
void RTDEF(SumDim)(Descriptor &result, const Descriptor &x, int dim,
    const char *source, int line, const Descriptor *mask) {
  // NOTE(review): presumably the smallest REAL kind this entry point handles
  // (compare the TODO about real/complex kinds 2 and 3); confirm against
  // reduction-templates.h.
  constexpr int minRealKind{4};
  TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
      ComplexSumAccumulator, minRealKind>(
      result, x, dim, source, line, mask, "SUM");
}
194

195
RT_EXT_API_GROUP_END
196
} // extern "C"
197
} // namespace Fortran::runtime
198

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.