google-research
41 строка · 1.3 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model definitions for the dSprites and 3DIdent experiments."""

from absl import flags

import tensorflow.compat.v2 as tf

# Module-level alias for absl's global flag values.
FLAGS = flags.FLAGS


class LinearLayerOverPretrainedSimclrModel(tf.keras.Model):
  """Trainable linear evaluation layer over a pretrained SimCLR model.

  Loads a pretrained SimCLR SavedModel and calls it with `trainable=False`.
  One entry of its output dict is fed through a single dense layer with no
  activation, producing one unnormalized score per class.
  """

  def __init__(self, path, optimizer, num_classes,
               feature_key='final_avg_pool'):
    """Builds the model.

    Args:
      path: Location of the pretrained SimCLR SavedModel, as accepted by
        `tf.saved_model.load`.
      optimizer: Stored unchanged as `self.optimizer`. Presumably this is the
        optimizer an external training loop uses for the linear layer
        -- TODO confirm with callers.
      num_classes: Number of units in the linear layer, i.e. the size of the
        last dimension of the model output.
      feature_key: Key into the SavedModel's output dict that selects the
        tensor fed to the linear layer. Defaults to 'final_avg_pool', the
        previously hard-coded key, so existing callers are unaffected.
    """
    super().__init__()
    self.saved_model = tf.saved_model.load(path)
    self.dense_layer = tf.keras.layers.Dense(
        units=num_classes, name='affine_transform')
    self.optimizer = optimizer
    self.feature_key = feature_key

  def call(self, x):
    """Computes linear-layer outputs for a batch of inputs.

    Args:
      x: Input passed unchanged to the pretrained SavedModel. Its expected
        format (shape, dtype, value range) is defined by that model and is
        not checked here.

    Returns:
      The dense layer's output, whose last dimension is `num_classes`.
    """
    # trainable=False is forwarded to the SavedModel call. NOTE(review):
    # presumably this runs the backbone in inference mode (e.g. fixed batch
    # norm statistics) -- confirm against the SavedModel's call signature.
    outputs = self.saved_model(x, trainable=False)
    return self.dense_layer(outputs[self.feature_key])