# google-research (repository listing: 45 lines, 1.7 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example code to create a TFRecord file for training.

Writes `num_examples` tf.train.Example records filled with random (synthetic)
pose data to 'example.tfrecord' (a path relative to the working directory).
Each record carries the features 'fps', 'pose_data', 'pose_confidence' and
'is_signing'.
"""

import numpy as np
import tensorflow as tf

num_examples = 5  # Number of examples (videos) to write.
frames = 100  # Number of frames in each example video.
fps = 25  # FPS of each example video.

with tf.io.TFRecordWriter('example.tfrecord') as writer:
  for _ in range(num_examples):
    # One np.byte (int8) label per frame, serialized as raw bytes.
    # np.random.randint's `high` bound is exclusive, so high=2 draws from
    # {0, 1}; with high=1 every label would always be 0.
    is_signing = np.random.randint(
        low=0, high=2, size=frames, dtype='byte').tobytes()

    # Tensor of shape (frames, 1, 137, 2); the axis meanings are presumably
    # (frame, person, keypoint, xy) -- confirm against the dataset reader.
    data = tf.io.serialize_tensor(
        tf.random.normal(shape=(frames, 1, 137, 2), dtype=tf.float32)).numpy()

    # Tensor of shape (frames, 1, 137): one value per keypoint.
    # NOTE(review): drawn from a normal distribution, so values are not
    # restricted to [0, 1]; fine for a smoke-test file, not for real data.
    confidence = tf.io.serialize_tensor(
        tf.random.normal(shape=(frames, 1, 137), dtype=tf.float32)).numpy()

    features = {
        'fps':
            tf.train.Feature(int64_list=tf.train.Int64List(value=[fps])),
        'pose_data':
            tf.train.Feature(bytes_list=tf.train.BytesList(value=[data])),
        'pose_confidence':
            tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[confidence])),
        'is_signing':
            tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[is_signing])),
    }

    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(example.SerializeToString())