google-research
51 строка · 1.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Functions for loading dataset, fragment text into short pieces, mutating text with random words."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20import numpy as np
21
22
def fragment_into_short_sentence(text_list, label_list, fix_len, rate):
    """Cut each text into randomly positioned, fixed-length fragments.

    For every text that is at least `fix_len` long, start offsets are drawn
    at random with `np.random.choice` (default settings, so the same offset
    may be drawn more than once). The `fix_len`-long slice at each offset is
    emitted together with the label of the text it came from. Texts shorter
    than `fix_len`, or whose computed sample count is zero, contribute
    nothing to the output.

    Args:
        text_list: a list of texts (each must support `len` and slicing).
        label_list: a list of class labels, indexed in parallel with
            `text_list`.
        fix_len: the length of every fragment.
        rate: the sampling rate. A text of length L yields
            int(rate * (L - fix_len + 1)) fragments.

    Returns:
        A pair of numpy arrays: the fragments, and the class label of each
        fragment's source text, in matching order.
    """
    fragments = []
    fragment_labels = []

    for idx, text in enumerate(text_list):
        # Number of start offsets at which a full fix_len window still fits.
        n_offsets = len(text) - fix_len + 1
        if n_offsets < 1:
            # Text is shorter than one fragment.
            continue

        n_sample = int(rate * n_offsets)
        if n_sample == 0:
            continue

        offsets = np.random.choice(n_offsets, size=[n_sample])
        label = label_list[idx]
        fragments.extend(text[start:start + fix_len] for start in offsets)
        fragment_labels.extend(label for _ in offsets)

    return np.array(fragments), np.array(fragment_labels)
52