google-research
190 lines · 6.3 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Example script for running VILA model."""
17
18from absl import app19from absl import flags20from absl import logging21import jax22import jax.numpy as jnp23from lingvo import compat as tf24from lingvo.core import tokenizers as lingvo_tokenizers25from paxml import checkpoints26from paxml import learners27from paxml import tasks_lib28from paxml import train_states29from praxis import base_layer30from praxis import optimizers31from praxis import pax_fiddle32from praxis import py_utils33from praxis import schedules34
35from vila import coca_vila36from vila import coca_vila_configs37
38
# Convenience alias for the praxis nested-dict container used throughout.
NestedMap = py_utils.NestedMap

# Command-line flags.
_CKPT_DIR = flags.DEFINE_string('ckpt_dir', '', 'Path to checkpoint.')
_SPM_MODEL_PATH = flags.DEFINE_string(
    'spm_model_path', '', 'Path to sentence piece tokenizer model.'
)
_IMAGE_PATH = flags.DEFINE_string('image_path', '', 'Path to input image.')

# The input image is first resized to _PRE_CROP_SIZE, then center
# cropped/padded to _IMAGE_SIZE (see preprocess_image).
_PRE_CROP_SIZE = 272
_IMAGE_SIZE = 224
# Maximum tokenized text length fed to the model.
_MAX_TEXT_LEN = 64
# Vocabulary size of the sentence piece tokenizer / text embedding table.
_TEXT_VOCAB_SIZE = 64000

# (positive, negative) prompt pairs for zero-shot (ZSL) quality scoring;
# in main(), a softmax over each pair yields a per-aspect quality probability.
_ZSL_QUALITY_PROMPTS = [
    ['good image', 'bad image'],
    ['good lighting', 'bad lighting'],
    ['good content', 'bad content'],
    ['good background', 'bad background'],
    ['good foreground', 'bad foreground'],
    ['good composition', 'bad composition'],
]
def load_vila_model(
    ckpt_dir,
):
  """Loads the VILA model from the checkpoint directory.

  Builds the CoCa-VILA rank-based finetune model, derives the (unpadded)
  train-state shapes via abstract initialization, and restores the
  checkpoint into those shapes.

  Args:
    ckpt_dir: The path to the checkpoint directory.

  Returns:
    A tuple of (VILA model, restored VILA model train states).
  """
  coca_config = coca_vila_configs.CocaVilaConfig()
  coca_config.model_type = coca_vila.CoCaVilaRankBasedFinetune
  coca_config.decoding_max_len = _MAX_TEXT_LEN
  coca_config.text_vocab_size = _TEXT_VOCAB_SIZE
  model_p = coca_vila_configs.build_coca_vila_model(coca_config)
  model_p.model_dims = coca_config.model_dims
  model = model_p.Instantiate()

  # Abstract init only needs input shapes/dtypes, never real data, so a
  # small dummy batch size is sufficient.
  dummy_batch_size = 4  # For initialization only
  text_shape = (dummy_batch_size, 1, _MAX_TEXT_LEN)
  image_shape = (dummy_batch_size, _IMAGE_SIZE, _IMAGE_SIZE, 3)
  input_specs = NestedMap(
      ids=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.int32),
      image=jax.ShapeDtypeStruct(shape=image_shape, dtype=jnp.float32),
      paddings=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.float32),
      # For initialization only
      labels=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.float32),
      regression_labels=jax.ShapeDtypeStruct(
          shape=(dummy_batch_size, 10), dtype=jnp.float32
      ),
  )
  # NOTE: the original constructed (and immediately discarded) a PRNG key
  # here; abstract_init_with_metadata takes no key, so it was dead code.
  vars_weight_params = model.abstract_init_with_metadata(input_specs)

  # `learner` is only used so the restore shapes include optimizer states.
  learner_p = pax_fiddle.Config(learners.Learner)
  learner_p.name = 'learner'
  learner_p.optimizer = pax_fiddle.Config(
      optimizers.ShardedAdafactor,
      decay_method='adam',
      lr_schedule=pax_fiddle.Config(schedules.Constant),
  )
  learner = learner_p.Instantiate()

  train_state_global_shapes = tasks_lib.create_state_unpadded_shapes(
      vars_weight_params, discard_opt_states=False, learners=[learner]
  )
  model_states = checkpoints.restore_checkpoint(
      train_state_global_shapes, ckpt_dir
  )
  return model, model_states
116
def preprocess_image(
    image_path, pre_crop_size, image_size
):
  """Reads an image file and converts it into a model-ready float array.

  The image is decoded to 3 channels, bilinearly resized to
  `pre_crop_size`, center cropped/padded to `image_size`, and scaled to
  the [0, 1] range.

  Args:
    image_path: Path to the image file on disk.
    pre_crop_size: Side length for the initial bilinear resize.
    image_size: Final side length after the center crop/pad.

  Returns:
    A numpy float32 array of shape (1, image_size, image_size, 3).
  """
  with tf.compat.v1.gfile.FastGFile(image_path, 'rb') as image_file:
    raw_bytes = image_file.read()
  decoded = tf.io.decode_image(raw_bytes, 3, expand_animations=False)
  batched = tf.expand_dims(decoded, 0)
  resized = tf.image.resize_bilinear(batched, [pre_crop_size, pre_crop_size])
  cropped = tf.image.resize_with_crop_or_pad(resized, image_size, image_size)
  scaled = tf.cast(cropped, tf.float32) / 255.0
  return tf.clip_by_value(scaled, 0.0, 1.0).numpy()
133
def main(_):
  """Scores an input image with VILA; optionally adds zero-shot scores."""
  # Suppresses verbose INFO/DEBUG log.
  logging.set_verbosity(logging.ERROR)
  model, model_states = load_vila_model(_CKPT_DIR.value)
  image = preprocess_image(_IMAGE_PATH.value, _PRE_CROP_SIZE, _IMAGE_SIZE)
  input_batch = NestedMap(
      image=image,
      ids=jnp.zeros((1, 1, _MAX_TEXT_LEN), dtype=jnp.int32),
      # float32 paddings to match the dtype declared in the model's
      # initialization input specs (was int32 — an inconsistency).
      paddings=jnp.zeros((1, 1, _MAX_TEXT_LEN), dtype=jnp.float32),
  )

  context_p = base_layer.JaxContext.HParams(do_eval=True)
  # Use the new_context factory consistently (the ZSL branch below already
  # did; the original constructed JaxContext directly here).
  with base_layer.JaxContext.new_context(hparams=context_p):
    predictions = model.apply(
        {'params': model_states.mdl_vars['params']},
        input_batch,
        method=model.compute_predictions,
    )
  quality_scores = predictions['quality_scores']
  print('===== VILA predicted quality score [0, 1]: ', quality_scores)

  if _SPM_MODEL_PATH.value:
    tokenizer_p = lingvo_tokenizers.SentencePieceTokenizer.Params().Set(
        spm_model=_SPM_MODEL_PATH.value,
        vocab_size=_TEXT_VOCAB_SIZE,
    )
    tokenizer = tokenizer_p.Instantiate()

    # Flatten the (positive, negative) prompt pairs into a single batch.
    all_prompts = [p for pair in _ZSL_QUALITY_PROMPTS for p in pair]  # pylint: disable=g-complex-comprehension
    ids, _, paddings = tokenizer.StringsToIds(all_prompts, max_length=4)
    context_p = base_layer.JaxContext.HParams(do_eval=True)
    with base_layer.JaxContext.new_context(hparams=context_p):
      input_batch = NestedMap(
          ids=ids.numpy(),
          paddings=paddings.numpy(),
          # Dummy image; only the text tower is exercised here.
          image=jnp.zeros((1, _IMAGE_SIZE, _IMAGE_SIZE, 3)),
      )
      text_encoded = model.apply(
          {'params': model_states.mdl_vars['params']},
          input_batch,
          method=model.compute_text_embedding,
      )
      text_embed = text_encoded.contrastive_txt_embed_norm
      image_embed = predictions.contrastive_img_embed_norm

      # Image-vs-prompt similarities, grouped per aspect into
      # (positive, negative) pairs.
      zsl_scores = jnp.matmul(image_embed, text_embed.T)
      zsl_scores = zsl_scores.reshape([-1, len(_ZSL_QUALITY_PROMPTS), 2])
      # Softmax over each pair, average across aspects, keep P(positive).
      zsl_scores = jax.nn.softmax(zsl_scores)
      zsl_scores = zsl_scores.mean(1)
      zsl_scores = zsl_scores[:, 0]
      print('===== VILA ZSL predicted quality score [0, 1]: ', zsl_scores)
188
if __name__ == '__main__':
  # app.run parses absl flags, then invokes main().
  app.run(main)