google-research

Форк
0
/
seq2act_decode.py 
218 строк · 7.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""seq2act decoder."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import os
22
import numpy as np
23
import tensorflow.compat.v1 as tf
24
from tensorflow.compat.v1 import estimator as tf_estimator
25
from seq2act.models import input as input_utils
26
from seq2act.models import seq2act_estimator
27
from seq2act.models import seq2act_model
28
from seq2act.utils import decode_utils
29

30
flags = tf.flags
FLAGS = flags.FLAGS

# Command-line flags for the decoder: which problem/data to decode, which
# checkpoint directory to restore from, where to write the decodes, and the
# beam / batch sizes used while decoding.
flags.DEFINE_integer(name="beam_size", default=1, help="beam size")
flags.DEFINE_string(name="problem", default="android_howto", help="problem")
flags.DEFINE_string(name="data_files", default="", help="data_files")
flags.DEFINE_string(name="checkpoint_path", default="",
                    help="checkpoint_path")
flags.DEFINE_string(name="output_dir", default="", help="output_dir")
flags.DEFINE_integer(name="decode_batch_size", default=1,
                     help="decode_batch_size")
39

40

41
def get_input(hparams, data_files):
  """Builds the decoding input pipeline and returns the next-batch tensors.

  The data source is chosen from FLAGS.problem. FLAGS.decode_batch_size is
  forwarded to input_utils.input_fn as its second argument (presumably the
  batch size), together with repeat=1 and load_extra=True.

  Args:
    hparams: model hyper-parameters; max_span, max_dom_pos, max_pixel_pos and
      screen_encoder are read here.
    data_files: data files passed straight through to input_utils.input_fn.

  Returns:
    The tensors produced by get_next() on a one-shot iterator over the
    dataset.

  Raises:
    ValueError: if FLAGS.problem is not one of "pixel_help",
      "android_howto" or "rico_sca".
  """
  # Problem name -> name of the matching input_utils.DataSource member. The
  # member itself is looked up only for the selected problem.
  member_by_problem = {
      "pixel_help": "PIXEL_HELP",
      "android_howto": "ANDROID_HOWTO",
      "rico_sca": "RICO_SCA",
  }
  problem = FLAGS.problem
  if problem not in member_by_problem:
    raise ValueError("Unrecognized test: %s" % problem)
  data_source = getattr(input_utils.DataSource, member_by_problem[problem])

  tf.logging.info("Testing data_source=%s data_files=%s",
                  problem, data_files)
  dataset = input_utils.input_fn(
      data_files,
      FLAGS.decode_batch_size,
      repeat=1,
      data_source=data_source,
      max_range=hparams.max_span,
      max_dom_pos=hparams.max_dom_pos,
      max_pixel_pos=hparams.max_pixel_pos,
      load_extra=True,
      # DOM distances are only requested for the "gcn" screen encoder.
      load_dom_dist=(hparams.screen_encoder == "gcn"))
  return tf.data.make_one_shot_iterator(dataset).get_next()
67

68

69
def generate_action_mask(features):
  """Computes the decode mask from "task" and "verb_refs".

  Let `eos` be the position in each row of features["task"] whose value is 1.
  A step "hits" when its verb_refs start equals eos and its end equals
  eos + 1. The returned boolean mask (from tf.sequence_mask) is True for
  every step before the first hit, and False from the first hit onward.

  Args:
    features: dict of tensors; "task" and the rank-3 "verb_refs" are read.

  Returns:
    A boolean mask with one row per example and one column per step of
    verb_refs.
  """
  task_ids = features["task"]
  verb_refs = features["verb_refs"]
  # Column index of the value 1 in each task row, as an int32 column vector
  # so it broadcasts across the step axis of verb_refs.
  # NOTE(review): tf.where emits one row per match, so this assumes exactly
  # one such token per example - TODO confirm.
  eos_index = tf.cast(
      tf.expand_dims(tf.where(tf.equal(task_ids, 1))[:, 1], 1), tf.int32)
  is_hit = tf.logical_and(
      tf.equal(verb_refs[:, :, 0], eos_index),
      tf.equal(verb_refs[:, :, 1], eos_index + 1))
  # Running number of hits; it is still 0 at every step before the first hit.
  hits_so_far = tf.cumsum(tf.cast(is_hit, tf.int32), axis=-1)
  # Counting the zero entries gives the index of the first hit per row
  # (or the full step count when there is no hit).
  steps_before_hit = tf.reduce_sum(
      tf.cast(tf.less(hits_so_far, 1), tf.int32), -1)
  return tf.sequence_mask(steps_before_hit, maxlen=tf.shape(hits_so_far)[1])
82

83

84
def _decode_common(hparams):
  """Builds the graph shared by all decoding modes.

  Args:
    hparams: model hyper-parameters, forwarded to the input pipeline, the
      model and the decoder.

  Returns:
    A tuple (decode_features, decode_mask, features):
      decode_features: the input features without the "*_refs" entries, as
        left by decode_utils.decode_n_step.
      decode_mask: generate_action_mask(decode_features).
      features: the unmodified input features (ground truth included).
  """
  features = get_input(hparams, FLAGS.data_files)
  # The decoder must not see the ground-truth "*_refs" tensors.
  decode_features = {
      name: tensor for name, tensor in features.items()
      if not name.endswith("_refs")
  }
  # EVAL-mode forward pass on the full features; only its fourth output is
  # used (its "areas" entry is handed to the decoder).
  _, _, _, references = seq2act_model.compute_logits(
      features, hparams, mode=tf_estimator.ModeKeys.EVAL)
  # NOTE(review): the return value is discarded, yet "verb_refs" (filtered out
  # above) is read from decode_features by generate_action_mask below, so
  # decode_n_step presumably writes its outputs into decode_features in place
  # - confirm in decode_utils.
  decode_utils.decode_n_step(
      seq2act_model.compute_logits, decode_features, references["areas"],
      hparams, n=20, beam_size=FLAGS.beam_size)
  return decode_features, generate_action_mask(decode_features), features
100

101

102
def to_string(name, seq):
  """Formats a sequence of steps as a single labelled line.

  Each step's values are joined with ",", and steps are joined with " - ".
  The label is followed by ": " so it cannot run into the first step's
  values (previously "gt_seq" + "1,2" produced "gt_seq1,2").

  Args:
    name: label placed at the start of the line.
    seq: iterable of steps, each an iterable of values rendered with str().

  Returns:
    A string such as "gt_seq: 1,2 - 3,4".
  """
  return name + ": " + " - ".join(",".join(map(str, step)) for step in seq)
107

108

109
def ref_acc_to_string_list(task_seqs, ref_seqs, masks):
  """Convert a seqs of refs to strings.

  For every example, this emits four lines:
    - the task;
    - the ground-truth sequence;
    - the predicted sequence restricted by the decode mask;
    - the example's complete/partial accuracies.
  A final line holds the mean accuracies over all examples.

  Args:
    task_seqs: per-example task values, emitted as-is.
    ref_seqs: per-example dicts with "gt_seq", "pred_seq",
      "complete_seq_acc" and "partial_seq_acc".
    masks: per-example decode masks used to select steps of "pred_seq".

  Returns:
    A list of lines. When ref_seqs is empty the means are reported as "nan"
    instead of raising ZeroDivisionError.
  """
  total_complete = 0.
  total_partial = 0.
  string_list = []
  for task, seq, mask in zip(task_seqs, ref_seqs, masks):
    # Assuming batch_size = 1: only the first example of each batch is shown.
    string_list.append(task)
    string_list.append(to_string("gt_seq", seq["gt_seq"][0]))
    # Keep only the predicted steps selected by the decode mask.
    string_list.append(to_string("pred_seq", seq["pred_seq"][0][mask[0]]))
    string_list.append(
        "complete_seq_acc: " + str(seq["complete_seq_acc"]) +
        " partial_seq_acc: " + str(seq["partial_seq_acc"]))
    total_complete += seq["complete_seq_acc"]
    total_partial += seq["partial_seq_acc"]
  num_examples = len(ref_seqs)
  if num_examples:
    mean_complete = total_complete / num_examples
    mean_partial = total_partial / num_examples
  else:
    # Nothing was decoded: report NaN rather than dividing by zero.
    mean_complete = mean_partial = float("nan")
  # Space before the second label, matching the per-example line above.
  string_list.append("mean_complete_seq_acc: " + str(mean_complete) +
                     " mean_partial_seq_acc: " + str(mean_partial))
  return string_list
130

131

132
def save(task_seqs, seqs, masks, tag):
  """Formats decode results, echoes them to stdout and writes them to disk.

  The lines built by ref_acc_to_string_list are printed and written one per
  line to os.path.join(FLAGS.output_dir, "decodes." + tag). The output
  directory is created first if it does not exist.
  """
  lines = ref_acc_to_string_list(task_seqs, seqs, masks)
  out_dir = FLAGS.output_dir
  if not tf.gfile.Exists(out_dir):
    tf.gfile.MakeDirs(out_dir)
  out_path = os.path.join(out_dir, "decodes." + tag)
  with tf.gfile.GFile(out_path, mode="w") as out_file:
    for line in lines:
      print(line)
      out_file.write(str(line) + "\n")
142

143

144
def _stack_refs(tensors):
  """Concatenates "verb_refs", "obj_refs" and "input_refs" on the last axis."""
  return np.concatenate(
      [tensors["verb_refs"], tensors["obj_refs"], tensors["input_refs"]],
      axis=-1)


def _stack_actions(tensors):
  """Concatenates "verbs", "objects" and "input_refs" on the last axis.

  "verbs" and "objects" are given a new axis 2 (np.expand_dims) so they can
  be concatenated with "input_refs".
  """
  return np.concatenate([
      np.expand_dims(tensors["verbs"], 2),
      np.expand_dims(tensors["objects"], 2),
      tensors["input_refs"]], axis=-1)


def decode_fn(hparams):
  """The main function.

  Builds the decoding graph, restores the latest checkpoint from
  FLAGS.checkpoint_path and runs the graph until the input is exhausted.
  The per-example reference and action sequences are then saved with tags
  "joint_refs" and "joint_act".

  Args:
    hparams: model hyper-parameters, forwarded to _decode_common.

  Raises:
    ValueError: if no checkpoint is found under FLAGS.checkpoint_path.
  """
  decode_dict, decode_mask, label_dict = _decode_common(hparams)
  if FLAGS.problem != "android_howto":
    decode_dict["input_refs"] = decode_utils.unify_input_ref(
        decode_dict["verbs"], decode_dict["input_refs"])
  acc_metrics = decode_utils.compute_seq_metrics(
      label_dict, decode_dict, mask=None)
  saver = tf.train.Saver()
  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    if latest_checkpoint is None:
      # Fail with context instead of handing None to saver.restore.
      raise ValueError(
          "No checkpoint found in %s" % FLAGS.checkpoint_path)
    tf.logging.info("Restoring from the latest checkpoint: %s",
                    latest_checkpoint)
    saver.restore(session, latest_checkpoint)
    task_seqs = []
    ref_seqs = []
    act_seqs = []
    mask_seqs = []
    i = 0
    try:
      while True:
        tf.logging.info("Example %d", i)
        task, acc, mask, label, decode = session.run([
            decode_dict["raw_task"], acc_metrics, decode_mask,
            label_dict, decode_dict
        ])
        ref_seq = {
            "gt_seq": _stack_refs(label),
            "pred_seq": _stack_refs(decode),
            "complete_seq_acc": acc["complete_refs_acc"],
            "partial_seq_acc": acc["partial_refs_acc"],
        }
        act_seq = {
            "gt_seq": _stack_actions(label),
            "pred_seq": _stack_actions(decode),
            "complete_seq_acc": acc["complete_acts_acc"],
            "partial_seq_acc": acc["partial_acts_acc"],
        }
        print("task", task)
        print("ref_seq", ref_seq)
        print("act_seq", act_seq)
        print("mask", mask)
        task_seqs.append(task)
        ref_seqs.append(ref_seq)
        act_seqs.append(act_seq)
        mask_seqs.append(mask)
        i += 1
    except tf.errors.OutOfRangeError:
      # session.run raises this once the input iterator has no more batches.
      pass
    save(task_seqs, ref_seqs, mask_seqs, "joint_refs")
    save(task_seqs, act_seqs, mask_seqs, "joint_act")
210

211

212
def main(_):
  """Entry point for tf.app.run: loads hparams and runs decoding.

  The hparams are loaded from FLAGS.checkpoint_path, and their batch_size is
  overridden with FLAGS.decode_batch_size before decoding.
  """
  loaded_hparams = seq2act_estimator.load_hparams(FLAGS.checkpoint_path)
  loaded_hparams.set_hparam("batch_size", FLAGS.decode_batch_size)
  decode_fn(loaded_hparams)
216

217
if __name__ == "__main__":
  # Parse flags and dispatch to this module's main().
  tf.app.run(main=main)
219

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.