google-research / run_vila_decode.py

# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example script for running VILA model for captioning."""
from typing import Sequence

from absl import app
from absl import flags
from absl import logging
import jax
import jax.numpy as jnp
from lingvo import compat as tf
from lingvo.core import tokenizers as lingvo_tokenizers
from paxml import checkpoints
from paxml import learners
from paxml import tasks_lib
from paxml import train_states
from praxis import base_layer
from praxis import optimizers
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis import schedules

from vila import coca_vila
from vila import coca_vila_configs


NestedMap = py_utils.NestedMap

_CKPT_DIR = flags.DEFINE_string('ckpt_dir', '', 'Path to checkpoint.')
_IS_PRETRAIN = flags.DEFINE_boolean(
    'is_pretrain',
    False,
    'Whether the checkpoint is pretrain or rank-based finetuned.',
)
_SPM_MODEL_PATH = flags.DEFINE_string(
    'spm_model_path', '', 'Path to sentence piece tokenizer model.'
)
_IMAGE_PATH = flags.DEFINE_string('image_path', '', 'Path to input image.')
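
# A possible invocation of this script (illustrative only; the paths below are
# placeholders, not part of the repository):
#
#   python run_vila_decode.py \
#     --ckpt_dir=/path/to/vila/checkpoint \
#     --is_pretrain=false \
#     --spm_model_path=/path/to/tokenizer.model \
#     --image_path=/path/to/image.jpg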

_PRE_CROP_SIZE = 272
_IMAGE_SIZE = 224
_MAX_TEXT_LEN = 64
_TEXT_VOCAB_SIZE = 64000
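
# These constants drive the rest of the script: images are bilinearly resized
# to _PRE_CROP_SIZE and then center-cropped to _IMAGE_SIZE in
# `preprocess_image`; _MAX_TEXT_LEN is used both as the tokenizer max_length
# and the model's decoding_max_len; _TEXT_VOCAB_SIZE is shared by the model
# config and the SentencePiece tokenizer.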


def load_vila_model(
    ckpt_dir: str, is_pretrain: bool = False
):
  """Loads the VILA model from checkpoint directory.

  Args:
    ckpt_dir: The path to checkpoint directory
    is_pretrain: True for `CoCaVilaPretrain`, False for
      `CoCaVilaRankBasedFinetune`

  Returns:
    VILA model, VILA model states
  """
  coca_config = coca_vila_configs.CocaVilaConfig()
  if is_pretrain:
    coca_config.model_type = coca_vila.CoCaVilaPretrain
  else:
    coca_config.model_type = coca_vila.CoCaVilaRankBasedFinetune
  coca_config.decoding_max_len = _MAX_TEXT_LEN
  coca_config.text_vocab_size = _TEXT_VOCAB_SIZE
  model_p = coca_vila_configs.build_coca_vila_model(coca_config)
  if not is_pretrain:
    model_p.model_dims = coca_config.model_dims
  model_p.generation_decode = True
  model = model_p.Instantiate()

  dummy_batch_size = 4  # For initialization only
  text_shape = (dummy_batch_size, 1, _MAX_TEXT_LEN)
  image_shape = (dummy_batch_size, _IMAGE_SIZE, _IMAGE_SIZE, 3)
  input_specs = NestedMap(
      ids=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.int32),
      image=jax.ShapeDtypeStruct(shape=image_shape, dtype=jnp.float32),
      paddings=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.float32),
      # For initialization only
      labels=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.float32),
      regression_labels=jax.ShapeDtypeStruct(
          shape=(dummy_batch_size, 10), dtype=jnp.float32
      ),
  )
  prng_key = jax.random.PRNGKey(123)
  prng_key, _ = jax.random.split(prng_key)
  vars_weight_params = model.abstract_init_with_metadata(input_specs)

  # `learner` is only used for initialization.
  learner_p = pax_fiddle.Config(learners.Learner)
  learner_p.name = 'learner'
  learner_p.optimizer = pax_fiddle.Config(
      optimizers.ShardedAdafactor,
      decay_method='adam',
      lr_schedule=pax_fiddle.Config(schedules.Constant),
  )
  learner = learner_p.Instantiate()
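
  # Shape tree of the full TrainState (optimizer state included), which
  # `restore_checkpoint` uses to restore the saved checkpoint.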
  train_state_global_shapes = tasks_lib.create_state_unpadded_shapes(
      vars_weight_params, discard_opt_states=False, learners=[learner]
  )
  model_states = checkpoints.restore_checkpoint(
      train_state_global_shapes, ckpt_dir
  )
  return model, model_states


def preprocess_image(
    image_path: str, pre_crop_size: int, image_size: int
):
  """Reads, resizes, center-crops and normalizes the input image to [0, 1]."""
  with tf.compat.v1.gfile.FastGFile(image_path, 'rb') as f:
    image_bytes = f.read()
  image = tf.io.decode_image(image_bytes, 3, expand_animations=False)
  image = tf.image.resize_bilinear(
      tf.expand_dims(image, 0), [pre_crop_size, pre_crop_size]
  )
  image = tf.image.resize_with_crop_or_pad(image, image_size, image_size)
  image = tf.cast(image, tf.float32)
  image = image / 255.0
  image = tf.clip_by_value(image, 0.0, 1.0)
  return image.numpy()
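
# With the sizes used in `main`, `preprocess_image` returns a float32 array of
# shape (1, _IMAGE_SIZE, _IMAGE_SIZE, 3) with values in [0, 1], which is fed
# directly as the `image` field of the input batch.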


class InputObject:
  """The function `ids_to_strings` is called in model.process_decode_out."""

  def __init__(self, tokenizer):
    self.tokenizer = tokenizer

  def ids_to_strings(
      self, ids, lengths
  ):
    decoded = self.tokenizer.IdsToStrings(ids, lengths).numpy()
    return [d.decode() for d in decoded]


def main(_):
  # Suppresses verbose INFO/DEBUG logs.
  logging.set_verbosity(logging.ERROR)
  model, model_states = load_vila_model(
      _CKPT_DIR.value, _IS_PRETRAIN.value
  )
  tokenizer_p = lingvo_tokenizers.SentencePieceTokenizer.Params().Set(
      spm_model=_SPM_MODEL_PATH.value,
      vocab_size=_TEXT_VOCAB_SIZE,
  )
  tokenizer = tokenizer_p.Instantiate()
  input_obj = InputObject(tokenizer)

  # Dummy text input, not actually used.
  ids, _, paddings = tokenizer.StringsToIds(
      ['dummy text'], max_length=_MAX_TEXT_LEN
  )

  image = preprocess_image(_IMAGE_PATH.value, _PRE_CROP_SIZE, _IMAGE_SIZE)
  input_batch = NestedMap(
      ids=ids[tf.newaxis].numpy(),
      image=image,
      paddings=paddings[tf.newaxis].numpy(),
  )
  context_p = base_layer.JaxContext.HParams(do_eval=True)
  with base_layer.JaxContext(context_p):
    prng_key = jax.random.PRNGKey(0)
    prng_key, compute_key = jax.random.split(prng_key)
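    # `model.decode` autoregressively generates caption token ids; the decode
    # cache it writes during generation must be passed as a mutable collection.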
    (_, decode_out, _), _ = model.apply(
        {'params': model_states.mdl_vars['params']},
        input_batch,
        method=model.decode,
        rngs={base_layer.RANDOM: compute_key},
        mutable=[base_layer.DECODE_CACHE],
    )
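    # `process_decode_out` detokenizes the generated ids through
    # `input_obj.ids_to_strings` and returns them under the 'decoded_str' key.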
    (process_metric, _, _), _ = model.apply(
        {'params': model_states.mdl_vars['params']},
        input_obj,
        decode_out,
        method=model.process_decode_out,
        mutable=[base_layer.DECODE_CACHE],
    )
    decoded_str_list = process_metric['decoded_str']
  print('===== VILA generated comment: ', decoded_str_list)


if __name__ == '__main__':
  app.run(main)