x-transformers
A concise but fully-featured transformer, complete with a set of promising experimental features from various papers.
Install
$ pip install x-transformers
Usage
Full encoder / decoder
import torch
from x_transformers import XTransformer
model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True # tie embeddings of encoder and decoder
)
src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))
loss = model(src, tgt, mask = src_mask) # scalar loss
loss.backward()
Decoder-only (GPT-like)
import torch
from x_transformers import TransformerWrapper, Decoder
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()
x = torch.randint(0, 256, (1, 1024)).cuda()
model(x) # (1, 1024, 20000)
GPT-3 would be approximately the following (though you would not be able to run it anyway):
gpt3 = TransformerWrapper(
    num_tokens = 50000,
    max_seq_len = 2048,
    attn_layers = Decoder(
        dim = 12288,
        depth = 96,
        heads = 96,
        attn_dim_head = 128
    )
).cuda()
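A back-of-the-envelope parameter count shows why the configuration above is "approximately GPT-3". This sketch rests on several assumptions:

- Each layer holds about 4·dim² attention weights (query, key, value and output projections). This holds here because heads × attn_dim_head = 96 × 128 = 12288 = dim.
- Each layer holds about 8·dim² feedforward weights, assuming a 4× hidden expansion.
- The token embedding is counted once.
- Biases, norms, positional embeddings and any untied output projection are ignored.

`approx_param_count` is a hypothetical helper, not part of x-transformers.

```python
def approx_param_count(dim: int, depth: int, num_tokens: int, ff_mult: int = 4) -> int:
    # hypothetical helper: rough estimate only, ignores biases / norms / positional embeddings
    attn = 4 * dim * dim               # q, k, v and output projections (assumes heads * dim_head == dim)
    ff = 2 * ff_mult * dim * dim       # two linear layers with an assumed ff_mult expansion
    emb = num_tokens * dim             # token embedding, counted once
    return depth * (attn + ff) + emb

n = approx_param_count(dim = 12288, depth = 96, num_tokens = 50000)
print(f"{n / 1e9:.1f}B")  # prints 174.6B
```

The estimate lands near GPT-3's published 175B. Most of that comes from the 12·dim² per-layer term, so depth and width dominate the count rather than the vocabulary size.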
Encoder-only (BERT-like)
import torch
from x_transformers import TransformerWrapper, Encoder
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()
x = torch.randint(0, 256, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()
model(x, mask = mask) # (1, 1024, 20000)
State of the art image classification (<a href="https://arxiv.org/abs/2205.01580">SimpleViT</a>)
import torch
from x_transformers import ViTransformerWrapper, Encoder
model = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
    )
)
img = torch.randn(1, 3, 256, 256)
model(img) # (1, 1000)
Image -> caption
import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder
encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)
img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))
encoded = encoder(img, return_embeddings = True)
decoder(caption, context = encoded) # (1, 1024, 20000)
<a href="https://arxiv.org/abs/2209.06794">PaLI</a>, state of the art language-vision model
import torch
from x_transformers import ViTransformerWrapper, XTransformer, Encoder
# PaLI is composed of
# 1. a vision transformer (ViTransformerWrapper) +
# 2. an encoder-decoder transformer (XTransformer)
vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
pali = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)
# training data
img = torch.randn(1, 3, 256, 256)               # images
prompt = torch.randint(0, 256, (1, 1024))       # prompt
prompt_mask = torch.ones(1, 1024).bool()        # prompt text mask
output_text = torch.randint(0, 256, (1, 1024))  # target output text
# train
img_embeds = vit(
    img,
    return_embeddings = True
)
loss = pali(
    prompt,
    output_text,
    mask = prompt_mask,
    src_prepend_embeds = img_embeds # prepends the image embeddings to the encoder text embeddings before attention
)
loss.backward()
# do the above for many steps on a 17B parameter model
# attention is all you need
Dropouts
import torch
from x_transformers import TransformerWrapper, Decoder, Encoder
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    emb_dropout = 0.1,      # dropout after embedding
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        layer_dropout = 0.1,  # stochastic depth - dropout entire layer
        attn_dropout = 0.1,   # dropout post-attention
        ff_dropout = 0.1      # feedforward dropout
    )
)
x = torch.randint(0, 20000, (1, 1024))
model(x)
Features
Flash Attention
<img src="./images/flash-attention.png" width="500px"></img>
What started as <a href="https://arxiv.org/abs/2112.05682">a short paper</a> from Markus Rabe culminated in a practical fused attention CUDA kernel, <a href="https://arxiv.org/abs/2205.14135">Flash Attention</a>, by <a href="https://tridao.me/">Tri Dao</a>.
The technique processes the attention matrix in tiles. It keeps track only of the running softmax statistics (row maximum and normalizer) and the exponentiated weighted sums of the values. The attention matrix is recomputed tile by tile on the backward pass rather than stored, so memory stays linear with respect to sequence length instead of quadratic. This has allowed many recent models to reach for longer context lengths without worrying about the memory bottleneck.
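The online-softmax bookkeeping described above can be illustrated with a minimal NumPy sketch of the forward pass. It is a simplification in several ways:

- Single head, no masking.
- Only keys/values are tiled. The real kernel also tiles queries and keeps tiles in on-chip SRAM.
- No backward recomputation and no fused CUDA kernel.

`naive_attention` and `tiled_attention` are hypothetical names, not Tri Dao's implementation or PyTorch's.

```python
import numpy as np

def naive_attention(q, k, v):
    # reference: materializes the full (n, n) score matrix
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, tile=16):
    # streams over key/value tiles, keeping only running statistics per query row
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    acc = np.zeros((n, v.shape[-1]))     # running exponentiated weighted sum of values
    row_max = np.full((n, 1), -np.inf)   # running max score, for numerical stability
    row_sum = np.zeros((n, 1))           # running softmax denominator
    for start in range(0, k.shape[0], tile):
        k_t, v_t = k[start:start + tile], v[start:start + tile]
        s = (q @ k_t.T) * scale                                # (n, tile) scores for this tile only
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)                 # rescales what was accumulated so far
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_t
        row_max = new_max
    return acc / row_sum                                       # normalize once at the end

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v, tile = 10)))
```

The tiled version never holds more than an (n, tile) block of scores. Whenever a tile raises a row's maximum, the `correction` factor rescales the partial sums, so the final result matches naive attention up to floating-point error.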
Other engineering decisions by Tri Dao contributed to its enormous success, chiefly minimizing accesses to GPU high-bandwidth memory (HBM), so that both the forward and backward passes outperform naive attention. In other words, flash attention is not only more memory efficient but also faster, making it a necessity for training transformers.
Meta AI has added the ability to use <a href="https://github.com/hazyresearch/flash-attention">Tri Dao's CUDA kernel</a> through the <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html">scaled_dot_product_attention</a> function in PyTorch 2.0. (There is also a mem_efficient attention, which follows the same design as flash attention but traverses the tiles differently.)
<a href="https://ai.facebook.com/blog/large-language-model-llama-meta-ai/">Llama</a> was trained using Flash Attention. The only reason to avoid it is if you need to operate on the attention matrix directly (e.g. dynamic positional bias, talking heads, residual attention).
You can use it in this repository by setting attn_flash to True and enjoy the immediate memory savings and increase in speed.
ex.
import torch
from x_transformers import TransformerWrapper, Decoder, Encoder
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True # just set this to True if you have PyTorch 2.0 or later installed
    )
)
Augmenting Self-attention with Persistent Memory
<img src="./images/all-attention.png" width="500px"></img>
https://arxiv.org/abs/1907.01470
Proposes adding learned memory key / values, which are concatenated to the keys / values before attention. The authors were able to remove the feedforwards altogether and attain performance similar to the original transformer. I have found that keeping the feedforwards and adding the memory key / values leads to even better performance.
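The mechanism can be sketched in NumPy as a minimal single-head example with no causal mask. The learned memory keys / values do not depend on the input. They are simply concatenated in front of the projected keys / values, so every query can also attend to them. `attention_with_memory_kv` and its arguments are hypothetical names. x-transformers enables this through `attn_num_mem_kv`, and its actual implementation (for example per-head memory and how it interacts with masking) may differ.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_with_memory_kv(x, w_q, w_k, w_v, mem_k, mem_v):
    # x: (n, dim) input; mem_k, mem_v: (num_mem, dim_head) learned, input-independent parameters
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    k = np.concatenate([mem_k, k], axis=0)                # (num_mem + n, dim_head)
    v = np.concatenate([mem_v, v], axis=0)
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))     # queries attend over memory slots and inputs
    return weights @ v, weights

rng = np.random.default_rng(0)
n, dim, dim_head, num_mem = 10, 32, 16, 16
x = rng.standard_normal((n, dim))
w_q, w_k, w_v = (rng.standard_normal((dim, dim_head)) / np.sqrt(dim) for _ in range(3))
mem_k, mem_v = (rng.standard_normal((num_mem, dim_head)) for _ in range(2))
out, weights = attention_with_memory_kv(x, w_q, w_k, w_v, mem_k, mem_v)
print(out.shape, weights.shape)   # (10, 16) (10, 26)
```

Only the key / value length grows (10 inputs + 16 memory slots). The output keeps one row per input position, which is why the memory key / values can be added without changing the rest of the layer.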
from x_transformers import Decoder, Encoder
enc = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    attn_num_mem_kv = 16 # 16 memory key / values
)
Memory Transformers
<img src="./images/memory-transformer.png" width="500px"></img>
https://arxiv.org/abs/2006.11527
Proposes adding learned tokens, akin to CLS tokens and named memory tokens, that are passed through the attention layers alongside the input tokens. This setting is compatible with both encoder and decoder training.
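A minimal NumPy sketch of the idea follows. A fixed set of learned embeddings is prepended to every input sequence, processed by the same layers as the real tokens, and stripped off before the output. Assumptions for this example:

- `toy_self_attention_layer` is a stand-in for a full transformer block.
- The layer is bidirectional (encoder-style), so causal masking for the decoder case is not shown.
- Both function names are hypothetical.

In the library this is configured with `num_memory_tokens`, and its internals may differ.

```python
import numpy as np

def toy_self_attention_layer(h):
    # stand-in for a transformer block: bidirectional single-head self-attention plus residual
    scores = h @ h.T / np.sqrt(h.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return h + w @ h

def forward_with_memory_tokens(x, memory, layers):
    # x: (n, dim) token embeddings; memory: (num_mem, dim) learned embeddings shared by every input
    num_mem = memory.shape[0]
    h = np.concatenate([memory, x], axis=0)   # memory tokens join the sequence
    for layer in layers:
        h = layer(h)                          # every position can read from and write to memory slots
    return h[num_mem:]                        # strip memory positions before the output head

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16))
memory = rng.standard_normal((20, 16))
out = forward_with_memory_tokens(x, memory, [toy_self_attention_layer] * 2)
print(out.shape)   # (8, 16)
```

Unlike memory key / values, memory tokens are updated layer by layer like ordinary tokens. The output shape still matches the input because the memory positions are dropped at the end.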
import torch
from x_transformers import TransformerWrapper, Decoder, Encoder
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20, # 20 memory tokens
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
Update: Meta AI researchers <a href="https://arxiv.org/abs/2309.16588">have found</a> that adding memory tokens (which they call register tokens) alleviates outliers, now suspected to be a pathology of attention networks being unable to <a href="https://arxiv.org/abs/2306.12929">attend to nothing</a>.
Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper.
Update 3: further corroborated by <a href="https://arxiv.org/abs/2501.00663">a paper</a> trying to extend memory in attention networks, where they are termed persistent memory.
Transformers Without Tears
<img src="./images/scalenorm.png"></img>
https://arxiv.org/abs/1910.05895
They experiment with alternatives to layer normalization and find one, ScaleNorm, that is both simpler and effective. Researchers have shared with me that this leads to faster convergence.
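A minimal NumPy sketch shows why ScaleNorm is simpler than layer normalization:

- LayerNorm centers each vector, divides by its standard deviation, and applies 2·dim learned gain/bias parameters.
- ScaleNorm only divides by the l2 norm and multiplies by a single learned scalar g.

Initializing g to sqrt(dim) is an assumption made here, so that the output RMS is 1 like LayerNorm's. The function names are hypothetical and are not the library's `use_scalenorm` implementation.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # 2 * dim learned parameters: per-feature gain and bias
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def scale_norm(x, g, eps=1e-5):
    # one learned scalar g: rescales each vector to have l2 norm g (no centering, no bias)
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return g * x / np.maximum(norm, eps)

dim = 512
rng = np.random.default_rng(0)
x = 3.0 * rng.standard_normal((4, dim)) + 1.0
y_ln = layer_norm(x, gamma=np.ones(dim), beta=np.zeros(dim))
y_sn = scale_norm(x, g=np.sqrt(dim))    # g assumed initialized to sqrt(dim)
print(np.linalg.norm(y_sn, axis=-1))    # every row has l2 norm sqrt(512)
```

Both layers keep activations at a fixed scale. ScaleNorm does it with one parameter instead of 1024 for dim = 512.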
import torch
from x_transformers import TransformerWrapper, Decoder, Encoder
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_scalenorm = True # set to True to use for all layers
    )
)
You can also use the l2-normalized embeddings proposed as part of fixnorm. I have found it leads to improved convergence when paired with small initialization (proposed by <a href="https://github.com/BlinkDL">BlinkDL</a>). The small in