llvm-project
562 строки · 19.5 Кб
1//===-- lib/Semantics/check-cuda.cpp ----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "check-cuda.h"10#include "flang/Common/template.h"11#include "flang/Evaluate/fold.h"12#include "flang/Evaluate/tools.h"13#include "flang/Evaluate/traverse.h"14#include "flang/Parser/parse-tree-visitor.h"15#include "flang/Parser/parse-tree.h"16#include "flang/Parser/tools.h"17#include "flang/Semantics/expression.h"18#include "flang/Semantics/symbol.h"19#include "flang/Semantics/tools.h"20
// Once labeled DO constructs have been canonicalized and their parse subtrees
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
// and merge each CUFKernelDoConstruct with the DoConstruct that immediately
// follows it, whenever the CUFKernelDoConstruct doesn't already have an
// embedded DoConstruct. Errors about improper or missing DoConstructs are
// emitted later, by CUDAChecker::Enter(const parser::CUFKernelDoConstruct &).
26
27namespace Fortran::parser {28struct Mutator {29template <typename A> bool Pre(A &) { return true; }30template <typename A> void Post(A &) {}31bool Pre(Block &);32};33
34bool Mutator::Pre(Block &block) {35for (auto iter{block.begin()}; iter != block.end(); ++iter) {36if (auto *kernel{Unwrap<CUFKernelDoConstruct>(*iter)}) {37auto &nested{std::get<std::optional<DoConstruct>>(kernel->t)};38if (!nested) {39if (auto next{iter}; ++next != block.end()) {40if (auto *doConstruct{Unwrap<DoConstruct>(*next)}) {41nested = std::move(*doConstruct);42block.erase(next);43}44}45}46} else {47Walk(*iter, *this);48}49}50return false;51}
52} // namespace Fortran::parser53
54namespace Fortran::semantics {55
56bool CanonicalizeCUDA(parser::Program &program) {57parser::Mutator mutator;58parser::Walk(program, mutator);59return true;60}
61
62using MaybeMsg = std::optional<parser::MessageFormattedText>;63
64// Traverses an evaluate::Expr<> in search of unsupported operations
65// on the device.
66
67struct DeviceExprChecker68: public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {69using Result = MaybeMsg;70using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;71DeviceExprChecker() : Base(*this) {}72using Base::operator();73Result operator()(const evaluate::ProcedureDesignator &x) const {74if (const Symbol * sym{x.GetInterfaceSymbol()}) {75const auto *subp{76sym->GetUltimate().detailsIf<semantics::SubprogramDetails>()};77if (subp) {78if (auto attrs{subp->cudaSubprogramAttrs()}) {79if (*attrs == common::CUDASubprogramAttrs::HostDevice ||80*attrs == common::CUDASubprogramAttrs::Device) {81return {};82}83}84}85} else if (x.GetSpecificIntrinsic()) {86// TODO(CUDA): Check for unsupported intrinsics here87return {};88}89return parser::MessageFormattedText(90"'%s' may not be called in device code"_err_en_US, x.GetName());91}92};93
94template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {95if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {96return DeviceExprChecker{}(expr->typedExpr);97}98return {};99}
100
101template <typename A>102static void CheckUnwrappedExpr(103SemanticsContext &context, SourceName at, const A &x) {104if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {105if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {106context.Say(at, std::move(*msg));107}108}109}
110
// Decides whether an action statement may appear in device code.
// WhyNotOk() returns an empty result when the statement is acceptable, and
// otherwise the diagnostic explaining why it is not. The generic overloads
// recurse through parse-tree wrappers (trait classes, Indirection, variant,
// tuple, list, optional, statement wrappers) until a type with a dedicated
// overload is reached; any leaf type without such an overload is rejected.
// The CUF_KERNEL template parameter is not referenced in this struct.
template <bool CUF_KERNEL> struct ActionStmtChecker {
  template <typename A> static MaybeMsg WhyNotOk(const A &x) {
    if constexpr (ConstraintTrait<A>) {
      return WhyNotOk(x.thing);
    } else if constexpr (WrapperTrait<A>) {
      return WhyNotOk(x.v);
    } else if constexpr (UnionTrait<A>) {
      return WhyNotOk(x.u);
    } else if constexpr (TupleTrait<A>) {
      return WhyNotOk(x.t);
    } else {
      // Leaf type with no dedicated overload below: not allowed on device.
      return parser::MessageFormattedText{
          "Statement may not appear in device code"_err_en_US};
    }
  }
  template <typename A>
  static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
    return WhyNotOk(x.value());
  }
  template <typename... As>
  static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
    return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
  }
  // Checks tuple elements in order, stopping at the first failure.
  template <std::size_t J = 0, typename... As>
  static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
    if constexpr (J == sizeof...(As)) {
      return {};
    } else if (auto msg{WhyNotOk(std::get<J>(x))}) {
      return msg;
    } else {
      return WhyNotOk<(J + 1)>(x);
    }
  }
  // Checks list elements in order, stopping at the first failure.
  template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
    for (const auto &y : x) {
      if (MaybeMsg result{WhyNotOk(y)}) {
        return result;
      }
    }
    return {};
  }
  // An absent optional component is always acceptable.
  template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
    if (x) {
      return WhyNotOk(*x);
    } else {
      return {};
    }
  }
  template <typename A>
  static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
    return WhyNotOk(x.statement);
  }
  template <typename A>
  static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
    return WhyNotOk(x.statement);
  }
  static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
    return {}; // AllocateObjects are checked elsewhere
  }
  // NOTE(review): the AllocateStmt overload above returns without recursing
  // into its components, so this overload looks unreachable from ALLOCATE;
  // confirm whether coarray allocation on the device is diagnosed elsewhere.
  static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
    return parser::MessageFormattedText(
        "A coarray may not be allocated on the device"_err_en_US);
  }
  static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
    return {}; // AllocateObjects are checked elsewhere
  }
  // Assignments, calls, NULLIFY and pointer assignments are checked by running
  // DeviceExprChecker over their analyzed (typed) forms.
  static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
    return DeviceExprChecker{}(x.typedAssignment);
  }
  static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
    return DeviceExprChecker{}(x.typedCall);
  }
  static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
  // Logical IF: the condition is checked first, then the controlled statement.
  static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
    if (auto result{
            CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
      return result;
    }
    return WhyNotOk(
        std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
            .statement);
  }
  static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
    for (const auto &y : x.v) {
      if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
        return result;
      }
    }
    return {};
  }
  static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
    return DeviceExprChecker{}(x.typedAssignment);
  }
};
// Walks executable device code and reports constructs, statements and
// expressions that may not appear there. It is used for device subprogram
// bodies and for the body of a !$CUF KERNEL DO loop nest. IsCUFKernelDo is
// forwarded as the template argument of ActionStmtChecker.
template <bool IsCUFKernelDo> class DeviceContextChecker {
public:
  explicit DeviceContextChecker(SemanticsContext &c) : context_{c} {}
  // Checks `body` only when `name` resolves to a subprogram whose CUDA
  // attributes are present and are not Host. For a separate module procedure,
  // the attributes are taken from its module interface.
  void CheckSubprogram(const parser::Name &name, const parser::Block &body) {
    if (name.symbol) {
      const auto *subp{
          name.symbol->GetUltimate().detailsIf<SubprogramDetails>()};
      if (subp && subp->moduleInterface()) {
        subp = subp->moduleInterface()
                   ->GetUltimate()
                   .detailsIf<SubprogramDetails>();
      }
      // Missing CUDA attributes are treated the same as Host (not checked).
      if (subp &&
          subp->cudaSubprogramAttrs().value_or(
              common::CUDASubprogramAttrs::Host) !=
              common::CUDASubprogramAttrs::Host) {
        Check(body);
      }
    }
  }
  // Checks every construct of a block, in order.
  void Check(const parser::Block &block) {
    for (const auto &epc : block) {
      Check(epc);
    }
  }

private:
  // ENTRY is an error; FORMAT, DATA, NAMELIST and error-recovery nodes are
  // ignored; executable constructs are checked.
  void Check(const parser::ExecutionPartConstruct &epc) {
    common::visit(
        common::visitors{
            [&](const parser::ExecutableConstruct &x) { Check(x); },
            [&](const parser::Statement<common::Indirection<parser::EntryStmt>>
                    &x) {
              context_.Say(x.source,
                  "Device code may not contain an ENTRY statement"_err_en_US);
            },
            [](const parser::Statement<common::Indirection<parser::FormatStmt>>
                    &) {},
            [](const parser::Statement<common::Indirection<parser::DataStmt>>
                    &) {},
            [](const parser::Statement<
                common::Indirection<parser::NamelistStmt>> &) {},
            [](const parser::ErrorRecovery &) {},
        },
        epc.u);
  }
  // Only action statements and DO, BLOCK and IF constructs are allowed; any
  // other executable construct is reported at its source location.
  void Check(const parser::ExecutableConstruct &ec) {
    common::visit(
        common::visitors{
            [&](const parser::Statement<parser::ActionStmt> &stmt) {
              Check(stmt.statement, stmt.source);
            },
            [&](const common::Indirection<parser::DoConstruct> &x) {
              // The loop control (bounds, a logical condition, or a
              // concurrent header) is checked before the loop body.
              if (const std::optional<parser::LoopControl> &control{
                      x.value().GetLoopControl()}) {
                common::visit([&](const auto &y) { Check(y); }, control->u);
              }
              Check(std::get<parser::Block>(x.value().t));
            },
            [&](const common::Indirection<parser::BlockConstruct> &x) {
              Check(std::get<parser::Block>(x.value().t));
            },
            [&](const common::Indirection<parser::IfConstruct> &x) {
              Check(x.value());
            },
            [&](const auto &x) {
              if (auto source{parser::GetSource(x)}) {
                context_.Say(*source,
                    "Statement may not appear in device code"_err_en_US);
              }
            },
        },
        ec.u);
  }
  // Returns the first I/O control specifier of type SEEK in `stmt`, or null.
  template <typename SEEK, typename A>
  static const SEEK *GetIOControl(const A &stmt) {
    for (const auto &spec : stmt.controls) {
      if (const auto *result{std::get_if<SEEK>(&spec.u)}) {
        return result;
      }
    }
    return nullptr;
  }
  // True when the I/O unit, given either positionally or as a UNIT= control,
  // is a Variable (internal I/O).
  template <typename A> static bool IsInternalIO(const A &stmt) {
    if (stmt.iounit.has_value()) {
      return std::holds_alternative<Fortran::parser::Variable>(stmt.iounit->u);
    }
    if (auto *unit{GetIOControl<Fortran::parser::IoUnit>(stmt)}) {
      return std::holds_alternative<Fortran::parser::Variable>(unit->u);
    }
    return false;
  }
  // Emits the I/O warning only when the CUDAUsage usage warning is enabled.
  void WarnOnIoStmt(const parser::CharBlock &source) {
    if (context_.ShouldWarn(common::UsageWarning::CUDAUsage)) {
      context_.Say(
          source, "I/O statement might not be supported on device"_warn_en_US);
    }
  }
  template <typename A>
  void WarnIfNotInternal(const A &stmt, const parser::CharBlock &source) {
    if (!IsInternalIO(stmt)) {
      WarnOnIoStmt(source);
    }
  }
  // Handling of action statements:
  //  - PRINT is accepted silently;
  //  - WRITE with '*' as both positional unit and format is accepted
  //    silently; other READ/WRITE statements warn unless the I/O is internal;
  //  - other I/O statements always warn;
  //  - a logical IF statement is checked via Check(const parser::IfStmt &);
  //  - everything else is delegated to ActionStmtChecker.
  void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {
    common::visit(
        common::visitors{
            [&](const common::Indirection<parser::PrintStmt> &) {},
            [&](const common::Indirection<parser::WriteStmt> &x) {
              // NOTE(review): only a '*' unit is exempted below; a WRITE to
              // unit 6 falls through to WarnIfNotInternal despite this comment.
              if (x.value().format) { // Formatted write to '*' or '6'
                if (std::holds_alternative<Fortran::parser::Star>(
                        x.value().format->u)) {
                  if (x.value().iounit) {
                    if (std::holds_alternative<Fortran::parser::Star>(
                            x.value().iounit->u)) {
                      return;
                    }
                  }
                }
              }
              WarnIfNotInternal(x.value(), source);
            },
            [&](const common::Indirection<parser::CloseStmt> &x) {
              WarnOnIoStmt(source);
            },
            [&](const common::Indirection<parser::EndfileStmt> &x) {
              WarnOnIoStmt(source);
            },
            [&](const common::Indirection<parser::OpenStmt> &x) {
              WarnOnIoStmt(source);
            },
            [&](const common::Indirection<parser::ReadStmt> &x) {
              WarnIfNotInternal(x.value(), source);
            },
            [&](const common::Indirection<parser::InquireStmt> &x) {
              WarnOnIoStmt(source);
            },
            [&](const common::Indirection<parser::RewindStmt> &x) {
              WarnOnIoStmt(source);
            },
            [&](const common::Indirection<parser::BackspaceStmt> &x) {
              WarnOnIoStmt(source);
            },
            [&](const common::Indirection<parser::IfStmt> &x) {
              Check(x.value());
            },
            [&](const auto &x) {
              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
                context_.Say(source, std::move(*msg));
              }
            },
        },
        stmt.u);
  }
  // Checks the IF, every ELSE IF condition, and every branch body.
  void Check(const parser::IfConstruct &ic) {
    const auto &ifS{std::get<parser::Statement<parser::IfThenStmt>>(ic.t)};
    CheckUnwrappedExpr(context_, ifS.source,
        std::get<parser::ScalarLogicalExpr>(ifS.statement.t));
    Check(std::get<parser::Block>(ic.t));
    for (const auto &eib :
        std::get<std::list<parser::IfConstruct::ElseIfBlock>>(ic.t)) {
      const auto &eIfS{std::get<parser::Statement<parser::ElseIfStmt>>(eib.t)};
      CheckUnwrappedExpr(context_, eIfS.source,
          std::get<parser::ScalarLogicalExpr>(eIfS.statement.t));
      Check(std::get<parser::Block>(eib.t));
    }
    if (const auto &eb{
            std::get<std::optional<parser::IfConstruct::ElseBlock>>(ic.t)}) {
      Check(std::get<parser::Block>(eb->t));
    }
  }
  // Logical IF statement. Diagnostics for the condition are reported at the
  // source location of the controlled statement (uS.source).
  void Check(const parser::IfStmt &is) {
    const auto &uS{
        std::get<parser::UnlabeledStatement<parser::ActionStmt>>(is.t)};
    CheckUnwrappedExpr(
        context_, uS.source, std::get<parser::ScalarLogicalExpr>(is.t));
    Check(uS.statement, uS.source);
  }
  void Check(const parser::LoopControl::Bounds &bounds) {
    Check(bounds.lower);
    Check(bounds.upper);
    if (bounds.step) {
      Check(*bounds.step);
    }
  }
  // Checks every concurrent-control expression and the optional mask.
  // Tuple elements 1 and 2 are presumably the lower and upper bounds.
  void Check(const parser::LoopControl::Concurrent &x) {
    const auto &header{std::get<parser::ConcurrentHeader>(x.t)};
    for (const auto &cc :
        std::get<std::list<parser::ConcurrentControl>>(header.t)) {
      Check(std::get<1>(cc.t));
      Check(std::get<2>(cc.t));
      if (const auto &step{
              std::get<std::optional<parser::ScalarIntExpr>>(cc.t)}) {
        Check(*step);
      }
    }
    if (const auto &mask{
            std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
      Check(*mask);
    }
  }
  // The scalar wrappers are unwrapped to their parser::Expr; DEREF
  // presumably asserts that the wrapper really contains one.
  void Check(const parser::ScalarLogicalExpr &x) {
    Check(DEREF(parser::Unwrap<parser::Expr>(x)));
  }
  void Check(const parser::ScalarIntExpr &x) {
    Check(DEREF(parser::Unwrap<parser::Expr>(x)));
  }
  void Check(const parser::ScalarExpr &x) {
    Check(DEREF(parser::Unwrap<parser::Expr>(x)));
  }
  void Check(const parser::Expr &expr) {
    if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
      context_.Say(expr.source, std::move(*msg));
    }
  }

  SemanticsContext &context_;
};
425void CUDAChecker::Enter(const parser::SubroutineSubprogram &x) {426DeviceContextChecker<false>{context_}.CheckSubprogram(427std::get<parser::Name>(428std::get<parser::Statement<parser::SubroutineStmt>>(x.t).statement.t),429std::get<parser::ExecutionPart>(x.t).v);430}
431
432void CUDAChecker::Enter(const parser::FunctionSubprogram &x) {433DeviceContextChecker<false>{context_}.CheckSubprogram(434std::get<parser::Name>(435std::get<parser::Statement<parser::FunctionStmt>>(x.t).statement.t),436std::get<parser::ExecutionPart>(x.t).v);437}
438
439void CUDAChecker::Enter(const parser::SeparateModuleSubprogram &x) {440DeviceContextChecker<false>{context_}.CheckSubprogram(441std::get<parser::Statement<parser::MpSubprogramStmt>>(x.t).statement.v,442std::get<parser::ExecutionPart>(x.t).v);443}
444
445// !$CUF KERNEL DO semantic checks
446
447static int DoConstructTightNesting(448const parser::DoConstruct *doConstruct, const parser::Block *&innerBlock) {449if (!doConstruct || !doConstruct->IsDoNormal()) {450return 0;451}452innerBlock = &std::get<parser::Block>(doConstruct->t);453if (innerBlock->size() == 1) {454if (const auto *execConstruct{455std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) {456if (const auto *next{457std::get_if<common::Indirection<parser::DoConstruct>>(458&execConstruct->u)}) {459return 1 + DoConstructTightNesting(&next->value(), innerBlock);460}461}462}463return 1;464}
465
466static void CheckReduce(467SemanticsContext &context, const parser::CUFReduction &reduce) {468auto op{std::get<parser::CUFReduction::Operator>(reduce.t).v};469for (const auto &var :470std::get<std::list<parser::Scalar<parser::Variable>>>(reduce.t)) {471if (const auto &typedExprPtr{var.thing.typedExpr};472typedExprPtr && typedExprPtr->v) {473const auto &expr{*typedExprPtr->v};474if (auto type{expr.GetType()}) {475auto cat{type->category()};476bool isOk{false};477switch (op) {478case parser::ReductionOperator::Operator::Plus:479case parser::ReductionOperator::Operator::Multiply:480case parser::ReductionOperator::Operator::Max:481case parser::ReductionOperator::Operator::Min:482isOk = cat == TypeCategory::Integer || cat == TypeCategory::Real;483break;484case parser::ReductionOperator::Operator::Iand:485case parser::ReductionOperator::Operator::Ior:486case parser::ReductionOperator::Operator::Ieor:487isOk = cat == TypeCategory::Integer;488break;489case parser::ReductionOperator::Operator::And:490case parser::ReductionOperator::Operator::Or:491case parser::ReductionOperator::Operator::Eqv:492case parser::ReductionOperator::Operator::Neqv:493isOk = cat == TypeCategory::Logical;494break;495}496if (!isOk) {497context.Say(var.thing.GetSource(),498"!$CUF KERNEL DO REDUCE operation is not acceptable for a variable with type %s"_err_en_US,499type->AsFortran());500}501}502}503}504}
505
506void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {507auto source{std::get<parser::CUFKernelDoConstruct::Directive>(x.t).source};508const auto &directive{std::get<parser::CUFKernelDoConstruct::Directive>(x.t)};509std::int64_t depth{1};510if (auto expr{AnalyzeExpr(context_,511std::get<std::optional<parser::ScalarIntConstantExpr>>(512directive.t))}) {513depth = evaluate::ToInt64(expr).value_or(0);514if (depth <= 0) {515context_.Say(source,516"!$CUF KERNEL DO (%jd): loop nesting depth must be positive"_err_en_US,517std::intmax_t{depth});518depth = 1;519}520}521const parser::DoConstruct *doConstruct{common::GetPtrFromOptional(522std::get<std::optional<parser::DoConstruct>>(x.t))};523const parser::Block *innerBlock{nullptr};524if (DoConstructTightNesting(doConstruct, innerBlock) < depth) {525context_.Say(source,526"!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,527std::intmax_t{depth});528}529if (innerBlock) {530DeviceContextChecker<true>{context_}.Check(*innerBlock);531}532for (const auto &reduce :533std::get<std::list<parser::CUFReduction>>(directive.t)) {534CheckReduce(context_, reduce);535}536}
537
538void CUDAChecker::Enter(const parser::AssignmentStmt &x) {539auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()};540const auto &scope{context_.FindScope(lhsLoc)};541const Scope &progUnit{GetProgramUnitContaining(scope)};542if (IsCUDADeviceContext(&progUnit)) {543return; // Data transfer with assignment is only perform on host.544}545
546const evaluate::Assignment *assign{semantics::GetAssignment(x)};547if (!assign) {548return;549}550
551int nbLhs{evaluate::GetNbOfCUDADeviceSymbols(assign->lhs)};552int nbRhs{evaluate::GetNbOfCUDADeviceSymbols(assign->rhs)};553
554// device to host transfer with more than one device object on the rhs is not555// legal.556if (nbLhs == 0 && nbRhs > 1) {557context_.Say(lhsLoc,558"More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);559}560}
561
562} // namespace Fortran::semantics563