llvm-project

Форк
0
2488 строк · 96.4 Кб
1
//===- ir.c - Simple test of C APIs ---------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4
// Exceptions.
5
// See https://llvm.org/LICENSE.txt for license information.
6
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7
//
8
//===----------------------------------------------------------------------===//
9

10
/* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
11
 */
12

13
#include "mlir-c/IR.h"
14
#include "mlir-c/AffineExpr.h"
15
#include "mlir-c/AffineMap.h"
16
#include "mlir-c/BuiltinAttributes.h"
17
#include "mlir-c/BuiltinTypes.h"
18
#include "mlir-c/Diagnostics.h"
19
#include "mlir-c/Dialect/Func.h"
20
#include "mlir-c/IntegerSet.h"
21
#include "mlir-c/RegisterEverything.h"
22
#include "mlir-c/Support.h"
23

24
#include <assert.h>
25
#include <inttypes.h>
26
#include <math.h>
27
#include <stdio.h>
28
#include <stdlib.h>
29
#include <string.h>
30

31
/// Calls mlirRegisterAllDialects on a temporary dialect registry, appends that
/// registry to `ctx`, and destroys the registry before returning.
static void registerAllUpstreamDialects(MlirContext ctx) {
  MlirDialectRegistry upstreamDialects = mlirDialectRegistryCreate();

  // Populate the scratch registry first, then append it to the context.
  mlirRegisterAllDialects(upstreamDialects);
  mlirContextAppendDialectRegistry(ctx, upstreamDialects);

  // The registry is not used past this point.
  mlirDialectRegistryDestroy(upstreamDialects);
}
37

38
/// User data consumed by reportResourceDelete, which prints the `name` field.
struct ResourceDeleteUserData {
  const char *name; // Label printed by reportResourceDelete.
};

/// File-local instance labelled "resource_i64_blob".
static struct ResourceDeleteUserData resourceI64BlobUserData = {
    .name = "resource_i64_blob"};
43
static void reportResourceDelete(void *userData, const void *data, size_t size,
44
                                 size_t align) {
45
  fprintf(stderr, "reportResourceDelete: %s\n",
46
          ((struct ResourceDeleteUserData *)userData)->name);
47
}
48

49
void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
50
                      MlirLocation location, MlirBlock funcBody) {
51
  MlirValue iv = mlirBlockGetArgument(loopBody, 0);
52
  MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
53
  MlirValue funcArg1 = mlirBlockGetArgument(funcBody, 1);
54
  MlirType f32Type =
55
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));
56

57
  MlirOperationState loadLHSState = mlirOperationStateGet(
58
      mlirStringRefCreateFromCString("memref.load"), location);
59
  MlirValue loadLHSOperands[] = {funcArg0, iv};
60
  mlirOperationStateAddOperands(&loadLHSState, 2, loadLHSOperands);
61
  mlirOperationStateAddResults(&loadLHSState, 1, &f32Type);
62
  MlirOperation loadLHS = mlirOperationCreate(&loadLHSState);
63
  mlirBlockAppendOwnedOperation(loopBody, loadLHS);
64

65
  MlirOperationState loadRHSState = mlirOperationStateGet(
66
      mlirStringRefCreateFromCString("memref.load"), location);
67
  MlirValue loadRHSOperands[] = {funcArg1, iv};
68
  mlirOperationStateAddOperands(&loadRHSState, 2, loadRHSOperands);
69
  mlirOperationStateAddResults(&loadRHSState, 1, &f32Type);
70
  MlirOperation loadRHS = mlirOperationCreate(&loadRHSState);
71
  mlirBlockAppendOwnedOperation(loopBody, loadRHS);
72

73
  MlirOperationState addState = mlirOperationStateGet(
74
      mlirStringRefCreateFromCString("arith.addf"), location);
75
  MlirValue addOperands[] = {mlirOperationGetResult(loadLHS, 0),
76
                             mlirOperationGetResult(loadRHS, 0)};
77
  mlirOperationStateAddOperands(&addState, 2, addOperands);
78
  mlirOperationStateAddResults(&addState, 1, &f32Type);
79
  MlirOperation add = mlirOperationCreate(&addState);
80
  mlirBlockAppendOwnedOperation(loopBody, add);
81

82
  MlirOperationState storeState = mlirOperationStateGet(
83
      mlirStringRefCreateFromCString("memref.store"), location);
84
  MlirValue storeOperands[] = {mlirOperationGetResult(add, 0), funcArg0, iv};
85
  mlirOperationStateAddOperands(&storeState, 3, storeOperands);
86
  MlirOperation store = mlirOperationCreate(&storeState);
87
  mlirBlockAppendOwnedOperation(loopBody, store);
88

89
  MlirOperationState yieldState = mlirOperationStateGet(
90
      mlirStringRefCreateFromCString("scf.yield"), location);
91
  MlirOperation yield = mlirOperationCreate(&yieldState);
92
  mlirBlockAppendOwnedOperation(loopBody, yield);
93
}
94

95
MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
96
  MlirModule moduleOp = mlirModuleCreateEmpty(location);
97
  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
98

99
  MlirType memrefType =
100
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
101
  MlirType funcBodyArgTypes[] = {memrefType, memrefType};
102
  MlirLocation funcBodyArgLocs[] = {location, location};
103
  MlirRegion funcBodyRegion = mlirRegionCreate();
104
  MlirBlock funcBody =
105
      mlirBlockCreate(sizeof(funcBodyArgTypes) / sizeof(MlirType),
106
                      funcBodyArgTypes, funcBodyArgLocs);
107
  mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);
108

109
  MlirAttribute funcTypeAttr = mlirAttributeParseGet(
110
      ctx,
111
      mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
112
  MlirAttribute funcNameAttr =
113
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
114
  MlirNamedAttribute funcAttrs[] = {
115
      mlirNamedAttributeGet(
116
          mlirIdentifierGet(ctx,
117
                            mlirStringRefCreateFromCString("function_type")),
118
          funcTypeAttr),
119
      mlirNamedAttributeGet(
120
          mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sym_name")),
121
          funcNameAttr)};
122
  MlirOperationState funcState = mlirOperationStateGet(
123
      mlirStringRefCreateFromCString("func.func"), location);
124
  mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
125
  mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
126
  MlirOperation func = mlirOperationCreate(&funcState);
127
  mlirBlockInsertOwnedOperation(moduleBody, 0, func);
128

129
  MlirType indexType =
130
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
131
  MlirAttribute indexZeroLiteral =
132
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
133
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
134
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
135
      indexZeroLiteral);
136
  MlirOperationState constZeroState = mlirOperationStateGet(
137
      mlirStringRefCreateFromCString("arith.constant"), location);
138
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
139
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
140
  MlirOperation constZero = mlirOperationCreate(&constZeroState);
141
  mlirBlockAppendOwnedOperation(funcBody, constZero);
142

143
  MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
144
  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
145
  MlirValue dimOperands[] = {funcArg0, constZeroValue};
146
  MlirOperationState dimState = mlirOperationStateGet(
147
      mlirStringRefCreateFromCString("memref.dim"), location);
148
  mlirOperationStateAddOperands(&dimState, 2, dimOperands);
149
  mlirOperationStateAddResults(&dimState, 1, &indexType);
150
  MlirOperation dim = mlirOperationCreate(&dimState);
151
  mlirBlockAppendOwnedOperation(funcBody, dim);
152

153
  MlirRegion loopBodyRegion = mlirRegionCreate();
154
  MlirBlock loopBody = mlirBlockCreate(0, NULL, NULL);
155
  mlirBlockAddArgument(loopBody, indexType, location);
156
  mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);
157

158
  MlirAttribute indexOneLiteral =
159
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
160
  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
161
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
162
      indexOneLiteral);
163
  MlirOperationState constOneState = mlirOperationStateGet(
164
      mlirStringRefCreateFromCString("arith.constant"), location);
165
  mlirOperationStateAddResults(&constOneState, 1, &indexType);
166
  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
167
  MlirOperation constOne = mlirOperationCreate(&constOneState);
168
  mlirBlockAppendOwnedOperation(funcBody, constOne);
169

170
  MlirValue dimValue = mlirOperationGetResult(dim, 0);
171
  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
172
  MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
173
  MlirOperationState loopState = mlirOperationStateGet(
174
      mlirStringRefCreateFromCString("scf.for"), location);
175
  mlirOperationStateAddOperands(&loopState, 3, loopOperands);
176
  mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
177
  MlirOperation loop = mlirOperationCreate(&loopState);
178
  mlirBlockAppendOwnedOperation(funcBody, loop);
179

180
  populateLoopBody(ctx, loopBody, location, funcBody);
181

182
  MlirOperationState retState = mlirOperationStateGet(
183
      mlirStringRefCreateFromCString("func.return"), location);
184
  MlirOperation ret = mlirOperationCreate(&retState);
185
  mlirBlockAppendOwnedOperation(funcBody, ret);
186

187
  MlirOperation module = mlirModuleGetOperation(moduleOp);
188
  mlirOperationDump(module);
189
  // clang-format off
190
  // CHECK: module {
191
  // CHECK:   func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) {
192
  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
193
  // CHECK:     %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
194
  // CHECK:     %[[C1:.*]] = arith.constant 1 : index
195
  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
196
  // CHECK:       %[[LHS:.*]] = memref.load %[[ARG0]][%[[I]]] : memref<?xf32>
197
  // CHECK:       %[[RHS:.*]] = memref.load %[[ARG1]][%[[I]]] : memref<?xf32>
198
  // CHECK:       %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
199
  // CHECK:       memref.store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32>
200
  // CHECK:     }
201
  // CHECK:     return
202
  // CHECK:   }
203
  // CHECK: }
204
  // clang-format on
205

206
  return moduleOp;
207
}
208

209
struct OpListNode {
210
  MlirOperation op;
211
  struct OpListNode *next;
212
};
213
typedef struct OpListNode OpListNode;
214

215
/// Counters accumulated by collectStatsSingle while walking the IR.
typedef struct ModuleStats {
  unsigned numOperations;     // Operations visited.
  unsigned numAttributes;     // Attributes summed over visited operations.
  unsigned numBlocks;         // Blocks seen in the regions of visited ops.
  unsigned numRegions;        // Regions summed over visited operations.
  unsigned numValues;         // Result counts plus block argument counts.
  unsigned numBlockArguments; // Block arguments that passed validation.
  unsigned numOpResults;      // Operation results that passed validation.
} ModuleStats;
225

226
/// Accounts for the operation stored in `head` in `stats` and queues every
/// operation nested in its regions by linking new nodes directly after `head`.
///
/// Each result and each block argument is validated through the value
/// inspection API. Returns 0 on success, otherwise a non-zero code:
///   1  a result is not reported as an op result,
///   2  a result is reported as a block argument,
///   3  a result's owner differs from the operation,
///   4  a result's number differs from its position,
///   5  a block argument is not reported as a block argument,
///   6  a block argument is reported as an op result,
///   7  a block argument's owner differs from its block,
///   8  a block argument's number differs from its position,
///   9  a work-list node could not be allocated.
/// Nodes linked before a failure stay in the list; the caller owns them.
int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
  MlirOperation operation = head->op;
  stats->numOperations += 1;
  stats->numValues += mlirOperationGetNumResults(operation);
  stats->numAttributes += mlirOperationGetNumAttributes(operation);

  unsigned numRegions = mlirOperationGetNumRegions(operation);

  stats->numRegions += numRegions;

  intptr_t numResults = mlirOperationGetNumResults(operation);
  for (intptr_t i = 0; i < numResults; ++i) {
    MlirValue result = mlirOperationGetResult(operation, i);
    if (!mlirValueIsAOpResult(result))
      return 1;
    if (mlirValueIsABlockArgument(result))
      return 2;
    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
      return 3;
    if (i != mlirOpResultGetResultNumber(result))
      return 4;
    ++stats->numOpResults;
  }

  MlirRegion region = mlirOperationGetFirstRegion(operation);
  while (!mlirRegionIsNull(region)) {
    for (MlirBlock block = mlirRegionGetFirstBlock(region);
         !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
      ++stats->numBlocks;
      intptr_t numArgs = mlirBlockGetNumArguments(block);
      stats->numValues += numArgs;
      for (intptr_t j = 0; j < numArgs; ++j) {
        MlirValue arg = mlirBlockGetArgument(block, j);
        if (!mlirValueIsABlockArgument(arg))
          return 5;
        if (mlirValueIsAOpResult(arg))
          return 6;
        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
          return 7;
        if (j != mlirBlockArgumentGetArgNumber(arg))
          return 8;
        ++stats->numBlockArguments;
      }

      for (MlirOperation child = mlirBlockGetFirstOperation(block);
           !mlirOperationIsNull(child);
           child = mlirOperationGetNextInBlock(child)) {
        OpListNode *node = malloc(sizeof *node);
        // Report allocation failure instead of dereferencing a null pointer.
        if (!node)
          return 9;
        // Push right behind the head; only counts are collected, so the
        // resulting visiting order does not matter.
        node->op = child;
        node->next = head->next;
        head->next = node;
      }
    }
    region = mlirRegionGetNextInOperation(region);
  }
  return 0;
}
283

284
/// Walks `operation` and everything nested in it through a heap-allocated work
/// list, accumulates counters with collectStatsSingle, and prints them to
/// stderr under an "@stats" header.
///
/// Returns 0 on success; the non-zero code of collectStatsSingle if a node
/// fails validation; 100 if the number of values differs from the sum of block
/// arguments and op results; 101 if the initial work-list node cannot be
/// allocated. The work list is released in full on every path.
int collectStats(MlirOperation operation) {
  OpListNode *head = malloc(sizeof *head);
  if (!head)
    return 101;
  head->op = operation;
  head->next = NULL;

  ModuleStats stats = {0};

  while (head) {
    int retval = collectStatsSingle(head, &stats);
    // Read the successor only after the call: collectStatsSingle links newly
    // discovered operations directly behind the current head.
    OpListNode *next = head->next;
    free(head);
    head = next;
    if (retval) {
      // Release the nodes that are still queued so a failure does not leak.
      while (head) {
        next = head->next;
        free(head);
        head = next;
      }
      return retval;
    }
  }

  if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
    return 100;

  fprintf(stderr, "@stats\n");
  fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
  fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
  fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
  fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
  fprintf(stderr, "Number of values: %u\n", stats.numValues);
  fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
  fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
  // clang-format off
  // CHECK-LABEL: @stats
  // CHECK: Number of operations: 12
  // CHECK: Number of attributes: 5
  // CHECK: Number of blocks: 3
  // CHECK: Number of regions: 3
  // CHECK: Number of values: 9
  // CHECK: Number of block arguments: 3
  // CHECK: Number of op results: 6
  // clang-format on
  return 0;
}
332

333
/// Print callback that writes the `str.length` bytes starting at `str.data` to
/// stderr. `userData` is ignored.
static void printToStderr(MlirStringRef str, void *userData) {
  const char *bytes = str.data;
  size_t count = str.length;

  fwrite(bytes, sizeof(char), count, stderr);
  (void)userData;
}
337

338
/// Exercises the inspection and printing entry points of the C API on the
/// first operation of the first function nested in `operation`: parent and
/// block queries, block/operation/value/type/attribute printing, identifier
/// lookup, discardable attribute set/remove, and printing with flags and with
/// an AsmState. All output goes to stderr and is pinned by the FileCheck
/// directives interleaved with the code.
static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
  // Assuming we are given a module, go to the first operation of the first
  // function.
  MlirRegion region = mlirOperationGetRegion(operation, 0);
  MlirBlock block = mlirRegionGetFirstBlock(region);
  MlirOperation function = mlirBlockGetFirstOperation(block);
  region = mlirOperationGetRegion(function, 0);
  MlirOperation parentOperation = function;
  block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  // The selected operation is not itself a module.
  assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));

  // Verify that parent operation and block report correctly.
  // CHECK: Parent operation eq: 1
  fprintf(stderr, "Parent operation eq: %d\n",
          mlirOperationEqual(mlirOperationGetParentOperation(operation),
                             parentOperation));
  // CHECK: Block eq: 1
  fprintf(stderr, "Block eq: %d\n",
          mlirBlockEqual(mlirOperationGetBlock(operation), block));
  // CHECK: Block parent operation eq: 1
  fprintf(
      stderr, "Block parent operation eq: %d\n",
      mlirOperationEqual(mlirBlockGetParentOperation(block), parentOperation));
  // CHECK: Block parent region eq: 1
  fprintf(stderr, "Block parent region eq: %d\n",
          mlirRegionEqual(mlirBlockGetParentRegion(block), region));

  // In the module we created, the first operation of the first function is
  // an "arith.constant", which has a "value" attribute and a single result
  // that we can use to test the printing mechanism.
  mlirBlockPrint(block, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "First operation: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK:   %[[C0:.*]] = arith.constant 0 : index
  // CHECK:   %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:     %[[LHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[RHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
  // CHECK:     memref.store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:   }
  // CHECK: return
  // CHECK: First operation: {{.*}} = arith.constant 0 : index
  // clang-format on

  // Get the operation name and print it.
  MlirIdentifier ident = mlirOperationGetName(operation);
  MlirStringRef identStr = mlirIdentifierStr(ident);
  fprintf(stderr, "Operation name: '");
  for (size_t i = 0; i < identStr.length; ++i)
    fputc(identStr.data[i], stderr);
  fprintf(stderr, "'\n");
  // CHECK: Operation name: 'arith.constant'

  // Get the identifier again and verify equal.
  MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr);
  fprintf(stderr, "Identifier equal: %d\n",
          mlirIdentifierEqual(ident, identAgain));
  // CHECK: Identifier equal: 1

  // Get the block terminator and print it.
  MlirOperation terminator = mlirBlockGetTerminator(block);
  fprintf(stderr, "Terminator: ");
  mlirOperationPrint(terminator, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Terminator: func.return

  // Get the attribute by name.
  bool hasValueAttr = mlirOperationHasInherentAttributeByName(
      operation, mlirStringRefCreateFromCString("value"));
  if (hasValueAttr) {
    // Terminate the line so this message does not run into the next one.
    // CHECK: Has attr "value"
    fprintf(stderr, "Has attr \"value\"\n");
  }

  MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName(
      operation, mlirStringRefCreateFromCString("value"));
  fprintf(stderr, "Get attr \"value\": ");
  mlirAttributePrint(valueAttr0, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr "value": 0 : index

  // Get a non-existing attribute and assert that it is null (sanity).
  fprintf(stderr, "does_not_exist is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
              operation, mlirStringRefCreateFromCString("does_not_exist"))));
  // CHECK: does_not_exist is null: 1

  // Get result 0 and its type.
  MlirValue value = mlirOperationGetResult(operation, 0);
  fprintf(stderr, "Result 0: ");
  mlirValuePrint(value, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
  // CHECK: Result 0: {{.*}} = arith.constant 0 : index
  // CHECK: Value is null: 0

  MlirType type = mlirValueGetType(value);
  fprintf(stderr, "Result 0 type: ");
  mlirTypePrint(type, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Result 0 type: index

  // Set a discardable attribute.
  mlirOperationSetDiscardableAttributeByName(
      operation, mlirStringRefCreateFromCString("custom_attr"),
      mlirBoolAttrGet(ctx, 1));
  fprintf(stderr, "Op with set attr: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Op with set attr: {{.*}} {custom_attr = true}

  // Remove the attribute; the second removal and the lookup afterwards are
  // expected to report that it is gone.
  fprintf(stderr, "Remove attr: %d\n",
          mlirOperationRemoveDiscardableAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Remove attr again: %d\n",
          mlirOperationRemoveDiscardableAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Removed attr is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr"))));
  // CHECK: Remove attr: 1
  // CHECK: Remove attr again: 0
  // CHECK: Removed attr is null: 1

  // Add a large attribute to verify printing flags.
  int64_t eltsShape[] = {4};
  int32_t eltsData[] = {1, 2, 3, 4};
  mlirOperationSetDiscardableAttributeByName(
      operation, mlirStringRefCreateFromCString("elts"),
      mlirDenseElementsAttrInt32Get(
          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
                                  mlirAttributeGetNull()),
          4, eltsData));
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
  mlirOpPrintingFlagsPrintGenericOpForm(flags);
  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/1, /*prettyForm=*/0);
  mlirOpPrintingFlagsUseLocalScope(flags);
  fprintf(stderr, "Op print with all flags: ");
  mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "Op print with state: ");
  MlirAsmState state = mlirAsmStateCreateForOperation(parentOperation, flags);
  mlirOperationPrintWithState(operation, state, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
  // clang-format on

  // Start over with a fresh set of flags that only skips regions.
  mlirOpPrintingFlagsDestroy(flags);
  flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsSkipRegions(flags);
  fprintf(stderr, "Op print with skip regions flag: ");
  mlirOperationPrintWithFlags(function, flags, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: Op print with skip regions flag: func.func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>)
  // CHECK-NOT: constant
  // CHECK-NOT: return
  // clang-format on

  fprintf(stderr, "With state: |");
  mlirValuePrintAsOperand(value, state, printToStderr, NULL);
  // CHECK: With state: |%0|
  fprintf(stderr, "|\n");
  mlirAsmStateDestroy(state);

  mlirOpPrintingFlagsDestroy(flags);
}
513

514
/// Builds the "add" module with makeAndDumpAdd, gathers its statistics with
/// collectStats and, when that succeeds, runs printFirstOfEach on it.
///
/// Returns 0 on success or the non-zero code reported by collectStats. The
/// module is destroyed on every path, including the failure path.
static int constructAndTraverseIr(MlirContext ctx) {
  MlirLocation location = mlirLocationUnknownGet(ctx);

  MlirModule moduleOp = makeAndDumpAdd(ctx, location);
  MlirOperation module = mlirModuleGetOperation(moduleOp);
  assert(!mlirModuleIsNull(mlirModuleFromOperation(module)));

  int errcode = collectStats(module);
  // Skip the printing checks when statistics collection already failed.
  if (!errcode)
    printFirstOfEach(ctx, module);

  // Single exit so that a collectStats failure no longer leaks the module.
  mlirModuleDestroy(moduleOp);
  return errcode;
}
530

531
/// Creates an operation with a region containing multiple blocks with
532
/// operations and dumps it. The blocks and operations are inserted using
533
/// block/operation-relative API and their final order is checked.
534
static void buildWithInsertionsAndPrint(MlirContext ctx) {
535
  MlirLocation loc = mlirLocationUnknownGet(ctx);
536
  mlirContextSetAllowUnregisteredDialects(ctx, true);
537

538
  MlirRegion owningRegion = mlirRegionCreate();
539
  MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
540
  MlirOperationState state = mlirOperationStateGet(
541
      mlirStringRefCreateFromCString("insertion.order.test"), loc);
542
  mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
543
  MlirOperation op = mlirOperationCreate(&state);
544
  MlirRegion region = mlirOperationGetRegion(op, 0);
545

546
  // Use integer types of different bitwidth as block arguments in order to
547
  // differentiate blocks.
548
  MlirType i1 = mlirIntegerTypeGet(ctx, 1);
549
  MlirType i2 = mlirIntegerTypeGet(ctx, 2);
550
  MlirType i3 = mlirIntegerTypeGet(ctx, 3);
551
  MlirType i4 = mlirIntegerTypeGet(ctx, 4);
552
  MlirType i5 = mlirIntegerTypeGet(ctx, 5);
553
  MlirBlock block1 = mlirBlockCreate(1, &i1, &loc);
554
  MlirBlock block2 = mlirBlockCreate(1, &i2, &loc);
555
  MlirBlock block3 = mlirBlockCreate(1, &i3, &loc);
556
  MlirBlock block4 = mlirBlockCreate(1, &i4, &loc);
557
  MlirBlock block5 = mlirBlockCreate(1, &i5, &loc);
558
  // Insert blocks so as to obtain the 1-2-3-4 order,
559
  mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
560
  mlirRegionInsertOwnedBlockBefore(region, block3, block2);
561
  mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
562
  mlirRegionInsertOwnedBlockAfter(region, block3, block4);
563
  mlirRegionInsertOwnedBlockBefore(region, block3, block5);
564

565
  MlirOperationState op1State =
566
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
567
  MlirOperationState op2State =
568
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
569
  MlirOperationState op3State =
570
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
571
  MlirOperationState op4State =
572
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
573
  MlirOperationState op5State =
574
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
575
  MlirOperationState op6State =
576
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
577
  MlirOperationState op7State =
578
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
579
  MlirOperationState op8State =
580
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op8"), loc);
581
  MlirOperation op1 = mlirOperationCreate(&op1State);
582
  MlirOperation op2 = mlirOperationCreate(&op2State);
583
  MlirOperation op3 = mlirOperationCreate(&op3State);
584
  MlirOperation op4 = mlirOperationCreate(&op4State);
585
  MlirOperation op5 = mlirOperationCreate(&op5State);
586
  MlirOperation op6 = mlirOperationCreate(&op6State);
587
  MlirOperation op7 = mlirOperationCreate(&op7State);
588
  MlirOperation op8 = mlirOperationCreate(&op8State);
589

590
  // Insert operations in the first block so as to obtain the 1-2-3-4 order.
591
  MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
592
  assert(mlirOperationIsNull(nullOperation));
593
  mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
594
  mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
595
  mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
596
  mlirBlockInsertOwnedOperationAfter(block1, op3, op4);
597

598
  // Append operations to the rest of blocks to make them non-empty and thus
599
  // printable.
600
  mlirBlockAppendOwnedOperation(block2, op5);
601
  mlirBlockAppendOwnedOperation(block3, op6);
602
  mlirBlockAppendOwnedOperation(block4, op7);
603
  mlirBlockAppendOwnedOperation(block5, op8);
604

605
  // Remove block5.
606
  mlirBlockDetach(block5);
607
  mlirBlockDestroy(block5);
608

609
  mlirOperationDump(op);
610
  mlirOperationDestroy(op);
611
  mlirContextSetAllowUnregisteredDialects(ctx, false);
612
  // clang-format off
613
  // CHECK-LABEL:  "insertion.order.test"
614
  // CHECK:      ^{{.*}}(%{{.*}}: i1
615
  // CHECK:        "dummy.op1"
616
  // CHECK-NEXT:   "dummy.op2"
617
  // CHECK-NEXT:   "dummy.op3"
618
  // CHECK-NEXT:   "dummy.op4"
619
  // CHECK:      ^{{.*}}(%{{.*}}: i2
620
  // CHECK:        "dummy.op5"
621
  // CHECK-NOT:  ^{{.*}}(%{{.*}}: i5
622
  // CHECK-NOT:    "dummy.op8"
623
  // CHECK:      ^{{.*}}(%{{.*}}: i3
624
  // CHECK:        "dummy.op6"
625
  // CHECK:      ^{{.*}}(%{{.*}}: i4
626
  // CHECK:        "dummy.op7"
627
  // clang-format on
628
}
629

630
/// Creates operations with type inference and tests various failure modes.
631
static int createOperationWithTypeInference(MlirContext ctx) {
632
  MlirLocation loc = mlirLocationUnknownGet(ctx);
633
  MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);
634

635
  // The shape.const_size op implements result type inference and is only used
636
  // for that reason.
637
  MlirOperationState state = mlirOperationStateGet(
638
      mlirStringRefCreateFromCString("shape.const_size"), loc);
639
  MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
640
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
641
  mlirOperationStateAddAttributes(&state, 1, &valueAttr);
642
  mlirOperationStateEnableResultTypeInference(&state);
643

644
  // Expect result type inference to succeed.
645
  MlirOperation op = mlirOperationCreate(&state);
646
  if (mlirOperationIsNull(op)) {
647
    fprintf(stderr, "ERROR: Result type inference unexpectedly failed");
648
    return 1;
649
  }
650

651
  // CHECK: RESULT_TYPE_INFERENCE: !shape.size
652
  fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
653
  mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
654
  fprintf(stderr, "\n");
655
  mlirOperationDestroy(op);
656
  return 0;
657
}
658

659
/// Dumps instances of all builtin types to check that C API works correctly.
660
/// Additionally, performs simple identity checks that a builtin type
661
/// constructed with C API can be inspected and has the expected type. The
662
/// latter achieves full coverage of C API for builtin types. Returns 0 on
663
/// success and a non-zero error code on failure.
664
static int printBuiltinTypes(MlirContext ctx) {
665
  // Integer types.
666
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
667
  MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
668
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
669
  if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
670
    return 1;
671
  if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
672
    return 2;
673
  if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
674
    return 3;
675
  if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
676
    return 4;
677
  if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
678
    return 5;
679
  fprintf(stderr, "@types\n");
680
  mlirTypeDump(i32);
681
  fprintf(stderr, "\n");
682
  mlirTypeDump(si32);
683
  fprintf(stderr, "\n");
684
  mlirTypeDump(ui32);
685
  fprintf(stderr, "\n");
686
  // CHECK-LABEL: @types
687
  // CHECK: i32
688
  // CHECK: si32
689
  // CHECK: ui32
690

691
  // Index type.
692
  MlirType index = mlirIndexTypeGet(ctx);
693
  if (!mlirTypeIsAIndex(index))
694
    return 6;
695
  mlirTypeDump(index);
696
  fprintf(stderr, "\n");
697
  // CHECK: index
698

699
  // Floating-point types.
700
  MlirType bf16 = mlirBF16TypeGet(ctx);
701
  MlirType f16 = mlirF16TypeGet(ctx);
702
  MlirType f32 = mlirF32TypeGet(ctx);
703
  MlirType f64 = mlirF64TypeGet(ctx);
704
  if (!mlirTypeIsABF16(bf16))
705
    return 7;
706
  if (!mlirTypeIsAF16(f16))
707
    return 9;
708
  if (!mlirTypeIsAF32(f32))
709
    return 10;
710
  if (!mlirTypeIsAF64(f64))
711
    return 11;
712
  mlirTypeDump(bf16);
713
  fprintf(stderr, "\n");
714
  mlirTypeDump(f16);
715
  fprintf(stderr, "\n");
716
  mlirTypeDump(f32);
717
  fprintf(stderr, "\n");
718
  mlirTypeDump(f64);
719
  fprintf(stderr, "\n");
720
  // CHECK: bf16
721
  // CHECK: f16
722
  // CHECK: f32
723
  // CHECK: f64
724

725
  // None type.
726
  MlirType none = mlirNoneTypeGet(ctx);
727
  if (!mlirTypeIsANone(none))
728
    return 12;
729
  mlirTypeDump(none);
730
  fprintf(stderr, "\n");
731
  // CHECK: none
732

733
  // Complex type.
734
  MlirType cplx = mlirComplexTypeGet(f32);
735
  if (!mlirTypeIsAComplex(cplx) ||
736
      !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
737
    return 13;
738
  mlirTypeDump(cplx);
739
  fprintf(stderr, "\n");
740
  // CHECK: complex<f32>
741

742
  // Vector (and Shaped) type. ShapedType is a common base class for vectors,
743
  // memrefs and tensors, one cannot create instances of this class so it is
744
  // tested on an instance of vector type.
745
  int64_t shape[] = {2, 3};
746
  MlirType vector =
747
      mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
748
  if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
749
    return 14;
750
  if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
751
      !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
752
      mlirShapedTypeGetDimSize(vector, 0) != 2 ||
753
      mlirShapedTypeIsDynamicDim(vector, 0) ||
754
      mlirShapedTypeGetDimSize(vector, 1) != 3 ||
755
      !mlirShapedTypeHasStaticShape(vector))
756
    return 15;
757
  mlirTypeDump(vector);
758
  fprintf(stderr, "\n");
759
  // CHECK: vector<2x3xf32>
760

761
  // Scalable vector type.
762
  bool scalable[] = {false, true};
763
  MlirType scalableVector = mlirVectorTypeGetScalable(
764
      sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
765
  if (!mlirTypeIsAVector(scalableVector))
766
    return 16;
767
  if (!mlirVectorTypeIsScalable(scalableVector) ||
768
      mlirVectorTypeIsDimScalable(scalableVector, 0) ||
769
      !mlirVectorTypeIsDimScalable(scalableVector, 1))
770
    return 17;
771
  mlirTypeDump(scalableVector);
772
  fprintf(stderr, "\n");
773
  // CHECK: vector<2x[3]xf32>
774

775
  // Ranked tensor type.
776
  MlirType rankedTensor = mlirRankedTensorTypeGet(
777
      sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
778
  if (!mlirTypeIsATensor(rankedTensor) ||
779
      !mlirTypeIsARankedTensor(rankedTensor) ||
780
      !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
781
    return 18;
782
  mlirTypeDump(rankedTensor);
783
  fprintf(stderr, "\n");
784
  // CHECK: tensor<2x3xf32>
785

786
  // Unranked tensor type.
787
  MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
788
  if (!mlirTypeIsATensor(unrankedTensor) ||
789
      !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
790
      mlirShapedTypeHasRank(unrankedTensor))
791
    return 19;
792
  mlirTypeDump(unrankedTensor);
793
  fprintf(stderr, "\n");
794
  // CHECK: tensor<*xf32>
795

796
  // MemRef type.
797
  MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2);
798
  MlirType memRef = mlirMemRefTypeContiguousGet(
799
      f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
800
  if (!mlirTypeIsAMemRef(memRef) ||
801
      !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
802
    return 20;
803
  mlirTypeDump(memRef);
804
  fprintf(stderr, "\n");
805
  // CHECK: memref<2x3xf32, 2>
806

807
  // Unranked MemRef type.
808
  MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4);
809
  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4);
810
  if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
811
      mlirTypeIsAMemRef(unrankedMemRef) ||
812
      !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
813
                          memSpace4))
814
    return 21;
815
  mlirTypeDump(unrankedMemRef);
816
  fprintf(stderr, "\n");
817
  // CHECK: memref<*xf32, 4>
818

819
  // Tuple type.
820
  MlirType types[] = {unrankedMemRef, f32};
821
  MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
822
  if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
823
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
824
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
825
    return 22;
826
  mlirTypeDump(tuple);
827
  fprintf(stderr, "\n");
828
  // CHECK: tuple<memref<*xf32, 4>, f32>
829

830
  // Function type.
831
  MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
832
  MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
833
                             mlirIntegerTypeGet(ctx, 32),
834
                             mlirIntegerTypeGet(ctx, 64)};
835
  MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
836
  if (mlirFunctionTypeGetNumInputs(funcType) != 2)
837
    return 23;
838
  if (mlirFunctionTypeGetNumResults(funcType) != 3)
839
    return 24;
840
  if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
841
      !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
842
    return 25;
843
  if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
844
      !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
845
      !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
846
    return 26;
847
  mlirTypeDump(funcType);
848
  fprintf(stderr, "\n");
849
  // CHECK: (index, i1) -> (i16, i32, i64)
850

851
  // Opaque type.
852
  MlirStringRef namespace = mlirStringRefCreate("dialect", 7);
853
  MlirStringRef data = mlirStringRefCreate("type", 4);
854
  mlirContextSetAllowUnregisteredDialects(ctx, true);
855
  MlirType opaque = mlirOpaqueTypeGet(ctx, namespace, data);
856
  mlirContextSetAllowUnregisteredDialects(ctx, false);
857
  if (!mlirTypeIsAOpaque(opaque) ||
858
      !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
859
                          namespace) ||
860
      !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
861
    return 27;
862
  mlirTypeDump(opaque);
863
  fprintf(stderr, "\n");
864
  // CHECK: !dialect.type
865

866
  return 0;
867
}
868

869
/// String callback that copies at most `len` characters of `data` into the
/// caller-provided buffer passed as `userData`, using `strncpy` semantics: no
/// terminator is written when `data` holds `len` or more characters.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  char *destination = userData;
  const size_t count = (size_t)len;
  strncpy(destination, data, count);
}
873

874
/// Returns true when the NUL-terminated string `lhs` has the same length and
/// the same characters as the string reference `rhs`.
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
  const size_t lhsLength = strlen(lhs);
  return lhsLength == rhs.length && strncmp(lhs, rhs.data, rhs.length) == 0;
}
880

881
/// Builds instances of the builtin attributes through the C API, verifies that
/// each one can be inspected and reports the expected values, and dumps them
/// to stderr. Returns 0 on success and a non-zero code identifying the first
/// group of verifications that failed.
int printBuiltinAttributes(MlirContext ctx) {
  MlirAttribute floating =
      mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
  if (!mlirAttributeIsAFloat(floating) ||
      fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
    return 1;
  fprintf(stderr, "@attrs\n");
  mlirAttributeDump(floating);
  // CHECK-LABEL: @attrs
  // CHECK: 2.000000e+00 : f64

  // Exercise mlirAttributeGetType() just for the first one.
  MlirType floatingType = mlirAttributeGetType(floating);
  mlirTypeDump(floatingType);
  // Terminate the line after the type dump, as printBuiltinTypes does after
  // every mlirTypeDump, so the next dump starts on its own line.
  fprintf(stderr, "\n");
  // CHECK: f64

  MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
  MlirAttribute signedInteger =
      mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx, 8), -1);
  MlirAttribute unsignedInteger =
      mlirIntegerAttrGet(mlirIntegerTypeUnsignedGet(ctx, 8), 255);
  if (!mlirAttributeIsAInteger(integer) ||
      mlirIntegerAttrGetValueInt(integer) != 42 ||
      mlirIntegerAttrGetValueSInt(signedInteger) != -1 ||
      mlirIntegerAttrGetValueUInt(unsignedInteger) != 255)
    return 2;
  mlirAttributeDump(integer);
  mlirAttributeDump(signedInteger);
  mlirAttributeDump(unsignedInteger);
  // CHECK: 42 : i32
  // CHECK: -1 : si8
  // CHECK: 255 : ui8

  MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
  if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
    return 3;
  mlirAttributeDump(boolean);
  // CHECK: true

  // Slices of this buffer provide the payloads of the opaque, string and
  // symbol reference attributes below.
  const char data[] = "abcdefghijklmnopqestuvwxyz";
  MlirAttribute opaque =
      mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("func"), 3, data,
                        mlirNoneTypeGet(ctx));
  if (!mlirAttributeIsAOpaque(opaque) ||
      !stringIsEqual("func", mlirOpaqueAttrGetDialectNamespace(opaque)))
    return 4;

  MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
  if (opaqueData.length != 3 ||
      strncmp(data, opaqueData.data, opaqueData.length))
    return 5;
  mlirAttributeDump(opaque);
  // CHECK: #func.abc

  MlirAttribute string =
      mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
  if (!mlirAttributeIsAString(string))
    return 6;

  MlirStringRef stringValue = mlirStringAttrGetValue(string);
  if (stringValue.length != 2 ||
      strncmp(data + 3, stringValue.data, stringValue.length))
    return 7;
  mlirAttributeDump(string);
  // CHECK: "de"

  MlirAttribute flatSymbolRef =
      mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
  if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
    return 8;

  MlirStringRef flatSymbolRefValue =
      mlirFlatSymbolRefAttrGetValue(flatSymbolRef);
  if (flatSymbolRefValue.length != 3 ||
      strncmp(data + 5, flatSymbolRefValue.data, flatSymbolRefValue.length))
    return 9;
  mlirAttributeDump(flatSymbolRef);
  // CHECK: @fgh

  MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
  MlirAttribute symbolRef =
      mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
  if (!mlirAttributeIsASymbolRef(symbolRef) ||
      mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
                          flatSymbolRef) ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
                          flatSymbolRef))
    return 10;

  MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(symbolRef);
  MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(symbolRef);
  if (symbolRefLeaf.length != 3 ||
      strncmp(data + 5, symbolRefLeaf.data, symbolRefLeaf.length) ||
      symbolRefRoot.length != 2 ||
      strncmp(data + 8, symbolRefRoot.data, symbolRefRoot.length))
    return 11;
  mlirAttributeDump(symbolRef);
  // CHECK: @ij::@fgh::@fgh

  MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
  if (!mlirAttributeIsAType(type) ||
      !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
    return 12;
  mlirAttributeDump(type);
  // CHECK: f32

  MlirAttribute unit = mlirUnitAttrGet(ctx);
  if (!mlirAttributeIsAUnit(unit))
    return 13;
  mlirAttributeDump(unit);
  // CHECK: unit

  int64_t shape[] = {1, 2};

  // Two-element source buffers, one per element type, shared by the dense
  // elements, dense array and resource attributes below.
  int bools[] = {0, 1};
  uint8_t uints8[] = {0u, 1u};
  int8_t ints8[] = {0, 1};
  uint16_t uints16[] = {0u, 1u};
  int16_t ints16[] = {0, 1};
  uint32_t uints32[] = {0u, 1u};
  int32_t ints32[] = {0, 1};
  uint64_t uints64[] = {0u, 1u};
  int64_t ints64[] = {0, 1};
  float floats[] = {0.0f, 1.0f};
  double doubles[] = {0.0, 1.0};
  uint16_t bf16s[] = {0x0, 0x3f80};
  uint16_t f16s[] = {0x0, 0x3c00};
  MlirAttribute encoding = mlirAttributeGetNull();
  MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      2, bools);
  MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      2, uints8);
  MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      2, ints8);
  MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
                              encoding),
      2, uints16);
  MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
      2, ints16);
  MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      2, uints32);
  MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      2, ints32);
  MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      2, uints64);
  MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, ints64);
  MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 2,
      floats);
  MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
      doubles);
  MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get(
      mlirRankedTensorTypeGet(2, shape, mlirBF16TypeGet(ctx), encoding), 2,
      bf16s);
  MlirAttribute f16Elements = mlirDenseElementsAttrFloat16Get(
      mlirRankedTensorTypeGet(2, shape, mlirF16TypeGet(ctx), encoding), 2,
      f16s);

  // Every dense elements attribute built above is verified here, including
  // the 16-bit integer ones whose values are read in the next group.
  if (!mlirAttributeIsADenseElements(boolElements) ||
      !mlirAttributeIsADenseElements(uint8Elements) ||
      !mlirAttributeIsADenseElements(int8Elements) ||
      !mlirAttributeIsADenseElements(uint16Elements) ||
      !mlirAttributeIsADenseElements(int16Elements) ||
      !mlirAttributeIsADenseElements(uint32Elements) ||
      !mlirAttributeIsADenseElements(int32Elements) ||
      !mlirAttributeIsADenseElements(uint64Elements) ||
      !mlirAttributeIsADenseElements(int64Elements) ||
      !mlirAttributeIsADenseElements(floatElements) ||
      !mlirAttributeIsADenseElements(doubleElements) ||
      !mlirAttributeIsADenseElements(bf16Elements) ||
      !mlirAttributeIsADenseElements(f16Elements))
    return 14;

  if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
    return 15;

  mlirAttributeDump(boolElements);
  mlirAttributeDump(uint8Elements);
  mlirAttributeDump(int8Elements);
  mlirAttributeDump(uint32Elements);
  mlirAttributeDump(int32Elements);
  mlirAttributeDump(uint64Elements);
  mlirAttributeDump(int64Elements);
  mlirAttributeDump(floatElements);
  mlirAttributeDump(doubleElements);
  mlirAttributeDump(bf16Elements);
  mlirAttributeDump(f16Elements);
  // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf16>

  MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      1);
  MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      1);
  MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      1);
  MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      1);
  MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      1);
  MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      1);
  MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      1);
  MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
  MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 1.0);

  if (!mlirAttributeIsADenseElements(splatBool) ||
      !mlirDenseElementsAttrIsSplat(splatBool) ||
      !mlirAttributeIsADenseElements(splatUInt8) ||
      !mlirDenseElementsAttrIsSplat(splatUInt8) ||
      !mlirAttributeIsADenseElements(splatInt8) ||
      !mlirDenseElementsAttrIsSplat(splatInt8) ||
      !mlirAttributeIsADenseElements(splatUInt32) ||
      !mlirDenseElementsAttrIsSplat(splatUInt32) ||
      !mlirAttributeIsADenseElements(splatInt32) ||
      !mlirDenseElementsAttrIsSplat(splatInt32) ||
      !mlirAttributeIsADenseElements(splatUInt64) ||
      !mlirDenseElementsAttrIsSplat(splatUInt64) ||
      !mlirAttributeIsADenseElements(splatInt64) ||
      !mlirDenseElementsAttrIsSplat(splatInt64) ||
      !mlirAttributeIsADenseElements(splatFloat) ||
      !mlirDenseElementsAttrIsSplat(splatFloat) ||
      !mlirAttributeIsADenseElements(splatDouble) ||
      !mlirDenseElementsAttrIsSplat(splatDouble))
    return 16;

  if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
      mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
      mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
      mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
      mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
      mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
      mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
    return 17;

  // The raw data of each dense elements attribute must match the buffer it
  // was built from, element for element.
  const uint8_t *uint8RawData =
      (const uint8_t *)mlirDenseElementsAttrGetRawData(uint8Elements);
  const int8_t *int8RawData =
      (const int8_t *)mlirDenseElementsAttrGetRawData(int8Elements);
  const uint32_t *uint32RawData =
      (const uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
  const int32_t *int32RawData =
      (const int32_t *)mlirDenseElementsAttrGetRawData(int32Elements);
  const uint64_t *uint64RawData =
      (const uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
  const int64_t *int64RawData =
      (const int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
  const float *floatRawData =
      (const float *)mlirDenseElementsAttrGetRawData(floatElements);
  const double *doubleRawData =
      (const double *)mlirDenseElementsAttrGetRawData(doubleElements);
  const uint16_t *bf16RawData =
      (const uint16_t *)mlirDenseElementsAttrGetRawData(bf16Elements);
  const uint16_t *f16RawData =
      (const uint16_t *)mlirDenseElementsAttrGetRawData(f16Elements);
  if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
      int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
      int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
      uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
      floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
      doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 ||
      bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80 || f16RawData[0] != 0 ||
      f16RawData[1] != 0x3c00)
    return 18;

  mlirAttributeDump(splatBool);
  mlirAttributeDump(splatUInt8);
  mlirAttributeDump(splatInt8);
  mlirAttributeDump(splatUInt32);
  mlirAttributeDump(splatInt32);
  mlirAttributeDump(splatUInt64);
  mlirAttributeDump(splatInt64);
  mlirAttributeDump(splatFloat);
  mlirAttributeDump(splatDouble);
  // CHECK: dense<true> : tensor<1x2xi1>
  // CHECK: dense<1> : tensor<1x2xui8>
  // CHECK: dense<1> : tensor<1x2xi8>
  // CHECK: dense<1> : tensor<1x2xui32>
  // CHECK: dense<1> : tensor<1x2xi32>
  // CHECK: dense<1> : tensor<1x2xui64>
  // CHECK: dense<1> : tensor<1x2xi64>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf64>

  // `uints64` ({0, 1}) doubles as the two-dimensional index of the element to
  // fetch from each rank-2 attribute.
  mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
  mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
  mlirAttributeDump(mlirElementsAttrGetValue(bf16Elements, 2, uints64));
  mlirAttributeDump(mlirElementsAttrGetValue(f16Elements, 2, uints64));
  // CHECK: 1.000000e+00 : f32
  // CHECK: 1.000000e+00 : f64
  // CHECK: 1.000000e+00 : bf16
  // CHECK: 1.000000e+00 : f16

  int64_t indices[] = {0, 1};
  int64_t one = 1;
  MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, indices);
  MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(1, &one, mlirF32TypeGet(ctx), encoding), 1,
      floats);
  MlirAttribute sparseAttr = mlirSparseElementsAttribute(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
      indicesAttr, valuesAttr);
  mlirAttributeDump(sparseAttr);
  // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>

  MlirAttribute boolArray = mlirDenseBoolArrayGet(ctx, 2, bools);
  MlirAttribute int8Array = mlirDenseI8ArrayGet(ctx, 2, ints8);
  MlirAttribute int16Array = mlirDenseI16ArrayGet(ctx, 2, ints16);
  MlirAttribute int32Array = mlirDenseI32ArrayGet(ctx, 2, ints32);
  MlirAttribute int64Array = mlirDenseI64ArrayGet(ctx, 2, ints64);
  MlirAttribute floatArray = mlirDenseF32ArrayGet(ctx, 2, floats);
  MlirAttribute doubleArray = mlirDenseF64ArrayGet(ctx, 2, doubles);
  if (!mlirAttributeIsADenseBoolArray(boolArray) ||
      !mlirAttributeIsADenseI8Array(int8Array) ||
      !mlirAttributeIsADenseI16Array(int16Array) ||
      !mlirAttributeIsADenseI32Array(int32Array) ||
      !mlirAttributeIsADenseI64Array(int64Array) ||
      !mlirAttributeIsADenseF32Array(floatArray) ||
      !mlirAttributeIsADenseF64Array(doubleArray))
    return 19;

  if (mlirDenseArrayGetNumElements(boolArray) != 2 ||
      mlirDenseArrayGetNumElements(int8Array) != 2 ||
      mlirDenseArrayGetNumElements(int16Array) != 2 ||
      mlirDenseArrayGetNumElements(int32Array) != 2 ||
      mlirDenseArrayGetNumElements(int64Array) != 2 ||
      mlirDenseArrayGetNumElements(floatArray) != 2 ||
      mlirDenseArrayGetNumElements(doubleArray) != 2)
    return 20;

  if (mlirDenseBoolArrayGetElement(boolArray, 1) != 1 ||
      mlirDenseI8ArrayGetElement(int8Array, 1) != 1 ||
      mlirDenseI16ArrayGetElement(int16Array, 1) != 1 ||
      mlirDenseI32ArrayGetElement(int32Array, 1) != 1 ||
      mlirDenseI64ArrayGetElement(int64Array, 1) != 1 ||
      fabsf(mlirDenseF32ArrayGetElement(floatArray, 1) - 1.0f) > 1E-6f ||
      fabs(mlirDenseF64ArrayGetElement(doubleArray, 1) - 1.0) > 1E-6)
    return 21;

  int64_t layoutStrides[3] = {5, 7, 13};
  MlirAttribute stridedLayoutAttr =
      mlirStridedLayoutAttrGet(ctx, 42, 3, &layoutStrides[0]);

  // CHECK: strided<[5, 7, 13], offset: 42>
  mlirAttributeDump(stridedLayoutAttr);

  if (mlirStridedLayoutAttrGetOffset(stridedLayoutAttr) != 42 ||
      mlirStridedLayoutAttrGetNumStrides(stridedLayoutAttr) != 3 ||
      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 0) != 5 ||
      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 1) != 7 ||
      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 2) != 13)
    return 22;

  MlirAttribute uint8Blob = mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      mlirStringRefCreateFromCString("resource_ui8"), 2, uints8);
  MlirAttribute uint16Blob = mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
                              encoding),
      mlirStringRefCreateFromCString("resource_ui16"), 2, uints16);
  MlirAttribute uint32Blob = mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      mlirStringRefCreateFromCString("resource_ui32"), 2, uints32);
  MlirAttribute uint64Blob = mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      mlirStringRefCreateFromCString("resource_ui64"), 2, uints64);
  MlirAttribute int8Blob = mlirUnmanagedDenseInt8ResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      mlirStringRefCreateFromCString("resource_i8"), 2, ints8);
  MlirAttribute int16Blob = mlirUnmanagedDenseInt16ResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
      mlirStringRefCreateFromCString("resource_i16"), 2, ints16);
  MlirAttribute int32Blob = mlirUnmanagedDenseInt32ResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      mlirStringRefCreateFromCString("resource_i32"), 2, ints32);
  MlirAttribute int64Blob = mlirUnmanagedDenseInt64ResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      mlirStringRefCreateFromCString("resource_i64"), 2, ints64);
  MlirAttribute floatsBlob = mlirUnmanagedDenseFloatResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
      mlirStringRefCreateFromCString("resource_f32"), 2, floats);
  MlirAttribute doublesBlob = mlirUnmanagedDenseDoubleResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
      mlirStringRefCreateFromCString("resource_f64"), 2, doubles);
  MlirAttribute blobBlob = mlirUnmanagedDenseResourceElementsAttrGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      mlirStringRefCreateFromCString("resource_i64_blob"), /*data=*/uints64,
      /*dataLength=*/sizeof(uints64),
      /*dataAlignment=*/_Alignof(uint64_t),
      /*dataIsMutable=*/false,
      /*deleter=*/reportResourceDelete,
      /*userData=*/(void *)&resourceI64BlobUserData);

  mlirAttributeDump(uint8Blob);
  mlirAttributeDump(uint16Blob);
  mlirAttributeDump(uint32Blob);
  mlirAttributeDump(uint64Blob);
  mlirAttributeDump(int8Blob);
  mlirAttributeDump(int16Blob);
  mlirAttributeDump(int32Blob);
  mlirAttributeDump(int64Blob);
  mlirAttributeDump(floatsBlob);
  mlirAttributeDump(doublesBlob);
  mlirAttributeDump(blobBlob);
  // CHECK: dense_resource<resource_ui8> : tensor<1x2xui8>
  // CHECK: dense_resource<resource_ui16> : tensor<1x2xui16>
  // CHECK: dense_resource<resource_ui32> : tensor<1x2xui32>
  // CHECK: dense_resource<resource_ui64> : tensor<1x2xui64>
  // CHECK: dense_resource<resource_i8> : tensor<1x2xi8>
  // CHECK: dense_resource<resource_i16> : tensor<1x2xi16>
  // CHECK: dense_resource<resource_i32> : tensor<1x2xi32>
  // CHECK: dense_resource<resource_i64> : tensor<1x2xi64>
  // CHECK: dense_resource<resource_f32> : tensor<1x2xf32>
  // CHECK: dense_resource<resource_f64> : tensor<1x2xf64>
  // CHECK: dense_resource<resource_i64_blob> : tensor<1x2xi64>

  // Only resource attributes are inspected in this group; the dense array
  // values were already verified above (return code 21).
  if (mlirDenseUInt8ResourceElementsAttrGetValue(uint8Blob, 1) != 1 ||
      mlirDenseUInt16ResourceElementsAttrGetValue(uint16Blob, 1) != 1 ||
      mlirDenseUInt32ResourceElementsAttrGetValue(uint32Blob, 1) != 1 ||
      mlirDenseUInt64ResourceElementsAttrGetValue(uint64Blob, 1) != 1 ||
      mlirDenseInt8ResourceElementsAttrGetValue(int8Blob, 1) != 1 ||
      mlirDenseInt16ResourceElementsAttrGetValue(int16Blob, 1) != 1 ||
      mlirDenseInt32ResourceElementsAttrGetValue(int32Blob, 1) != 1 ||
      mlirDenseInt64ResourceElementsAttrGetValue(int64Blob, 1) != 1 ||
      fabsf(mlirDenseFloatResourceElementsAttrGetValue(floatsBlob, 1) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseDoubleResourceElementsAttrGetValue(doublesBlob, 1) - 1.0) >
          1E-6 ||
      mlirDenseUInt64ResourceElementsAttrGetValue(blobBlob, 1) != 1)
    return 23;

  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirAttribute locAttr = mlirLocationGetAttribute(loc);
  if (!mlirAttributeIsALocation(locAttr))
    return 24;

  return 0;
}
1374

1375
/// Builds a set of representative affine maps, dumps them, and compares the
/// answer of every affine map query against a table of expected values.
/// Also dumps three sub-maps of the identity map and a distinct attribute.
/// Returns 0 on success, otherwise the number (1-10) of the first group of
/// queries that produced an unexpected answer.
int printAffineMap(MlirContext ctx) {
  // Positions of the maps under test inside `maps` and `want`.
  enum {
    kEmpty,
    kZeroResult,
    kConstant,
    kMultiDimIdentity,
    kMinorIdentity,
    kPermutation,
    kNumMaps
  };

  unsigned perm[] = {1, 2, 0};
  MlirAffineMap maps[kNumMaps];
  maps[kEmpty] = mlirAffineMapEmptyGet(ctx);
  maps[kZeroResult] = mlirAffineMapZeroResultGet(ctx, 3, 2);
  maps[kConstant] = mlirAffineMapConstantGet(ctx, 2);
  maps[kMultiDimIdentity] = mlirAffineMapMultiDimIdentityGet(ctx, 3);
  maps[kMinorIdentity] = mlirAffineMapMinorIdentityGet(ctx, 3, 2);
  maps[kPermutation] =
      mlirAffineMapPermutationGet(ctx, sizeof(perm) / sizeof(perm[0]), perm);

  // Expected answer of each query, one row per map, in `maps` order.
  static const struct {
    bool isIdentity;
    bool isMinorIdentity;
    bool isEmpty;
    bool isSingleConstant;
    bool isProjectedPermutation;
    bool isPermutation;
    intptr_t numDims;
    intptr_t numSymbols;
    intptr_t numResults;
    intptr_t numInputs;
  } want[kNumMaps] = {
      /* kEmpty            */ {true, true, true, false, true, true, 0, 0, 0, 0},
      /* kZeroResult       */
      {false, false, false, false, false, false, 3, 2, 0, 5},
      /* kConstant (minor identity is not queried for this map) */
      {false, false, false, true, false, false, 0, 0, 1, 0},
      /* kMultiDimIdentity */
      {true, true, false, false, true, true, 3, 0, 3, 3},
      /* kMinorIdentity    */
      {false, true, false, false, true, false, 3, 0, 2, 3},
      /* kPermutation      */
      {false, false, false, false, true, true, 3, 0, 3, 3},
  };

  fprintf(stderr, "@affineMap\n");
  for (int i = 0; i < kNumMaps; ++i)
    mlirAffineMapDump(maps[i]);
  // CHECK-LABEL: @affineMap
  // CHECK: () -> ()
  // CHECK: (d0, d1, d2)[s0, s1] -> ()
  // CHECK: () -> (2)
  // CHECK: (d0, d1, d2) -> (d0, d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2, d0)

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapIsIdentity(maps[i]) != want[i].isIdentity)
      return 1;

  for (int i = 0; i < kNumMaps; ++i) {
    // The constant map is deliberately left out of this group.
    if (i == kConstant)
      continue;
    if (mlirAffineMapIsMinorIdentity(maps[i]) != want[i].isMinorIdentity)
      return 2;
  }

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapIsEmpty(maps[i]) != want[i].isEmpty)
      return 3;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapIsSingleConstant(maps[i]) != want[i].isSingleConstant)
      return 4;

  if (mlirAffineMapGetSingleConstantResult(maps[kConstant]) != 2)
    return 5;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapGetNumDims(maps[i]) != want[i].numDims)
      return 6;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapGetNumSymbols(maps[i]) != want[i].numSymbols)
      return 7;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapGetNumResults(maps[i]) != want[i].numResults)
      return 8;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapGetNumInputs(maps[i]) != want[i].numInputs)
      return 9;

  // Both permutation queries are asked per map, projected variant first.
  for (int i = 0; i < kNumMaps; ++i) {
    if (mlirAffineMapIsProjectedPermutation(maps[i]) !=
        want[i].isProjectedPermutation)
      return 10;
    if (mlirAffineMapIsPermutation(maps[i]) != want[i].isPermutation)
      return 10;
  }

  // Sub-maps of the 3-d identity: result #1, the first result, the last one.
  intptr_t resultPos[] = {1};
  MlirAffineMap identity = maps[kMultiDimIdentity];
  mlirAffineMapDump(mlirAffineMapGetSubMap(
      identity, sizeof(resultPos) / sizeof(resultPos[0]), resultPos));
  mlirAffineMapDump(mlirAffineMapGetMajorSubMap(identity, 1));
  mlirAffineMapDump(mlirAffineMapGetMinorSubMap(identity, 1));
  // CHECK: (d0, d1, d2) -> (d1)
  // CHECK: (d0, d1, d2) -> (d0)
  // CHECK: (d0, d1, d2) -> (d2)

  // CHECK: distinct[0]<"foo">
  MlirAttribute fooString =
      mlirStringAttrGet(ctx, mlirStringRefCreateFromCString("foo"));
  mlirAttributeDump(mlirDisctinctAttrCreate(fooString));

  return 0;
}
1503

1504
/// Builds one affine expression of every kind (dimension, symbol, constant
/// and the five binary forms over d5 and s5), dumps them, and compares the
/// affine expression queries against a table of expected values.
/// Returns 0 on success, otherwise the number (1-19) of the first failed test.
int printAffineExpr(MlirContext ctx) {
  // Positions of the expressions under test inside `exprs` and `want`.
  enum {
    kDim,
    kSymbol,
    kConstant,
    kAdd,
    kMul,
    kMod,
    kFloorDiv,
    kCeilDiv,
    kNumExprs
  };

  MlirAffineExpr exprs[kNumExprs];
  exprs[kDim] = mlirAffineDimExprGet(ctx, 5);
  exprs[kSymbol] = mlirAffineSymbolExprGet(ctx, 5);
  exprs[kConstant] = mlirAffineConstantExprGet(ctx, 5);
  exprs[kAdd] = mlirAffineAddExprGet(exprs[kDim], exprs[kSymbol]);
  exprs[kMul] = mlirAffineMulExprGet(exprs[kDim], exprs[kSymbol]);
  exprs[kMod] = mlirAffineModExprGet(exprs[kDim], exprs[kSymbol]);
  exprs[kFloorDiv] = mlirAffineFloorDivExprGet(exprs[kDim], exprs[kSymbol]);
  exprs[kCeilDiv] = mlirAffineCeilDivExprGet(exprs[kDim], exprs[kSymbol]);

  // Expected answer of each generic query, one row per expression.
  // `divisor` doubles as the factor passed to the multiple-of query.
  static const struct {
    bool isSymbolicOrConstant;
    bool isPureAffine;
    bool isFunctionOfDim5;
    int64_t divisor;
  } want[kNumExprs] = {
      /* kDim      */ {false, true, true, 1},
      /* kSymbol   */ {true, true, false, 1},
      /* kConstant */ {true, true, false, 5},
      /* kAdd      */ {false, true, true, 1},
      /* kMul      */ {false, false, true, 1},
      /* kMod      */ {false, false, true, 1},
      /* kFloorDiv */ {false, false, true, 1},
      /* kCeilDiv  */ {false, false, true, 1},
  };

  // Tests mlirAffineExprDump.
  fprintf(stderr, "@affineExpr\n");
  for (int i = 0; i < kNumExprs; ++i)
    mlirAffineExprDump(exprs[i]);
  // CHECK-LABEL: @affineExpr
  // CHECK: d5
  // CHECK: s5
  // CHECK: 5
  // CHECK: d5 + s5
  // CHECK: d5 * s5
  // CHECK: d5 mod s5
  // CHECK: d5 floordiv s5
  // CHECK: d5 ceildiv s5

  // Operand accessors of a binary expression, using the sum as the sample.
  mlirAffineExprDump(mlirAffineBinaryOpExprGetLHS(exprs[kAdd]));
  mlirAffineExprDump(mlirAffineBinaryOpExprGetRHS(exprs[kAdd]));
  // CHECK: d5
  // CHECK: s5

  // Kind-specific value accessors.
  if (mlirAffineDimExprGetPosition(exprs[kDim]) != 5)
    return 1;
  if (mlirAffineSymbolExprGetPosition(exprs[kSymbol]) != 5)
    return 2;
  if (mlirAffineConstantExprGetValue(exprs[kConstant]) != 5)
    return 3;

  // Generic queries, each one run over all expressions before the next.
  for (int i = 0; i < kNumExprs; ++i)
    if (mlirAffineExprIsSymbolicOrConstant(exprs[i]) !=
        want[i].isSymbolicOrConstant)
      return 4;

  for (int i = 0; i < kNumExprs; ++i)
    if (mlirAffineExprIsPureAffine(exprs[i]) != want[i].isPureAffine)
      return 5;

  for (int i = 0; i < kNumExprs; ++i)
    if (mlirAffineExprGetLargestKnownDivisor(exprs[i]) != want[i].divisor)
      return 6;

  for (int i = 0; i < kNumExprs; ++i)
    if (!mlirAffineExprIsMultipleOf(exprs[i], want[i].divisor))
      return 7;

  for (int i = 0; i < kNumExprs; ++i)
    if (mlirAffineExprIsFunctionOfDim(exprs[i], 5) != want[i].isFunctionOfDim5)
      return 8;

  // 'IsA' predicates of the binary expression kinds.
  if (!mlirAffineExprIsAAdd(exprs[kAdd]))
    return 9;
  if (!mlirAffineExprIsAMul(exprs[kMul]))
    return 10;
  if (!mlirAffineExprIsAMod(exprs[kMod]))
    return 11;
  if (!mlirAffineExprIsAFloorDiv(exprs[kFloorDiv]))
    return 12;
  if (!mlirAffineExprIsACeilDiv(exprs[kCeilDiv]))
    return 13;
  if (!mlirAffineExprIsABinary(exprs[kAdd]))
    return 14;

  // 'IsA' predicates of the leaf expression kinds.
  if (!mlirAffineExprIsAConstant(exprs[kConstant]))
    return 15;
  if (!mlirAffineExprIsADim(exprs[kDim]))
    return 16;
  if (!mlirAffineExprIsASymbol(exprs[kSymbol]))
    return 17;

  // Equality against a second handle to d5, then nullity.
  MlirAffineExpr sameDim = mlirAffineDimExprGet(ctx, 5);
  if (!mlirAffineExprEqual(exprs[kDim], sameDim))
    return 18;
  if (mlirAffineExprIsNull(exprs[kDim]))
    return 19;

  return 0;
}
1648

1649
/// Builds an affine map with 3 dimensions and 3 symbols from an explicit list
/// of result expressions, dumps it, and verifies that the results can be read
/// back and that composing an expression with the map selects the expected
/// result. Returns 0 on success, otherwise the number (1-4) of the failed
/// test.
int affineMapFromExprs(MlirContext ctx) {
  MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr};
  // Derive the result count from the array so the two cannot drift apart.
  MlirAffineMap map =
      mlirAffineMapGet(ctx, 3, 3, sizeof(exprs) / sizeof(exprs[0]), exprs);

  // The label is terminated with a newline, like the other section labels in
  // this file, so that the map dump starts on its own line.
  // CHECK-LABEL: @affineMapFromExprs
  fprintf(stderr, "@affineMapFromExprs\n");
  // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1)
  mlirAffineMapDump(map);

  if (mlirAffineMapGetNumResults(map) != 2)
    return 1;

  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), affineDimExpr))
    return 2;

  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr))
    return 3;

  // Composing d1 with the map is expected to yield the map's result #1.
  MlirAffineExpr affineDim2Expr = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr composed = mlirAffineExprCompose(affineDim2Expr, map);
  // CHECK: s1
  mlirAffineExprDump(composed);
  if (!mlirAffineExprEqual(composed, affineSymbolExpr))
    return 4;

  return 0;
}
1678

1679
/// Exercises the integer set C API: the canonical empty set, a set built from
/// explicit constraints, replacement of dimensions/symbols, and the counting
/// and constraint accessors. Returns 0 on success, otherwise the number
/// (1-12) of the failed test.
int printIntegerSet(MlirContext ctx) {
  MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);

  // The label is terminated with a newline, like the other section labels in
  // this file, so that the first set dump starts on its own line.
  // CHECK-LABEL: @printIntegerSet
  fprintf(stderr, "@printIntegerSet\n");

  // CHECK: (d0, d1)[s0] : (1 == 0)
  mlirIntegerSetDump(emptySet);

  if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
    return 1;

  MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
  if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
    return 2;

  // Construct a set constrained by:
  //   d0 - s0 == 0,
  //   d1 - 42 >= 0.
  MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
  MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
  MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
  MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
  MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
  MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
  MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
  // One flag per constraint: true marks an equality, false an inequality.
  bool flags[] = {true, false};

  MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
  // CHECK: (d0, d1)[s0] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: d1 - 42 >= 0
  mlirIntegerSetDump(set);

  // Transform d1 into s1: d0 and s0 are replaced by themselves, and the
  // resulting set has one dimension and two symbols.
  MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr repl[] = {d0, s1};
  MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
  // CHECK: (d0)[s0, s1] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: s1 - 42 >= 0
  mlirIntegerSetDump(replaced);

  if (mlirIntegerSetGetNumDims(set) != 2)
    return 3;
  if (mlirIntegerSetGetNumDims(replaced) != 1)
    return 4;

  if (mlirIntegerSetGetNumSymbols(set) != 1)
    return 5;
  if (mlirIntegerSetGetNumSymbols(replaced) != 2)
    return 6;

  if (mlirIntegerSetGetNumInputs(set) != 3)
    return 7;

  if (mlirIntegerSetGetNumConstraints(set) != 2)
    return 8;

  if (mlirIntegerSetGetNumEqualities(set) != 1)
    return 9;

  if (mlirIntegerSetGetNumInequalities(set) != 1)
    return 10;

  // The stored order of the constraints is not assumed: each constraint is
  // matched against the expression that corresponds to its equality flag.
  MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
  MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
  bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
  bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
  if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
    return 11;
  if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
    return 12;

  return 0;
}
1757

1758
int registerOnlyStd(void) {
1759
  MlirContext ctx = mlirContextCreate();
1760
  // The built-in dialect is always loaded.
1761
  if (mlirContextGetNumLoadedDialects(ctx) != 1)
1762
    return 1;
1763

1764
  MlirDialectHandle stdHandle = mlirGetDialectHandle__func__();
1765

1766
  MlirDialect std = mlirContextGetOrLoadDialect(
1767
      ctx, mlirDialectHandleGetNamespace(stdHandle));
1768
  if (!mlirDialectIsNull(std))
1769
    return 2;
1770

1771
  mlirDialectHandleRegisterDialect(stdHandle, ctx);
1772

1773
  std = mlirContextGetOrLoadDialect(ctx,
1774
                                    mlirDialectHandleGetNamespace(stdHandle));
1775
  if (mlirDialectIsNull(std))
1776
    return 3;
1777

1778
  MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
1779
  if (!mlirDialectEqual(std, alsoStd))
1780
    return 4;
1781

1782
  MlirStringRef stdNs = mlirDialectGetNamespace(std);
1783
  MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
1784
  if (stdNs.length != alsoStdNs.length ||
1785
      strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
1786
    return 5;
1787

1788
  fprintf(stderr, "@registration\n");
1789
  // CHECK-LABEL: @registration
1790

1791
  // CHECK: func.call is_registered: 1
1792
  fprintf(stderr, "func.call is_registered: %d\n",
1793
          mlirContextIsRegisteredOperation(
1794
              ctx, mlirStringRefCreateFromCString("func.call")));
1795

1796
  // CHECK: func.not_existing_op is_registered: 0
1797
  fprintf(stderr, "func.not_existing_op is_registered: %d\n",
1798
          mlirContextIsRegisteredOperation(
1799
              ctx, mlirStringRefCreateFromCString("func.not_existing_op")));
1800

1801
  // CHECK: not_existing_dialect.not_existing_op is_registered: 0
1802
  fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
1803
          mlirContextIsRegisteredOperation(
1804
              ctx, mlirStringRefCreateFromCString(
1805
                       "not_existing_dialect.not_existing_op")));
1806

1807
  mlirContextDestroy(ctx);
1808
  return 0;
1809
}
1810

1811
/// Tests backreference APIs
1812
static int testBackreferences(void) {
1813
  fprintf(stderr, "@test_backreferences\n");
1814

1815
  MlirContext ctx = mlirContextCreate();
1816
  mlirContextSetAllowUnregisteredDialects(ctx, true);
1817
  MlirLocation loc = mlirLocationUnknownGet(ctx);
1818

1819
  MlirOperationState opState =
1820
      mlirOperationStateGet(mlirStringRefCreateFromCString("invalid.op"), loc);
1821
  MlirRegion region = mlirRegionCreate();
1822
  MlirBlock block = mlirBlockCreate(0, NULL, NULL);
1823
  mlirRegionAppendOwnedBlock(region, block);
1824
  mlirOperationStateAddOwnedRegions(&opState, 1, &region);
1825
  MlirOperation op = mlirOperationCreate(&opState);
1826
  MlirIdentifier ident =
1827
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("identifier"));
1828

1829
  if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) {
1830
    fprintf(stderr, "ERROR: Getting context from operation failed\n");
1831
    return 1;
1832
  }
1833
  if (!mlirOperationEqual(op, mlirBlockGetParentOperation(block))) {
1834
    fprintf(stderr, "ERROR: Getting parent operation from block failed\n");
1835
    return 2;
1836
  }
1837
  if (!mlirContextEqual(ctx, mlirIdentifierGetContext(ident))) {
1838
    fprintf(stderr, "ERROR: Getting context from identifier failed\n");
1839
    return 3;
1840
  }
1841

1842
  mlirOperationDestroy(op);
1843
  mlirContextDestroy(ctx);
1844

1845
  // CHECK-LABEL: @test_backreferences
1846
  return 0;
1847
}
1848

1849
/// Tests operand APIs.
1850
int testOperands(void) {
1851
  fprintf(stderr, "@testOperands\n");
1852
  // CHECK-LABEL: @testOperands
1853

1854
  MlirContext ctx = mlirContextCreate();
1855
  registerAllUpstreamDialects(ctx);
1856

1857
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
1858
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
1859
  MlirLocation loc = mlirLocationUnknownGet(ctx);
1860
  MlirType indexType = mlirIndexTypeGet(ctx);
1861

1862
  // Create some constants to use as operands.
1863
  MlirAttribute indexZeroLiteral =
1864
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1865
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
1866
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1867
      indexZeroLiteral);
1868
  MlirOperationState constZeroState = mlirOperationStateGet(
1869
      mlirStringRefCreateFromCString("arith.constant"), loc);
1870
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1871
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1872
  MlirOperation constZero = mlirOperationCreate(&constZeroState);
1873
  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
1874

1875
  MlirAttribute indexOneLiteral =
1876
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1877
  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
1878
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1879
      indexOneLiteral);
1880
  MlirOperationState constOneState = mlirOperationStateGet(
1881
      mlirStringRefCreateFromCString("arith.constant"), loc);
1882
  mlirOperationStateAddResults(&constOneState, 1, &indexType);
1883
  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
1884
  MlirOperation constOne = mlirOperationCreate(&constOneState);
1885
  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
1886

1887
  // Create the operation under test.
1888
  mlirContextSetAllowUnregisteredDialects(ctx, true);
1889
  MlirOperationState opState =
1890
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
1891
  MlirValue initialOperands[] = {constZeroValue};
1892
  mlirOperationStateAddOperands(&opState, 1, initialOperands);
1893
  MlirOperation op = mlirOperationCreate(&opState);
1894

1895
  // Test operand APIs.
1896
  intptr_t numOperands = mlirOperationGetNumOperands(op);
1897
  fprintf(stderr, "Num Operands: %" PRIdPTR "\n", numOperands);
1898
  // CHECK: Num Operands: 1
1899

1900
  MlirValue opOperand1 = mlirOperationGetOperand(op, 0);
1901
  fprintf(stderr, "Original operand: ");
1902
  mlirValuePrint(opOperand1, printToStderr, NULL);
1903
  // CHECK: Original operand: {{.+}} arith.constant 0 : index
1904

1905
  mlirOperationSetOperand(op, 0, constOneValue);
1906
  MlirValue opOperand2 = mlirOperationGetOperand(op, 0);
1907
  fprintf(stderr, "Updated operand: ");
1908
  mlirValuePrint(opOperand2, printToStderr, NULL);
1909
  // CHECK: Updated operand: {{.+}} arith.constant 1 : index
1910

1911
  // Test op operand APIs.
1912
  MlirOpOperand use1 = mlirValueGetFirstUse(opOperand1);
1913
  if (!mlirOpOperandIsNull(use1)) {
1914
    fprintf(stderr, "ERROR: Use should be null\n");
1915
    return 1;
1916
  }
1917

1918
  MlirOpOperand use2 = mlirValueGetFirstUse(opOperand2);
1919
  if (mlirOpOperandIsNull(use2)) {
1920
    fprintf(stderr, "ERROR: Use should not be null\n");
1921
    return 2;
1922
  }
1923

1924
  fprintf(stderr, "Use owner: ");
1925
  mlirOperationPrint(mlirOpOperandGetOwner(use2), printToStderr, NULL);
1926
  fprintf(stderr, "\n");
1927
  // CHECK: Use owner: "dummy.op"
1928

1929
  fprintf(stderr, "Use operandNumber: %d\n",
1930
          mlirOpOperandGetOperandNumber(use2));
1931
  // CHECK: Use operandNumber: 0
1932

1933
  use2 = mlirOpOperandGetNextUse(use2);
1934
  if (!mlirOpOperandIsNull(use2)) {
1935
    fprintf(stderr, "ERROR: Next use should be null\n");
1936
    return 3;
1937
  }
1938

1939
  MlirOperationState op2State =
1940
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
1941
  MlirValue initialOperands2[] = {constOneValue};
1942
  mlirOperationStateAddOperands(&op2State, 1, initialOperands2);
1943
  MlirOperation op2 = mlirOperationCreate(&op2State);
1944

1945
  MlirOpOperand use3 = mlirValueGetFirstUse(constOneValue);
1946
  fprintf(stderr, "First use owner: ");
1947
  mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
1948
  fprintf(stderr, "\n");
1949
  // CHECK: First use owner: "dummy.op2"
1950

1951
  use3 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constOneValue));
1952
  fprintf(stderr, "Second use owner: ");
1953
  mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
1954
  fprintf(stderr, "\n");
1955
  // CHECK: Second use owner: "dummy.op"
1956

1957
  MlirAttribute indexTwoLiteral =
1958
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("2 : index"));
1959
  MlirNamedAttribute indexTwoValueAttr = mlirNamedAttributeGet(
1960
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1961
      indexTwoLiteral);
1962
  MlirOperationState constTwoState = mlirOperationStateGet(
1963
      mlirStringRefCreateFromCString("arith.constant"), loc);
1964
  mlirOperationStateAddResults(&constTwoState, 1, &indexType);
1965
  mlirOperationStateAddAttributes(&constTwoState, 1, &indexTwoValueAttr);
1966
  MlirOperation constTwo = mlirOperationCreate(&constTwoState);
1967
  MlirValue constTwoValue = mlirOperationGetResult(constTwo, 0);
1968

1969
  mlirValueReplaceAllUsesOfWith(constOneValue, constTwoValue);
1970

1971
  use3 = mlirValueGetFirstUse(constOneValue);
1972
  if (!mlirOpOperandIsNull(use3)) {
1973
    fprintf(stderr, "ERROR: Use should be null\n");
1974
    return 4;
1975
  }
1976

1977
  MlirOpOperand use4 = mlirValueGetFirstUse(constTwoValue);
1978
  fprintf(stderr, "First replacement use owner: ");
1979
  mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
1980
  fprintf(stderr, "\n");
1981
  // CHECK: First replacement use owner: "dummy.op"
1982

1983
  use4 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constTwoValue));
1984
  fprintf(stderr, "Second replacement use owner: ");
1985
  mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
1986
  fprintf(stderr, "\n");
1987
  // CHECK: Second replacement use owner: "dummy.op2"
1988

1989
  MlirOpOperand use5 = mlirValueGetFirstUse(constTwoValue);
1990
  MlirOpOperand use6 = mlirOpOperandGetNextUse(use5);
1991
  if (!mlirValueEqual(mlirOpOperandGetValue(use5),
1992
                      mlirOpOperandGetValue(use6))) {
1993
    fprintf(stderr,
1994
            "ERROR: First and second operand should share the same value\n");
1995
    return 5;
1996
  }
1997

1998
  mlirOperationDestroy(op);
1999
  mlirOperationDestroy(op2);
2000
  mlirOperationDestroy(constZero);
2001
  mlirOperationDestroy(constOne);
2002
  mlirOperationDestroy(constTwo);
2003
  mlirContextDestroy(ctx);
2004

2005
  return 0;
2006
}
2007

2008
/// Tests clone APIs.
2009
int testClone(void) {
2010
  fprintf(stderr, "@testClone\n");
2011
  // CHECK-LABEL: @testClone
2012

2013
  MlirContext ctx = mlirContextCreate();
2014
  registerAllUpstreamDialects(ctx);
2015

2016
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
2017
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
2018
  MlirLocation loc = mlirLocationUnknownGet(ctx);
2019
  MlirType indexType = mlirIndexTypeGet(ctx);
2020
  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
2021

2022
  MlirAttribute indexZeroLiteral =
2023
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
2024
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
2025
      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
2026
  MlirOperationState constZeroState = mlirOperationStateGet(
2027
      mlirStringRefCreateFromCString("arith.constant"), loc);
2028
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
2029
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
2030
  MlirOperation constZero = mlirOperationCreate(&constZeroState);
2031

2032
  MlirAttribute indexOneLiteral =
2033
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
2034
  MlirOperation constOne = mlirOperationClone(constZero);
2035
  mlirOperationSetAttributeByName(constOne, valueStringRef, indexOneLiteral);
2036

2037
  mlirOperationPrint(constZero, printToStderr, NULL);
2038
  mlirOperationPrint(constOne, printToStderr, NULL);
2039
  // CHECK: arith.constant 0 : index
2040
  // CHECK: arith.constant 1 : index
2041

2042
  mlirOperationDestroy(constZero);
2043
  mlirOperationDestroy(constOne);
2044
  mlirContextDestroy(ctx);
2045
  return 0;
2046
}
2047

2048
// Wraps a diagnostic into additional text we can match against.
2049
MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
2050
  fprintf(stderr, "processing diagnostic (userData: %" PRIdPTR ") <<\n",
2051
          (intptr_t)userData);
2052
  mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
2053
  fprintf(stderr, "\n");
2054
  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
2055
  mlirLocationPrint(loc, printToStderr, NULL);
2056
  assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
2057
  fprintf(stderr, "\n>> end of diagnostic (userData: %" PRIdPTR ")\n",
2058
          (intptr_t)userData);
2059
  return mlirLogicalResultSuccess();
2060
}
2061

2062
// Callback that only logs its invocation: prints `userData`, reinterpreted as
// an integer, to stderr. Nothing is freed here.
static void deleteUserData(void *userData) {
  const intptr_t tag = (intptr_t)userData;
  fprintf(stderr, "deleting user data (userData: %" PRIdPTR ")\n", tag);
}
2067

2068
/// Tests the type id accessors of types, attributes and operations:
///   - type ids of builtin types and attributes are present,
///   - signless and unsigned 32-bit integer types share one type id,
///   - integer, float and integer-attribute type ids differ,
///   - an `arith.constant` operation has a type id, while an operation with
///     an unregistered name (`dummy.op`) has a null one.
/// Enables unregistered dialects on `ctx` as a side effect.
/// Returns 0 on success, otherwise the number (1-9) of the failed test.
/// Operations created here are destroyed on success and on failure alike.
int testTypeID(MlirContext ctx) {
  fprintf(stderr, "@testTypeID\n");

  // Test getting and comparing type and attribute type ids.
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirTypeID i32ID = mlirTypeGetTypeID(i32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  MlirTypeID ui32ID = mlirTypeGetTypeID(ui32);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirTypeID f32ID = mlirTypeGetTypeID(f32);
  MlirAttribute i32Attr = mlirIntegerAttrGet(i32, 1);
  MlirTypeID i32AttrID = mlirAttributeGetTypeID(i32Attr);

  if (mlirTypeIDIsNull(i32ID) || mlirTypeIDIsNull(ui32ID) ||
      mlirTypeIDIsNull(f32ID) || mlirTypeIDIsNull(i32AttrID)) {
    fprintf(stderr, "ERROR: Expected type ids to be present\n");
    return 1;
  }

  if (!mlirTypeIDEqual(i32ID, ui32ID) ||
      mlirTypeIDHashValue(i32ID) != mlirTypeIDHashValue(ui32ID)) {
    fprintf(
        stderr,
        "ERROR: Expected different integer types to have the same type id\n");
    return 2;
  }

  if (mlirTypeIDEqual(i32ID, f32ID)) {
    fprintf(stderr,
            "ERROR: Expected integer type id to not equal float type id\n");
    return 3;
  }

  if (mlirTypeIDEqual(i32ID, i32AttrID)) {
    fprintf(stderr, "ERROR: Expected integer type id to not equal integer "
                    "attribute type id\n");
    return 4;
  }

  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirType indexType = mlirIndexTypeGet(ctx);
  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");

  // Create a registered operation, which should have a type id.
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), loc);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);

  // Test for a null handle first, so that a null operation is never handed
  // to the verifier. The return codes of the two tests are unchanged.
  if (mlirOperationIsNull(constZero)) {
    fprintf(stderr, "ERROR: Expected registered operation to be present\n");
    return 6;
  }

  if (!mlirOperationVerify(constZero)) {
    fprintf(stderr, "ERROR: Expected operation to verify correctly\n");
    mlirOperationDestroy(constZero);
    return 5;
  }

  MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero);

  if (mlirTypeIDIsNull(registeredOpID)) {
    fprintf(stderr,
            "ERROR: Expected registered operation type id to be present\n");
    mlirOperationDestroy(constZero);
    return 7;
  }

  // Create an unregistered operation, which should not have a type id.
  mlirContextSetAllowUnregisteredDialects(ctx, true);
  MlirOperationState opState =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
  MlirOperation unregisteredOp = mlirOperationCreate(&opState);
  if (mlirOperationIsNull(unregisteredOp)) {
    fprintf(stderr, "ERROR: Expected unregistered operation to be present\n");
    mlirOperationDestroy(constZero);
    return 8;
  }

  MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp);

  // Both operations are released before the last test is evaluated, so the
  // success and the failure path share the same cleanup.
  int result = 0;
  if (!mlirTypeIDIsNull(unregisteredOpID)) {
    fprintf(stderr,
            "ERROR: Expected unregistered operation type id to be null\n");
    result = 9;
  }

  mlirOperationDestroy(constZero);
  mlirOperationDestroy(unregisteredOp);

  return result;
}
2163

2164
int testSymbolTable(MlirContext ctx) {
2165
  fprintf(stderr, "@testSymbolTable\n");
2166

2167
  const char *moduleString = "func.func private @foo()"
2168
                             "func.func private @bar()";
2169
  const char *otherModuleString = "func.func private @qux()"
2170
                                  "func.func private @foo()";
2171

2172
  MlirModule module =
2173
      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
2174
  MlirModule otherModule = mlirModuleCreateParse(
2175
      ctx, mlirStringRefCreateFromCString(otherModuleString));
2176

2177
  MlirSymbolTable symbolTable =
2178
      mlirSymbolTableCreate(mlirModuleGetOperation(module));
2179

2180
  MlirOperation funcFoo =
2181
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("foo"));
2182
  if (mlirOperationIsNull(funcFoo))
2183
    return 1;
2184

2185
  MlirOperation funcBar =
2186
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
2187
  if (mlirOperationEqual(funcFoo, funcBar))
2188
    return 2;
2189

2190
  MlirOperation missing =
2191
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
2192
  if (!mlirOperationIsNull(missing))
2193
    return 3;
2194

2195
  MlirBlock moduleBody = mlirModuleGetBody(module);
2196
  MlirBlock otherModuleBody = mlirModuleGetBody(otherModule);
2197
  MlirOperation operation = mlirBlockGetFirstOperation(otherModuleBody);
2198
  mlirOperationRemoveFromParent(operation);
2199
  mlirBlockAppendOwnedOperation(moduleBody, operation);
2200

2201
  // At this moment, the operation is still missing from the symbol table.
2202
  MlirOperation stillMissing =
2203
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
2204
  if (!mlirOperationIsNull(stillMissing))
2205
    return 4;
2206

2207
  // After it is added to the symbol table, and not only the operation with
2208
  // which the table is associated, it can be looked up.
2209
  mlirSymbolTableInsert(symbolTable, operation);
2210
  MlirOperation funcQux =
2211
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
2212
  if (!mlirOperationEqual(operation, funcQux))
2213
    return 5;
2214

2215
  // Erasing from the symbol table also removes the operation.
2216
  mlirSymbolTableErase(symbolTable, funcBar);
2217
  MlirOperation nowMissing =
2218
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
2219
  if (!mlirOperationIsNull(nowMissing))
2220
    return 6;
2221

2222
  // Adding a symbol with the same name to the table should rename.
2223
  MlirOperation duplicateNameOp = mlirBlockGetFirstOperation(otherModuleBody);
2224
  mlirOperationRemoveFromParent(duplicateNameOp);
2225
  mlirBlockAppendOwnedOperation(moduleBody, duplicateNameOp);
2226
  MlirAttribute newName = mlirSymbolTableInsert(symbolTable, duplicateNameOp);
2227
  MlirStringRef newNameStr = mlirStringAttrGetValue(newName);
2228
  if (mlirStringRefEqual(newNameStr, mlirStringRefCreateFromCString("foo")))
2229
    return 7;
2230
  MlirAttribute updatedName = mlirOperationGetAttributeByName(
2231
      duplicateNameOp, mlirSymbolTableGetSymbolAttributeName());
2232
  if (!mlirAttributeEqual(updatedName, newName))
2233
    return 8;
2234

2235
  mlirOperationDump(mlirModuleGetOperation(module));
2236
  mlirOperationDump(mlirModuleGetOperation(otherModule));
2237
  // clang-format off
2238
  // CHECK-LABEL: @testSymbolTable
2239
  // CHECK: module
2240
  // CHECK:   func private @foo
2241
  // CHECK:   func private @qux
2242
  // CHECK:   func private @foo{{.+}}
2243
  // CHECK: module
2244
  // CHECK-NOT: @qux
2245
  // CHECK-NOT: @foo
2246
  // clang-format on
2247

2248
  mlirSymbolTableDestroy(symbolTable);
2249
  mlirModuleDestroy(module);
2250
  mlirModuleDestroy(otherModule);
2251

2252
  return 0;
2253
}
2254

2255
// User data handed to the operation-walk callbacks through their `void *`
// argument.
typedef struct {
  // Label the callbacks print in front of each visited operation's name.
  const char *x;
} callBackData;
2258

2259
// Walk callback that prints "<label>: <operation name>" to stderr for every
// visited operation and always asks the walk to continue. `rootOpVoid` must
// point to a callBackData; its `x` member supplies the label.
MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
  const callBackData *data = (const callBackData *)rootOpVoid;
  MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
  // MlirStringRef carries an explicit length and is not required to be
  // NUL-terminated, so bound the print by that length instead of using %s.
  fprintf(stderr, "%s: %.*s\n", data->x, (int)opName.length, opName.data);
  return MlirWalkResultAdvance;
}
2264

2265
// Walk callback that prints "<label>: <operation name>" to stderr and then
// steers the walk by operation name: `func.func` returns MlirWalkResultSkip,
// `arith.addi` returns MlirWalkResultInterrupt, and everything else returns
// MlirWalkResultAdvance. `rootOpVoid` must point to a callBackData; its `x`
// member supplies the label.
MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
  const callBackData *data = (const callBackData *)rootOpVoid;
  // Fetch the name once; it is used for both the log line and the dispatch.
  MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
  // MlirStringRef carries an explicit length and is not required to be
  // NUL-terminated, so print and compare it by length rather than through
  // %s / strcmp.
  fprintf(stderr, "%s: %.*s\n", data->x, (int)opName.length, opName.data);
  if (mlirStringRefEqual(opName, mlirStringRefCreateFromCString("func.func")))
    return MlirWalkResultSkip;
  if (mlirStringRefEqual(opName, mlirStringRefCreateFromCString("arith.addi")))
    return MlirWalkResultInterrupt;
  return MlirWalkResultAdvance;
}
2276

2277
// Parses a small module in `ctx` and walks it four times: post-order and
// pre-order with the plain logging callback, then post-order and pre-order
// with the callback that interrupts at `arith.addi` and skips `func.func`.
// The printed traversal is verified by the FileCheck directives below.
// Returns 0 on success and 1 when the module cannot be parsed.
int testOperationWalk(MlirContext ctx) {
  // CHECK-LABEL: @testOperationWalk
  fprintf(stderr, "@testOperationWalk\n");

  const char *moduleString = "module {\n"
                             "  func.func @foo() {\n"
                             "    %1 = arith.constant 10: i32\n"
                             "    arith.addi %1, %1: i32\n"
                             "    return\n"
                             "  }\n"
                             "  func.func @bar() {\n"
                             "    return\n"
                             "  }\n"
                             "}";
  MlirModule module =
      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
  // A failed parse yields a null module, which must not be walked.
  if (mlirModuleIsNull(module)) {
    fprintf(stderr, "ERROR: Expected module to parse\n");
    return 1;
  }

  MlirOperation root = mlirModuleGetOperation(module);
  callBackData data;
  data.x = "i love you";

  // CHECK-NEXT: i love you: arith.constant
  // CHECK-NEXT: i love you: arith.addi
  // CHECK-NEXT: i love you: func.return
  // CHECK-NEXT: i love you: func.func
  // CHECK-NEXT: i love you: func.return
  // CHECK-NEXT: i love you: func.func
  // CHECK-NEXT: i love you: builtin.module
  mlirOperationWalk(root, walkCallBack, (void *)(&data), MlirWalkPostOrder);

  data.x = "i don't love you";
  // CHECK-NEXT: i don't love you: builtin.module
  // CHECK-NEXT: i don't love you: func.func
  // CHECK-NEXT: i don't love you: arith.constant
  // CHECK-NEXT: i don't love you: arith.addi
  // CHECK-NEXT: i don't love you: func.return
  // CHECK-NEXT: i don't love you: func.func
  // CHECK-NEXT: i don't love you: func.return
  mlirOperationWalk(root, walkCallBack, (void *)(&data), MlirWalkPreOrder);

  data.x = "interrupt";
  // Interrupted at `arith.addi`
  // CHECK-NEXT: interrupt: arith.constant
  // CHECK-NEXT: interrupt: arith.addi
  mlirOperationWalk(root, walkCallBackTestWalkResult, (void *)(&data),
                    MlirWalkPostOrder);

  data.x = "skip";
  // Skip at `func.func`
  // CHECK-NEXT: skip: builtin.module
  // CHECK-NEXT: skip: func.func
  // CHECK-NEXT: skip: func.func
  mlirOperationWalk(root, walkCallBackTestWalkResult, (void *)(&data),
                    MlirWalkPreOrder);

  mlirModuleDestroy(module);
  return 0;
}
2336

2337
int testDialectRegistry(void) {
2338
  fprintf(stderr, "@testDialectRegistry\n");
2339

2340
  MlirDialectRegistry registry = mlirDialectRegistryCreate();
2341
  if (mlirDialectRegistryIsNull(registry)) {
2342
    fprintf(stderr, "ERROR: Expected registry to be present\n");
2343
    return 1;
2344
  }
2345

2346
  MlirDialectHandle stdHandle = mlirGetDialectHandle__func__();
2347
  mlirDialectHandleInsertDialect(stdHandle, registry);
2348

2349
  MlirContext ctx = mlirContextCreate();
2350
  if (mlirContextGetNumRegisteredDialects(ctx) != 0) {
2351
    fprintf(stderr,
2352
            "ERROR: Expected no dialects to be registered to new context\n");
2353
  }
2354

2355
  mlirContextAppendDialectRegistry(ctx, registry);
2356
  if (mlirContextGetNumRegisteredDialects(ctx) != 1) {
2357
    fprintf(stderr, "ERROR: Expected the dialect in the registry to be "
2358
                    "registered to the context\n");
2359
  }
2360

2361
  mlirContextDestroy(ctx);
2362
  mlirDialectRegistryDestroy(registry);
2363

2364
  return 0;
2365
}
2366

2367
void testExplicitThreadPools(void) {
2368
  MlirLlvmThreadPool threadPool = mlirLlvmThreadPoolCreate();
2369
  MlirDialectRegistry registry = mlirDialectRegistryCreate();
2370
  mlirRegisterAllDialects(registry);
2371
  MlirContext context =
2372
      mlirContextCreateWithRegistry(registry, /*threadingEnabled=*/false);
2373
  mlirContextSetThreadPool(context, threadPool);
2374
  mlirContextDestroy(context);
2375
  mlirDialectRegistryDestroy(registry);
2376
  mlirLlvmThreadPoolDestroy(threadPool);
2377
}
2378

2379
void testDiagnostics(void) {
2380
  MlirContext ctx = mlirContextCreate();
2381
  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
2382
      ctx, errorHandler, (void *)42, deleteUserData);
2383
  fprintf(stderr, "@test_diagnostics\n");
2384
  MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
2385
  mlirEmitError(unknownLoc, "test diagnostics");
2386
  MlirAttribute unknownAttr = mlirLocationGetAttribute(unknownLoc);
2387
  MlirLocation unknownClone = mlirLocationFromAttribute(unknownAttr);
2388
  mlirEmitError(unknownClone, "test clone");
2389
  MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
2390
      ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
2391
  mlirEmitError(fileLineColLoc, "test diagnostics");
2392
  MlirLocation callSiteLoc = mlirLocationCallSiteGet(
2393
      mlirLocationFileLineColGet(
2394
          ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
2395
      fileLineColLoc);
2396
  mlirEmitError(callSiteLoc, "test diagnostics");
2397
  MlirLocation null = {0};
2398
  MlirLocation nameLoc =
2399
      mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
2400
  mlirEmitError(nameLoc, "test diagnostics");
2401
  MlirLocation locs[2] = {nameLoc, callSiteLoc};
2402
  MlirAttribute nullAttr = {0};
2403
  MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
2404
  mlirEmitError(fusedLoc, "test diagnostics");
2405
  mlirContextDetachDiagnosticHandler(ctx, id);
2406
  mlirEmitError(unknownLoc, "more test diagnostics");
2407
  // CHECK-LABEL: @test_diagnostics
2408
  // CHECK: processing diagnostic (userData: 42) <<
2409
  // CHECK:   test diagnostics
2410
  // CHECK:   loc(unknown)
2411
  // CHECK: processing diagnostic (userData: 42) <<
2412
  // CHECK:   test clone
2413
  // CHECK:   loc(unknown)
2414
  // CHECK: >> end of diagnostic (userData: 42)
2415
  // CHECK: processing diagnostic (userData: 42) <<
2416
  // CHECK:   test diagnostics
2417
  // CHECK:   loc("file.c":1:2)
2418
  // CHECK: >> end of diagnostic (userData: 42)
2419
  // CHECK: processing diagnostic (userData: 42) <<
2420
  // CHECK:   test diagnostics
2421
  // CHECK:   loc(callsite("other-file.c":2:3 at "file.c":1:2))
2422
  // CHECK: >> end of diagnostic (userData: 42)
2423
  // CHECK: processing diagnostic (userData: 42) <<
2424
  // CHECK:   test diagnostics
2425
  // CHECK:   loc("named")
2426
  // CHECK: >> end of diagnostic (userData: 42)
2427
  // CHECK: processing diagnostic (userData: 42) <<
2428
  // CHECK:   test diagnostics
2429
  // CHECK:   loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
2430
  // CHECK: deleting user data (userData: 42)
2431
  // CHECK-NOT: processing diagnostic
2432
  // CHECK:     more test diagnostics
2433
  mlirContextDestroy(ctx);
2434
}
2435

2436
int main(void) {
2437
  MlirContext ctx = mlirContextCreate();
2438
  registerAllUpstreamDialects(ctx);
2439
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
2440
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("memref"));
2441
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("shape"));
2442
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("scf"));
2443

2444
  if (constructAndTraverseIr(ctx))
2445
    return 1;
2446
  buildWithInsertionsAndPrint(ctx);
2447
  if (createOperationWithTypeInference(ctx))
2448
    return 2;
2449

2450
  if (printBuiltinTypes(ctx))
2451
    return 3;
2452
  if (printBuiltinAttributes(ctx))
2453
    return 4;
2454
  if (printAffineMap(ctx))
2455
    return 5;
2456
  if (printAffineExpr(ctx))
2457
    return 6;
2458
  if (affineMapFromExprs(ctx))
2459
    return 7;
2460
  if (printIntegerSet(ctx))
2461
    return 8;
2462
  if (registerOnlyStd())
2463
    return 9;
2464
  if (testBackreferences())
2465
    return 10;
2466
  if (testOperands())
2467
    return 11;
2468
  if (testClone())
2469
    return 12;
2470
  if (testTypeID(ctx))
2471
    return 13;
2472
  if (testSymbolTable(ctx))
2473
    return 14;
2474
  if (testDialectRegistry())
2475
    return 15;
2476
  if (testOperationWalk(ctx))
2477
    return 16;
2478

2479
  testExplicitThreadPools();
2480
  testDiagnostics();
2481

2482
  // CHECK: DESTROY MAIN CONTEXT
2483
  // CHECK: reportResourceDelete: resource_i64_blob
2484
  fprintf(stderr, "DESTROY MAIN CONTEXT\n");
2485
  mlirContextDestroy(ctx);
2486

2487
  return 0;
2488
}
2489

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.