llvm-project

Форк
0
/
sgemm-naive-codegen.mlir 
79 строк · 2.9 Кб
1
// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-linalg-to-loops,lower-affine,convert-scf-to-cf,convert-arith-to-llvm),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" %s | mlir-cpu-runner -O3 -e main -entry-point-result=void -shared-libs=%mlir_c_runner_utils | FileCheck %s
2

3
func.func @main() {
4
  %A = memref.alloc() : memref<16x16xf32>
5
  %B = memref.alloc() : memref<16x16xf32>
6
  %C = memref.alloc() : memref<16x16xf32>
7

8
  %cf1 = arith.constant 1.00000e+00 : f32
9

10
  linalg.fill ins(%cf1 : f32) outs(%A : memref<16x16xf32>)
11
  linalg.fill ins(%cf1 : f32) outs(%B : memref<16x16xf32>)
12

13
  %num_reps = arith.constant 5 : index
14

15
  %t_start = call @rtclock() : () -> f64
16
  affine.for %arg0 = 0 to %num_reps {
17
    linalg.fill ins(%cf1 : f32) outs(%C : memref<16x16xf32>)
18
    func.call @sgemm_naive(%A, %B, %C) : (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) -> ()
19
  }
20
  %t_end = call @rtclock() : () -> f64
21
  %t = arith.subf %t_end, %t_start : f64
22

23
  %res = affine.load %C[0, 0]: memref<16x16xf32>
24
  vector.print %res: f32
25

26
  %c0 = arith.constant 0 : index
27
  %c1 = arith.constant 1 : index
28
  %c2 = arith.constant 2 : index
29

30
  %M = memref.dim %C, %c0 : memref<16x16xf32>
31
  %N = memref.dim %C, %c1 : memref<16x16xf32>
32
  %K = memref.dim %A, %c1 : memref<16x16xf32>
33

34
  // num_flops_per_iter = 2*M*N*K
35
  %f1 = arith.muli %M, %N : index
36
  %f2 = arith.muli %f1, %K : index
37
  %num_flops_per_iter = arith.muli %c2, %f2 : index
38

39
  // num_flops_total = num_flops_per_iter * num_reps
40
  %num_flops_total = arith.muli %num_flops_per_iter, %num_reps: index
41

42
  // Print the number of flops per second
43
  %num_flops_total_i = arith.index_cast %num_flops_total : index to i16
44
  %num_flops_total_f = arith.uitofp %num_flops_total_i : i16 to f64
45
  %flops_per_s = arith.divf %num_flops_total_f, %t : f64
46
  call @printFlops(%flops_per_s) : (f64) -> ()
47

48
  memref.dealloc %A : memref<16x16xf32>
49
  memref.dealloc %B : memref<16x16xf32>
50
  memref.dealloc %C : memref<16x16xf32>
51
  return
52
}
53
// CHECK: 17
54

55
func.func @sgemm_naive(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>) {
56
  %c0 = arith.constant 0 : index
57
  affine.for %arg3 = 0 to 16 {
58
    affine.for %arg4 = 0 to 16 {
59
      %m = memref.alloc() : memref<1xf32>
60
      %v = affine.load %arg2[%arg3, %arg4] : memref<16x16xf32>
61
      affine.store %v, %m[%c0] : memref<1xf32>
62
      affine.for %arg5 = 0 to 16 {
63
        %3 = affine.load %arg0[%arg3, %arg5] : memref<16x16xf32>
64
        %4 = affine.load %arg1[%arg5, %arg4] : memref<16x16xf32>
65
        %5 = affine.load %m[0] : memref<1xf32>
66
        %6 = arith.mulf %3, %4 : f32
67
        %7 = arith.addf %6, %5 : f32
68
        affine.store %7, %m[0] : memref<1xf32>
69
      }
70
      %s = affine.load %m[%c0] : memref<1xf32>
71
      affine.store %s, %arg2[%arg3, %arg4] : memref<16x16xf32>
72
      memref.dealloc %m : memref<1xf32>
73
    }
74
  }
75
  return
76
}
77

78
func.func private @printFlops(f64)
79
func.func private @rtclock() -> f64
80

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.