google-research
210 строк · 6.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Definitions of kernels for Gaussian Process models for UQ experiments."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import tensorflow.compat.v2 as tf23import tensorflow_probability as tfp24
25
class RBFKernelFn(tf.keras.layers.Layer):
  """ExponentiatedQuadratic kernel provider.

  Holds the kernel hyperparameters as Keras variables and builds a
  `tfp.math.psd_kernels` kernel from them on demand through the `kernel`
  property. The base kernel is an `ExponentiatedQuadratic` wrapped in
  `FeatureScaled`; optionally a `Linear` kernel is added to it.

  Args:
    num_classes: Number of classes. Together with `per_class_kernel` it decides
      whether the hyperparameter variables get a leading per-class dimension.
    per_class_kernel: If True and `num_classes > 1`, one set of hyperparameters
      is created per class (variable shapes `(num_classes,)` and
      `(num_classes, feature_size)`); otherwise a single shared set is created
      (shapes `()` and `(feature_size,)`).
    feature_size: Size of the last dimension of the length-scale variable.
    initial_amplitude: Initial value of the raw (pre-softplus) amplitude.
    initial_length_scale: Initial value of every raw (pre-softplus) length
      scale.
    initial_linear_bias: Initial value of the linear-kernel bias variable
      (only used when `add_linear` is True).
    initial_linear_slope: Initial value of the linear-kernel slope variable
      (only used when `add_linear` is True).
    add_linear: If True, a `Linear` kernel is added to the RBF kernel.
    name: Name of the `tf.compat.v1.variable_scope` the variables are created
      in. It is not forwarded to the Keras `Layer` constructor.
    **kwargs: Forwarded to `tf.keras.layers.Layer.__init__`.
  """

  def __init__(self,
               num_classes,
               per_class_kernel,
               feature_size,
               initial_amplitude,
               initial_length_scale,
               initial_linear_bias,
               initial_linear_slope,
               add_linear=False,
               name='vgp_kernel',
               **kwargs):
    super(RBFKernelFn, self).__init__(**kwargs)
    self._per_class_kernel = per_class_kernel
    self._initial_linear_bias = initial_linear_bias
    self._initial_linear_slope = initial_linear_slope
    self._add_linear = add_linear

    # NOTE(review): under TF2 eager execution this v1 scope probably does not
    # influence Keras variable names -- confirm if the names matter.
    with tf.compat.v1.variable_scope(name):
      if self._per_class_kernel and num_classes > 1:
        amplitude_shape = (num_classes,)
        length_scale_shape = (num_classes, feature_size)
      else:
        amplitude_shape = ()
        length_scale_shape = (feature_size,)

      # `add_weight` is the supported Keras API; `add_variable` is only a
      # (deprecated in tf.keras 2.x) alias that forwards to it.
      self._amplitude = self.add_weight(
          initializer=tf.constant_initializer(initial_amplitude),
          shape=amplitude_shape,
          name='amplitude')

      self._length_scale = self.add_weight(
          initializer=tf.constant_initializer(initial_length_scale),
          shape=length_scale_shape,
          name='length_scale')

      if self._add_linear:
        # The linear-kernel variables share the amplitude's (per-class or
        # scalar) shape.
        self._linear_bias = self.add_weight(
            initializer=tf.constant_initializer(self._initial_linear_bias),
            shape=amplitude_shape,
            name='linear_bias')
        self._linear_slope = self.add_weight(
            initializer=tf.constant_initializer(self._initial_linear_slope),
            shape=amplitude_shape,
            name='linear_slope')

  def call(self, x):
    # Never called -- this is just a layer so it can hold variables
    # in a way Keras understands.
    return x

  @property
  def kernel(self):
    """Builds the PSD kernel from the current variable values.

    Returns:
      `FeatureScaled(ExponentiatedQuadratic(softplus(amplitude)),
      scale_diag=sqrt(softplus(length_scale)))`, plus a `Linear` kernel when
      `add_linear` was set.
    """
    # softplus keeps the amplitude and the length scales strictly positive.
    k = tfp.math.psd_kernels.FeatureScaled(
        tfp.math.psd_kernels.ExponentiatedQuadratic(
            amplitude=tf.nn.softplus(self._amplitude)),
        scale_diag=tf.math.sqrt(tf.nn.softplus(self._length_scale)))
    if self._add_linear:
      # NOTE(review): unlike the RBF hyperparameters, the linear variables
      # are passed through unconstrained (no softplus) -- confirm this is
      # intended for the installed TFP `Linear` parameterization.
      k += tfp.math.psd_kernels.Linear(
          bias_amplitude=self._linear_bias,
          slope_amplitude=self._linear_slope)
    return k
91
class MaternKernelFn(tf.keras.layers.Layer):
  """Matern kernel provider.

  Holds the kernel hyperparameters as Keras variables and builds a
  `tfp.math.psd_kernels` Matern kernel from them on demand through the
  `kernel` property. `degree` selects `MaternOneHalf` (1),
  `MaternThreeHalves` (3) or `MaternFiveHalves` (5); the chosen kernel is
  wrapped in `FeatureScaled`, and optionally a `Linear` kernel is added.

  Args:
    num_classes: Number of classes. Together with `per_class_kernel` it decides
      whether the hyperparameter variables get a leading per-class dimension.
    degree: Matern smoothness selector; must be one of 1, 3 or 5.
    per_class_kernel: If True and `num_classes > 1`, one set of hyperparameters
      is created per class (variable shapes `(num_classes,)` and
      `(num_classes, feature_size)`); otherwise a single shared set is created
      (shapes `()` and `(feature_size,)`).
    feature_size: Size of the last dimension of the length-scale variable.
    initial_amplitude: Initial value of the raw (pre-softplus) amplitude.
    initial_length_scale: Initial value of every raw (pre-softplus) length
      scale.
    initial_linear_bias: Initial value of the linear-kernel bias variable
      (only used when `add_linear` is True).
    initial_linear_slope: Initial value of the linear-kernel slope variable
      (only used when `add_linear` is True).
    add_linear: If True, a `Linear` kernel is added to the Matern kernel.
    name: Name of the `tf.compat.v1.variable_scope` the variables are created
      in. It is not forwarded to the Keras `Layer` constructor.
    **kwargs: Forwarded to `tf.keras.layers.Layer.__init__`.

  Raises:
    ValueError: If `degree` is not one of 1, 3 or 5.
  """

  def __init__(self,
               num_classes,
               degree,
               per_class_kernel,
               feature_size,
               initial_amplitude,
               initial_length_scale,
               initial_linear_bias,
               initial_linear_slope,
               add_linear=False,
               name='vgp_kernel',
               **kwargs):
    super(MaternKernelFn, self).__init__(**kwargs)
    self._per_class_kernel = per_class_kernel
    self._initial_linear_bias = initial_linear_bias
    self._initial_linear_slope = initial_linear_slope
    self._add_linear = add_linear

    if degree not in [1, 3, 5]:
      raise ValueError(
          'Matern degree must be one of [1, 3, 5]: {}'.format(degree))

    self._degree = degree

    # NOTE(review): under TF2 eager execution this v1 scope probably does not
    # influence Keras variable names -- confirm if the names matter.
    with tf.compat.v1.variable_scope(name):
      if self._per_class_kernel and num_classes > 1:
        amplitude_shape = (num_classes,)
        length_scale_shape = (num_classes, feature_size)
      else:
        amplitude_shape = ()
        length_scale_shape = (feature_size,)

      # `add_weight` is the supported Keras API; `add_variable` is only a
      # (deprecated in tf.keras 2.x) alias that forwards to it.
      self._amplitude = self.add_weight(
          initializer=tf.constant_initializer(initial_amplitude),
          shape=amplitude_shape,
          name='amplitude')

      self._length_scale = self.add_weight(
          initializer=tf.constant_initializer(initial_length_scale),
          shape=length_scale_shape,
          name='length_scale')

      if self._add_linear:
        # The linear-kernel variables share the amplitude's (per-class or
        # scalar) shape.
        self._linear_bias = self.add_weight(
            initializer=tf.constant_initializer(self._initial_linear_bias),
            shape=amplitude_shape,
            name='linear_bias')
        self._linear_slope = self.add_weight(
            initializer=tf.constant_initializer(self._initial_linear_slope),
            shape=amplitude_shape,
            name='linear_slope')

  def call(self, x):
    # Never called -- this is just a layer so it can hold variables
    # in a way Keras understands.
    return x

  @property
  def kernel(self):
    """Builds the PSD kernel from the current variable values.

    Returns:
      `FeatureScaled(Matern<degree>(softplus(amplitude)),
      scale_diag=sqrt(softplus(length_scale)))`, plus a `Linear` kernel when
      `add_linear` was set.
    """
    # `self._degree` was validated in `__init__`; a single lookup replaces the
    # former chain of independent `if`s, which left `kernel_class` unbound on
    # an unexpected degree.
    kernel_class = {
        1: tfp.math.psd_kernels.MaternOneHalf,
        3: tfp.math.psd_kernels.MaternThreeHalves,
        5: tfp.math.psd_kernels.MaternFiveHalves,
    }[self._degree]

    # softplus keeps the amplitude and the length scales strictly positive.
    k = tfp.math.psd_kernels.FeatureScaled(
        kernel_class(amplitude=tf.nn.softplus(self._amplitude)),
        scale_diag=tf.math.sqrt(tf.nn.softplus(self._length_scale)))
    if self._add_linear:
      # NOTE(review): unlike the Matern hyperparameters, the linear variables
      # are passed through unconstrained (no softplus) -- confirm this is
      # intended for the installed TFP `Linear` parameterization.
      k += tfp.math.psd_kernels.Linear(
          bias_amplitude=self._linear_bias,
          slope_amplitude=self._linear_slope)
    return k
170
class LinearKernelFn(tf.keras.layers.Layer):
  """Linear kernel provider.

  Holds the bias and slope of a `tfp.math.psd_kernels.Linear` kernel as Keras
  variables and builds the kernel from them on demand through the `kernel`
  property.

  Args:
    num_classes: Number of classes. Together with `per_class_kernel` it decides
      whether the variables get a per-class dimension.
    per_class_kernel: If True and `num_classes > 1`, the bias and slope
      variables have shape `(num_classes,)`; otherwise they are scalars.
    initial_linear_bias: Initial value of the bias variable.
    initial_linear_slope: Initial value of the slope variable.
    name: Name of the `tf.compat.v1.variable_scope` the variables are created
      in. It is not forwarded to the Keras `Layer` constructor.
    **kwargs: Forwarded to `tf.keras.layers.Layer.__init__`.
  """

  def __init__(self,
               num_classes,
               per_class_kernel,
               initial_linear_bias,
               initial_linear_slope,
               name='vgp_kernel',
               **kwargs):
    super(LinearKernelFn, self).__init__(**kwargs)
    self._per_class_kernel = per_class_kernel
    self._initial_linear_bias = initial_linear_bias
    self._initial_linear_slope = initial_linear_slope

    # NOTE(review): under TF2 eager execution this v1 scope probably does not
    # influence Keras variable names -- confirm if the names matter.
    with tf.compat.v1.variable_scope(name):
      if self._per_class_kernel and num_classes > 1:
        shape = (num_classes,)
      else:
        shape = ()

      # `add_weight` is the supported Keras API; `add_variable` is only a
      # (deprecated in tf.keras 2.x) alias that forwards to it.
      self._linear_bias = self.add_weight(
          initializer=tf.constant_initializer(self._initial_linear_bias),
          shape=shape,
          name='linear_bias')
      self._linear_slope = self.add_weight(
          initializer=tf.constant_initializer(self._initial_linear_slope),
          shape=shape,
          name='linear_slope')

  def call(self, x):
    # Never called -- this is just a layer so it can hold variables
    # in a way Keras understands.
    return x

  @property
  def kernel(self):
    """Builds a `Linear` PSD kernel from the current variable values."""
    # NOTE(review): the variables are passed through unconstrained (no
    # softplus) -- confirm this matches the installed TFP `Linear`
    # parameterization.
    return tfp.math.psd_kernels.Linear(
        bias_amplitude=self._linear_bias,
        slope_amplitude=self._linear_slope)