google-research
221 строка · 8.3 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Functions for sampling from different distributions.
17
18Sampling functions for YOTO. Also includes functions to transform the samples,
19for instance via softmax.
20"""
21
22import ast23import enum24import gin25import numpy as np26import tensorflow.compat.v1 as tf27from tensorflow_probability import distributions as tfd28
29
@gin.constants_from_enum
class DistributionType(enum.Enum):
  """Base distributions from which untransformed samples can be drawn."""

  # Uniform between the "low" and "high" distribution parameters.
  UNIFORM = 0
  # Uniform between log(low) and log(high), exponentiated afterwards.
  LOG_UNIFORM = 1

35
@gin.constants_from_enum
class TransformType(enum.Enum):
  """Transforms that can be applied to a sample after it has been drawn."""

  # Leave the sample unchanged.
  IDENTITY = 0
  # Take the elementwise logarithm of the sample.
  LOG = 1

41
@gin.configurable("DistributionSpec")
class DistributionSpec(object):
  """Spec of a distribution for YOTO training or evaluation.

  A plain value holder: the constructor stores its three arguments unchanged.

  Attributes:
    distribution_type: which base distribution to draw from (compared against
      DistributionType members when sampling)
    params: parameters of the distribution; the samplers in this module read
      the "low" and "high" entries, and also accept a list of such mappings
    transform: transform applied to the drawn sample (compared against
      TransformType members), or None to skip the transform step
  """
  # NOTE(adosovitskiy) Tried to do it with namedtuple, but failed to make
  # it work with gin

  def __init__(self, distribution_type, params, transform):
    self.distribution_type = distribution_type
    self.params = params
    self.transform = transform

53
54# TODO(adosovitskiy): have one signature with distributionspec and one without
def _check_names_match_keys(names, weights):
  """Raise ValueError unless `names` is None or equals the keys of `weights`."""
  given_keys = weights.keys()
  if not (names is None or set(names) == set(given_keys)):
    raise ValueError(
        "Provided names {} do not match with the keys of the provided "
        "dictionary {}".format(names, given_keys))


def get_samples_as_dicts(distribution_spec, num_samples=1,
                         names=None, seed=None):
  """Sample weight dictionaries for multi-loss problems.

  Supports many different distribution specifications, including random
  distributions given via DistributionSpec or fixed sets of weights given by
  dictionaries or lists of dictionaries. Fixed weights are validated and
  returned as they are; a DistributionSpec is actually sampled from.

  Args:
    distribution_spec: One of the following:
      * An instance of DistributionSpec
      * DistributionSpec class
      * A dictionary mapping loss names to their values
      * A list of such dictionaries
      * A string representing a dictionary or a list of dictionaries
    num_samples: how many samples to return (only if given DistributionSpec)
    names: names of losses; required if given DistributionSpec, otherwise
      only used to validate the keys of the provided dictionaries
    seed: random seed to use for sampling (only if given DistributionSpec)

  Returns:
    samples_dicts: list of dictionaries with the samples weights

  Raises:
    ValueError: if `names` does not match the keys of the given dictionary
      (for a list only the first dictionary is checked), if a given list is
      empty or does not start with a dictionary, or if `names` is None when
      sampling from a DistributionSpec.
    TypeError: if distribution_spec is of none of the supported kinds.
  """
  # If given a string, first eval it. literal_eval only builds Python
  # literals, so a string can encode a dict or a list of dicts, nothing else.
  if isinstance(distribution_spec, str):
    distribution_spec = ast.literal_eval(distribution_spec)

  # Fixed weights need no sampling: validate them and return right away.
  if isinstance(distribution_spec, dict):
    _check_names_match_keys(names, distribution_spec)
    return [distribution_spec]

  if isinstance(distribution_spec, list):
    if not (distribution_spec and
            isinstance(distribution_spec[0], dict)):
      raise ValueError(
          "If distribution_spec is a list, it should be non-empty and "
          "consist of dictionaries.")
    _check_names_match_keys(names, distribution_spec[0])
    return distribution_spec

  # A class is instantiated without arguments, so the constructor arguments
  # have to come from outside this function (presumably gin bindings).
  if isinstance(distribution_spec, type):
    distribution_spec = distribution_spec()

  # Accepts both a DistributionSpec passed in directly and one that was just
  # instantiated above; everything else is unsupported.
  if not isinstance(distribution_spec, DistributionSpec):
    raise TypeError(
        "The distribution_spec should be a dictionary ot a list of dictionaries"
        " or an instance of DistributionSpec or class DistributionSpec")

  if names is None:
    raise ValueError(
        "names must be provided when sampling from a DistributionSpec")

  # Sample a (num_samples, len(names)) array and turn every row into a
  # dictionary keyed by loss name.
  samples = get_sample((num_samples, len(names)), distribution_spec,
                       seed=seed, return_numpy=True)
  return [
      {name: samples[k, n] for n, name in enumerate(names)}
      for k in range(num_samples)
  ]

127
def get_sample_untransformed(shape, distribution_type, distribution_params,
                             seed):
  """Draw an untransformed sample for the given distribution specification.

  If the parameters are a list, each list member is used to generate one
  column of the resulting sample matrix, and the columns are concatenated
  along axis 1. Otherwise, the same parameters are used for the whole sample.

  Args:
    shape: Tuple/List representing the shape of the output
    distribution_type: DistributionType object
    distribution_params: Dict of distribution parameters, or a list of such
      dicts with one entry per column
    seed: random seed to be used

  Returns:
    sample: TF Tensor with a sample from the distribution

  Raises:
    ValueError: if distribution_params is a list and `shape` is not
      2-dimensional or `shape[1]` differs from the length of the list.
  """
  # Shared parameters: a single draw covers the whole requested shape.
  if not isinstance(distribution_params, list):
    return get_one_sample_untransformed(shape, distribution_type,
                                        distribution_params, seed)

  shape_fits_params = (len(shape) == 2 and
                       len(distribution_params) == shape[1])
  if not shape_fits_params:
    raise ValueError("If distribution_params is a list, the desired 'shape' "
                     "should be 2-dimensional and number of elements in the "
                     "list should match 'shape[1]'")

  # NOTE(review): every column is drawn with the same `seed` - confirm that
  # the columns are not expected to be independent when a seed is given.
  columns = [
      get_one_sample_untransformed([shape[0], 1], distribution_type,
                                   column_params, seed)
      for column_params in distribution_params
  ]
  return tf.concat(columns, axis=1)

161
def get_one_sample_untransformed(shape, distribution_type, distribution_params,
                                 seed):
  """Get one untransformed sample from a single set of parameters.

  Args:
    shape: shape of the output; `shape[0]` is the sample shape requested from
      the distribution and `shape[1:]` is the shape given to its bounds
    distribution_type: DistributionType object
    distribution_params: mapping providing the "low" and "high" bounds
    seed: random seed passed on to the sampling call

  Returns:
    sample: TF Tensor with a sample from the distribution

  Raises:
    ValueError: if distribution_type is not a known DistributionType.
  """
  if distribution_type == DistributionType.UNIFORM:
    lower, upper = distribution_params["low"], distribution_params["high"]
    # NOTE(review): no dtype is given here, while the LOG_UNIFORM branch pins
    # tf.float32 - confirm that the inferred dtype is what callers expect.
    uniform = tfd.Uniform(low=tf.constant(lower, shape=shape[1:]),
                          high=tf.constant(upper, shape=shape[1:]))
    return uniform.sample(shape[0], seed=seed)

  if distribution_type == DistributionType.LOG_UNIFORM:
    lower, upper = distribution_params["low"], distribution_params["high"]
    # Sample uniformly between the logs of the bounds, then map back with exp.
    log_low = tf.constant(np.log(lower), shape=shape[1:], dtype=tf.float32)
    log_high = tf.constant(np.log(upper), shape=shape[1:], dtype=tf.float32)
    log_uniform = tfd.Uniform(low=log_low, high=log_high)
    return tf.exp(log_uniform.sample(shape[0], seed=seed))

  raise ValueError("Unknown distribution type {}".format(distribution_type))

180
def get_sample(shape, distribution_spec, seed=None, return_numpy=False):
  """Sample a tensor of random numbers.

  Draws an untransformed sample according to the spec and, unless the spec's
  transform is None, applies that transform to it.

  Args:
    shape: shape of the resulting tensor
    distribution_spec: DistributionSpec
    seed: random seed to use for sampling
    return_numpy: if True, the sample is evaluated once in a newly opened
      tf.Session and returned as a fixed numpy array; otherwise the TF op is
      returned, which allows sampling repeatedly

  Returns:
    samples: numpy array or TF op representing the random numbers
  """
  # Read all spec fields up front, before any op is created.
  kind = distribution_spec.distribution_type  # pytype: disable=attribute-error
  params = distribution_spec.params  # pytype: disable=attribute-error
  transform_kind = distribution_spec.transform  # pytype: disable=attribute-error

  sample_op = get_sample_untransformed(shape, kind, params, seed)
  if transform_kind is not None:
    sample_op = get_transform(transform_kind)(sample_op)

  if not return_numpy:
    return sample_op

  with tf.Session() as session:
    return session.run(sample_op)

212
def get_transform(transform_type):
  """Return the callable that implements the given transform type.

  Args:
    transform_type: TransformType object

  Returns:
    transform: function mapping a sample to its transformed version; the
      identity function for IDENTITY and `tf.log` for LOG

  Raises:
    ValueError: if transform_type is not a known TransformType.
  """
  if transform_type == TransformType.IDENTITY:
    return lambda x: x
  if transform_type == TransformType.LOG:
    return tf.log
  raise ValueError("Unknown transform type {}".format(transform_type))