Efficient content-based sparse attention with Routing Transformers
Code base accompanying the paper (to appear in TACL). See also the accompanying slides for a quick overview.
Table of Contents
- Updates
- Pre-trained PG-19 Checkpoint
- Explanation of hyperparameters
- Samples
- Acknowledgments
- How to Cite
Updates
- Routing Transformer + REALM is now SOTA on long-form Question Answering (QA) on the ELI5 dataset in Facebook AI's Knowledge Intensive Language Tasks (KILT) benchmark, with significant improvements in generation quality over BART, RAG and T5/Mesh TF: +4.11, +5.78 and +9.14 Rouge-L over T5/Mesh TF, BART + DPR and RAG respectively. Check out the source code and pre-trained model weights at Kalpesh's GitHub repository.
Pre-trained PG-19 Checkpoint
Model | Hparams | Context Length (tokens) | Dataset | Vocab | Download
---|---|---|---|---|---
Local-base | `pg19_local8k` | 8192 | PG-19 | `vocab98K` | checkpoint.zip
RT-base | `pg19_local_cluster8k` | 8192 | PG-19 | `vocab98K` | checkpoint.zip
RT-base | `pg19_local_cluster8k` | 8192 | ELI-5 | `vocab98K` | checkpoint.zip
Explanation of hyperparameters
Local Attention
- `local_num_heads`: Number of local attention heads
- `query_shape`: This represents the shape of the query block.
  - For 1-d local attention with block size `b`, this would be `(b,)`
- `memory_query_shape`: This represents the query shape of the memory antecedent and is useful for encoder-decoder attention
  - This is usually set the same as `query_shape` by default
  - This is useful when inputs and targets are of different lengths
  - E.g., if inputs are of length `4096` and targets are of length `8192`
  - Plausible setting: `query_shape = (256,)`, `memory_flange = (256,)` and `memory_query_shape = (128,)` (see the sketch after this list)
  - This is because with block size `256`, the targets will have `32` blocks
  - To match this in enc-dec attention, the inputs must also have `32` blocks
  - This is why we set `memory_query_shape = (4096/32,) = (128,)`
- `memory_flange`: This represents the overlap of the memory block
  - Example setting: `query_shape = (b,)` and `memory_flange = (m * b,)`
  - Masked: each query block attends to `m` previous blocks
  - Unmasked: each query block attends to `m` previous and `m` subsequent blocks
  - Setting this to `(0,)` means all the blocks are independent of each other
  - Setting to `(0,)` is used for full attention, or for axial attention
  - This must be a multiple of `query_shape` in every dimension
- Example setting can be found in `sparse_transformer.py` under `pg19_local8k`
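To make the encoder-decoder example above concrete, here is a minimal sketch of how these local-attention hyperparameters could be combined, written as a plain Python dictionary. The head count is an assumed value for illustration; this is not the repository's actual hparams registration code.

```python
# Minimal sketch (assumed values, not repository defaults): local attention
# settings for 4096-token inputs and 8192-token targets, as reasoned above.
local_attention_config = {
    "local_num_heads": 8,          # number of local attention heads (assumed)
    "query_shape": (256,),         # 8192-token targets -> 8192 / 256 = 32 query blocks
    "memory_flange": (256,),       # each block also attends to one neighbouring block
    "memory_query_shape": (128,),  # 4096-token inputs / 32 blocks = 128 tokens per block
}
```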
Routing Attention
- `sparsity_cluster_num_heads`: Number of routing attention heads
- `sparsity_cluster_size`: Number of clusters
- `sparsity_cluster_attention_window`: Average size of each cluster
- `sparsity_skip_first`: Number of initial layers in which routing attention is skipped (see the sketch after this list)
  - `sparsity_skip_first = 0` would have routing attention in every layer
  - `sparsity_skip_first` equal to the total number of layers would have no routing attention
- Example setting can be found in `sparse_transformer.py` under `pg19_local_cluster8k`
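These hyperparameters interact: the number of clusters times the average cluster size should roughly cover the attended sequence. Below is a minimal sketch of one self-consistent choice for an 8k-token model; the values are illustrative assumptions, not the settings registered under `pg19_local_cluster8k`.

```python
# Minimal sketch (assumed values, not repository defaults): routing attention
# settings for an 8192-token sequence.
routing_attention_config = {
    "sparsity_cluster_num_heads": 8,           # routing attention heads (assumed)
    "sparsity_cluster_size": 16,               # number of clusters (assumed)
    "sparsity_cluster_attention_window": 512,  # average tokens per cluster: 16 * 512 = 8192
    "sparsity_skip_first": 0,                  # routing attention in every layer
}
```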
Samples
PG-19 (sequence length 8k)
Unconditional Samples
Conditional Samples
Document Machine Translation (sequence length 4k)
Acknowledgments
The authors would like to thank Phillip Wang and Aran Komatsuzaki for a PyTorch implementation of the Routing Transformer, and Yonghui Wu, Weikang Zhou and Dehao Chen for helpful feedback on improving the implementation of this work. The authors would also like to thank the anonymous reviewers and the TACL Action Editor, Xavier Carreras, for their constructive comments, which helped improve the exposition of this work.
How to Cite
If you extend or use this work, please cite the paper where it was introduced:
```bibtex
@article{roy2020efficient,
  title={Efficient content-based sparse attention with routing transformers},
  author={Roy, Aurko and Saffar, Mohammad and Vaswani, Ashish and Grangier, David},
  journal={arXiv preprint arXiv:2003.05997},
  year={2020}
}
```