# google-research
# 199 lines · 6.2 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Example script for running VILA model for captioning."""
17from typing import Sequence
18
19from absl import app
20from absl import flags
21from absl import logging
22import jax
23import jax.numpy as jnp
24from lingvo import compat as tf
25from lingvo.core import tokenizers as lingvo_tokenizers
26from paxml import checkpoints
27from paxml import learners
28from paxml import tasks_lib
29from paxml import train_states
30from praxis import base_layer
31from praxis import optimizers
32from praxis import pax_fiddle
33from praxis import py_utils
34from praxis import pytypes
35from praxis import schedules
36
37from vila import coca_vila
38from vila import coca_vila_configs
39
40
# Short alias for the nested-dict container used throughout this script.
NestedMap = py_utils.NestedMap

# Command-line flags.
_CKPT_DIR = flags.DEFINE_string('ckpt_dir', '', 'Path to checkpoint.')
_IS_PRETRAIN = flags.DEFINE_boolean(
    'is_pretrain',
    False,
    'Whether the checkpoint is pretrain or rank-based finetuned.',
)
_SPM_MODEL_PATH = flags.DEFINE_string(
    'spm_model_path', '', 'Path to sentence piece tokenizer model.'
)
_IMAGE_PATH = flags.DEFINE_string('image_path', '', 'Path to input image.')

# Preprocessing / model constants.
_PRE_CROP_SIZE = 272  # Image is bilinearly resized to this size pre-crop.
_IMAGE_SIZE = 224  # Final square resolution fed to the model.
_MAX_TEXT_LEN = 64  # Maximum token length for text decoding.
_TEXT_VOCAB_SIZE = 64000  # SentencePiece vocabulary size.
58
59
def load_vila_model(
    ckpt_dir: str, is_pretrain: bool = False
):
  """Loads the VILA model from checkpoint directory.

  Builds the model config, abstractly initializes variables to derive the
  checkpoint's expected state shapes (including optimizer state), then
  restores the checkpoint from `ckpt_dir`.

  Args:
    ckpt_dir: The path to checkpoint directory
    is_pretrain: True for `CoCaVilaPretrain`, False for
      `CoCaVilaRankBasedFinetune`

  Returns:
    VILA model, VILA model states
  """
  coca_config = coca_vila_configs.CocaVilaConfig()
  if is_pretrain:
    coca_config.model_type = coca_vila.CoCaVilaPretrain
  else:
    coca_config.model_type = coca_vila.CoCaVilaRankBasedFinetune
  coca_config.decoding_max_len = _MAX_TEXT_LEN
  coca_config.text_vocab_size = _TEXT_VOCAB_SIZE
  model_p = coca_vila_configs.build_coca_vila_model(coca_config)
  if not is_pretrain:
    # NOTE(review): presumably the finetune model template does not pick up
    # model_dims from the config automatically — confirm in
    # coca_vila_configs.build_coca_vila_model.
    model_p.model_dims = coca_config.model_dims
  model_p.generation_decode = True
  model = model_p.Instantiate()

  dummy_batch_size = 4  # For initialization only
  text_shape = (dummy_batch_size, 1, _MAX_TEXT_LEN)
  image_shape = (dummy_batch_size, _IMAGE_SIZE, _IMAGE_SIZE, 3)
  # Abstract shape/dtype specs — no real data is materialized here; they only
  # drive abstract variable initialization below.
  input_specs = NestedMap(
      ids=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.int32),
      image=jax.ShapeDtypeStruct(shape=image_shape, dtype=jnp.float32),
      paddings=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.float32),
      # For initialization only
      labels=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.float32),
      regression_labels=jax.ShapeDtypeStruct(
          shape=(dummy_batch_size, 10), dtype=jnp.float32
      ),
  )
  prng_key = jax.random.PRNGKey(123)
  prng_key, _ = jax.random.split(prng_key)
  vars_weight_params = model.abstract_init_with_metadata(input_specs)

  # `learner` is only used for initialization: the restored checkpoint
  # includes optimizer state, so the state-shape tree must be built with a
  # learner attached (discard_opt_states=False below).
  learner_p = pax_fiddle.Config(learners.Learner)
  learner_p.name = 'learner'
  learner_p.optimizer = pax_fiddle.Config(
      optimizers.ShardedAdafactor,
      decay_method='adam',
      lr_schedule=pax_fiddle.Config(schedules.Constant),
  )
  learner = learner_p.Instantiate()

  train_state_global_shapes = tasks_lib.create_state_unpadded_shapes(
      vars_weight_params, discard_opt_states=False, learners=[learner]
  )
  model_states = checkpoints.restore_checkpoint(
      train_state_global_shapes, ckpt_dir
  )
  return model, model_states
120
121
def preprocess_image(
    image_path, pre_crop_size, image_size
):
  """Reads an image file and returns a model-ready float batch.

  Decodes the file, bilinearly resizes the image to
  `pre_crop_size` x `pre_crop_size`, center-crops (or pads) to
  `image_size` x `image_size`, and scales pixels into [0, 1].
  """
  with tf.compat.v1.gfile.FastGFile(image_path, 'rb') as f:
    raw_bytes = f.read()
  decoded = tf.io.decode_image(raw_bytes, 3, expand_animations=False)
  # Add a leading batch dimension before resizing.
  batched = tf.expand_dims(decoded, 0)
  resized = tf.image.resize_bilinear(
      batched, [pre_crop_size, pre_crop_size]
  )
  cropped = tf.image.resize_with_crop_or_pad(resized, image_size, image_size)
  scaled = tf.cast(cropped, tf.float32) / 255.0
  return tf.clip_by_value(scaled, 0.0, 1.0).numpy()
137
138
class InputObject:
  """Minimal adapter exposing `ids_to_strings` for decode post-processing.

  `model.process_decode_out` calls `ids_to_strings` on its input object, so
  this class wraps a lingvo tokenizer behind exactly that interface.
  """

  def __init__(self, tokenizer):
    self.tokenizer = tokenizer

  def ids_to_strings(self, ids, lengths):
    """Detokenizes `ids` (with per-example `lengths`) to Python strings."""
    raw = self.tokenizer.IdsToStrings(ids, lengths).numpy()
    strings = []
    for item in raw:
      strings.append(item.decode())
    return strings
150
151
def main(_):
  """Loads a VILA checkpoint and prints a generated caption for one image."""
  # Suppresses verbose INFO/DEBUG log.
  logging.set_verbosity(logging.ERROR)
  model, model_states = load_vila_model(
      _CKPT_DIR.value, _IS_PRETRAIN.value
  )
  tokenizer_p = lingvo_tokenizers.SentencePieceTokenizer.Params().Set(
      spm_model=_SPM_MODEL_PATH.value,
      vocab_size=_TEXT_VOCAB_SIZE,
  )
  tokenizer = tokenizer_p.Instantiate()
  input_obj = InputObject(tokenizer)

  # Dummy text input, not actually used.
  ids, _, paddings = tokenizer.StringsToIds(
      ['dummy text'], max_length=_MAX_TEXT_LEN
  )

  image = preprocess_image(_IMAGE_PATH.value, _PRE_CROP_SIZE, _IMAGE_SIZE)
  # `tf.newaxis` adds the leading batch dimension expected by the model
  # (matching the (batch, 1, len) spec used at initialization time).
  input_batch = NestedMap(
      ids=ids[tf.newaxis].numpy(),
      image=image,
      paddings=paddings[tf.newaxis].numpy(),
  )
  context_p = base_layer.JaxContext.HParams(do_eval=True)
  with base_layer.JaxContext(context_p):
    prng_key = jax.random.PRNGKey(0)
    prng_key, compute_key = jax.random.split(prng_key)
    # First apply: run generative decoding over the input batch.
    (_, decode_out, _), _ = model.apply(
        {'params': model_states.mdl_vars['params']},
        input_batch,
        method=model.decode,
        rngs={base_layer.RANDOM: compute_key},
        mutable=[base_layer.DECODE_CACHE],
    )
    # Second apply: convert decoded ids back to strings via `input_obj`.
    (process_metric, _, _), _ = model.apply(
        {'params': model_states.mdl_vars['params']},
        input_obj,
        decode_out,
        method=model.process_decode_out,
        mutable=[base_layer.DECODE_CACHE],
    )
  decoded_str_list = process_metric['decoded_str']
  print('===== VILA generated comment: ', decoded_str_list)
196
197
if __name__ == '__main__':
  # absl parses the command-line flags before dispatching to `main`.
  app.run(main)
200