# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RED-ACE config class."""
from official.legacy.bert import configs


class RedAceConfig(configs.BertConfig):
  """Model configuration for RED-ACE."""

  def __init__(
      self,
      vocab_size=30522,
      hidden_size=768,
      num_hidden_layers=12,
      num_attention_heads=12,
      intermediate_size=3072,
      hidden_act="gelu",
      hidden_dropout_prob=0.1,
      attention_probs_dropout_prob=0.1,
      max_position_embeddings=512,
      type_vocab_size=2,
      initializer_range=0.02,
      num_classes=2,
      enable_async_checkpoint=True,
  ):
    """Initializes an instance of RED-ACE configuration.

    This initializer expects both the BERT-specific arguments and the
    RED-ACE-specific arguments listed below.

    Args:
      vocab_size: Vocabulary size of `input_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer
        in the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in
        the encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model
        might ever be used with. Typically set this to something large just
        in case (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed
        into `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
      num_classes: Number of tags.
      enable_async_checkpoint: Whether saving the model should happen
        asynchronously.
    """
    super(RedAceConfig, self).__init__(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
        hidden_dropout_prob=hidden_dropout_prob,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
        max_position_embeddings=max_position_embeddings,
        type_vocab_size=type_vocab_size,
        initializer_range=initializer_range,
    )
    self.num_classes = num_classes
    self.enable_async_checkpoint = enable_async_checkpoint
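

# A minimal usage sketch (an illustrative addition, not part of the original
# module): builds a config with the defaults above, overriding only the
# RED-ACE specific `num_classes` field. The `to_dict()` call assumes the
# serialization helper inherited from `configs.BertConfig`.
if __name__ == "__main__":
  config = RedAceConfig(num_classes=4)
  print(config.hidden_size)  # 768, forwarded to the BERT superclass.
  print(config.num_classes)  # 4, stored directly on RedAceConfig.
  print(config.to_dict())  # Includes both the BERT and RED-ACE fields.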