pytorch

Форк
0
/
complex_math.cpp 
93 строки · 3.1 Кб
1
#include <c10/util/complex.h>
2

3
#include <cmath>
4

5
// Note [ Complex Square root in libc++]
6
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7
// In libc++ complex square root is computed using polar form
8
// This is a reasonably fast algorithm, but can result in significant
9
// numerical errors when arg is close to 0, pi/2, pi, or 3pi/4
10
// In that case provide a more conservative implementation which is
11
// slower but less prone to those kinds of errors
12
// In libstdc++ complex square root yield invalid results
13
// for -x-0.0j unless C99 csqrt/csqrtf fallbacks are used
14

15
#if defined(_LIBCPP_VERSION) || \
16
    (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
17

18
namespace {
19
template <typename T>
20
c10::complex<T> compute_csqrt(const c10::complex<T>& z) {
21
  constexpr auto half = T(.5);
22

23
  // Trust standard library to correctly handle infs and NaNs
24
  if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) ||
25
      std::isnan(z.imag())) {
26
    return static_cast<c10::complex<T>>(
27
        std::sqrt(static_cast<std::complex<T>>(z)));
28
  }
29

30
  // Special case for square root of pure imaginary values
31
  if (z.real() == T(0)) {
32
    if (z.imag() == T(0)) {
33
      return c10::complex<T>(T(0), z.imag());
34
    }
35
    auto v = std::sqrt(half * std::abs(z.imag()));
36
    return c10::complex<T>(v, std::copysign(v, z.imag()));
37
  }
38

39
  // At this point, z is non-zero and finite
40
  if (z.real() >= 0.0) {
41
    auto t = std::sqrt((z.real() + std::abs(z)) * half);
42
    return c10::complex<T>(t, half * (z.imag() / t));
43
  }
44

45
  auto t = std::sqrt((-z.real() + std::abs(z)) * half);
46
  return c10::complex<T>(
47
      half * std::abs(z.imag() / t), std::copysign(t, z.imag()));
48
}
49

50
// Compute complex arccosine using formula from W. Kahan
51
// "Branch Cuts for Complex Elementary Functions" 1986 paper:
52
// cacos(z).re = 2*atan2(sqrt(1-z).re(), sqrt(1+z).re())
53
// cacos(z).im = asinh((sqrt(conj(1+z))*sqrt(1-z)).im())
54
template <typename T>
55
c10::complex<T> compute_cacos(const c10::complex<T>& z) {
56
  auto constexpr one = T(1);
57
  // Trust standard library to correctly handle infs and NaNs
58
  if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) ||
59
      std::isnan(z.imag())) {
60
    return static_cast<c10::complex<T>>(
61
        std::acos(static_cast<std::complex<T>>(z)));
62
  }
63
  auto a = compute_csqrt(c10::complex<T>(one - z.real(), -z.imag()));
64
  auto b = compute_csqrt(c10::complex<T>(one + z.real(), z.imag()));
65
  auto c = compute_csqrt(c10::complex<T>(one + z.real(), -z.imag()));
66
  auto r = T(2) * std::atan2(a.real(), b.real());
67
  // Explicitly unroll (a*c).imag()
68
  auto i = std::asinh(a.real() * c.imag() + a.imag() * c.real());
69
  return c10::complex<T>(r, i);
70
}
71
} // anonymous namespace
72

73
namespace c10_complex_math {
74
namespace _detail {
75
c10::complex<float> sqrt(const c10::complex<float>& in) {
76
  return compute_csqrt(in);
77
}
78

79
c10::complex<double> sqrt(const c10::complex<double>& in) {
80
  return compute_csqrt(in);
81
}
82

83
c10::complex<float> acos(const c10::complex<float>& in) {
84
  return compute_cacos(in);
85
}
86

87
c10::complex<double> acos(const c10::complex<double>& in) {
88
  return compute_cacos(in);
89
}
90

91
} // namespace _detail
92
} // namespace c10_complex_math
93
#endif
94

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.