google-research

Форк
0
438 строк · 13.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Functions shared among files under word2act/data_generation."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import collections
23
import os
24

25
import attr
26
from enum import Enum
27
import numpy as np
28
import tensorflow.compat.v1 as tf  # tf
29

30
from seq2act.data_generation import config
31
from seq2act.data_generation import view_hierarchy
32

33

34
# Short module-level alias for the TF1-compat file API; _get_view_hierachy
# below reads its input files through gfile.GFile.
gfile = tf.gfile
35

36

37
@attr.s
class MaxValues(object):
  """Represents max values for a task and UI.

  Every field defaults to None, meaning "no value observed yet".
  """

  # For instruction
  max_word_num = attr.ib(default=None)
  max_word_length = attr.ib(default=None)

  # For UI objects
  max_ui_object_num = attr.ib(default=None)
  max_ui_object_word_num = attr.ib(default=None)
  max_ui_object_word_length = attr.ib(default=None)

  @staticmethod
  def _max_ignoring_none(current, candidate):
    """Returns the larger value, treating None as "no value yet"."""
    # A bare max(None, n) raises TypeError on Python 3 (it only worked on
    # Python 2, where None sorted below every number), so None is handled
    # explicitly and never wins over a real value.
    if current is None:
      return candidate
    if candidate is None:
      return current
    return max(current, candidate)

  def update(self, other):
    """Update max value from another MaxValues instance.

    Each field becomes the larger of this instance's and `other`'s value.
    None fields are ignored, so merging into a fresh MaxValues() works:

      max_values_list = ...
      result = MaxValues()
      for v in max_values_list:
        result.update(v)

    Then `result` contains merged max values in each field.

    Args:
      other: another MaxValues instance, contains updated data.
    """
    for field_name in ('max_word_num', 'max_word_length',
                       'max_ui_object_num', 'max_ui_object_word_num',
                       'max_ui_object_word_length'):
      setattr(self, field_name,
              self._max_ignoring_none(getattr(self, field_name),
                                      getattr(other, field_name)))
73

74

75
class ActionRules(Enum):
  """Identifies which rule (rule_id) produced a synthetic action.

  Not every member is a generation rule: REAL marks a recorded user action,
  and several members are used only for the Windows ("win") data.
  """
  SINGLE_OBJECT_RULE = 0
  GRID_CONTEXT_RULE = 1
  NEIGHBOR_CONTEXT_RULE = 2
  SWIPE_TO_OBJECT_RULE = 3
  SWIPE_TO_DIRECTION_RULE = 4
  # Not generated at all: the action is a real user action.
  REAL = 5
  CROWD_COMPUTE = 6
  # Windows only, e.g. "click button under some tab/combobox".
  DIRECTION_VERB_RULE = 7
  # Windows only: used when the target verb is not a direction verb.
  CONSUMED_MULTI_STEP = 8
  UNCONSUMED_MULTI_STEP = 9
  NO_VERB_RULE = 10
88

89

90
class ActionTypes(Enum):
  """Android action types; each member's value is its numeric action id.

  Ids start at 2 (presumably lower ids are reserved elsewhere -- TODO
  confirm).
  """
  # Touch gestures.
  CLICK = 2
  INPUT = 3
  SWIPE = 4
  # Toggle state changes.
  CHECK = 5
  UNCHECK = 6
  LONG_CLICK = 7
  OTHERS = 8
  # System navigation.
  GO_HOME = 9
  GO_BACK = 10
101

102

103
# Lookup from a lower-case verb phrase to the ActionTypes member it implies.
# Multi-word phrases ('tap and hold', 'head over') are single keys; 'tap' is
# the only phrase that maps to a plain CLICK.
VERB_ID_MAP = {
    'check': ActionTypes.CHECK,
    'find': ActionTypes.SWIPE,
    'navigate': ActionTypes.SWIPE,
    'uncheck': ActionTypes.UNCHECK,
    'head to': ActionTypes.SWIPE,
    'enable': ActionTypes.CHECK,
    'turn on': ActionTypes.CHECK,
    'locate': ActionTypes.SWIPE,
    'disable': ActionTypes.UNCHECK,
    'tap and hold': ActionTypes.LONG_CLICK,
    'long press': ActionTypes.LONG_CLICK,
    'look': ActionTypes.SWIPE,
    'press and hold': ActionTypes.LONG_CLICK,
    'turn it on': ActionTypes.CHECK,
    'turn off': ActionTypes.UNCHECK,
    'switch on': ActionTypes.CHECK,
    'visit': ActionTypes.SWIPE,
    'hold': ActionTypes.LONG_CLICK,
    'switch off': ActionTypes.UNCHECK,
    'head': ActionTypes.SWIPE,
    'head over': ActionTypes.SWIPE,
    'long-press': ActionTypes.LONG_CLICK,
    'un-click': ActionTypes.UNCHECK,
    'tap': ActionTypes.CLICK,
    'check off': ActionTypes.UNCHECK,
}
131

132

133
class WinActionTypes(Enum):
  """Windows action types; each member's value is its numeric action id."""
  # Mouse actions.
  LEFT_CLICK = 2
  RIGHT_CLICK = 3
  DOUBLE_CLICK = 4
  # Keyboard text entry.
  INPUT = 5
139

140

141
@attr.s
class Action(object):
  """The class for a word2act action.

  Holds an instruction string, the verb / object-description / input-content
  pieces parsed from it, their positions, the action type and the rule that
  produced the action.
  """
  instruction_str = attr.ib(default=None)
  verb_str = attr.ib(default=None)
  obj_desc_str = attr.ib(default=None)
  input_content_str = attr.ib(default=None)
  action_type = attr.ib(default=None)
  action_rule = attr.ib(default=None)
  target_obj_idx = attr.ib(default=None)
  obj_str_pos = attr.ib(default=None)
  input_str_pos = attr.ib(default=None)
  verb_str_pos = attr.ib(default=None)
  # start/end position of one whole step. attr.Factory gives every instance
  # its own list; a literal `default=[0, 0]` is one list shared by all
  # instances, so mutating one action's step_str_pos changed every action.
  step_str_pos = attr.ib(default=attr.Factory(lambda: [0, 0]))
  # Default action is a 1-step consumed action.
  is_consumed = attr.ib(default=True)

  # NOTE(review): `@attr.s` (eq=True, auto_detect=False by default) generates
  # its own __eq__ and installs it over this method, so Actions actually
  # compare field by field. Passing eq=False would activate the
  # instruction-only comparison below, but it also removes the generated
  # ordering methods; confirm which semantics callers need first.
  def __eq__(self, other):
    if not isinstance(other, Action):
      return NotImplemented
    return self.instruction_str == other.instruction_str

  @staticmethod
  def _apply_unless_none(value, func):
    """Returns func(value), or None when value is None."""
    return None if value is None else func(value)

  def is_valid(self):
    """Does valid check for action instance.

    Returns False when instruction_str is empty or None, when obj_str_pos is
    all zeros, or when obj_desc_str is None or only whitespace; returns True
    otherwise.

    Returns:
      a boolean
    """
    if not self.instruction_str:
      return False
    if (np.array(self.obj_str_pos) == 0).all():
      return False
    # Check None explicitly: calling .strip() on the None default used to
    # raise AttributeError instead of reporting the action as invalid.
    if self.obj_desc_str is None or not self.obj_desc_str.strip():
      return False
    return True

  def has_valid_input(self):
    """Does valid check for input positions.

    Returns true when input_str_pos is set and is not entirely the default
    value config.LABEL_DEFAULT_VALUE_INT.

    Returns:
      a boolean
    """
    if self.input_str_pos is None:
      # `None != ndarray` broadcasts to all-True, so without this guard an
      # action with no input positions was reported as having valid input.
      return False
    default_pos = np.array(
        [config.LABEL_DEFAULT_VALUE_INT, config.LABEL_DEFAULT_VALUE_INT])
    return (np.asarray(self.input_str_pos) != default_pos).any()

  def regularize_strs(self):
    """Trims action instance's obj_desc_str, input_content_str, verb_str.

    Fields that are None (the attrs default) stay None instead of raising.
    """
    strip = lambda s: s.strip()
    self.obj_desc_str = self._apply_unless_none(self.obj_desc_str, strip)
    self.input_content_str = self._apply_unless_none(self.input_content_str,
                                                     strip)
    self.verb_str = self._apply_unless_none(self.verb_str, strip)

  def convert_to_lower_case(self):
    """Lower-cases instruction_str, obj_desc_str, input_content_str, verb_str.

    Fields that are None (the attrs default) stay None instead of raising.
    """
    lower = lambda s: s.lower()
    self.instruction_str = self._apply_unless_none(self.instruction_str, lower)
    self.obj_desc_str = self._apply_unless_none(self.obj_desc_str, lower)
    self.input_content_str = self._apply_unless_none(self.input_content_str,
                                                     lower)
    self.verb_str = self._apply_unless_none(self.verb_str, lower)
202

203

204
@attr.s
class ActionEvent(object):
  """A high-level action summarized from low-level Android event logs.

  Example: this output of `adb shell getevent -lt /dev/input/event1`,
  captured while running an Android emulator:

  [      42.407808] EV_ABS       ABS_MT_TRACKING_ID   00000000
  [      42.407808] EV_ABS       ABS_MT_TOUCH_MAJOR   00000004
  [      42.407808] EV_ABS       ABS_MT_PRESSURE      00000081
  [      42.407808] EV_ABS       ABS_MT_POSITION_X    00004289
  [      42.407808] EV_ABS       ABS_MT_POSITION_Y    00007758
  [      42.407808] EV_SYN       SYN_REPORT           00000000
  [      42.453256] EV_ABS       ABS_MT_PRESSURE      00000000
  [      42.453256] EV_ABS       ABS_MT_TRACKING_ID   ffffffff
  [      42.453256] EV_SYN       SYN_REPORT           00000000

  The position values are hex (0x4289 = 17033, 0x7758 = 30552). This is
  summarized as:

    ActionEvent(
      event_time = 42.407808
      action_type = ActionTypes.CLICK
      coordinates_x = [17033,]
      coordinates_y = [30552,]
      action_params = []
    )

  For a [480, 800] pixel screen, update_norm_coordinates((480, 800)) then sets
  the pixel coordinates. Assuming ANDROID_LOG_MAX_ABS_X = ANDROID_LOG_MAX_ABS_Y
  = 32676 (the real values come from config), the result is:
      coordinates_x_pixel = [250,]
      coordinates_y_pixel = [747,]
  object_id keeps config.LABEL_DEFAULT_INVALID_INT until
  update_info_from_screen() resolves it.
  """

  event_time = attr.ib()
  action_type = attr.ib()
  coordinates_x = attr.ib()
  coordinates_y = attr.ib()
  action_params = attr.ib()

  # The following are plain class attributes, not attrs fields: they are not
  # __init__ arguments and act as defaults until update_info_from_screen()
  # assigns per-instance values.
  coordinates_x_pixel = None
  coordinates_y_pixel = None
  object_id = config.LABEL_DEFAULT_INVALID_INT
  # With dedup=True this can hold fewer nodes than the XML contains.
  leaf_nodes = None
  debug_target_object_word_sequence = None

  def update_info_from_screen(self, screen_info, dedup=False):
    """Fills pixel coordinates, leaf nodes and object_id from a screen.

    Pixel coordinates are computed for the
    (config.SCREEN_WIDTH, config.SCREEN_HEIGHT) screen size. The screen's
    view hierarchy XML is then parsed, and the touched UI object is located.

    Args:
      screen_info: ScreenInfo protobuf
      dedup: whether dedup the UI objs with same text or content desc.
    Raises:
      ValueError when fail to find object id.
    """
    self.update_norm_coordinates((config.SCREEN_WIDTH, config.SCREEN_HEIGHT))
    hierarchy = view_hierarchy.ViewHierarchy()
    hierarchy.load_xml(screen_info.view_hierarchy.xml.encode('utf-8'))
    if dedup:
      touch_point = (self.coordinates_x_pixel[0], self.coordinates_y_pixel[0])
      hierarchy.dedup(touch_point)
    self.leaf_nodes = hierarchy.get_leaf_nodes()
    self._update_object_id(hierarchy.get_ui_objects())

  def _update_object_id(self, ui_object_list):
    """Sets object_id to the index of the UI object under the touch point.

    The touch point is the first (x, y) pixel coordinate. When several
    bounding boxes contain it (edges inclusive), the one with the smallest
    area wins; on ties, the earliest in ui_object_list wins.

    Args:
      ui_object_list: UI objects, each with a `bounding_box` (x1, y1, x2, y2)
        and a `word_sequence`.
    Raises:
      ValueError when fail to find object id.
    """
    touch_x = self.coordinates_x_pixel[0]
    touch_y = self.coordinates_y_pixel[0]
    best_area = None
    for position, candidate in enumerate(ui_object_list):
      bbox = candidate.bounding_box
      contains_touch = (bbox.x1 <= touch_x <= bbox.x2 and
                        bbox.y1 <= touch_y <= bbox.y2)
      if not contains_touch:
        continue
      area = (bbox.x2 - bbox.x1) * (bbox.y2 - bbox.y1)
      if best_area is None or area < best_area:
        best_area = area
        self.object_id = position
        self.debug_target_object_word_sequence = candidate.word_sequence

    if best_area is None:
      raise ValueError(('Object id not found: x,y=%d,%d coordinates fail to '
                        'match every UI bounding box') % (touch_x, touch_y))

  def update_norm_coordinates(self, screen_size):
    """Updates coordinates_x(y)_pixel from the raw event-log coordinates.

    Event-log coordinates are scaled to [0, ANDROID_LOG_MAX_ABS_X] and
    [0, ANDROID_LOG_MAX_ABS_Y]. This maps them back to screen pixels,
    truncating toward zero:

    coordinates_x_pixel = int(coordinates_x * horizontal_pixel
                              / ANDROID_LOG_MAX_ABS_X)
    coordinates_y_pixel = int(coordinates_y * vertical_pixel
                              / ANDROID_LOG_MAX_ABS_Y)

    For example, with ANDROID_LOG_MAX_ABS_X = ANDROID_LOG_MAX_ABS_Y = 32676,
    coordinates_x = [17033, ], coordinates_y = [30552, ] and
    screen_size = (480, 800), the result is:
      coordinates_x_pixel = [250, ]
      coordinates_y_pixel = [747, ]

    Args:
      screen_size: a tuple of screen pixel size, (width, height).
    """
    width_px, height_px = screen_size
    max_abs_x = config.ANDROID_LOG_MAX_ABS_X
    max_abs_y = config.ANDROID_LOG_MAX_ABS_Y
    self.coordinates_x_pixel = [
        int(raw_x * width_px / max_abs_x) for raw_x in self.coordinates_x
    ]
    self.coordinates_y_pixel = [
        int(raw_y * height_px / max_abs_y) for raw_y in self.coordinates_y
    ]
331

332

333
# Debug-only histograms, incremented by get_word_statistics():
#   word_num_distribution_dict[k]    -> count of UI objects with k words.
#   word_length_distribution_dict[k] -> count of words that are k long.
word_num_distribution_dict = collections.defaultdict(int)
word_length_distribution_dict = collections.defaultdict(int)
336

337

338
def get_word_statistics(file_path):
  """Computes the UI object count and word statistics of one xml/json file.

  As a side effect, it increments the module-level debug histograms
  word_num_distribution_dict and word_length_distribution_dict.

  Args:
    file_path: The full path of a xml/json file.

  Returns:
    A tuple (ui_object_num, max_word_num, max_word_length):
      ui_object_num: number of view hierarchy leaf nodes in the file.
      max_word_num: the largest number of words in any leaf node's word
        sequence (0 if there are none).
      max_word_length: the length of the longest word in any of those
        sequences (0 if there are none).
  """
  nodes = get_view_hierarchy_list(file_path)
  longest_sequence = 0
  longest_word = 0
  for node in nodes:
    words = node.uiobject.word_sequence
    num_words = len(words)
    word_num_distribution_dict[num_words] += 1
    if num_words > longest_sequence:
      longest_sequence = num_words
    for word in words:
      word_len = len(word)
      word_length_distribution_dict[word_len] += 1
      if word_len > longest_word:
        longest_word = word_len
  return len(nodes), longest_sequence, longest_word
363

364

365
def _max_ignoring_none(current, candidate):
  """Returns the larger value, treating None as "no value yet"."""
  # max(None, n) raises TypeError on Python 3 (Python 2 sorted None below
  # every number), and MaxValues() starts with every field set to None.
  if current is None:
    return candidate
  if candidate is None:
    return current
  return max(current, candidate)


def get_ui_max_values(file_paths):
  """Calculates max values from ui objects in multi xml/json files.

  Only the UI-object fields (max_ui_object_num, max_ui_object_word_num,
  max_ui_object_word_length) are filled; the instruction fields stay None.
  With no file paths, every field stays None.

  Args:
    file_paths: The full paths of multi xml/json files.
  Returns:
    max_values: instance of MaxValues.
  """
  max_values = MaxValues()
  for file_path in file_paths:
    (ui_object_num,
     max_ui_object_word_num,
     max_ui_object_word_length) = get_word_statistics(file_path)

    max_values.max_ui_object_num = _max_ignoring_none(
        max_values.max_ui_object_num, ui_object_num)
    max_values.max_ui_object_word_num = _max_ignoring_none(
        max_values.max_ui_object_word_num, max_ui_object_word_num)
    max_values.max_ui_object_word_length = _max_ignoring_none(
        max_values.max_ui_object_word_length, max_ui_object_word_length)
  return max_values
386

387

388
def get_ui_object_list(file_path):
  """Returns the UI objects of the view hierarchy stored in a file.

  Args:
    file_path: file path of xml or json
  Returns:
    The result of ViewHierarchy.get_ui_objects() for the loaded file.
  """
  hierarchy = _get_view_hierachy(file_path)
  ui_objects = hierarchy.get_ui_objects()
  return ui_objects
399

400

401
def get_view_hierarchy_list(file_path):
  """Returns the leaf nodes of the view hierarchy stored in a file.

  Args:
    file_path: file path of xml or json
  Returns:
    The result of ViewHierarchy.get_leaf_nodes() for the loaded file.
  """
  hierarchy = _get_view_hierachy(file_path)
  leaf_nodes = hierarchy.get_leaf_nodes()
  return leaf_nodes
411

412

413
def _get_view_hierachy(file_path):
  """Loads a ViewHierarchy from an xml or json file.

  `.xml` files are parsed with the (config.SCREEN_WIDTH, config.SCREEN_HEIGHT)
  screen size via load_xml. `.json` files are parsed with the
  (config.RICO_SCREEN_WIDTH, config.RICO_SCREEN_HEIGHT) screen size via
  load_json. The extension match is case-insensitive.

  Args:
    file_path: The full path of an input xml/json file.
  Returns:
    A ViewHierarchy object.
  Raises:
    ValueError: unsupported file format. This is checked before the file is
      opened, so no I/O is done for unsupported paths.
  """
  _, file_extension = os.path.splitext(file_path)
  normalized_extension = file_extension.lower()
  if normalized_extension == '.xml':
    screen_width = config.SCREEN_WIDTH
    screen_height = config.SCREEN_HEIGHT
  elif normalized_extension == '.json':
    screen_width = config.RICO_SCREEN_WIDTH
    screen_height = config.RICO_SCREEN_HEIGHT
  else:
    raise ValueError('unsupported file format %s' % file_extension)

  with gfile.GFile(file_path, 'r') as f:
    data = f.read()

  vh = view_hierarchy.ViewHierarchy(
      screen_width=screen_width, screen_height=screen_height)
  if normalized_extension == '.xml':
    vh.load_xml(data)
  else:
    vh.load_json(data)
  return vh
439

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.