# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example script for running VILA model."""

from absl import app
from absl import flags
from absl import logging
import jax
import jax.numpy as jnp
from lingvo import compat as tf
from lingvo.core import tokenizers as lingvo_tokenizers
from paxml import checkpoints
from paxml import learners
from paxml import tasks_lib
from paxml import train_states
from praxis import base_layer
from praxis import optimizers
from praxis import pax_fiddle
from praxis import py_utils
from praxis import schedules

from vila import coca_vila
from vila import coca_vila_configs


NestedMap = py_utils.NestedMap

_CKPT_DIR = flags.DEFINE_string('ckpt_dir', '', 'Path to checkpoint.')
_SPM_MODEL_PATH = flags.DEFINE_string(
    'spm_model_path', '', 'Path to sentence piece tokenizer model.'
)
_IMAGE_PATH = flags.DEFINE_string('image_path', '', 'Path to input image.')

_PRE_CROP_SIZE = 272
_IMAGE_SIZE = 224
_MAX_TEXT_LEN = 64
_TEXT_VOCAB_SIZE = 64000
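
# Each entry pairs a positive prompt with its negative counterpart; the
# zero-shot quality score computed in main() is the softmax mass assigned
# to the positive prompt, averaged over all pairs.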
_ZSL_QUALITY_PROMPTS = [
    ['good image', 'bad image'],
    ['good lighting', 'bad lighting'],
    ['good content', 'bad content'],
    ['good background', 'bad background'],
    ['good foreground', 'bad foreground'],
    ['good composition', 'bad composition'],
]


def load_vila_model(
    ckpt_dir,
):
  """Loads the VILA model from a checkpoint directory.

  Args:
    ckpt_dir: The path to the checkpoint directory.

  Returns:
    VILA model, VILA model states.
  """
  coca_config = coca_vila_configs.CocaVilaConfig()
  coca_config.model_type = coca_vila.CoCaVilaRankBasedFinetune
  coca_config.decoding_max_len = _MAX_TEXT_LEN
  coca_config.text_vocab_size = _TEXT_VOCAB_SIZE
  model_p = coca_vila_configs.build_coca_vila_model(coca_config)
  model_p.model_dims = coca_config.model_dims
  model = model_p.Instantiate()
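
  # Initialization here is abstract: only variable shapes/dtypes are derived
  # from the dummy input specs below; real weights come from the checkpoint.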
  dummy_batch_size = 4  # For initialization only
  text_shape = (dummy_batch_size, 1, _MAX_TEXT_LEN)
  image_shape = (dummy_batch_size, _IMAGE_SIZE, _IMAGE_SIZE, 3)
  input_specs = NestedMap(
      ids=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.int32),
      image=jax.ShapeDtypeStruct(shape=image_shape, dtype=jnp.float32),
      paddings=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.float32),
      # For initialization only
      labels=jax.ShapeDtypeStruct(shape=text_shape, dtype=jnp.float32),
      regression_labels=jax.ShapeDtypeStruct(
          shape=(dummy_batch_size, 10), dtype=jnp.float32
      ),
  )
  prng_key = jax.random.PRNGKey(123)
  prng_key, _ = jax.random.split(prng_key)
  vars_weight_params = model.abstract_init_with_metadata(input_specs)

  # `learner` is only used for initialization.
  learner_p = pax_fiddle.Config(learners.Learner)
  learner_p.name = 'learner'
  learner_p.optimizer = pax_fiddle.Config(
      optimizers.ShardedAdafactor,
      decay_method='adam',
      lr_schedule=pax_fiddle.Config(schedules.Constant),
  )
  learner = learner_p.Instantiate()

  train_state_global_shapes = tasks_lib.create_state_unpadded_shapes(
      vars_weight_params, discard_opt_states=False, learners=[learner]
  )
  model_states = checkpoints.restore_checkpoint(
      train_state_global_shapes, ckpt_dir
  )
  return model, model_states


def preprocess_image(
    image_path, pre_crop_size, image_size
):
  """Reads an image, resizes, center-crops, and scales it to [0, 1]."""
  with tf.compat.v1.gfile.FastGFile(image_path, 'rb') as f:
    image_bytes = f.read()
  image = tf.io.decode_image(image_bytes, 3, expand_animations=False)
  image = tf.image.resize_bilinear(
      tf.expand_dims(image, 0), [pre_crop_size, pre_crop_size]
  )
  image = tf.image.resize_with_crop_or_pad(image, image_size, image_size)
  image = tf.cast(image, tf.float32)
  image = image / 255.0
  image = tf.clip_by_value(image, 0.0, 1.0)
  return image.numpy()


def main(_):
  # Suppresses verbose INFO/DEBUG logs.
  logging.set_verbosity(logging.ERROR)
  model, model_states = load_vila_model(_CKPT_DIR.value)
  image = preprocess_image(_IMAGE_PATH.value, _PRE_CROP_SIZE, _IMAGE_SIZE)
  input_batch = NestedMap(
      image=image,
      ids=jnp.zeros((1, 1, _MAX_TEXT_LEN), dtype=jnp.int32),
      paddings=jnp.zeros((1, 1, _MAX_TEXT_LEN), dtype=jnp.int32),
  )
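
  # Run the forward pass in eval mode to get the predicted quality score.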
  context_p = base_layer.JaxContext.HParams(do_eval=True)
  with base_layer.JaxContext(context_p):
    predictions = model.apply(
        {'params': model_states.mdl_vars['params']},
        input_batch,
        method=model.compute_predictions,
    )
    quality_scores = predictions['quality_scores']
  print('===== VILA predicted quality score [0, 1]: ', quality_scores)
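
  # If a SentencePiece model is provided, also compute a zero-shot (ZSL)
  # quality score by comparing the image embedding against text embeddings
  # of the (positive, negative) prompt pairs in _ZSL_QUALITY_PROMPTS.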
  if _SPM_MODEL_PATH.value:
    tokenizer_p = lingvo_tokenizers.SentencePieceTokenizer.Params().Set(
        spm_model=_SPM_MODEL_PATH.value,
        vocab_size=_TEXT_VOCAB_SIZE,
    )
    tokenizer = tokenizer_p.Instantiate()

    all_prompts = [p for pair in _ZSL_QUALITY_PROMPTS for p in pair]  # pylint: disable=g-complex-comprehension
    ids, _, paddings = tokenizer.StringsToIds(all_prompts, max_length=4)
    context_p = base_layer.JaxContext.HParams(do_eval=True)
    with base_layer.JaxContext.new_context(hparams=context_p):
      input_batch = NestedMap(
          ids=ids.numpy(),
          paddings=paddings.numpy(),
          image=jnp.zeros((1, _IMAGE_SIZE, _IMAGE_SIZE, 3)),
      )
      text_encoded = model.apply(
          {'params': model_states.mdl_vars['params']},
          input_batch,
          method=model.compute_text_embedding,
      )
      text_embed = text_encoded.contrastive_txt_embed_norm
      image_embed = predictions.contrastive_img_embed_norm

      zsl_scores = jnp.matmul(image_embed, text_embed.T)
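
      # Reshape to [num_images, num_pairs, 2] and softmax over each
      # (positive, negative) pair; the final score is the positive-prompt
      # probability averaged across pairs.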
      zsl_scores = zsl_scores.reshape([-1, len(_ZSL_QUALITY_PROMPTS), 2])
      zsl_scores = jax.nn.softmax(zsl_scores)
      zsl_scores = zsl_scores.mean(1)
      zsl_scores = zsl_scores[:, 0]
      print('===== VILA ZSL predicted quality score [0, 1]: ', zsl_scores)


if __name__ == '__main__':
  app.run(main)