pytorch

Форк
0
189 строк · 5.6 Кб
1
## @package app
2
# Module caffe2.python.mint.app
3
import argparse
4
import flask
5
import glob
6
import numpy as np
7
import nvd3
8
import os
9
import sys
10
# pyre-fixme[21]: Could not find module `tornado.httpserver`.
11
import tornado.httpserver
12
# pyre-fixme[21]: Could not find a module corresponding to import `tornado.wsgi`
13
import tornado.wsgi
14

15
# Absolute path of the directory containing this module. Templates and static
# assets are resolved relative to it rather than to the working directory.
__folder__ = os.path.abspath(os.path.dirname(__file__))

app = flask.Flask(
    __name__,
    template_folder=os.path.join(__folder__, "templates"),
    static_folder=os.path.join(__folder__, "static")
)
# Parsed command-line flags. Assigned by main() before the server starts and
# read by the request handlers (root folder, sampling, chart options).
args = None
23

24

25
def jsonify_nvd3(chart):
    """Render a python-nvd3 chart into a JSON response for the front-end.

    python-nvd3 builds one HTML string holding both the chart container and
    its <script>. This splits them so the client can insert the container and
    run the script separately.

    Args:
        chart: a python-nvd3 chart object (e.g. ``nvd3.lineChart``).

    Returns:
        A Flask JSON response with ``result`` (``chart.container``) and
        ``script`` (the stripped body of the first <script> element in
        ``chart.htmlcontent``, or an empty string if there is none).
    """
    chart.buildcontent()
    # Note(Yangqing): python-nvd3 does not seem to separate the built HTML part
    # and the script part. Luckily, the HTML part appears to be only a <div>,
    # available as chart.container, while the script occupies the rest of the
    # html content, which we locate with chart.htmlcontent.find('<script>').
    html = chart.htmlcontent
    open_tag = '<script>'
    open_pos = html.find(open_tag)
    if open_pos < 0:
        # No script element: do not slice from find()'s -1 sentinel offset.
        script = ''
    else:
        script_start = open_pos + len(open_tag)
        # Search for the closing tag only after the opening one.
        script_end = html.find('</script>', script_start)
        if script_end < 0:
            script_end = len(html)
        script = html[script_start:script_end].strip()
    return flask.jsonify(
        result=chart.container,
        script=script
    )
38

39

40
def visualize_summary(filename):
    """Build the JSON chart for a ``*summary`` file.

    Each row of the file is one data point (plotted against its row index)
    whose first four columns are used as min, max, mean and standard
    deviation. The chart shows min, max, mean, mean + std and mean - std,
    down-sampled according to ``args.sample``.

    Args:
        filename: path of the whitespace-separated text file to load.

    Returns:
        A Flask JSON response with ``result`` and ``script`` keys, as built
        by ``jsonify_nvd3``. If the file cannot be loaded or has fewer than
        four columns, ``result`` holds an error message and ``script`` is
        empty.
    """
    try:
        # ndmin=2 keeps a single-row file two-dimensional, so the
        # data[rows, column] indexing below always works.
        data = np.loadtxt(filename, ndmin=2)
    except Exception as e:
        # Same {result, script} shape as the other responses so the client
        # can handle errors uniformly.
        return flask.jsonify(
            result='Cannot load file {}: {}'.format(filename, str(e)),
            script=''
        )
    if data.shape[1] < 4:
        return flask.jsonify(
            result='Summary file {} needs 4 columns (min, max, mean, std), '
                   'got {}'.format(filename, data.shape[1]),
            script=''
        )
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_summary_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    num_points = data.shape[0]
    if args.sample < 0:
        # A negative --sample is the approximate number of points to show.
        # Integer division: a float step would make xdata a float array,
        # which numpy refuses to use as an index.
        step = max(num_points // -args.sample, 1)
    else:
        # Clamp so that --sample 0 does not make np.arange raise.
        step = max(args.sample, 1)
    xdata = np.arange(0, num_points, step)
    sampled = data[xdata]
    mean = sampled[:, 2]
    std = sampled[:, 3]
    chart.add_serie(x=xdata, y=sampled[:, 0], name='min')
    chart.add_serie(x=xdata, y=sampled[:, 1], name='max')
    chart.add_serie(x=xdata, y=mean, name='mean')
    chart.add_serie(x=xdata, y=mean + std, name='m+std')
    chart.add_serie(x=xdata, y=mean - std, name='m-std')
    return jsonify_nvd3(chart)
63

64

65
def visualize_print_log(filename):
    """Build the JSON chart for a ``*log`` file.

    Each row of the file is one data point (plotted against its row index)
    and each column is one curve. With a single column, the chart also shows
    the running min and max over consecutive windows of ``step`` points. With
    several columns, at most ``args.max_curves`` curves are drawn.

    Args:
        filename: path of the whitespace-separated text file to load.

    Returns:
        A Flask JSON response with ``result`` and ``script`` keys, as built
        by ``jsonify_nvd3``. If the file cannot be loaded, ``result`` holds
        the error message and ``script`` is empty.
    """
    try:
        # ndmin=2 always yields an (N, M) matrix: a single-column file
        # becomes (N, 1), and a single-row file stays (1, M) instead of being
        # flattened into one column.
        data = np.loadtxt(filename, ndmin=2)
    except Exception as e:
        # Same {result, script} shape as the other responses so the client
        # can handle errors uniformly.
        return flask.jsonify(
            result='Cannot load file {}: {}'.format(filename, str(e)),
            script=''
        )
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_log_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    num_points = data.shape[0]
    if args.sample < 0:
        # A negative --sample is the approximate number of points to show.
        # Integer division: a float step would make xdata a float array,
        # which numpy refuses to use as an index.
        step = max(num_points // -args.sample, 1)
    else:
        # Clamp so that --sample 0 does not make np.arange raise.
        step = max(args.sample, 1)
    xdata = np.arange(0, num_points, step)
    if data.shape[1] == 1:
        # With a single curve, also show its running min and max. Split the
        # series into consecutive windows of `step` points, dropping the
        # trailing partial window so the reshape is exact, and reduce each.
        num_windows = num_points // step
        windows = data[:num_windows * step, 0].reshape((num_windows, step))
        chart.add_serie(
            x=xdata[:num_windows],
            y=windows.min(axis=1),
            name='running_min'
        )
        chart.add_serie(
            x=xdata[:num_windows],
            y=windows.max(axis=1),
            name='running_max'
        )
        chart.add_serie(x=xdata, y=data[xdata, 0], name=chart_name)
    else:
        for i in range(min(data.shape[1], args.max_curves)):
            chart.add_serie(
                x=xdata,
                y=data[xdata, i],
                name='{}[{}]'.format(chart_name, i)
            )

    return jsonify_nvd3(chart)
109

110

111
def visualize_file(filename):
    """Dispatch a file under ``args.root`` to the matching chart builder.

    Args:
        filename: bare file name taken from the request URL (untrusted).

    Returns:
        The chart response from ``visualize_summary`` (names ending in
        ``summary``) or ``visualize_print_log`` (names ending in ``log``).
        Otherwise it returns a JSON response whose ``result`` is an error
        message and whose ``script`` is empty.
    """
    fullname = os.path.join(args.root, filename)
    # `filename` comes straight from the URL. Flask's <string:> converter
    # rejects '/', but '..' or a backslash-separated path (on Windows) could
    # still resolve outside args.root. The index page only lists direct
    # children of the root, so accept nothing else.
    root = os.path.abspath(args.root)
    if os.path.dirname(os.path.abspath(fullname)) != root:
        return flask.jsonify(
            result='Invalid file name: {}'.format(filename),
            script=''
        )
    if filename.endswith('summary'):
        return visualize_summary(fullname)
    if filename.endswith('log'):
        return visualize_print_log(fullname)
    return flask.jsonify(
        result='Unsupported file: {}'.format(filename),
        script=''
    )
122

123

124
@app.route('/')
def index():
    """Render ``index.html`` with the files found directly in ``args.root``.

    Only names matching the glob pattern ``*.*`` are listed, sorted by path.
    The same name list is also passed to the template as ``debug_messages``.
    """
    paths = sorted(glob.glob(os.path.join(args.root, "*.*")))
    names = [os.path.basename(path) for path in paths]
    return flask.render_template(
        'index.html',
        root=args.root,
        names=names,
        debug_messages=names
    )
135

136

137
@app.route('/visualization/<string:name>')
def visualization(name):
    """Return the chart JSON for the file ``name`` under ``args.root``."""
    return visualize_file(name)
141

142

143
def main(argv):
    """Parse command-line flags and serve the Flask app through Tornado.

    Assigns the module-level ``args`` read by the request handlers, then
    blocks running the Tornado IOLoop.

    Args:
        argv: command-line arguments, excluding the program name.
    """
    # tornado.ioloop is used below but the file never imports it explicitly;
    # import it here instead of relying on another tornado import having
    # loaded it.
    # pyre-fixme[21]: Could not find module `tornado.ioloop`.
    import tornado.ioloop

    global args
    # The first positional argument of ArgumentParser is `prog`; pass the
    # text as the description so the usage line keeps the real program name.
    parser = argparse.ArgumentParser(description="The mint visualizer.")
    parser.add_argument(
        '-p',
        '--port',
        type=int,
        default=5000,
        help="The flask port to use."
    )
    parser.add_argument(
        '-r',
        '--root',
        type=str,
        default='.',
        help="The root folder to read files for visualization."
    )
    parser.add_argument(
        '--max_curves',
        type=int,
        default=5,
        help="The max number of curves to show in a dump tensor."
    )
    parser.add_argument(
        '--chart_height',
        type=int,
        default=300,
        help="The chart height for nvd3."
    )
    parser.add_argument(
        '-s',
        '--sample',
        type=int,
        default=-200,
        help="Sample every given number of data points. A negative "
        "number means the total points we will sample on the "
        "whole curve. Default 200 points."
    )
    args = parser.parse_args(argv)
    server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app))
    server.listen(args.port)
    print("Tornado server starting on port {}.".format(args.port))
    # IOLoop.instance() is deprecated since Tornado 5 in favor of current().
    tornado.ioloop.IOLoop.current().start()
186

187

188
if __name__ == '__main__':
    # Run the visualizer with the process's command-line flags.
    main(sys.argv[1:])
190

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.