# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ESMFold model. """


import unittest

from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers.models.esm.modeling_esmfold import EsmForProteinFolding


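# Helper that builds a deliberately tiny ESMFold configuration plus random inputs so the
# model tests below can run quickly on CPU. The sizes here are test-only values and do not
# resemble the real facebook/esmfold_v1 checkpoint.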
class EsmFoldModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=False,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=False,
        vocab_size=19,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

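    # Builds random token ids (and, when use_input_mask is set, a random attention mask)
    # of shape [batch_size, seq_length], together with the small config from get_config().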
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

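    # The folding head is configured through the nested `esmfold_config` dict: the trunk and
    # its structure module are shrunk to 2 blocks with small state/head widths so a forward
    # pass stays cheap, and `fp16_esm` is disabled so the test runs in full precision on CPU.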
    def get_config(self):
        esmfold_config = {
            "trunk": {
                "num_blocks": 2,
                "sequence_state_dim": 64,
                "pairwise_state_dim": 16,
                "sequence_head_width": 4,
                "pairwise_head_width": 4,
                "position_bins": 4,
                "chunk_size": 16,
                "structure_module": {
                    "ipa_dim": 16,
                    "num_angles": 7,
                    "num_blocks": 2,
                    "num_heads_ipa": 4,
                    "pairwise_dim": 16,
                    "resnet_dim": 16,
                    "sequence_dim": 48,
                },
            },
            "fp16_esm": False,
            "lddt_head_hid_dim": 16,
        }
        config = EsmConfig(
            vocab_size=33,
            hidden_size=self.hidden_size,
            pad_token_id=1,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            is_folding_model=True,
            esmfold_config=esmfold_config,
        )
        return config

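    # Runs a few forward passes and checks output shapes: `positions` uses the atom14
    # representation (14 atoms per residue, xyz coordinates) and `angles` holds 7 torsion
    # angles per residue as (sin, cos) pairs; the leading dimension of 2 comes from the
    # outputs being stacked per structure-module block (num_blocks=2 above).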
    def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
        model = EsmForProteinFolding(config=config).float()
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        result = model(input_ids)

        self.parent.assertEqual(result.positions.shape, (2, self.batch_size, self.seq_length, 14, 3))
        self.parent.assertEqual(result.angles.shape, (2, self.batch_size, self.seq_length, 7, 2))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


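# Common-test wiring for ESMFold. Most of the inherited ModelTesterMixin checks are skipped
# below because ESMFold does not expose the usual transformer surface (attention/hidden-state
# outputs, embedding resizing, head pruning, torchscript, ...); only the config tests and the
# tiny forward-pass shape check above remain active.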
@require_torch
class EsmFoldModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_mismatched_shapes = False

    all_model_classes = (EsmForProteinFolding,) if is_torch_available() else ()
    all_generative_model_classes = ()
    pipeline_model_mapping = {} if is_torch_available() else {}
    test_sequence_classification_problem_types = False

    def setUp(self):
        self.model_tester = EsmFoldModelTester(self)
        self.config_tester = ConfigTester(self, config_class=EsmConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip("Does not support attention outputs")
    def test_attention_outputs(self):
        pass

    @unittest.skip
    def test_correct_missing_keys(self):
        pass

    @unittest.skip("Esm does not support embedding resizing")
    def test_resize_embeddings_untied(self):
        pass

    @unittest.skip("Esm does not support embedding resizing")
    def test_resize_tokens_embeddings(self):
        pass

    @unittest.skip("ESMFold does not support passing input embeds!")
    def test_inputs_embeds(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_head_pruning(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_head_pruning_integration(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_head_pruning_save_load_from_config_init(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_head_pruning_save_load_from_pretrained(self):
        pass

    @unittest.skip("ESMFold does not support head pruning.")
    def test_headmasking(self):
        pass

    @unittest.skip("ESMFold does not output hidden states in the normal way.")
    def test_hidden_states_output(self):
        pass

    @unittest.skip("ESMFold does not output hidden states in the normal way.")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip("ESMFold only has one output format.")
    def test_model_outputs_equivalence(self):
        pass

    @unittest.skip("This test doesn't work for ESMFold and doesn't test core functionality")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("ESMFold does not support input chunking.")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip("ESMFold doesn't respect you and it certainly doesn't respect your initialization arguments.")
    def test_initialization(self):
        pass

    @unittest.skip("ESMFold doesn't support torchscript compilation.")
    def test_torchscript_output_attentions(self):
        pass

    @unittest.skip("ESMFold doesn't support torchscript compilation.")
    def test_torchscript_output_hidden_state(self):
        pass

    @unittest.skip("ESMFold doesn't support torchscript compilation.")
    def test_torchscript_simple(self):
        pass

    @unittest.skip("ESMFold doesn't support data parallel.")
    def test_multi_gpu_data_parallel_forward(self):
        pass


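# Slow integration test: loads the full facebook/esmfold_v1 checkpoint and compares a slice
# of the predicted atom positions against precomputed reference values.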
@require_torch
class EsmModelIntegrationTest(TestCasePlus):
    @slow
    def test_inference_protein_folding(self):
        model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").float()
        model.eval()
        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
        position_outputs = model(input_ids)["positions"]
        expected_slice = torch.tensor([2.5828, 0.7993, -10.9334], dtype=torch.float32)
        self.assertTrue(torch.allclose(position_outputs[0, 0, 0, 0], expected_slice, atol=1e-4))
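# To run this module locally (paths assume the standard transformers test layout):
#   python -m pytest tests/models/esm/test_modeling_esmfold.py -v
# The integration test above is decorated with @slow, so set RUN_SLOW=1 to include it.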
