pytorch

Форк
0
/
log_extract.py 
153 строки · 4.4 Кб
1
import argparse
2
import functools
3
import traceback
4
from typing import Callable, List, Optional, Tuple
5

6
from torch.utils.jit.log_extract import (
7
    extract_ir,
8
    load_graph_and_inputs,
9
    run_baseline_no_fusion,
10
    run_nnc,
11
    run_nvfuser,
12
)
13

14

15
"""
16
Usage:
17
1. Run your script and pipe into a log file
18
  PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
19
2. Run log_extract:
20
  log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static
21

22
You can also extract the list of extracted IR:
23
  log_extract.py log.txt --output
24

25
Passing in --graphs 0 2 will only run graphs 0 and 2
26
"""
27

28

29
def test_runners(
30
    graphs: List[str],
31
    runners: List[Tuple[str, Callable]],
32
    graph_set: Optional[List[int]],
33
):
34
    for i, ir in enumerate(graphs):
35
        _, inputs = load_graph_and_inputs(ir)
36
        if graph_set and i not in graph_set:
37
            continue
38

39
        print(f"Running Graph {i}")
40
        prev_result = None
41
        prev_runner_name = None
42
        for runner in runners:
43
            runner_name, runner_fn = runner
44
            try:
45
                result = runner_fn(ir, inputs)
46
                if prev_result:
47
                    improvement = (prev_result / result - 1) * 100
48
                    print(
49
                        f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
50
                    )
51
                else:
52
                    print(f"{runner_name} : {result:.6f} ms")
53
                prev_result = result
54
                prev_runner_name = runner_name
55
            except RuntimeError:
56
                print(f"  Graph {i} failed for {runner_name} :", traceback.format_exc())
57

58

59
def run():
60
    parser = argparse.ArgumentParser(
61
        description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
62
    )
63
    parser.add_argument("filename", help="Filename of log file")
64
    parser.add_argument(
65
        "--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser"
66
    )
67
    parser.add_argument(
68
        "--no-nvfuser",
69
        dest="nvfuser",
70
        action="store_false",
71
        help="DON'T benchmark nvfuser",
72
    )
73
    parser.set_defaults(nvfuser=False)
74
    parser.add_argument(
75
        "--nnc-static",
76
        dest="nnc_static",
77
        action="store_true",
78
        help="benchmark nnc static",
79
    )
80
    parser.add_argument(
81
        "--no-nnc-static",
82
        dest="nnc_static",
83
        action="store_false",
84
        help="DON'T benchmark nnc static",
85
    )
86
    parser.set_defaults(nnc_static=False)
87

88
    parser.add_argument(
89
        "--nnc-dynamic",
90
        dest="nnc_dynamic",
91
        action="store_true",
92
        help="nnc with dynamic shapes",
93
    )
94
    parser.add_argument(
95
        "--no-nnc-dynamic",
96
        dest="nnc_dynamic",
97
        action="store_false",
98
        help="DONT't benchmark nnc with dynamic shapes",
99
    )
100
    parser.set_defaults(nnc_dynamic=False)
101

102
    parser.add_argument(
103
        "--baseline", dest="baseline", action="store_true", help="benchmark baseline"
104
    )
105
    parser.add_argument(
106
        "--no-baseline",
107
        dest="baseline",
108
        action="store_false",
109
        help="DON'T benchmark baseline",
110
    )
111
    parser.set_defaults(baseline=False)
112

113
    parser.add_argument(
114
        "--output", dest="output", action="store_true", help="Output graph IR"
115
    )
116
    parser.add_argument(
117
        "--no-output", dest="output", action="store_false", help="DON'T output graph IR"
118
    )
119
    parser.set_defaults(output=False)
120

121
    parser.add_argument(
122
        "--graphs", nargs="+", type=int, help="Run only specified graph indices"
123
    )
124

125
    args = parser.parse_args()
126
    graphs = extract_ir(args.filename)
127

128
    graph_set = args.graphs
129
    graph_set = graph_set if graph_set else None
130

131
    options = []
132
    if args.baseline:
133
        options.append(("Baseline no fusion", run_baseline_no_fusion))
134
    if args.nnc_dynamic:
135
        options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
136
    if args.nnc_static:
137
        options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
138
    if args.nvfuser:
139
        options.append(("NVFuser", run_nvfuser))
140

141
    test_runners(graphs, options, graph_set)
142

143
    if args.output:
144
        quoted = []
145
        for i, ir in enumerate(graphs):
146
            if graph_set and i not in graph_set:
147
                continue
148
            quoted.append('"""' + ir + '"""')
149
        print("[" + ", ".join(quoted) + "]")
150

151

152
if __name__ == "__main__":
153
    run()
154

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.