4
class SimpleTransformer:
7
SimpleTransformer is a simple class for preprocessing and deprocessing
11
def __init__(self, mean=[128, 128, 128]):
12
self.mean = np.array(mean, dtype=np.float32)
15
def set_mean(self, mean):
17
Set the mean to subtract for centering the data.
21
def set_scale(self, scale):
27
def preprocess(self, im):
29
preprocess() emulates the pre-processing occurring in the vgg16 caffe
34
im = im[:, :, ::-1] # change to BGR
37
im = im.transpose((2, 0, 1))
41
def deprocess(self, im):
43
inverse of preprocess()
45
im = im.transpose(1, 2, 0)
48
im = im[:, :, ::-1] # change to RGB
56
Caffesolver is a class for creating a solver.prototxt file. It sets default
57
values and can export a solver parameter file.
58
Note that all parameters are stored as strings. String-valued parameters are
59
stored as quoted strings inside strings (e.g. '"snapshot"').
62
def __init__(self, testnet_prototxt_path="testnet.prototxt",
63
trainnet_prototxt_path="trainnet.prototxt", debug=False):
68
self.sp['base_lr'] = '0.001'
69
self.sp['momentum'] = '0.9'
72
self.sp['test_iter'] = '100'
73
self.sp['test_interval'] = '250'
76
self.sp['display'] = '25'
77
self.sp['snapshot'] = '2500'
78
self.sp['snapshot_prefix'] = '"snapshot"' # string within a string!
80
# learning rate policy
81
self.sp['lr_policy'] = '"fixed"'
83
# important, but rare:
84
self.sp['gamma'] = '0.1'
85
self.sp['weight_decay'] = '0.0005'
86
self.sp['train_net'] = '"' + trainnet_prototxt_path + '"'
87
self.sp['test_net'] = '"' + testnet_prototxt_path + '"'
89
# pretty much never change these.
90
self.sp['max_iter'] = '100000'
91
self.sp['test_initialization'] = 'false'
92
self.sp['average_loss'] = '25' # this has to do with the display.
93
self.sp['iter_size'] = '1' # this is for accumulating gradients
96
self.sp['max_iter'] = '12'
97
self.sp['test_iter'] = '1'
98
self.sp['test_interval'] = '4'
99
self.sp['display'] = '1'
101
def add_from_file(self, filepath):
103
Reads a caffe solver prototxt file and updates the Caffesolver
106
with open(filepath, 'r') as f:
110
splitLine = line.split(':')
111
self.sp[splitLine[0].strip()] = splitLine[1].strip()
113
def write(self, filepath):
    """
    Export the solver parameters to the file at ``filepath``.

    One ``key: value`` line is written per entry of ``self.sp``, sorted
    alphabetically by key. An existing file at ``filepath`` is overwritten.

    Raises:
        TypeError: if any solver parameter value is not a string. Values
            are checked before the file is opened, so a failed call does
            not create or truncate ``filepath``.
    """
    entries = sorted(self.sp.items())

    # Validate up front so a bad value cannot leave a half-written file.
    for _, value in entries:
        if not isinstance(value, str):
            raise TypeError('All solver parameters must be strings')

    # The context manager closes the handle even if a write fails; a bare
    # open() without close() leaks the handle and may delay the flush.
    with open(filepath, 'w') as f:
        for key, value in entries:
            f.write('%s: %s\n' % (key, value))