# transformers/hubconf.py — torch.hub entry points for the Transformers library.
1# Copyright 2020 The HuggingFace Team. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import os
16import sys
17
18
19SRC_DIR = os.path.join(os.path.dirname(__file__), "src")
20sys.path.append(SRC_DIR)
21
22
23from transformers import (
24AutoConfig,
25AutoModel,
26AutoModelForCausalLM,
27AutoModelForMaskedLM,
28AutoModelForQuestionAnswering,
29AutoModelForSequenceClassification,
30AutoTokenizer,
31add_start_docstrings,
32)
33
34
# Packages torch.hub checks for before loading this repo (read by torch.hub.load).
dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata", "huggingface_hub"]
36
37
@add_start_docstrings(AutoConfig.__doc__)
def config(*model_args, **model_kwargs):
    r"""
    # Using torch.hub !
    import torch

    config = torch.hub.load('huggingface/transformers', 'config', 'google-bert/bert-base-uncased')  # Download configuration from huggingface.co and cache.
    config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
    config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/my_configuration.json')
    config = torch.hub.load('huggingface/transformers', 'config', 'google-bert/bert-base-uncased', output_attentions=True, foo=False)
    assert config.output_attentions == True
    config, unused_kwargs = torch.hub.load('huggingface/transformers', 'config', 'google-bert/bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
    assert config.output_attentions == True
    assert unused_kwargs == {'foo': False}

    """
    # torch.hub entry point: forward everything to AutoConfig unchanged.
    return AutoConfig.from_pretrained(*model_args, **model_kwargs)
56
57
@add_start_docstrings(AutoTokenizer.__doc__)
def tokenizer(*model_args, **model_kwargs):
    r"""
    # Using torch.hub !
    import torch

    tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'google-bert/bert-base-uncased')  # Download vocabulary from huggingface.co and cache.
    tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', './test/bert_saved_model/')  # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`

    """
    # torch.hub entry point: forward everything to AutoTokenizer unchanged.
    return AutoTokenizer.from_pretrained(*model_args, **model_kwargs)
70
71
@add_start_docstrings(AutoModel.__doc__)
def model(*model_args, **model_kwargs):
    r"""
    # Using torch.hub !
    import torch

    model = torch.hub.load('huggingface/transformers', 'model', 'google-bert/bert-base-uncased')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'model', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'model', 'google-bert/bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'model', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    # torch.hub entry point: forward everything to AutoModel unchanged.
    return AutoModel.from_pretrained(*model_args, **model_kwargs)
89
90
@add_start_docstrings(AutoModelForCausalLM.__doc__)
def modelForCausalLM(*model_args, **model_kwargs):
    r"""
    # Using torch.hub !
    import torch

    model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'openai-community/gpt2')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'openai-community/gpt2', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/gpt_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './tf_model/gpt_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    # torch.hub entry point: forward everything to AutoModelForCausalLM unchanged.
    return AutoModelForCausalLM.from_pretrained(*model_args, **model_kwargs)
107
108
@add_start_docstrings(AutoModelForMaskedLM.__doc__)
def modelForMaskedLM(*model_args, **model_kwargs):
    r"""
    # Using torch.hub !
    import torch

    model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'google-bert/bert-base-uncased')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'google-bert/bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    # torch.hub entry point: forward everything to AutoModelForMaskedLM unchanged.
    return AutoModelForMaskedLM.from_pretrained(*model_args, **model_kwargs)
126
127
@add_start_docstrings(AutoModelForSequenceClassification.__doc__)
def modelForSequenceClassification(*model_args, **model_kwargs):
    r"""
    # Using torch.hub !
    import torch

    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'google-bert/bert-base-uncased')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'google-bert/bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    # torch.hub entry point: forward everything to AutoModelForSequenceClassification unchanged.
    return AutoModelForSequenceClassification.from_pretrained(*model_args, **model_kwargs)
145
146
@add_start_docstrings(AutoModelForQuestionAnswering.__doc__)
def modelForQuestionAnswering(*model_args, **model_kwargs):
    r"""
    # Using torch.hub !
    import torch

    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', 'google-bert/bert-base-uncased')  # Download model and configuration from huggingface.co and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', 'google-bert/bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions == True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    # torch.hub entry point: forward everything to AutoModelForQuestionAnswering unchanged.
    return AutoModelForQuestionAnswering.from_pretrained(*model_args, **model_kwargs)
163