google-research

Форк
0
436 строк · 15.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Defines standard network layers that train using l0 regularization."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import tensorflow.compat.v1 as tf
22

23
from state_of_sparsity.layers.l0_regularization import common
24
from state_of_sparsity.layers.utils import layer_utils
25

26

27
def _verify_weight_parameters(weight_parameters):
28
  """Verifies that the format of the input `weight_parameters`.
29

30
  Checks that the input parameters is a 2-tuple of tensors of equal shape.
31

32
  Args:
33
    weight_parameters: The parameters to check.
34

35
  Raises:
36
    RuntimeError: If the input is not a 2-tuple of tensors with equal shape.
37

38
  Returns:
39
    The input `weight_parameters`.
40
  """
41
  if len(weight_parameters) != 2:
42
    raise RuntimeError("Incorrect number of weight parameters. Expected "
43
                       "2 tensors, got {}".format(len(weight_parameters)))
44
  if weight_parameters[0].shape != weight_parameters[1].shape:
45
    raise RuntimeError("Expected theta and log alpha parameter tensor "
46
                       "to be same shape. Got shapes {} and {}"
47
                       .format(weight_parameters[0].get_shape().as_list(),
48
                               weight_parameters[1].get_shape().as_list()))
49
  return weight_parameters
50

51

52
def matmul_train(
    x,
    weight_parameters,
    transpose_a=False,
    transpose_b=False,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training computation for a l0-regularized matmul.

  Args:
    x: 2D Tensor representing the input batch.
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    transpose_a: If True, a is transposed before multiplication.
    transpose_b: If True, b is transposed before multiplication.
    beta: The beta parameter, which controls the "temperature" of
      the distribution. Defaults to 2/3 from the above paper.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  x.get_shape().assert_has_rank(2)
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # At training time, gate the weights multiplicatively with stochastic
  # z values sampled from the hard-concrete distribution.
  gate_sample = common.hard_concrete_sample(log_alpha, beta, gamma, zeta, eps)
  masked_weights = theta * gate_sample
  return tf.matmul(
      x, masked_weights, transpose_a=transpose_a, transpose_b=transpose_b)
96

97

98
def matmul_eval(
    x,
    weight_parameters,
    transpose_a=False,
    transpose_b=False,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation computation for a l0-regularized matmul.

  Args:
    x: 2D Tensor representing the input batch.
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    transpose_a: If True, a is transposed before multiplication.
    transpose_b: If True, b is transposed before multiplication.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.

  Returns:
    Output Tensor of the matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  x.get_shape().assert_has_rank(2)
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # At evaluation time the stochastic gates are replaced by their
  # deterministic mean under the learned hard-concrete distribution.
  gate_mean = common.hard_concrete_mean(log_alpha, gamma, zeta)
  masked_weights = theta * gate_mean
  return tf.matmul(
      x, masked_weights, transpose_a=transpose_a, transpose_b=transpose_b)
133

134

135
def broadcast_matmul_train(
    x,
    weight_parameters,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training computation for l0 matrix multiplication with N input matrices.

  Multiplies a tensor `x` of rank >= 2 with a set of 2D parameters. Each
  trailing 2D matrix in the input tensor is multiplied independently with
  the parameters, e.g. a 3D input produces an output tensor with shape
  `x.shape[:2] + weight_parameters[0].shape[1]`.

  Args:
    x: Tensor of rank >= 2 representing the input batch.
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    beta: The beta parameter, which controls the "temperature" of
      the distribution. Defaults to 2/3 from the above paper.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the batched matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)
  theta.get_shape().assert_has_rank(2)

  # The input data must be rank 2 or greater.
  input_rank = x.get_shape().ndims
  assert input_rank >= 2

  # At training time, gate the weights multiplicatively with stochastic
  # z values sampled from the hard-concrete distribution.
  gate_sample = common.hard_concrete_sample(log_alpha, beta, gamma, zeta, eps)
  masked_weights = theta * gate_sample

  # Contract the last axis of `x` against the first axis of the weights,
  # computing the whole batch of matmuls in a single op.
  return tf.tensordot(x, masked_weights, [[input_rank - 1], [0]])
186

187

188
def broadcast_matmul_eval(
    x,
    weight_parameters,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation computation for l0 matrix multiplication with N input matrices.

  Multiplies a tensor `x` of rank >= 2 with a set of 2D parameters. Each
  trailing 2D matrix in the input tensor is multiplied independently with
  the parameters, e.g. a 3D input produces an output tensor with shape
  `x.shape[:2] + weight_parameters[0].shape[1]`.

  Args:
    x: Tensor of rank >= 2 representing the input batch.
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.

  Returns:
    Output Tensor of the batched matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)
  theta.get_shape().assert_has_rank(2)

  # The input data must be rank 2 or greater.
  input_rank = x.get_shape().ndims
  assert input_rank >= 2

  # At evaluation time the stochastic gates are replaced by their
  # deterministic mean under the learned hard-concrete distribution.
  gate_mean = common.hard_concrete_mean(log_alpha, gamma, zeta)
  masked_weights = theta * gate_mean

  # Contract the last axis of `x` against the first axis of the weights,
  # computing the whole batch of matmuls in a single op.
  return tf.tensordot(x, masked_weights, [[input_rank - 1], [0]])
230

231

232
def conv2d_train(
    x,
    weight_parameters,
    strides,
    padding,
    data_format="NHWC",
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training computation for a l0-regularized conv2d.

  Args:
    x: NHWC tf.Tensor representing the input batch of features.
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    strides: The stride of the sliding window for each dimension of 'x'.
      Identical to standard strides argument for tf.conv2d.
    padding: String. One of "SAME", or "VALID". Identical to standard
      padding argument for tf.conv2d.
    data_format: 'NHWC' or 'NCHW' ordering of 4-D input Tensor.
    beta: The beta parameter, which controls the "temperature" of
      the distribution. Defaults to 2/3 from the above paper.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the conv2d operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # At training time, gate the filters multiplicatively with stochastic
  # z values sampled from the hard-concrete distribution.
  gate_sample = common.hard_concrete_sample(log_alpha, beta, gamma, zeta, eps)
  masked_filters = theta * gate_sample
  return tf.nn.conv2d(
      x, masked_filters, strides, padding, data_format=data_format)
279

280

281
def conv2d_eval(
    x,
    weight_parameters,
    strides,
    padding,
    data_format="NHWC",
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation computation for a l0-regularized conv2d.

  Args:
    x: NHWC tf.Tensor representing the input batch of features.
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    strides: The stride of the sliding window for each dimension of 'x'.
      Identical to standard strides argument for tf.conv2d.
    padding: String. One of "SAME", or "VALID". Identical to standard
      padding argument for tf.conv2d.
    data_format: 'NHWC' or 'NCHW' ordering of 4-D input Tensor.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.

  Returns:
    Output Tensor of the conv2d operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # At evaluation time the stochastic gates are replaced by their
  # deterministic mean under the learned hard-concrete distribution.
  gate_mean = common.hard_concrete_mean(log_alpha, gamma, zeta)
  masked_filters = theta * gate_mean
  return tf.nn.conv2d(
      x, masked_filters, strides, padding, data_format=data_format)
319

320

321
def embedding_lookup_train(
    weight_parameters,
    ids,
    name=None,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training computation for a l0-regularized embedding lookup.

  Args:
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    beta: The beta parameter, which controls the "temperature" of
      the distribution. Defaults to 2/3 from the above paper.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the embedding lookup.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Gather the theta and log_alpha values first so that the sampling and
  # weight scaling happen in the lower-dimensional output batch rather
  # than over the full embedding table.
  gathered_theta = layer_utils.gather(theta, ids)
  gathered_log_alpha = layer_utils.gather(log_alpha, ids)

  # Sample stochastic z values for the output batch from the
  # hard-concrete distribution.
  gate_sample = common.hard_concrete_sample(
      gathered_log_alpha, beta, gamma, zeta, eps)
  return tf.identity(gathered_theta * gate_sample, name=name)
368

369

370
def embedding_lookup_eval(
    weight_parameters,
    ids,
    name=None,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation computation for a l0-regularized embedding lookup.

  Args:
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.

  Returns:
    Output Tensor of the embedding lookup.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Gather the theta and log_alpha values first so that the mean
  # computation and weight scaling happen in the lower-dimensional
  # output batch rather than over the full embedding table.
  gathered_theta = layer_utils.gather(theta, ids)
  gathered_log_alpha = layer_utils.gather(log_alpha, ids)

  # Scale the gathered embedding vectors by the deterministic mean of
  # the learned hard-concrete distribution.
  gate_mean = common.hard_concrete_mean(gathered_log_alpha, gamma, zeta)
  return tf.identity(gathered_theta * gate_mean, name=name)
411

412

413
def l0_norm(
    log_alpha,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Calculate the l0-regularization contribution to the loss.

  Args:
    log_alpha: Tensor of the log alpha parameters for the hard concrete
      distribution.
    beta: The beta parameter, which controls the "temperature" of
      the distribution. Defaults to 2/3 from the above paper.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameters, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.

  Returns:
    Scalar tensor containing the unweighted l0-regularization term contribution
    to the loss.
  """
  # One minus the CDF of the hard-concrete distribution at zero gives the
  # probability that each gate is active; summing these probabilities over
  # all weights yields the expected l0 norm.
  cdf_shift = beta * tf.log(-gamma / zeta)
  gate_active_prob = tf.sigmoid(log_alpha - cdf_shift)
  return tf.reduce_sum(gate_active_prob)
437

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.