google-research

Форк
0
324 строки · 8.3 Кб
1
using Latexify, DataFrames
2
import CSV, Plots
3

4
include("../src/methods_enum.jl")
5
include("scripts_parameters.jl")
6

7
"""
    standard_plot_setup()

Create the base plot shared by the figures in this script: "iterations" on the
x-axis, "normalized duality gap" on a logarithmic y-axis, dashed lines by
default, and the legend in the top-right corner. Returns the value produced by
`Plots.plot`.
"""
function standard_plot_setup()
  base_attributes = (
    xlabel = "iterations",
    ylabel = "normalized duality gap",
    yaxis = :log,
    linestyle = :dash,
    legend = :topright,
  )
  return Plots.plot(; base_attributes...)
end
16

17
"""
    get_residual(data::DataFrame)

Return the column of `data` that the plots in this script use as the residual.

The `average_normalized_gap` column is currently selected. Alternatives that
were previously considered here are `kkt_error_average_iterate` and the
element-wise sum `primal_delta_norms .+ dual_delta_norms`.
"""
function get_residual(data::DataFrame)
  return getproperty(data, :average_normalized_gap)
end
22

23

24
"""
    generic_restart_plot(; data, color, markershape, label, linestyle, y_min)

Add the residual curve of one run to the current plot, together with markers
for its restart points and a star at the last active set change.

# Keywords
- `data`: run log; the columns `iteration`, `restart_occurred` and
  `number_of_active_set_changes` are read, plus the column selected by
  `get_residual`.
- `color`: color used for every series added by this call.
- `markershape`: marker used for the restart points and the legend entry.
- `label`: legend text for this run.
- `linestyle`: line style of the residual curve and the legend entry.
- `y_min`: lower y-axis limit; the upper limit is the largest residual.

# Returns
The value returned by the final `Plots.plot!` call (the legend series).
"""
function generic_restart_plot(;
  data::DataFrame,
  color::Symbol,
  markershape::Symbol,
  label::String,
  linestyle::Symbol,
  y_min::Float64,
)
  residual = get_residual(data)
  # Residual curve. It is left unlabeled because the legend entry is produced
  # by the dedicated series at the end of this function.
  Plots.plot!(
    data.iteration,
    residual,
    label = "",
    ylims = (y_min, maximum(residual)),
    color = color,
    linestyle = linestyle,
  )
  # Restart points: markers only (the connecting line is fully transparent).
  restart_indices = findall(data.restart_occurred)
  Plots.plot!(
    data.iteration[restart_indices],
    residual[restart_indices],
    color = color,
    linealpha = 0.0,
    markershape = markershape,
    label = "",
    markerstrokewidth = 0,
    markerstrokecolor = :auto,
    markersize = 5.0,
  )
  # Star marking the last row that reports at least one active set change.
  last_active_set_change_index =
    findlast(data.number_of_active_set_changes .> 0)
  if last_active_set_change_index !== nothing
    Plots.plot!(
      [data.iteration[last_active_set_change_index]],
      [residual[last_active_set_change_index]],
      color = color,
      markershape = :star5,
      linealpha = 0.0,
      label = "",
      markerstrokecolor = :auto,
      markersize = 7.0,
    )
  end
  # Artificial series for the legend, showing line style and marker together.
  # It is anchored at the first restart point when there is one. A run without
  # restarts yields an empty selection here (just like the marker series
  # above) instead of raising a BoundsError.
  legend_indices = restart_indices[1:min(1, end)]
  restart_plt = Plots.plot!(
    data.iteration[legend_indices],
    residual[legend_indices],
    color = color,
    markershape = markershape,
    label = label,
    markerstrokecolor = :auto,
    linestyle = linestyle,
  )
  return restart_plt
end
82

83
"""
    plot_no_restart_and_adaptive(results_directory::String, y_min::Float64)

Add two runs to the current plot: the "No restarts" run read from
`no_restarts.csv` (black, using its `current_normalized_gap` column) and the
"Adaptive restarts" run read from `adaptive_restarts.csv` (blue, drawn through
`generic_restart_plot`). Both files are looked up inside `results_directory`.
Returns whatever `generic_restart_plot` returns.
"""
function plot_no_restart_and_adaptive(results_directory::String, y_min::Float64)
  # Small local loader so both CSV files are read the same way.
  load_results(file_name) =
    CSV.read(joinpath(results_directory, file_name), DataFrame)

  baseline_df = load_results("no_restarts.csv")
  Plots.plot!(
    baseline_df.iteration,
    baseline_df.current_normalized_gap,
    color = :black,
    label = "No restarts",
  )

  adaptive_df = load_results("adaptive_restarts.csv")
  return generic_restart_plot(
    data = adaptive_df,
    color = :blue,
    markershape = :circle,
    label = "Adaptive restarts",
    linestyle = :solid,
    y_min = y_min,
  )
end
107

108
"""
    plot_no_restart_adaptive_and_fixed_frequency_results(
      results_directory, restart_lengths, y_min)

Build a figure comparing the best fixed-frequency restart runs with the
"no restarts" and "adaptive restarts" runs.

Each entry `N` of `restart_lengths` is read from `restart_length<N>.csv` inside
`results_directory`. Runs are ranked by the first row at which the residual
(see `get_residual`) drops below `y_min`, with ties broken by the final
residual; the best three (or fewer, if fewer were supplied) are plotted.

# Arguments
- `results_directory`: directory containing the CSV logs.
- `restart_lengths`: fixed restart lengths to consider.
- `y_min`: residual threshold used for ranking and as the lower y-axis limit.

# Returns
The value returned by `plot_no_restart_and_adaptive`.
"""
function plot_no_restart_adaptive_and_fixed_frequency_results(
  results_directory::String,
  restart_lengths::Vector{Int64},
  y_min::Float64,
)
  # Figure out which restart lengths to plot.
  df_restart_performance = DataFrame(
    restart_length = Int64[],
    first_approx_opt_index = Float64[],
    final_function_value = Float64[],
  )
  for restart_length in restart_lengths
    fixed_frequency_df = CSV.read(
      joinpath(results_directory, "restart_length$(restart_length).csv"),
      DataFrame,
    )
    residuals = get_residual(fixed_frequency_df)
    first_hit = findfirst(residuals .< y_min)
    push!(
      df_restart_performance,
      (
        restart_length = restart_length,
        # Inf marks runs that never reach y_min so that they sort last.
        first_approx_opt_index = first_hit === nothing ? Inf :
                                 Float64(first_hit),
        final_function_value = residuals[end],
      ),
    )
  end
  # Ascending order (the default) on both ranking columns.
  sort!(df_restart_performance, [:first_approx_opt_index, :final_function_value])
  # Pick the three best restart lengths. `min(3, end)` keeps this valid when
  # fewer than three restart lengths were supplied.
  subset_of_restart_lengths =
    sort(df_restart_performance.restart_length[1:min(3, end)])

  standard_plot_setup()

  colors = [:red, :green, :purple, :orange, :pink]
  markers = [
    :dtriangle,
    :rect,
    :diamond,
    :hexagon,
    :cross,
    :xcross,
    :utriangle,
    :rtriangle,
    :ltriangle,
    :pentagon,
    :heptagon,
    :octagon,
    :vline,
    :hline,
  ]
  for (i, restart_length) in enumerate(subset_of_restart_lengths)
    fixed_frequency_df = CSV.read(
      joinpath(results_directory, "restart_length$(restart_length).csv"),
      DataFrame,
    )
    generic_restart_plot(;
      data = fixed_frequency_df,
      color = colors[i],
      markershape = markers[i],
      label = "Restart length = $restart_length",
      linestyle = :dot,
      y_min = y_min,
    )
  end

  restarts_plt = plot_no_restart_and_adaptive(results_directory, y_min)
  return restarts_plt
end
189

190
"""
    plot_dynamic_adaptive_and_no_restarts(
      results_directory, restart_lengths, y_min)

Start a fresh figure, add the "no restarts" and "adaptive restarts" runs, and
then add the "Flexible restarts" run read from `dynamic_adaptive_restarts.csv`
inside `results_directory` as a blue dotted curve.

`restart_lengths` is accepted but not used by this function. `y_min` is the
lower y-axis limit. Returns the value of the final `Plots.plot!` call.
"""
function plot_dynamic_adaptive_and_no_restarts(
  results_directory::String,
  restart_lengths::Vector{Int64},
  y_min::Float64,
)
  standard_plot_setup()
  plot_no_restart_and_adaptive(results_directory, y_min)

  flexible_csv_path =
    joinpath(results_directory, "dynamic_adaptive_restarts.csv")
  flexible_df = CSV.read(flexible_csv_path, DataFrame)
  gap = get_residual(flexible_df)
  upper_limit = maximum(gap)
  return Plots.plot!(
    flexible_df.iteration,
    gap;
    label = "Flexible restarts",
    ylims = (y_min, upper_limit),
    color = :blue,
    linestyle = :dot,
  )
end
213

214
"""
    first_iteration_to_hit_tolerance(df::DataFrame, target_tolerance::Float64)

Return the `iteration` value of the first row of `df` at which the smaller of
`kkt_error_average_iterate` and `kkt_error_current_iterate` is at most
`target_tolerance`. Returns `Inf` when no row satisfies this.
"""
function first_iteration_to_hit_tolerance(
  df::DataFrame,
  target_tolerance::Float64,
)
  # Row-wise best of the two KKT error measures.
  best_kkt_error =
    min.(df.kkt_error_average_iterate, df.kkt_error_current_iterate)
  rows_within_tolerance = findall(best_kkt_error .<= target_tolerance)
  isempty(rows_within_tolerance) && return Inf
  return df.iteration[first(rows_within_tolerance)]
end
228

229

230
"""
    create_dictionary_of_iterations_to_hit_tolerance(
      problem_name, results_directory, restart_lengths, target_tolerance)

Summarize, for one problem, how many iterations each restart scheme needs to
reach `target_tolerance` (as measured by `first_iteration_to_hit_tolerance`).

The returned dictionary has the keys `"problem_name"`, `"adaptive_restarts"`
(from `adaptive_restarts.csv`), `"flexible_restarts"` (from
`dynamic_adaptive_restarts.csv`) and `"best_fixed_frequency"` (the smallest
value over the `restart_length<N>.csv` files for `N` in `restart_lengths`, or
`Inf` if none improves on that). All files are read from `results_directory`.
"""
function create_dictionary_of_iterations_to_hit_tolerance(
  problem_name::String,
  results_directory::String,
  restart_lengths::Vector{Int64},
  target_tolerance::Float64,
)
  # Read one result file and report when it first meets the tolerance.
  iterations_for(file_name) = first_iteration_to_hit_tolerance(
    CSV.read(joinpath(results_directory, file_name), DataFrame),
    target_tolerance,
  )

  adaptive_iterations = iterations_for("adaptive_restarts.csv")
  flexible_iterations = iterations_for("dynamic_adaptive_restarts.csv")

  best_fixed_frequency_iterations = Inf
  for restart_length in restart_lengths
    candidate = iterations_for("restart_length$(restart_length).csv")
    best_fixed_frequency_iterations =
      candidate < best_fixed_frequency_iterations ? candidate :
      best_fixed_frequency_iterations
  end

  return Dict{Any,Any}(
    "problem_name" => problem_name,
    "adaptive_restarts" => adaptive_iterations,
    "flexible_restarts" => flexible_iterations,
    "best_fixed_frequency" => best_fixed_frequency_iterations,
  )
end
266

267
# The script takes exactly one command-line argument: the directory that holds
# one sub-directory of results per problem and that receives latex_table.txt.
# The argument count is checked before ARGS is indexed so that a missing
# argument produces a usage message rather than a bare BoundsError.
if length(ARGS) != 1
  error(
    "Expected exactly one argument (the results directory), got $(length(ARGS)).",
  )
end
results_directory = ARGS[1]
269

270

271
"""
    main()

Generate all outputs of this script for the global `results_directory`.

For every entry of `ALL_PROBLEM_NAMES` two PDF figures are saved into the
problem's sub-directory: `<problem>-adaptive-residuals.pdf` and
`<problem>-flexible-residuals.pdf`. Afterwards a table with the iterations
needed to reach a KKT error of `1e-6` is written as LaTeX to `latex_table.txt`
in `results_directory`.

`ALL_PROBLEM_NAMES` and `RESTART_LENGTHS_DICT` are not defined in this function;
presumably they come from the included `scripts_parameters.jl`.
"""
function main()
  # Lower y-axis limit of the figures, also used to rank fixed restart lengths.
  y_min = 1e-7

  for problem_name in ALL_PROBLEM_NAMES
    subdirectory = joinpath(results_directory, problem_name)
    restarts_plt = plot_no_restart_adaptive_and_fixed_frequency_results(
      subdirectory,
      RESTART_LENGTHS_DICT[problem_name],
      y_min,
    )
    Plots.savefig(
      restarts_plt,
      joinpath(subdirectory, "$(problem_name)-adaptive-residuals.pdf"),
    )
    flexible_plt = plot_dynamic_adaptive_and_no_restarts(
      subdirectory,
      RESTART_LENGTHS_DICT[problem_name],
      y_min,
    )
    Plots.savefig(
      flexible_plt,
      joinpath(subdirectory, "$(problem_name)-flexible-residuals.pdf"),
    )
  end

  ####################
  # table of results #
  ####################
  TABLE_OF_RESULTS_KKT_ERROR = 1e-6
  df_hit_tolerance = DataFrames.DataFrame(
    problem_name = String[],
    best_fixed_frequency = Float64[],
    adaptive_restarts = Float64[],
    flexible_restarts = Float64[],
  )
  for problem_name in ALL_PROBLEM_NAMES
    subdirectory = joinpath(results_directory, problem_name)
    push!(
      df_hit_tolerance,
      create_dictionary_of_iterations_to_hit_tolerance(
        problem_name,
        subdirectory,
        RESTART_LENGTHS_DICT[problem_name],
        TABLE_OF_RESULTS_KKT_ERROR,
      ),
    )
  end

  # The do-block form closes the file even if latexify or write throws.
  open(joinpath(results_directory, "latex_table.txt"), "w") do latex_table_file
    write(latex_table_file, latexify(df_hit_tolerance, env = :table))
  end
  return nothing
end
323

324
# Script entry point: runs unconditionally whenever this file is executed or included.
main()
325

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.