caffe
classdef test_solver < matlab.unittest.TestCase
  % TEST_SOLVER  Unit test for caffe.Solver through the MATLAB interface.
  %   The constructor writes a temporary solver prototxt that points at the
  %   network file returned by caffe.test.test_net.simple_net_file, builds a
  %   caffe.Solver from it, and fills the 'label' blobs of the training net
  %   and of the first test net with random integer labels. The temporary
  %   files are removed when the constructor exits, on success or on error.

  properties
    num_output  % number of outputs passed to simple_net_file (set to 13)
    solver      % caffe.Solver instance under test
  end

  methods
    function self = test_solver()
      % Build the solver fixture used by the test methods.
      self.num_output = 13;
      model_file = caffe.test.test_net.simple_net_file(self.num_output);
      % Delete the temporary model file on every exit path, including errors.
      model_cleanup = onCleanup(@() delete(model_file)); %#ok<NASGU>

      solver_file = tempname();
      [fid, open_msg] = fopen(solver_file, 'w');
      if fid < 0
        error('caffe:test_solver:fopen', ...
          'Could not open "%s" for writing: %s', solver_file, open_msg);
      end
      % The file exists now, so it can be scheduled for deletion.
      solver_cleanup = onCleanup(@() delete(solver_file)); %#ok<NASGU>
      % Ensure the handle is closed even if a write below throws.
      file_closer = onCleanup(@() fclose(fid));
      % Pass the path through %s so backslashes or '%' in it (e.g. Windows
      % temp paths) are not read as fprintf escapes/conversions. Backslashes
      % are doubled because protobuf text format treats them as escapes
      % inside quoted strings.
      fprintf(fid, 'net: "%s"\n', strrep(model_file, '\', '\\'));
      fprintf(fid, [ ...
        'test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9\n' ...
        'weight_decay: 0.0005 lr_policy: "inv" gamma: 0.0001 power: 0.75\n' ...
        'display: 100 max_iter: 100 snapshot_after_train: false\n' ]);
      % Close (and flush) the file before caffe reads it.
      clear file_closer;

      self.solver = caffe.Solver(solver_file);
      % also make sure get_solver runs
      caffe.get_solver(solver_file);
      caffe.set_mode_cpu();
      % Fill in valid labels. randi(imax, sz) draws integers in 1..imax
      % into an array of size sz.
      self.solver.net.blobs('label').set_data(randi( ...
        self.num_output - 1, self.solver.net.blobs('label').shape));
      self.solver.test_nets(1).blobs('label').set_data(randi( ...
        self.num_output - 1, self.solver.test_nets(1).blobs('label').shape));
      % model_cleanup and solver_cleanup delete both temp files on return.
    end
  end

  methods (Test)
    function test_solve(self)
      % step(n) advances the iteration counter by n. solve() then runs
      % until max_iter (100, as written in the solver file above).
      self.verifyEqual(self.solver.iter(), 0);
      self.solver.step(30);
      self.verifyEqual(self.solver.iter(), 30);
      self.solver.solve();
      self.verifyEqual(self.solver.iter(), 100);
    end
  end
end
46