# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""seq2act decoder."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import numpy as np
23import tensorflow.compat.v1 as tf
24from tensorflow.compat.v1 import estimator as tf_estimator
25from seq2act.models import input as input_utils
26from seq2act.models import seq2act_estimator
27from seq2act.models import seq2act_model
28from seq2act.utils import decode_utils
29
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_integer("beam_size", 1, "Beam size for beam search decoding.")
flags.DEFINE_string("problem", "android_howto",
                    "Problem to decode: pixel_help, android_howto or "
                    "rico_sca.")
flags.DEFINE_string("data_files", "", "File pattern of the input data.")
flags.DEFINE_string("checkpoint_path", "",
                    "Directory containing the trained checkpoint.")
flags.DEFINE_string("output_dir", "", "Directory to write decode results to.")
flags.DEFINE_integer("decode_batch_size", 1, "Batch size for decoding.")
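
# A minimal invocation sketch; the module path and all file paths below are
# hypothetical and should be adjusted to your checkout and data layout:
#
#   python -m seq2act.bin.seq2act_decode \
#     --problem=android_howto \
#     --data_files=/path/to/android_howto/*.tfrecord \
#     --checkpoint_path=/path/to/checkpoint_dir \
#     --output_dir=/tmp/seq2act_decodes
#
# decode_batch_size is expected to stay at 1: the report writer below
# (ref_acc_to_string_list) assumes batch_size = 1.
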
def get_input(hparams, data_files):
  """Gets the input features for decoding."""
  if FLAGS.problem == "pixel_help":
    data_source = input_utils.DataSource.PIXEL_HELP
  elif FLAGS.problem == "android_howto":
    data_source = input_utils.DataSource.ANDROID_HOWTO
  elif FLAGS.problem == "rico_sca":
    data_source = input_utils.DataSource.RICO_SCA
  else:
    raise ValueError("Unrecognized problem: %s" % FLAGS.problem)
  tf.logging.info("Testing data_source=%s data_files=%s" % (
      FLAGS.problem, data_files))
  dataset = input_utils.input_fn(
      data_files,
      FLAGS.decode_batch_size,
      repeat=1,
      data_source=data_source,
      max_range=hparams.max_span,
      max_dom_pos=hparams.max_dom_pos,
      max_pixel_pos=hparams.max_pixel_pos,
      load_extra=True,
      load_dom_dist=(hparams.screen_encoder == "gcn"))
  iterator = tf.data.make_one_shot_iterator(dataset)
  features = iterator.get_next()
  return features


def generate_action_mask(features):
  """Computes the decode mask from "task" and "verb_refs"."""
  eos_positions = tf.to_int32(tf.expand_dims(
      tf.where(tf.equal(features["task"], 1))[:, 1], 1))
  decode_mask = tf.cumsum(tf.to_int32(
      tf.logical_and(
          tf.equal(features["verb_refs"][:, :, 0], eos_positions),
          tf.equal(features["verb_refs"][:, :, 1], eos_positions + 1))),
      axis=-1)
  decode_mask = tf.sequence_mask(
      tf.reduce_sum(tf.to_int32(tf.less(decode_mask, 1)), -1),
      maxlen=tf.shape(decode_mask)[1])
  return decode_mask


def _decode_common(hparams):
  """Common graph for decoding."""
  features = get_input(hparams, FLAGS.data_files)
  decode_features = {}
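  # Drop the ground-truth reference features ("*_refs") from the decode
  # inputs so that decode_n_step predicts them step by step rather than
  # conditioning on the labels; the full features are kept for computing
  # the evaluation metrics.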
  for key in features:
    if key.endswith("_refs"):
      continue
    decode_features[key] = features[key]
  _, _, _, references = seq2act_model.compute_logits(
      features, hparams, mode=tf_estimator.ModeKeys.EVAL)
  decode_utils.decode_n_step(seq2act_model.compute_logits,
                             decode_features, references["areas"],
                             hparams, n=20,
                             beam_size=FLAGS.beam_size)
  decode_mask = generate_action_mask(decode_features)
  return decode_features, decode_mask, features


def to_string(name, seq):
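  """Formats a sequence of reference tuples as a single string.

  For example (illustrative values), to_string("gt_seq", [[0, 1], [2, 3]])
  returns "gt_seq 0,1 - 2,3".
  """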
  steps = []
  for step in seq:
    steps.append(",".join(map(str, step)))
  return name + " " + " - ".join(steps)


def ref_acc_to_string_list(task_seqs, ref_seqs, masks):
  """Converts sequences of refs and their accuracies to strings."""
  cra = 0.
  pra = 0.
  string_list = []
  for task, seq, mask in zip(task_seqs, ref_seqs, masks):
    # Assuming batch_size = 1
    string_list.append(task)
    string_list.append(to_string("gt_seq", seq["gt_seq"][0]))
    string_list.append(to_string("pred_seq", seq["pred_seq"][0][mask[0]]))
    string_list.append(
        "complete_seq_acc: " + str(
            seq["complete_seq_acc"]) + " partial_seq_acc: " + str(
                seq["partial_seq_acc"]))
    cra += seq["complete_seq_acc"]
    pra += seq["partial_seq_acc"]
  mcra = cra / len(ref_seqs)
  mpra = pra / len(ref_seqs)
  string_list.append("mean_complete_seq_acc: " + str(mcra) + (
      " mean_partial_seq_acc: " + str(mpra)))
  return string_list


def save(task_seqs, seqs, masks, tag):
  """Writes the decode report for `seqs` to output_dir/decodes.<tag>."""
  string_list = ref_acc_to_string_list(task_seqs, seqs, masks)
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)
  with tf.gfile.GFile(os.path.join(FLAGS.output_dir, "decodes." + tag),
                      mode="w") as f:
    for item in string_list:
      print(item)
      f.write(str(item))
      f.write("\n")


def decode_fn(hparams):
  """Runs decoding over the input data and saves the results."""
  decode_dict, decode_mask, label_dict = _decode_common(hparams)
  if FLAGS.problem != "android_howto":
    decode_dict["input_refs"] = decode_utils.unify_input_ref(
        decode_dict["verbs"], decode_dict["input_refs"])
  print_ops = []
  for key in ["raw_task", "verbs", "objects",
              "verb_refs", "obj_refs", "input_refs"]:
    print_ops.append(tf.print(key, tf.shape(decode_dict[key]),
                              decode_dict[key], label_dict[key],
                              "decode_mask", decode_mask,
                              summarize=100))
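  # Note: these tf.print ops are kept for debugging; in graph mode they only
  # execute if fetched in session.run or attached as control dependencies.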
  acc_metrics = decode_utils.compute_seq_metrics(
      label_dict, decode_dict, mask=None)
  saver = tf.train.Saver()
  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    tf.logging.info("Restoring from the latest checkpoint: %s" %
                    (latest_checkpoint))
    saver.restore(session, latest_checkpoint)
    task_seqs = []
    ref_seqs = []
    act_seqs = []
    mask_seqs = []
    try:
      i = 0
      while True:
        tf.logging.info("Example %d" % i)
        task, acc, mask, label, decode = session.run([
            decode_dict["raw_task"], acc_metrics, decode_mask,
            label_dict, decode_dict
        ])
        ref_seq = {}
        ref_seq["gt_seq"] = np.concatenate([
            label["verb_refs"], label["obj_refs"], label["input_refs"]],
            axis=-1)
        ref_seq["pred_seq"] = np.concatenate([
            decode["verb_refs"], decode["obj_refs"], decode["input_refs"]],
            axis=-1)
        ref_seq["complete_seq_acc"] = acc["complete_refs_acc"]
        ref_seq["partial_seq_acc"] = acc["partial_refs_acc"]
        act_seq = {}
        act_seq["gt_seq"] = np.concatenate([
            np.expand_dims(label["verbs"], 2),
            np.expand_dims(label["objects"], 2),
            label["input_refs"]], axis=-1)
        act_seq["pred_seq"] = np.concatenate([
            np.expand_dims(decode["verbs"], 2),
            np.expand_dims(decode["objects"], 2),
            decode["input_refs"]], axis=-1)
        act_seq["complete_seq_acc"] = acc["complete_acts_acc"]
        act_seq["partial_seq_acc"] = acc["partial_acts_acc"]
        print("task", task)
        print("ref_seq", ref_seq)
        print("act_seq", act_seq)
        print("mask", mask)
        task_seqs.append(task)
        ref_seqs.append(ref_seq)
        act_seqs.append(act_seq)
        mask_seqs.append(mask)
        i += 1
    except tf.errors.OutOfRangeError:
      pass
    save(task_seqs, ref_seqs, mask_seqs, "joint_refs")
    save(task_seqs, act_seqs, mask_seqs, "joint_act")


def main(_):
  hparams = seq2act_estimator.load_hparams(FLAGS.checkpoint_path)
  hparams.set_hparam("batch_size", FLAGS.decode_batch_size)
  decode_fn(hparams)


if __name__ == "__main__":
  tf.app.run()