google-research
324 lines · 8.3 KB
1using Latexify, DataFrames
2import CSV, Plots
3
4include("../src/methods_enum.jl")
5include("scripts_parameters.jl")
6
"""
    standard_plot_setup()

Start a fresh Plots.jl figure for the residual plots in this script.

The figure has iterations on the x axis and the normalized duality gap on a
logarithmic y axis. The default line style is dashed and the legend sits in
the top-right corner. Returns the new plot object.
"""
function standard_plot_setup()
    axis_options = (
        xlabel = "iterations",
        ylabel = "normalized duality gap",
        yaxis = :log,
    )
    return Plots.plot(; axis_options..., linestyle = :dash, legend = :topright)
end
16
"""
    get_residual(data::DataFrame)

Return the residual series that this script plots and ranks: the
`average_normalized_gap` column of `data`.

Other residuals that were present as commented-out alternatives are
`kkt_error_average_iterate` and `primal_delta_norms .+ dual_delta_norms`.
"""
get_residual(data::DataFrame) = getproperty(data, :average_normalized_gap)
22
23
"""
    generic_restart_plot(; data, color, markershape, label, linestyle, y_min)

Add one restart scheme's residual curve to the current plot.

Everything is drawn in `color`:
- the residual (`get_residual(data)`) against `data.iteration`, with y limits
  `(y_min, maximum(residual))`;
- a `markershape` marker (no line) at every row where
  `data.restart_occurred` is true;
- a `:star5` marker at the last row whose `number_of_active_set_changes` is
  positive, if there is such a row;
- a single-point series that carries `label`, so the legend shows the line
  style and marker together.

If the run never restarted, the legend series is a NaN point instead of
failing on an empty index list. Returns the plot object from the final
`Plots.plot!` call.
"""
function generic_restart_plot(;
    data::DataFrame,
    color::Symbol,
    markershape::Symbol,
    label::String,
    linestyle::Symbol,
    y_min::Float64,
)
    residual = get_residual(data)

    # Residual curve. Its legend entry is added separately at the end.
    Plots.plot!(
        data.iteration,
        residual,
        label = "",
        ylims = (y_min, maximum(residual)),
        color = color,
        linestyle = linestyle,
    )

    # Markers at the restart points, with no connecting line.
    restart_indices = findall(data.restart_occurred)
    Plots.plot!(
        data.iteration[restart_indices],
        residual[restart_indices],
        color = color,
        linealpha = 0.0,
        markershape = markershape,
        label = "",
        markerstrokewidth = 0,
        markerstrokecolor = :auto,
        markersize = 5.0,
    )

    # Star at the last row where the active set changed (if it ever did).
    last_change_index = findlast(>(0), data.number_of_active_set_changes)
    if last_change_index !== nothing
        Plots.plot!(
            [data.iteration[last_change_index]],
            [residual[last_change_index]],
            color = color,
            markershape = :star5,
            linealpha = 0.0,
            label = "",
            markerstrokecolor = :auto,
            markersize = 7.0,
        )
    end

    # Artificial one-point series for the legend. Reusing the first restart
    # point means nothing new appears on the curve (it is already marked).
    # With no restarts, a NaN point yields the legend entry without drawing
    # anything and without indexing into an empty vector.
    if isempty(restart_indices)
        legend_x, legend_y = [NaN], [NaN]
    else
        first_restart = first(restart_indices)
        legend_x = data.iteration[first_restart:first_restart]
        legend_y = residual[first_restart:first_restart]
    end
    return Plots.plot!(
        legend_x,
        legend_y,
        color = color,
        markershape = markershape,
        label = label,
        markerstrokecolor = :auto,
        linestyle = linestyle,
    )
end
82
"""
    plot_no_restart_and_adaptive(results_directory::String, y_min::Float64)

Overlay two runs from `results_directory` on the current plot:

- the no-restart baseline from `no_restarts.csv`, drawn as a black line of
  its `current_normalized_gap` column and labelled "No restarts";
- the adaptive-restart run from `adaptive_restarts.csv`, drawn by
  `generic_restart_plot` in blue with circle restart markers and a solid line.

Draws onto whatever plot is current rather than creating one. Returns the plot
object from `generic_restart_plot`.
"""
function plot_no_restart_and_adaptive(results_directory::String, y_min::Float64)
    read_result(filename) = CSV.read(joinpath(results_directory, filename), DataFrame)

    baseline = read_result("no_restarts.csv")
    # NOTE(review): unlike the other series in this script, the baseline plots
    # `current_normalized_gap` instead of `get_residual` — confirm intentional.
    Plots.plot!(
        baseline.iteration,
        baseline.current_normalized_gap,
        color = :black,
        label = "No restarts",
    )

    return generic_restart_plot(;
        data = read_result("adaptive_restarts.csv"),
        color = :blue,
        markershape = :circle,
        label = "Adaptive restarts",
        linestyle = :solid,
        y_min = y_min,
    )
end
107
"""
    plot_no_restart_adaptive_and_fixed_frequency_results(
        results_directory, restart_lengths, y_min)

Start a new figure with the best fixed-frequency restart lengths (at most
three), followed by the no-restart and adaptive-restart runs.

Each restart length `L` in `restart_lengths` is read from
`restart_length<L>.csv` in `results_directory`. Lengths are ranked by:
1. the first row whose residual (`get_residual`) is below `y_min`, or `Inf`
   if none is;
2. ties broken by the final residual.

Lower is better on both criteria. The selected lengths are plotted in
increasing order of restart length. If fewer than three lengths are given,
all of them are plotted. Returns the plot object from
`plot_no_restart_and_adaptive`.
"""
function plot_no_restart_adaptive_and_fixed_frequency_results(
    results_directory::String,
    restart_lengths::Vector{Int64},
    y_min::Float64,
)
    # Read every fixed-frequency run once; reused for ranking and plotting.
    fixed_frequency_dfs = Dict{Int64,DataFrame}(
        restart_length => CSV.read(
            joinpath(results_directory, "restart_length$(restart_length).csv"),
            DataFrame,
        ) for restart_length in restart_lengths
    )

    # Rank the restart lengths.
    df_restart_performance = DataFrame(
        restart_length = Int64[],
        first_approx_opt_index = Float64[],
        final_function_value = Float64[],
    )
    for restart_length in restart_lengths
        residuals = get_residual(fixed_frequency_dfs[restart_length])
        first_hit = findfirst(<(y_min), residuals)
        push!(
            df_restart_performance,
            (
                restart_length = restart_length,
                first_approx_opt_index = something(first_hit, Inf),
                final_function_value = residuals[end],
            ),
        )
    end
    sort!(df_restart_performance, [:first_approx_opt_index, :final_function_value])

    # Keep the three best (or fewer, if fewer were configured).
    subset_of_restart_lengths = sort(first(df_restart_performance.restart_length, 3))

    standard_plot_setup()

    colors = [:red, :green, :purple, :orange, :pink]
    markers = [
        :dtriangle,
        :rect,
        :diamond,
        :hexagon,
        :cross,
        :xcross,
        :utriangle,
        :rtriangle,
        :ltriangle,
        :pentagon,
        :heptagon,
        :octagon,
        :vline,
        :hline,
    ]
    for (i, restart_length) in enumerate(subset_of_restart_lengths)
        generic_restart_plot(;
            data = fixed_frequency_dfs[restart_length],
            color = colors[i],
            markershape = markers[i],
            label = "Restart length = $restart_length",
            linestyle = :dot,
            y_min = y_min,
        )
    end

    return plot_no_restart_and_adaptive(results_directory, y_min)
end
189
"""
    plot_dynamic_adaptive_and_no_restarts(results_directory, restart_lengths, y_min)

Start a new figure with the no-restart and adaptive-restart runs (drawn by
`plot_no_restart_and_adaptive`). Then add the flexible-restart run read from
`dynamic_adaptive_restarts.csv` in `results_directory`.

The flexible-restart run is a dotted blue line labelled "Flexible restarts",
with y limits `(y_min, maximum(residual))`. `restart_lengths` is accepted but
not used by this function. Returns the plot object.
"""
function plot_dynamic_adaptive_and_no_restarts(
    results_directory::String,
    restart_lengths::Vector{Int64},
    y_min::Float64,
)
    standard_plot_setup()
    plot_no_restart_and_adaptive(results_directory, y_min)

    flexible_path = joinpath(results_directory, "dynamic_adaptive_restarts.csv")
    flexible_df = CSV.read(flexible_path, DataFrame)
    flexible_residual = get_residual(flexible_df)
    residual_ceiling = maximum(flexible_residual)

    return Plots.plot!(
        flexible_df.iteration,
        flexible_residual;
        label = "Flexible restarts",
        ylims = (y_min, residual_ceiling),
        color = :blue,
        linestyle = :dot,
    )
end
213
"""
    first_iteration_to_hit_tolerance(df::DataFrame, target_tolerance::Float64)

Return the `iteration` value of the first row of `df` at which the smaller of
`kkt_error_average_iterate` and `kkt_error_current_iterate` is at most
`target_tolerance`. Returns `Inf` if no row reaches the tolerance.
"""
function first_iteration_to_hit_tolerance(
    df::DataFrame,
    target_tolerance::Float64,
)
    # A row counts as soon as either iterate meets the tolerance.
    best_kkt_error = min.(df.kkt_error_average_iterate, df.kkt_error_current_iterate)
    # findfirst stops at the first hit instead of materializing every index.
    first_row = findfirst(<=(target_tolerance), best_kkt_error)
    return first_row === nothing ? Inf : df.iteration[first_row]
end
228
229
"""
    create_dictionary_of_iterations_to_hit_tolerance(
        problem_name, results_directory, restart_lengths, target_tolerance)

Collect, for one problem, the first iteration at which each restart scheme
reaches `target_tolerance` (as computed by `first_iteration_to_hit_tolerance`).

Returns a `Dict{String,Any}` with these keys:
- `"problem_name"`: `problem_name`;
- `"adaptive_restarts"`: result for `adaptive_restarts.csv`;
- `"flexible_restarts"`: result for `dynamic_adaptive_restarts.csv`;
- `"best_fixed_frequency"`: the minimum over `restart_length<L>.csv` for each
  `L` in `restart_lengths`. This is `Inf` if `restart_lengths` is empty or no
  run reaches the tolerance.

All CSV files are read from `results_directory`.
"""
function create_dictionary_of_iterations_to_hit_tolerance(
    problem_name::String,
    results_directory::String,
    restart_lengths::Vector{Int64},
    target_tolerance::Float64,
)
    # First iteration reaching the tolerance for the run stored in `filename`.
    iterations_to_hit(filename) = first_iteration_to_hit_tolerance(
        CSV.read(joinpath(results_directory, filename), DataFrame),
        target_tolerance,
    )

    dictionary_hits = Dict{String,Any}()
    dictionary_hits["problem_name"] = problem_name
    dictionary_hits["adaptive_restarts"] = iterations_to_hit("adaptive_restarts.csv")
    dictionary_hits["flexible_restarts"] =
        iterations_to_hit("dynamic_adaptive_restarts.csv")
    dictionary_hits["best_fixed_frequency"] = reduce(
        min,
        (iterations_to_hit("restart_length$(L).csv") for L in restart_lengths);
        init = Inf,
    )
    return dictionary_hits
end
266
# Command-line entry: exactly one argument, the results directory.
# Check the argument count before indexing ARGS, so that a missing argument
# gives a usage message instead of a BoundsError.
if length(ARGS) != 1
    throw(ArgumentError("usage: julia $(PROGRAM_FILE) <results_directory>"))
end
results_directory = ARGS[1]
269
270
# Save both residual figures for `problem_name` into
# `joinpath(results_directory, problem_name)`.
function _save_residual_plots(results_directory, problem_name, y_min)
    subdirectory = joinpath(results_directory, problem_name)
    restart_lengths = RESTART_LENGTHS_DICT[problem_name]

    restarts_plt = plot_no_restart_adaptive_and_fixed_frequency_results(
        subdirectory,
        restart_lengths,
        y_min,
    )
    Plots.savefig(
        restarts_plt,
        joinpath(subdirectory, "$(problem_name)-adaptive-residuals.pdf"),
    )

    flexible_plt = plot_dynamic_adaptive_and_no_restarts(
        subdirectory,
        restart_lengths,
        y_min,
    )
    Plots.savefig(
        flexible_plt,
        joinpath(subdirectory, "$(problem_name)-flexible-residuals.pdf"),
    )
    return nothing
end

# Write `latex_table.txt` in `results_directory`. It is a LaTeX table with one
# row per problem giving the first iteration each scheme reaches `kkt_tolerance`.
function _write_results_table(results_directory, kkt_tolerance)
    df_hit_tolerance = DataFrames.DataFrame(
        problem_name = String[],
        best_fixed_frequency = Float64[],
        adaptive_restarts = Float64[],
        flexible_restarts = Float64[],
    )
    for problem_name in ALL_PROBLEM_NAMES
        push!(
            df_hit_tolerance,
            create_dictionary_of_iterations_to_hit_tolerance(
                problem_name,
                joinpath(results_directory, problem_name),
                RESTART_LENGTHS_DICT[problem_name],
                kkt_tolerance,
            ),
        )
    end

    # Render first so a latexify failure does not truncate an existing file.
    # The do-block closes the file even if `write` throws.
    latex_table = latexify(df_hit_tolerance, env = :table)
    open(joinpath(results_directory, "latex_table.txt"), "w") do io
        write(io, latex_table)
    end
    return nothing
end

"""
    main()

Generate every figure and the summary table for the global
`results_directory`.

For each problem in `ALL_PROBLEM_NAMES`, writes two PDFs into
`joinpath(results_directory, problem_name)`:
- `<problem>-adaptive-residuals.pdf`;
- `<problem>-flexible-residuals.pdf`.

Both plots use a y-axis floor of `1e-7`. Finally writes `latex_table.txt`
into `results_directory`, using a KKT-error tolerance of `1e-6`.
"""
function main()
    y_min = 1e-7
    table_of_results_kkt_error = 1e-6

    for problem_name in ALL_PROBLEM_NAMES
        _save_residual_plots(results_directory, problem_name, y_min)
    end

    _write_results_table(results_directory, table_of_results_kkt_error)
    return nothing
end

main()
325