llvm-project

Форк
0
/
check-cuda.cpp 
562 строки · 19.5 Кб
1
//===-- lib/Semantics/check-cuda.cpp ----------------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8

9
#include "check-cuda.h"
10
#include "flang/Common/template.h"
11
#include "flang/Evaluate/fold.h"
12
#include "flang/Evaluate/tools.h"
13
#include "flang/Evaluate/traverse.h"
14
#include "flang/Parser/parse-tree-visitor.h"
15
#include "flang/Parser/parse-tree.h"
16
#include "flang/Parser/tools.h"
17
#include "flang/Semantics/expression.h"
18
#include "flang/Semantics/symbol.h"
19
#include "flang/Semantics/tools.h"
20

21
// Once labeled DO constructs have been canonicalized and their parse subtrees
22
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
23
// and merge adjacent CUFKernelDoConstructs and DoConstructs whenever the
24
// CUFKernelDoConstruct doesn't already have an embedded DoConstruct.  Also
25
// emit errors about improper or missing DoConstructs.
26

27
namespace Fortran::parser {
28
struct Mutator {
29
  template <typename A> bool Pre(A &) { return true; }
30
  template <typename A> void Post(A &) {}
31
  bool Pre(Block &);
32
};
33

34
bool Mutator::Pre(Block &block) {
35
  for (auto iter{block.begin()}; iter != block.end(); ++iter) {
36
    if (auto *kernel{Unwrap<CUFKernelDoConstruct>(*iter)}) {
37
      auto &nested{std::get<std::optional<DoConstruct>>(kernel->t)};
38
      if (!nested) {
39
        if (auto next{iter}; ++next != block.end()) {
40
          if (auto *doConstruct{Unwrap<DoConstruct>(*next)}) {
41
            nested = std::move(*doConstruct);
42
            block.erase(next);
43
          }
44
        }
45
      }
46
    } else {
47
      Walk(*iter, *this);
48
    }
49
  }
50
  return false;
51
}
52
} // namespace Fortran::parser
53

54
namespace Fortran::semantics {
55

56
// Runs the CUDA Fortran canonicalization over the whole program. It walks
// the tree with parser::Mutator so that a !$CUF KERNEL DO directive lacking an
// embedded DO construct absorbs the DO construct that immediately follows it.
// Always returns true.
bool CanonicalizeCUDA(parser::Program &program) {
  parser::Mutator kernelDoMerger;
  parser::Walk(program, kernelDoMerger);
  return true;
}
61

62
using MaybeMsg = std::optional<parser::MessageFormattedText>;
63

64
// Traverses an evaluate::Expr<> in search of unsupported operations
65
// on the device.
66

67
struct DeviceExprChecker
68
    : public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
69
  using Result = MaybeMsg;
70
  using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
71
  DeviceExprChecker() : Base(*this) {}
72
  using Base::operator();
73
  Result operator()(const evaluate::ProcedureDesignator &x) const {
74
    if (const Symbol * sym{x.GetInterfaceSymbol()}) {
75
      const auto *subp{
76
          sym->GetUltimate().detailsIf<semantics::SubprogramDetails>()};
77
      if (subp) {
78
        if (auto attrs{subp->cudaSubprogramAttrs()}) {
79
          if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
80
              *attrs == common::CUDASubprogramAttrs::Device) {
81
            return {};
82
          }
83
        }
84
      }
85
    } else if (x.GetSpecificIntrinsic()) {
86
      // TODO(CUDA): Check for unsupported intrinsics here
87
      return {};
88
    }
89
    return parser::MessageFormattedText(
90
        "'%s' may not be called in device code"_err_en_US, x.GetName());
91
  }
92
};
93

94
// Runs DeviceExprChecker over the analyzed form of the parser::Expr wrapped
// by `x` (e.g. a ScalarLogicalExpr). Returns nothing when `x` does not wrap a
// parser::Expr or when the checker finds no problem.
template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
  const auto *expr{parser::Unwrap<parser::Expr>(x)};
  if (!expr) {
    return {};
  }
  return DeviceExprChecker{}(expr->typedExpr);
}
100

101
template <typename A>
102
static void CheckUnwrappedExpr(
103
    SemanticsContext &context, SourceName at, const A &x) {
104
  if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
105
    if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
106
      context.Say(at, std::move(*msg));
107
    }
108
  }
109
}
110

111
template <bool CUF_KERNEL> struct ActionStmtChecker {
112
  template <typename A> static MaybeMsg WhyNotOk(const A &x) {
113
    if constexpr (ConstraintTrait<A>) {
114
      return WhyNotOk(x.thing);
115
    } else if constexpr (WrapperTrait<A>) {
116
      return WhyNotOk(x.v);
117
    } else if constexpr (UnionTrait<A>) {
118
      return WhyNotOk(x.u);
119
    } else if constexpr (TupleTrait<A>) {
120
      return WhyNotOk(x.t);
121
    } else {
122
      return parser::MessageFormattedText{
123
          "Statement may not appear in device code"_err_en_US};
124
    }
125
  }
126
  template <typename A>
127
  static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
128
    return WhyNotOk(x.value());
129
  }
130
  template <typename... As>
131
  static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
132
    return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
133
  }
134
  template <std::size_t J = 0, typename... As>
135
  static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
136
    if constexpr (J == sizeof...(As)) {
137
      return {};
138
    } else if (auto msg{WhyNotOk(std::get<J>(x))}) {
139
      return msg;
140
    } else {
141
      return WhyNotOk<(J + 1)>(x);
142
    }
143
  }
144
  template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
145
    for (const auto &y : x) {
146
      if (MaybeMsg result{WhyNotOk(y)}) {
147
        return result;
148
      }
149
    }
150
    return {};
151
  }
152
  template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
153
    if (x) {
154
      return WhyNotOk(*x);
155
    } else {
156
      return {};
157
    }
158
  }
159
  template <typename A>
160
  static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
161
    return WhyNotOk(x.statement);
162
  }
163
  template <typename A>
164
  static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
165
    return WhyNotOk(x.statement);
166
  }
167
  static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
168
    return {}; // AllocateObjects are checked elsewhere
169
  }
170
  static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
171
    return parser::MessageFormattedText(
172
        "A coarray may not be allocated on the device"_err_en_US);
173
  }
174
  static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
175
    return {}; // AllocateObjects are checked elsewhere
176
  }
177
  static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
178
    return DeviceExprChecker{}(x.typedAssignment);
179
  }
180
  static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
181
    return DeviceExprChecker{}(x.typedCall);
182
  }
183
  static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
184
  static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
185
    if (auto result{
186
            CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
187
      return result;
188
    }
189
    return WhyNotOk(
190
        std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
191
            .statement);
192
  }
193
  static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
194
    for (const auto &y : x.v) {
195
      if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
196
        return result;
197
      }
198
    }
199
    return {};
200
  }
201
  static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
202
    return DeviceExprChecker{}(x.typedAssignment);
203
  }
204
};
205

206
template <bool IsCUFKernelDo> class DeviceContextChecker {
207
public:
208
  explicit DeviceContextChecker(SemanticsContext &c) : context_{c} {}
209
  void CheckSubprogram(const parser::Name &name, const parser::Block &body) {
210
    if (name.symbol) {
211
      const auto *subp{
212
          name.symbol->GetUltimate().detailsIf<SubprogramDetails>()};
213
      if (subp && subp->moduleInterface()) {
214
        subp = subp->moduleInterface()
215
                   ->GetUltimate()
216
                   .detailsIf<SubprogramDetails>();
217
      }
218
      if (subp &&
219
          subp->cudaSubprogramAttrs().value_or(
220
              common::CUDASubprogramAttrs::Host) !=
221
              common::CUDASubprogramAttrs::Host) {
222
        Check(body);
223
      }
224
    }
225
  }
226
  void Check(const parser::Block &block) {
227
    for (const auto &epc : block) {
228
      Check(epc);
229
    }
230
  }
231

232
private:
233
  void Check(const parser::ExecutionPartConstruct &epc) {
234
    common::visit(
235
        common::visitors{
236
            [&](const parser::ExecutableConstruct &x) { Check(x); },
237
            [&](const parser::Statement<common::Indirection<parser::EntryStmt>>
238
                    &x) {
239
              context_.Say(x.source,
240
                  "Device code may not contain an ENTRY statement"_err_en_US);
241
            },
242
            [](const parser::Statement<common::Indirection<parser::FormatStmt>>
243
                    &) {},
244
            [](const parser::Statement<common::Indirection<parser::DataStmt>>
245
                    &) {},
246
            [](const parser::Statement<
247
                common::Indirection<parser::NamelistStmt>> &) {},
248
            [](const parser::ErrorRecovery &) {},
249
        },
250
        epc.u);
251
  }
252
  void Check(const parser::ExecutableConstruct &ec) {
253
    common::visit(
254
        common::visitors{
255
            [&](const parser::Statement<parser::ActionStmt> &stmt) {
256
              Check(stmt.statement, stmt.source);
257
            },
258
            [&](const common::Indirection<parser::DoConstruct> &x) {
259
              if (const std::optional<parser::LoopControl> &control{
260
                      x.value().GetLoopControl()}) {
261
                common::visit([&](const auto &y) { Check(y); }, control->u);
262
              }
263
              Check(std::get<parser::Block>(x.value().t));
264
            },
265
            [&](const common::Indirection<parser::BlockConstruct> &x) {
266
              Check(std::get<parser::Block>(x.value().t));
267
            },
268
            [&](const common::Indirection<parser::IfConstruct> &x) {
269
              Check(x.value());
270
            },
271
            [&](const auto &x) {
272
              if (auto source{parser::GetSource(x)}) {
273
                context_.Say(*source,
274
                    "Statement may not appear in device code"_err_en_US);
275
              }
276
            },
277
        },
278
        ec.u);
279
  }
280
  template <typename SEEK, typename A>
281
  static const SEEK *GetIOControl(const A &stmt) {
282
    for (const auto &spec : stmt.controls) {
283
      if (const auto *result{std::get_if<SEEK>(&spec.u)}) {
284
        return result;
285
      }
286
    }
287
    return nullptr;
288
  }
289
  template <typename A> static bool IsInternalIO(const A &stmt) {
290
    if (stmt.iounit.has_value()) {
291
      return std::holds_alternative<Fortran::parser::Variable>(stmt.iounit->u);
292
    }
293
    if (auto *unit{GetIOControl<Fortran::parser::IoUnit>(stmt)}) {
294
      return std::holds_alternative<Fortran::parser::Variable>(unit->u);
295
    }
296
    return false;
297
  }
298
  void WarnOnIoStmt(const parser::CharBlock &source) {
299
    if (context_.ShouldWarn(common::UsageWarning::CUDAUsage)) {
300
      context_.Say(
301
          source, "I/O statement might not be supported on device"_warn_en_US);
302
    }
303
  }
304
  template <typename A>
305
  void WarnIfNotInternal(const A &stmt, const parser::CharBlock &source) {
306
    if (!IsInternalIO(stmt)) {
307
      WarnOnIoStmt(source);
308
    }
309
  }
310
  void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {
311
    common::visit(
312
        common::visitors{
313
            [&](const common::Indirection<parser::PrintStmt> &) {},
314
            [&](const common::Indirection<parser::WriteStmt> &x) {
315
              if (x.value().format) { // Formatted write to '*' or '6'
316
                if (std::holds_alternative<Fortran::parser::Star>(
317
                        x.value().format->u)) {
318
                  if (x.value().iounit) {
319
                    if (std::holds_alternative<Fortran::parser::Star>(
320
                            x.value().iounit->u)) {
321
                      return;
322
                    }
323
                  }
324
                }
325
              }
326
              WarnIfNotInternal(x.value(), source);
327
            },
328
            [&](const common::Indirection<parser::CloseStmt> &x) {
329
              WarnOnIoStmt(source);
330
            },
331
            [&](const common::Indirection<parser::EndfileStmt> &x) {
332
              WarnOnIoStmt(source);
333
            },
334
            [&](const common::Indirection<parser::OpenStmt> &x) {
335
              WarnOnIoStmt(source);
336
            },
337
            [&](const common::Indirection<parser::ReadStmt> &x) {
338
              WarnIfNotInternal(x.value(), source);
339
            },
340
            [&](const common::Indirection<parser::InquireStmt> &x) {
341
              WarnOnIoStmt(source);
342
            },
343
            [&](const common::Indirection<parser::RewindStmt> &x) {
344
              WarnOnIoStmt(source);
345
            },
346
            [&](const common::Indirection<parser::BackspaceStmt> &x) {
347
              WarnOnIoStmt(source);
348
            },
349
            [&](const common::Indirection<parser::IfStmt> &x) {
350
              Check(x.value());
351
            },
352
            [&](const auto &x) {
353
              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
354
                context_.Say(source, std::move(*msg));
355
              }
356
            },
357
        },
358
        stmt.u);
359
  }
360
  void Check(const parser::IfConstruct &ic) {
361
    const auto &ifS{std::get<parser::Statement<parser::IfThenStmt>>(ic.t)};
362
    CheckUnwrappedExpr(context_, ifS.source,
363
        std::get<parser::ScalarLogicalExpr>(ifS.statement.t));
364
    Check(std::get<parser::Block>(ic.t));
365
    for (const auto &eib :
366
        std::get<std::list<parser::IfConstruct::ElseIfBlock>>(ic.t)) {
367
      const auto &eIfS{std::get<parser::Statement<parser::ElseIfStmt>>(eib.t)};
368
      CheckUnwrappedExpr(context_, eIfS.source,
369
          std::get<parser::ScalarLogicalExpr>(eIfS.statement.t));
370
      Check(std::get<parser::Block>(eib.t));
371
    }
372
    if (const auto &eb{
373
            std::get<std::optional<parser::IfConstruct::ElseBlock>>(ic.t)}) {
374
      Check(std::get<parser::Block>(eb->t));
375
    }
376
  }
377
  void Check(const parser::IfStmt &is) {
378
    const auto &uS{
379
        std::get<parser::UnlabeledStatement<parser::ActionStmt>>(is.t)};
380
    CheckUnwrappedExpr(
381
        context_, uS.source, std::get<parser::ScalarLogicalExpr>(is.t));
382
    Check(uS.statement, uS.source);
383
  }
384
  void Check(const parser::LoopControl::Bounds &bounds) {
385
    Check(bounds.lower);
386
    Check(bounds.upper);
387
    if (bounds.step) {
388
      Check(*bounds.step);
389
    }
390
  }
391
  void Check(const parser::LoopControl::Concurrent &x) {
392
    const auto &header{std::get<parser::ConcurrentHeader>(x.t)};
393
    for (const auto &cc :
394
        std::get<std::list<parser::ConcurrentControl>>(header.t)) {
395
      Check(std::get<1>(cc.t));
396
      Check(std::get<2>(cc.t));
397
      if (const auto &step{
398
              std::get<std::optional<parser::ScalarIntExpr>>(cc.t)}) {
399
        Check(*step);
400
      }
401
    }
402
    if (const auto &mask{
403
            std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
404
      Check(*mask);
405
    }
406
  }
407
  void Check(const parser::ScalarLogicalExpr &x) {
408
    Check(DEREF(parser::Unwrap<parser::Expr>(x)));
409
  }
410
  void Check(const parser::ScalarIntExpr &x) {
411
    Check(DEREF(parser::Unwrap<parser::Expr>(x)));
412
  }
413
  void Check(const parser::ScalarExpr &x) {
414
    Check(DEREF(parser::Unwrap<parser::Expr>(x)));
415
  }
416
  void Check(const parser::Expr &expr) {
417
    if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
418
      context_.Say(expr.source, std::move(*msg));
419
    }
420
  }
421

422
  SemanticsContext &context_;
423
};
424

425
// SUBROUTINE subprogram: check its execution part as device code. The check
// is skipped inside CheckSubprogram unless the subroutine has a non-HOST
// CUDA attribute.
void CUDAChecker::Enter(const parser::SubroutineSubprogram &x) {
  const auto &stmt{std::get<parser::Statement<parser::SubroutineStmt>>(x.t)};
  const parser::Name &name{std::get<parser::Name>(stmt.statement.t)};
  const parser::Block &body{std::get<parser::ExecutionPart>(x.t).v};
  DeviceContextChecker<false> checker{context_};
  checker.CheckSubprogram(name, body);
}
431

432
// FUNCTION subprogram: check its execution part as device code. The check
// is skipped inside CheckSubprogram unless the function has a non-HOST
// CUDA attribute.
void CUDAChecker::Enter(const parser::FunctionSubprogram &x) {
  const auto &stmt{std::get<parser::Statement<parser::FunctionStmt>>(x.t)};
  const parser::Name &name{std::get<parser::Name>(stmt.statement.t)};
  const parser::Block &body{std::get<parser::ExecutionPart>(x.t).v};
  DeviceContextChecker<false> checker{context_};
  checker.CheckSubprogram(name, body);
}
438

439
// Separate module procedure (MODULE PROCEDURE ... END PROCEDURE): the name
// comes from the MpSubprogramStmt. CheckSubprogram resolves the CUDA
// attribute via the symbol's moduleInterface() when present.
void CUDAChecker::Enter(const parser::SeparateModuleSubprogram &x) {
  const auto &stmt{std::get<parser::Statement<parser::MpSubprogramStmt>>(x.t)};
  const parser::Block &body{std::get<parser::ExecutionPart>(x.t).v};
  DeviceContextChecker<false> checker{context_};
  checker.CheckSubprogram(stmt.statement.v, body);
}
444

445
// !$CUF KERNEL DO semantic checks
446

447
static int DoConstructTightNesting(
448
    const parser::DoConstruct *doConstruct, const parser::Block *&innerBlock) {
449
  if (!doConstruct || !doConstruct->IsDoNormal()) {
450
    return 0;
451
  }
452
  innerBlock = &std::get<parser::Block>(doConstruct->t);
453
  if (innerBlock->size() == 1) {
454
    if (const auto *execConstruct{
455
            std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) {
456
      if (const auto *next{
457
              std::get_if<common::Indirection<parser::DoConstruct>>(
458
                  &execConstruct->u)}) {
459
        return 1 + DoConstructTightNesting(&next->value(), innerBlock);
460
      }
461
    }
462
  }
463
  return 1;
464
}
465

466
static void CheckReduce(
467
    SemanticsContext &context, const parser::CUFReduction &reduce) {
468
  auto op{std::get<parser::CUFReduction::Operator>(reduce.t).v};
469
  for (const auto &var :
470
      std::get<std::list<parser::Scalar<parser::Variable>>>(reduce.t)) {
471
    if (const auto &typedExprPtr{var.thing.typedExpr};
472
        typedExprPtr && typedExprPtr->v) {
473
      const auto &expr{*typedExprPtr->v};
474
      if (auto type{expr.GetType()}) {
475
        auto cat{type->category()};
476
        bool isOk{false};
477
        switch (op) {
478
        case parser::ReductionOperator::Operator::Plus:
479
        case parser::ReductionOperator::Operator::Multiply:
480
        case parser::ReductionOperator::Operator::Max:
481
        case parser::ReductionOperator::Operator::Min:
482
          isOk = cat == TypeCategory::Integer || cat == TypeCategory::Real;
483
          break;
484
        case parser::ReductionOperator::Operator::Iand:
485
        case parser::ReductionOperator::Operator::Ior:
486
        case parser::ReductionOperator::Operator::Ieor:
487
          isOk = cat == TypeCategory::Integer;
488
          break;
489
        case parser::ReductionOperator::Operator::And:
490
        case parser::ReductionOperator::Operator::Or:
491
        case parser::ReductionOperator::Operator::Eqv:
492
        case parser::ReductionOperator::Operator::Neqv:
493
          isOk = cat == TypeCategory::Logical;
494
          break;
495
        }
496
        if (!isOk) {
497
          context.Say(var.thing.GetSource(),
498
              "!$CUF KERNEL DO REDUCE operation is not acceptable for a variable with type %s"_err_en_US,
499
              type->AsFortran());
500
        }
501
      }
502
    }
503
  }
504
}
505

506
// Semantic checks for !$CUF KERNEL DO [(n)]:
// 1. The optional nesting depth n must evaluate to a positive value. A
//    missing depth means 1; after an error the depth falls back to 1.
// 2. The directive must be followed by at least n tightly nested counted DO
//    loops.
// 3. The body of the deepest tightly nested loop is checked as device code.
// 4. Every REDUCE clause is type-checked.
// NOTE(review): the loop-control expressions of nested DOs after the
// outermost are never passed to the device checker, including levels deeper
// than n that presumably run on the device — confirm whether that is
// intended.
void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {
  const auto &directive{std::get<parser::CUFKernelDoConstruct::Directive>(x.t)};
  auto source{directive.source};

  std::int64_t depth{1};
  const auto &depthExpr{
      std::get<std::optional<parser::ScalarIntConstantExpr>>(directive.t)};
  if (auto analyzed{AnalyzeExpr(context_, depthExpr)}) {
    depth = evaluate::ToInt64(analyzed).value_or(0);
    if (depth <= 0) {
      context_.Say(source,
          "!$CUF KERNEL DO (%jd): loop nesting depth must be positive"_err_en_US,
          std::intmax_t{depth});
      depth = 1;
    }
  }

  const parser::DoConstruct *outerDo{common::GetPtrFromOptional(
      std::get<std::optional<parser::DoConstruct>>(x.t))};
  const parser::Block *innerBlock{nullptr};
  const int tightLevels{DoConstructTightNesting(outerDo, innerBlock)};
  if (tightLevels < depth) {
    context_.Say(source,
        "!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,
        std::intmax_t{depth});
  }
  if (innerBlock) {
    DeviceContextChecker<true>{context_}.Check(*innerBlock);
  }

  for (const auto &reduce :
      std::get<std::list<parser::CUFReduction>>(directive.t)) {
    CheckReduce(context_, reduce);
  }
}
537

538
// Host-side assignment used as a CUDA data transfer. When the left-hand side
// references no device object (device-to-host direction), the right-hand
// side may reference at most one device object. Assignments inside a device
// program unit are skipped, because data transfer through assignment only
// happens on the host.
void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
  const auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()};
  const Scope &progUnit{
      GetProgramUnitContaining(context_.FindScope(lhsLoc))};
  if (IsCUDADeviceContext(&progUnit)) {
    return;
  }

  const evaluate::Assignment *assign{semantics::GetAssignment(x)};
  if (!assign) {
    return; // assignment was not analyzed
  }

  const int deviceRefsOnLhs{evaluate::GetNbOfCUDADeviceSymbols(assign->lhs)};
  const int deviceRefsOnRhs{evaluate::GetNbOfCUDADeviceSymbols(assign->rhs)};
  const bool isDeviceToHost{deviceRefsOnLhs == 0};
  if (isDeviceToHost && deviceRefsOnRhs > 1) {
    context_.Say(lhsLoc,
        "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
  }
}
561

562
} // namespace Fortran::semantics
563

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.