llvm-project
197 строк · 7.5 Кб
1//===-- runtime/sum.cpp ---------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// Implements SUM for all required operand types and shapes.
10//
// Real and complex SUM reductions limit the accumulation of floating-point
// rounding error in intermediate results by using "Kahan summation"
// (compensated summation, similar in spirit to a manual "double-double").
14
15#include "reduction-templates.h"
16#include "flang/Common/float128.h"
17#include "flang/Runtime/reduction.h"
18#include <cfloat>
19#include <cinttypes>
20#include <complex>
21
22namespace Fortran::runtime {
23
24template <typename INTERMEDIATE> class IntegerSumAccumulator {
25public:
26explicit RT_API_ATTRS IntegerSumAccumulator(const Descriptor &array)
27: array_{array} {}
28void RT_API_ATTRS Reinitialize() { sum_ = 0; }
29template <typename A>
30RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
31*p = static_cast<A>(sum_);
32}
33template <typename A>
34RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
35sum_ += *array_.Element<A>(at);
36return true;
37}
38
39private:
40const Descriptor &array_;
41INTERMEDIATE sum_{0};
42};
43
44template <typename INTERMEDIATE> class RealSumAccumulator {
45public:
46explicit RT_API_ATTRS RealSumAccumulator(const Descriptor &array)
47: array_{array} {}
48void RT_API_ATTRS Reinitialize() { sum_ = correction_ = 0; }
49template <typename A> RT_API_ATTRS A Result() const { return sum_; }
50template <typename A>
51RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
52*p = Result<A>();
53}
54template <typename A> RT_API_ATTRS bool Accumulate(A x) {
55// Kahan summation
56auto next{x + correction_};
57auto oldSum{sum_};
58sum_ += next;
59correction_ = (sum_ - oldSum) - next; // algebraically zero
60return true;
61}
62template <typename A>
63RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
64return Accumulate(*array_.Element<A>(at));
65}
66
67private:
68const Descriptor &array_;
69INTERMEDIATE sum_{0.0}, correction_{0.0};
70};
71
72template <typename PART> class ComplexSumAccumulator {
73public:
74explicit RT_API_ATTRS ComplexSumAccumulator(const Descriptor &array)
75: array_{array} {}
76void RT_API_ATTRS Reinitialize() {
77reals_.Reinitialize();
78imaginaries_.Reinitialize();
79}
80template <typename A>
81RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
82using ResultPart = typename A::value_type;
83*p = {reals_.template Result<ResultPart>(),
84imaginaries_.template Result<ResultPart>()};
85}
86template <typename A> RT_API_ATTRS bool Accumulate(const A &z) {
87reals_.Accumulate(z.real());
88imaginaries_.Accumulate(z.imag());
89return true;
90}
91template <typename A>
92RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
93return Accumulate(*array_.Element<A>(at));
94}
95
96private:
97const Descriptor &array_;
98RealSumAccumulator<PART> reals_{array_}, imaginaries_{array_};
99};
100
101extern "C" {
102RT_EXT_API_GROUP_BEGIN
103
104CppTypeFor<TypeCategory::Integer, 1> RTDEF(SumInteger1)(const Descriptor &x,
105const char *source, int line, int dim, const Descriptor *mask) {
106return GetTotalReduction<TypeCategory::Integer, 1>(x, source, line, dim, mask,
107IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
108}
109CppTypeFor<TypeCategory::Integer, 2> RTDEF(SumInteger2)(const Descriptor &x,
110const char *source, int line, int dim, const Descriptor *mask) {
111return GetTotalReduction<TypeCategory::Integer, 2>(x, source, line, dim, mask,
112IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
113}
114CppTypeFor<TypeCategory::Integer, 4> RTDEF(SumInteger4)(const Descriptor &x,
115const char *source, int line, int dim, const Descriptor *mask) {
116return GetTotalReduction<TypeCategory::Integer, 4>(x, source, line, dim, mask,
117IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, "SUM");
118}
119CppTypeFor<TypeCategory::Integer, 8> RTDEF(SumInteger8)(const Descriptor &x,
120const char *source, int line, int dim, const Descriptor *mask) {
121return GetTotalReduction<TypeCategory::Integer, 8>(x, source, line, dim, mask,
122IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 8>>{x}, "SUM");
123}
124#ifdef __SIZEOF_INT128__
125CppTypeFor<TypeCategory::Integer, 16> RTDEF(SumInteger16)(const Descriptor &x,
126const char *source, int line, int dim, const Descriptor *mask) {
127return GetTotalReduction<TypeCategory::Integer, 16>(x, source, line, dim,
128mask, IntegerSumAccumulator<CppTypeFor<TypeCategory::Integer, 16>>{x},
129"SUM");
130}
131#endif
132
133// TODO: real/complex(2 & 3)
134CppTypeFor<TypeCategory::Real, 4> RTDEF(SumReal4)(const Descriptor &x,
135const char *source, int line, int dim, const Descriptor *mask) {
136return GetTotalReduction<TypeCategory::Real, 4>(
137x, source, line, dim, mask, RealSumAccumulator<float>{x}, "SUM");
138}
139CppTypeFor<TypeCategory::Real, 8> RTDEF(SumReal8)(const Descriptor &x,
140const char *source, int line, int dim, const Descriptor *mask) {
141return GetTotalReduction<TypeCategory::Real, 8>(
142x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
143}
#if LDBL_MANT_DIG == 64
// Total SUM of a REAL(10) array. Only built where long double has a 64-bit
// significand (80-bit extended format); the sum is kept in long double.
CppTypeFor<TypeCategory::Real, 10> RTDEF(SumReal10)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator = RealSumAccumulator<long double>;
  return GetTotalReduction<TypeCategory::Real, 10>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// Total SUM of a REAL(16) array. The Kahan-compensated sum is carried in the
// REAL(16) C++ type itself rather than in `long double`: when REAL(16) is
// available only through HAS_FLOAT128 (so long double may be a narrower
// format, e.g. the 64-bit-significand one), a long double accumulator would
// silently discard precision from every REAL(16) element.
CppTypeFor<TypeCategory::Real, 16> RTDEF(SumReal16)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim, mask,
      RealSumAccumulator<CppTypeFor<TypeCategory::Real, 16>>{x}, "SUM");
}
#endif
158
159void RTDEF(CppSumComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
160const Descriptor &x, const char *source, int line, int dim,
161const Descriptor *mask) {
162result = GetTotalReduction<TypeCategory::Complex, 4>(
163x, source, line, dim, mask, ComplexSumAccumulator<float>{x}, "SUM");
164}
165void RTDEF(CppSumComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
166const Descriptor &x, const char *source, int line, int dim,
167const Descriptor *mask) {
168result = GetTotalReduction<TypeCategory::Complex, 8>(
169x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
170}
#if LDBL_MANT_DIG == 64
// Total SUM of a COMPLEX(10) array, stored into `result`; each component is
// Kahan-summed in long double (64-bit significand on this configuration).
void RTDEF(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator = ComplexSumAccumulator<long double>;
  result = GetTotalReduction<TypeCategory::Complex, 10>(
      x, source, line, dim, mask, Accumulator{x}, "SUM");
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// Total SUM of a COMPLEX(16) array, stored into `result`. Each component is
// Kahan-summed in the REAL(16) C++ type rather than `long double`: when
// REAL(16) comes from HAS_FLOAT128, long double may be a narrower format and
// would silently discard precision from every element.
void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  result = GetTotalReduction<TypeCategory::Complex, 16>(x, source, line, dim,
      mask, ComplexSumAccumulator<CppTypeFor<TypeCategory::Real, 16>>{x},
      "SUM");
}
#endif
187
188void RTDEF(SumDim)(Descriptor &result, const Descriptor &x, int dim,
189const char *source, int line, const Descriptor *mask) {
190TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
191ComplexSumAccumulator, /*MIN_REAL_KIND=*/4>(
192result, x, dim, source, line, mask, "SUM");
193}
194
195RT_EXT_API_GROUP_END
196} // extern "C"
197} // namespace Fortran::runtime
198