# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
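
# hubconf.py is the torch.hub entry-point file for this repository: torch.hub
# treats each callable defined below as a loadable entry point, e.g.
# torch.hub.load('huggingface/transformers', 'model', 'google-bert/bert-base-uncased').
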
import os
import sys


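# torch.hub clones this repository into its local hub cache; appending `src`
# to sys.path makes the `transformers` imports below resolve to the
# checked-out sources, with no separately installed package required.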
SRC_DIR = os.path.join(os.path.dirname(__file__), "src")
sys.path.append(SRC_DIR)


from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    add_start_docstrings,
)


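# torch.hub reads this module-level `dependencies` list when the repo is
# loaded and raises a RuntimeError if any listed package cannot be imported.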
dependencies = [
    "torch",
    "numpy",
    "tokenizers",
    "filelock",
    "requests",
    "tqdm",
    "regex",
    "sentencepiece",
    "sacremoses",
    "importlib_metadata",
    "huggingface_hub",
]


@add_start_docstrings(AutoConfig.__doc__)
def config(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch

    config = torch.hub.load('huggingface/transformers', 'config', 'google-bert/bert-base-uncased')  # Download configuration from huggingface.co and cache.
    config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/bert_saved_model/')`
    config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/my_configuration.json')
    config = torch.hub.load('huggingface/transformers', 'config', 'google-bert/bert-base-uncased', output_attentions=True, foo=False)
    assert config.output_attentions == True
    config, unused_kwargs = torch.hub.load('huggingface/transformers', 'config', 'google-bert/bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
    assert config.output_attentions == True
    assert unused_kwargs == {'foo': False}
    """
    return AutoConfig.from_pretrained(*args, **kwargs)


@add_start_docstrings(AutoTokenizer.__doc__)
def tokenizer(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch

    tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'google-bert/bert-base-uncased')  # Download vocabulary from huggingface.co and cache.
    tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', './test/bert_saved_model/')  # E.g. tokenizer was saved using `save_pretrained('./test/bert_saved_model/')`
    """
    return AutoTokenizer.from_pretrained(*args, **kwargs)


@add_start_docstrings(AutoModel.__doc__)
def model(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch

    model = torch.hub.load('huggingface/transformers', 'model', 'google-bert/bert-base-uncased')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'model', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
    model = torch.hub.load('huggingface/transformers', 'model', 'google-bert/bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'model', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModel.from_pretrained(*args, **kwargs)


@add_start_docstrings(AutoModelForCausalLM.__doc__)
def modelForCausalLM(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch

    model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'openai-community/gpt2')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'openai-community/gpt2', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/gpt_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './tf_model/gpt_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModelForCausalLM.from_pretrained(*args, **kwargs)


@add_start_docstrings(AutoModelForMaskedLM.__doc__)
def modelForMaskedLM(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch

    model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'google-bert/bert-base-uncased')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'google-bert/bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModelForMaskedLM.from_pretrained(*args, **kwargs)


@add_start_docstrings(AutoModelForSequenceClassification.__doc__)
def modelForSequenceClassification(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch

    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'google-bert/bert-base-uncased')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'google-bert/bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)


@add_start_docstrings(AutoModelForQuestionAnswering.__doc__)
def modelForQuestionAnswering(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch

    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', 'google-bert/bert-base-uncased')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', 'google-bert/bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModelForQuestionAnswering.from_pretrained(*args, **kwargs)
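
# A minimal end-to-end sketch of the entry points above (assumes network
# access and the dependencies listed at the top of this file):
#
#     import torch
#
#     tok = torch.hub.load('huggingface/transformers', 'tokenizer', 'google-bert/bert-base-uncased')
#     mdl = torch.hub.load('huggingface/transformers', 'model', 'google-bert/bert-base-uncased')
#     out = mdl(**tok("Hello world", return_tensors="pt"))
#     print(out.last_hidden_state.shape)  # torch.Size([1, 4, 768])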