google-research

Форк
0
/
distributions.py 
221 строка · 8.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Functions for sampling from different distributions.
17

18
Sampling functions for YOTO. Also includes functions to transform the samples,
19
for instance via a logarithm.
20
"""
21

22
import ast
23
import enum
24
import gin
25
import numpy as np
26
import tensorflow.compat.v1 as tf
27
from tensorflow_probability import distributions as tfd
28

29

30
@gin.constants_from_enum
class DistributionType(enum.Enum):
  """Kinds of distributions that untransformed samples can be drawn from."""
  # Sampled uniformly between the "low" and "high" parameters.
  UNIFORM = 0
  # Sampled uniformly between log("low") and log("high"), then exponentiated.
  LOG_UNIFORM = 1
34

35

36
@gin.constants_from_enum
class TransformType(enum.Enum):
  """Kinds of transforms that can be applied to the raw samples."""
  # Samples are passed through unchanged.
  IDENTITY = 0
  # The logarithm of the samples is taken.
  LOG = 1
40

41

42
@gin.configurable("DistributionSpec")
class DistributionSpec(object):
  """Describes how to sample weights for YOTO training or evaluation.

  Attributes:
    distribution_type: DistributionType member selecting the distribution.
    params: parameters of the distribution: a dict with "low" and "high"
      entries, or a list of such dicts (one per column of the sample).
    transform: TransformType member applied to the raw samples, or None to
      skip the transform step.
  """
  # NOTE(adosovitskiy): A namedtuple was tried here first, but it could not be
  # made to work with gin.

  def __init__(self, distribution_type, params, transform):
    self.distribution_type, self.params, self.transform = (
        distribution_type, params, transform)
52

53

54
# TODO(adosovitskiy): have one signature with distributionspec and one without
55
def _check_names_match_keys(names, given_keys):
  """Raises ValueError unless names is None or equals given_keys as a set."""
  if not (names is None or set(names) == set(given_keys)):
    raise ValueError(
        "Provided names {} do not match with the keys of the provided "
        "dictionary {}".format(names, given_keys))


def get_samples_as_dicts(distribution_spec, num_samples=1,
                         names=None, seed=None):
  """Sample weight dictionaries for multi-loss problems.

  Supports many different distribution specifications, including random
  distributions given via DistributionSpec or fixed sets of weights given by
  dictionaries or lists of dictionaries. The function first parses the different
  options and then actually computes the weights to be returned.

  Args:
    distribution_spec: One of the following:
      * An instance of DistributionSpec
      * DistributionSpec class (it is instantiated without arguments)
      * A dictionary mapping loss names to their values
      * A list of such dictionaries
      * A string holding a Python literal for a dictionary or a list of
        dictionaries (parsed with ast.literal_eval)
    num_samples: how many samples to return (only if given DistributionSpec)
    names: names of losses. Required if given DistributionSpec; for fixed
      dictionaries it is optional and only used to validate the keys (of the
      first dictionary, when a list is given).
    seed: random seed to use for sampling (only if given DistributionSpec)

  Returns:
    samples_dicts: list of dictionaries with the samples weights

  Raises:
    ValueError: if `names` does not match the keys of the given dictionary,
      if a given list is empty or contains non-dictionaries, or if `names` is
      missing when sampling from a DistributionSpec.
    TypeError: if `distribution_spec` is of none of the supported kinds.
  """

  # If given a string, first eval it
  if isinstance(distribution_spec, str):
    distribution_spec = ast.literal_eval(distribution_spec)

  # Fixed weights: nothing to sample, only validate and normalize to a list.
  if isinstance(distribution_spec, dict):
    _check_names_match_keys(names, distribution_spec.keys())
    return [distribution_spec]

  if isinstance(distribution_spec, list):
    if not (distribution_spec and
            all(isinstance(entry, dict) for entry in distribution_spec)):
      raise ValueError(
          "If distribution_spec is a list, it should be non-empty and "
          "consist of dictionaries.")
    _check_names_match_keys(names, distribution_spec[0].keys())
    return distribution_spec

  # A class (e.g. a gin-configured DistributionSpec) is instantiated first.
  if isinstance(distribution_spec, type):
    distribution_spec = distribution_spec()

  # Explicit check instead of an assert, so it also holds under `python -O`
  # and so that an already constructed DistributionSpec is accepted.
  if not isinstance(distribution_spec, DistributionSpec):
    raise TypeError(
        "The distribution_spec should be a dictionary ot a list of dictionaries"
        " or an instance of DistributionSpec or class DistributionSpec")

  if names is None:
    raise ValueError(
        "names must be provided when sampling from a DistributionSpec")
  names = list(names)

  # Sample and convert to a list of dictionaries
  samples = get_sample((num_samples, len(names)), distribution_spec,
                       seed=seed, return_numpy=True)
  return [{name: samples[k, n] for n, name in enumerate(names)}
          for k in range(num_samples)]
126

127

128
def get_sample_untransformed(shape, distribution_type, distribution_params,
                             seed):
  """Get an untransformed sample based on specification and parameters.

  Parameters can be a list (or tuple), in which case each of its members is
  used to generate one column of the resulting sample matrix. Otherwise, the
  same parameters are used for the whole matrix.

  Args:
    shape: Tuple/List representing the shape of the output
    distribution_type: DistributionType object
    distribution_params: Dict of distribution parameters, or a list/tuple of
      such dicts with one entry per column (requires a 2-dimensional `shape`)
    seed: random seed to be used. When per-column parameters are given and the
      seed is an integer, column `i` is sampled with `seed + i`.

  Returns:
    sample: TF Tensor with a sample from the distribution

  Raises:
    ValueError: if per-column parameters are given, but `shape` is not
      2-dimensional or `shape[1]` differs from the number of parameter sets.
  """
  if not isinstance(distribution_params, (list, tuple)):
    return get_one_sample_untransformed(shape, distribution_type,
                                        distribution_params, seed)

  if len(shape) != 2 or len(distribution_params) != shape[1]:
    raise ValueError("If distribution_params is a list, the desired 'shape' "
                     "should be 2-dimensional and number of elements in the "
                     "list should match 'shape[1]'")

  columns = []
  for column, curr_params in enumerate(distribution_params):
    # Passing the very same integer seed to every column's sampling op would
    # make all columns share the same underlying random draws, i.e. the
    # columns would be perfectly correlated. Offsetting the seed by the column
    # index keeps the result reproducible while decorrelating the columns.
    if isinstance(seed, (int, np.integer)):
      column_seed = int(seed) + column
    else:
      column_seed = seed
    columns.append(
        get_one_sample_untransformed([shape[0], 1], distribution_type,
                                     curr_params, column_seed))
  return tf.concat(columns, axis=1)
160

161

162
def get_one_sample_untransformed(shape, distribution_type, distribution_params,
                                 seed):
  """Get one untransformed sample from a single distribution.

  Args:
    shape: Tuple/List representing the shape of the output. `shape[0]` is the
      number of draws, `shape[1:]` is the shape of the distribution parameters.
    distribution_type: DistributionType object
    distribution_params: Dict with "low" and "high" entries. For LOG_UNIFORM
      the logarithm of both is taken, so they have to be positive.
    seed: random seed to be used

  Returns:
    sample: float32 TF Tensor with a sample from the distribution

  Raises:
    ValueError: if `distribution_type` is not a known DistributionType.
  """
  if distribution_type == DistributionType.UNIFORM:
    low, high = distribution_params["low"], distribution_params["high"]
    # The dtype is fixed explicitly (as in the LOG_UNIFORM branch), so that
    # integer bounds such as low=0, high=1 do not produce integer tensors.
    distribution = tfd.Uniform(
        low=tf.constant(low, shape=shape[1:], dtype=tf.float32),
        high=tf.constant(high, shape=shape[1:], dtype=tf.float32))
    sample = distribution.sample(shape[0], seed=seed)
  elif distribution_type == DistributionType.LOG_UNIFORM:
    low, high = distribution_params["low"], distribution_params["high"]
    # Sample uniformly in log-space and map back with exp.
    distribution = tfd.Uniform(
        low=tf.constant(np.log(low), shape=shape[1:], dtype=tf.float32),
        high=tf.constant(np.log(high), shape=shape[1:], dtype=tf.float32))
    sample = tf.exp(distribution.sample(shape[0], seed=seed))
  else:
    raise ValueError("Unknown distribution type {}".format(distribution_type))
  return sample
179

180

181
def get_sample(shape, distribution_spec, seed=None, return_numpy=False):
  """Draw a tensor of random numbers as described by a distribution spec.

  The raw sample is produced by get_sample_untransformed and then, unless the
  spec's transform is None, passed through the transform of the spec.

  Args:
    shape: shape of the resulting tensor
    distribution_spec: DistributionSpec providing the distribution type, its
      parameters and the transform
    seed: random seed to use for sampling
    return_numpy: if True, the sampling op is evaluated once in a new
      tf.Session and the resulting value is returned; otherwise the TF op
      itself is returned, which allows sampling repeatedly

  Returns:
    samples: evaluated value of the op or the TF op representing the random
      numbers
  """
  dist_kind = distribution_spec.distribution_type  # pytype: disable=attribute-error
  dist_params = distribution_spec.params  # pytype: disable=attribute-error
  transform_kind = distribution_spec.transform  # pytype: disable=attribute-error

  result = get_sample_untransformed(shape, dist_kind, dist_params, seed)
  if transform_kind is not None:
    result = get_transform(transform_kind)(result)

  if not return_numpy:
    return result

  # Evaluate the op once; the session is closed when the block is left.
  with tf.Session() as sess:
    return sess.run(result)
211

212

213
def get_transform(transform_type):
  """Return the callable that converts raw samples for a transform type.

  Args:
    transform_type: TransformType object

  Returns:
    transform: the identity function for IDENTITY, tf.log for LOG

  Raises:
    ValueError: if `transform_type` is not a known TransformType.
  """
  if transform_type == TransformType.IDENTITY:
    return lambda x: x
  if transform_type == TransformType.LOG:
    return tf.log
  raise ValueError("Unknown transform type {}".format(transform_type))
222

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.