22
using namespace std::literals::string_literals;
32
Expr* Expr::copy() const {
33
auto res = new Expr{*this};
34
for (auto& arg : res->args) {
40
Expr::Expr(int c, sym_idx_t name_idx, std::initializer_list<Expr*> _arglist) : cls(c), args(std::move(_arglist)) {
41
sym = sym::lookup_symbol(name_idx);
46
void Expr::chk_rvalue(const Lexem& lem) const {
48
lem.error_at("rvalue expected before `", "`");
52
void Expr::chk_lvalue(const Lexem& lem) const {
54
lem.error_at("lvalue expected before `", "`");
58
void Expr::chk_type(const Lexem& lem) const {
60
lem.error_at("type expression expected before `", "`");
64
bool Expr::deduce_type(const Lexem& lem) {
73
SymVal* sym_val = dynamic_cast<SymVal*>(sym->value);
74
if (!sym_val || !sym_val->get_type()) {
77
std::vector<TypeExpr*> arg_types;
78
for (const auto& arg : args) {
79
arg_types.push_back(arg->e_type);
81
TypeExpr* fun_type = TypeExpr::new_map(TypeExpr::new_tensor(arg_types), TypeExpr::new_hole());
83
unify(fun_type, sym_val->sym_type);
84
} catch (UnifyError& ue) {
85
std::ostringstream os;
86
os << "cannot apply function " << sym->name() << " : " << sym_val->get_type() << " to arguments of type "
87
<< fun_type->args[0] << ": " << ue;
90
e_type = fun_type->args[1];
91
TypeExpr::remove_indirect(e_type);
95
func_assert(args.size() == 2);
96
TypeExpr* fun_type = TypeExpr::new_map(args[1]->e_type, TypeExpr::new_hole());
98
unify(fun_type, args[0]->e_type);
99
} catch (UnifyError& ue) {
100
std::ostringstream os;
101
os << "cannot apply expression of type " << args[0]->e_type << " to an expression of type " << args[1]->e_type
105
e_type = fun_type->args[1];
106
TypeExpr::remove_indirect(e_type);
110
func_assert(args.size() == 2);
113
unify(args[0]->e_type, args[1]->e_type);
114
} catch (UnifyError& ue) {
115
std::ostringstream os;
116
os << "cannot assign an expression of type " << args[1]->e_type << " to a variable or pattern of type "
117
<< args[0]->e_type << ": " << ue;
120
e_type = args[0]->e_type;
121
TypeExpr::remove_indirect(e_type);
125
func_assert(args.size() == 2);
126
TypeExpr* rhs_type = TypeExpr::new_tensor({args[0]->e_type, TypeExpr::new_hole()});
129
unify(rhs_type, args[1]->e_type);
130
} catch (UnifyError& ue) {
131
std::ostringstream os;
132
os << "cannot implicitly assign an expression of type " << args[1]->e_type
133
<< " to a variable or pattern of type " << rhs_type << " in modifying method `" << sym::symbols.get_name(val)
137
e_type = rhs_type->args[1];
138
TypeExpr::remove_indirect(e_type);
143
func_assert(args.size() == 3);
144
auto flag_type = TypeExpr::new_atomic(_Int);
146
unify(args[0]->e_type, flag_type);
147
} catch (UnifyError& ue) {
148
std::ostringstream os;
149
os << "condition in a conditional expression has non-integer type " << args[0]->e_type << ": " << ue;
153
unify(args[1]->e_type, args[2]->e_type);
154
} catch (UnifyError& ue) {
155
std::ostringstream os;
156
os << "the two variants in a conditional expression have different types " << args[1]->e_type << " and "
157
<< args[2]->e_type << " : " << ue;
160
e_type = args[1]->e_type;
161
TypeExpr::remove_indirect(e_type);
168
int Expr::define_new_vars(CodeBlob& code) {
174
for (const auto& x : args) {
175
res += x->define_new_vars(code);
181
val = code.create_var(TmpVar::_Named, e_type, sym, &here);
187
val = code.create_var(TmpVar::_Tmp, e_type, nullptr, &here);
194
int Expr::predefine_vars() {
200
for (const auto& x : args) {
201
res += x->predefine_vars();
207
func_assert(val < 0 && here.defined());
208
if (prohibited_var_names.count(sym::symbols.get_name(~val))) {
209
throw src::ParseError{
210
here, PSTRING() << "symbol `" << sym::symbols.get_name(~val) << "` cannot be redefined as a variable"};
212
sym = sym::define_symbol(~val, false, here);
215
throw src::ParseError{here, std::string{"redefined variable `"} + sym::symbols.get_name(~val) + "`"};
217
sym->value = new SymVal{SymVal::_Var, -1, e_type};
225
var_idx_t Expr::new_tmp(CodeBlob& code) const {
226
return code.create_tmp_var(e_type, &here);
229
void add_set_globs(CodeBlob& code, std::vector<std::pair<SymDef*, var_idx_t>>& globs, const SrcLocation& here) {
230
for (const auto& p : globs) {
231
auto& op = code.emplace_back(here, Op::_SetGlob, std::vector<var_idx_t>{}, std::vector<var_idx_t>{ p.second }, p.first);
232
op.flags |= Op::_Impure;
236
std::vector<var_idx_t> Expr::pre_compile_let(CodeBlob& code, Expr* lhs, Expr* rhs, const SrcLocation& here) {
237
while (lhs->is_type_apply()) {
238
lhs = lhs->args.at(0);
240
while (rhs->is_type_apply()) {
241
rhs = rhs->args.at(0);
243
if (lhs->is_mktuple()) {
244
if (rhs->is_mktuple()) {
245
return pre_compile_let(code, lhs->args.at(0), rhs->args.at(0), here);
247
auto right = rhs->pre_compile(code);
248
TypeExpr::remove_indirect(rhs->e_type);
249
auto unpacked_type = rhs->e_type->args.at(0);
250
std::vector<var_idx_t> tmp{code.create_tmp_var(unpacked_type, &rhs->here)};
251
code.emplace_back(lhs->here, Op::_UnTuple, tmp, std::move(right));
252
auto tvar = new Expr{_Var};
253
tvar->set_val(tmp[0]);
254
tvar->set_location(rhs->here);
255
tvar->e_type = unpacked_type;
256
pre_compile_let(code, lhs->args.at(0), tvar, here);
259
auto right = rhs->pre_compile(code);
260
std::vector<std::pair<SymDef*, var_idx_t>> globs;
261
auto left = lhs->pre_compile(code, &globs);
262
for (var_idx_t v : left) {
263
code.on_var_modification(v, here);
265
code.emplace_back(here, Op::_Let, std::move(left), right);
266
add_set_globs(code, globs, here);
270
std::vector<var_idx_t> pre_compile_tensor(const std::vector<Expr *> args, CodeBlob &code,
271
std::vector<std::pair<SymDef*, var_idx_t>> *lval_globs,
272
std::vector<int> arg_order) {
273
if (arg_order.empty()) {
274
arg_order.resize(args.size());
275
std::iota(arg_order.begin(), arg_order.end(), 0);
277
func_assert(args.size() == arg_order.size());
278
std::vector<std::vector<var_idx_t>> res_lists(args.size());
284
auto modified_vars = std::make_shared<std::vector<ModifiedVar>>();
285
for (size_t i : arg_order) {
286
res_lists[i] = args[i]->pre_compile(code, lval_globs);
287
for (size_t j = 0; j < res_lists[i].size(); ++j) {
288
TmpVar& var = code.vars.at(res_lists[i][j]);
289
if (code.flags & CodeBlob::_AllowPostModification) {
290
if (!lval_globs && (var.cls & TmpVar::_Named)) {
291
Op *op = &code.emplace_back(nullptr, Op::_Let, std::vector<var_idx_t>(), std::vector<var_idx_t>());
292
op->flags |= Op::_Disabled;
293
var.on_modification.push_back([modified_vars, i, j, op, done = false](const SrcLocation &here) mutable {
296
modified_vars->push_back({i, j, op});
300
var.on_modification.push_back([](const SrcLocation &) {
304
var.on_modification.push_back([name = var.to_string()](const SrcLocation &here) {
305
throw src::ParseError{here, PSTRING() << "Modifying local variable " << name
306
<< " after using it in the same expression"};
311
for (const auto& list : res_lists) {
312
for (var_idx_t v : list) {
313
func_assert(!code.vars.at(v).on_modification.empty());
314
code.vars.at(v).on_modification.pop_back();
317
for (const ModifiedVar &m : *modified_vars) {
318
var_idx_t& v = res_lists[m.i][m.j];
319
var_idx_t v2 = code.create_tmp_var(code.vars[v].v_type, code.vars[v].where.get());
322
m.op->flags &= ~Op::_Disabled;
325
std::vector<var_idx_t> res;
326
for (const auto& list : res_lists) {
327
res.insert(res.end(), list.cbegin(), list.cend());
332
std::vector<var_idx_t> Expr::pre_compile(CodeBlob& code, std::vector<std::pair<SymDef*, var_idx_t>>* lval_globs) const {
333
if (lval_globs && !(cls == _Tensor || cls == _Var || cls == _Hole || cls == _TypeApply || cls == _GlobVar)) {
334
std::cerr << "lvalue expression constructor is " << cls << std::endl;
335
throw src::Fatal{"cannot compile lvalue expression with unknown constructor"};
339
return pre_compile_tensor(args, code, lval_globs, {});
343
auto func = dynamic_cast<SymValFunc*>(sym->value);
344
std::vector<var_idx_t> res;
345
if (func && func->arg_order.size() == args.size() && !(code.flags & CodeBlob::_ComputeAsmLtr)) {
347
res = pre_compile_tensor(args, code, lval_globs, func->arg_order);
349
res = pre_compile_tensor(args, code, lval_globs, {});
351
auto rvect = new_tmp_vect(code);
352
auto& op = code.emplace_back(here, Op::_Call, rvect, std::move(res), sym);
353
if (flags & _IsImpure) {
354
op.flags |= Op::_Impure;
359
return args[0]->pre_compile(code, lval_globs);
363
throw src::ParseError{here, "unexpected variable definition"};
367
if (args[0]->cls == _Glob) {
368
auto res = args[1]->pre_compile(code);
369
auto rvect = new_tmp_vect(code);
370
auto& op = code.emplace_back(here, Op::_Call, rvect, std::move(res), args[0]->sym);
371
if (args[0]->flags & _IsImpure) {
372
op.flags |= Op::_Impure;
376
auto res = args[1]->pre_compile(code);
377
auto tfunc = args[0]->pre_compile(code);
378
if (tfunc.size() != 1) {
379
throw src::Fatal{"stack tuple used as a function"};
381
res.push_back(tfunc[0]);
382
auto rvect = new_tmp_vect(code);
383
code.emplace_back(here, Op::_CallInd, rvect, std::move(res));
387
auto rvect = new_tmp_vect(code);
388
code.emplace_back(here, Op::_IntConst, rvect, intval);
393
auto rvect = new_tmp_vect(code);
395
lval_globs->push_back({ sym, rvect[0] });
398
code.emplace_back(here, Op::_GlobVar, rvect, std::vector<var_idx_t>{}, sym);
403
return pre_compile_let(code, args.at(0), args.at(1), here);
406
auto rvect = new_tmp_vect(code);
407
auto right = args[1]->pre_compile(code);
408
std::vector<std::pair<SymDef*, var_idx_t>> local_globs;
410
lval_globs = &local_globs;
412
auto left = args[0]->pre_compile(code, lval_globs);
413
left.push_back(rvect[0]);
414
for (var_idx_t v : left) {
415
code.on_var_modification(v, here);
417
code.emplace_back(here, Op::_Let, std::move(left), std::move(right));
418
add_set_globs(code, local_globs, here);
422
auto left = new_tmp_vect(code);
423
auto right = args[0]->pre_compile(code);
424
code.emplace_back(here, Op::_Tuple, left, std::move(right));
428
auto cond = args[0]->pre_compile(code);
429
func_assert(cond.size() == 1);
430
auto rvect = new_tmp_vect(code);
431
Op& if_op = code.emplace_back(here, Op::_If, cond);
432
code.push_set_cur(if_op.block0);
433
code.emplace_back(here, Op::_Let, rvect, args[1]->pre_compile(code));
434
code.close_pop_cur(args[1]->here);
435
code.push_set_cur(if_op.block1);
436
code.emplace_back(here, Op::_Let, rvect, args[2]->pre_compile(code));
437
code.close_pop_cur(args[2]->here);
441
auto rvect = new_tmp_vect(code);
442
code.emplace_back(here, Op::_SliceConst, rvect, strval);
446
std::cerr << "expression constructor is " << cls << std::endl;
447
throw src::Fatal{"cannot compile expression with unknown constructor"};