
x-transformers


A concise but fully-featured transformer, complete with a set of promising experimental features from various papers.

Install

$ pip install x-transformers

Usage

Full encoder / decoder

import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True      # tie embeddings of encoder and decoder
)

src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))

loss = model(src, tgt, mask = src_mask) # scalar cross entropy loss on the target sequence
loss.backward()

Decoder-only (GPT-like)

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()

model(x) # (1, 1024, 20000)
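
The output above is a tensor of next-token logits. To train autoregressively, one can shift the targets by one position and apply cross entropy; a minimal sketch using plain PyTorch (the shifting and loss below are standard PyTorch, not a library API):

import torch.nn.functional as F

logits = model(x)                        # (1, 1024, 20000)

preds   = logits[:, :-1]                 # each position predicts the following token
targets = x[:, 1:]

loss = F.cross_entropy(
    preds.reshape(-1, preds.shape[-1]),  # (1023, 20000)
    targets.reshape(-1)                  # (1023,)
)
loss.backward()

The repository also provides an AutoregressiveWrapper that packages up this shift-and-cross-entropy objective (and sampling) for you.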

GPT-3 would be approximately the following (but you wouldn't be able to run it anyway)


gpt3 = TransformerWrapper(
    num_tokens = 50000,
    max_seq_len = 2048,
    attn_layers = Decoder(
        dim = 12288,
        depth = 96,
        heads = 96,
        attn_dim_head = 128
    )
).cuda()

Encoder-only (BERT-like)

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 1024, 20000)
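
The logits over the vocabulary can be used for BERT-style masked token prediction. A rough sketch, assuming a hypothetical mask token id of 0 and a 15% masking rate (the corruption and loss below are not part of the library):

import torch.nn.functional as F

mask_token_id = 0      # hypothetical [MASK] id
mask_prob = 0.15

masked = torch.rand(x.shape, device = x.device) < mask_prob   # positions to predict
corrupted = x.masked_fill(masked, mask_token_id)

logits = model(corrupted, mask = mask)   # (1, 1024, 20000)

loss = F.cross_entropy(
    logits[masked],                      # (num_masked, 20000)
    x[masked]                            # (num_masked,)
)
loss.backward()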

State-of-the-art image classification (<a href="https://arxiv.org/abs/2205.01580">SimpleViT</a>)

import torch
from x_transformers import ViTransformerWrapper, Encoder

model = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
    )
)

img = torch.randn(1, 3, 256, 256)
model(img) # (1, 1000)

Image -> caption

import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)

img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

encoded = encoder(img, return_embeddings = True)
decoder(caption, context = encoded) # (1, 1024, 20000)

<a href="https://arxiv.org/abs/2209.06794">PaLI</a>, a state-of-the-art language-vision model

import torch
from x_transformers import ViTransformerWrapper, XTransformer, Encoder

# PaLI is composed of
# 1. vision transformer (ViTransformerWrapper) +
# 2. encoder-decoder transformer (XTransformer)

vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

pali = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

# training data

img = torch.randn(1, 3, 256, 256)               # images
prompt = torch.randint(0, 256, (1, 1024))       # prompt
prompt_mask = torch.ones(1, 1024).bool()        # prompt text mask
output_text = torch.randint(0, 256, (1, 1024))  # target output text

# train

img_embeds = vit(
    img,
    return_embeddings = True
)

loss = pali(
    prompt,
    output_text,
    mask = prompt_mask,
    src_prepend_embeds = img_embeds             # will prepend image embeddings to encoder text embeddings before attention
)

loss.backward()

# do the above for many steps on a 17B parameter model
# attention is all you need

Dropouts

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    emb_dropout = 0.1,         # dropout after embedding
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        layer_dropout = 0.1,   # stochastic depth - dropout entire layer
        attn_dropout = 0.1,    # dropout post-attention
        ff_dropout = 0.1       # feedforward dropout
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
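
The dropouts above follow the usual PyTorch convention of being active only in training mode; for deterministic inference, switch the module to eval mode:

model.eval()                # dropout layers become no-ops in eval mode
with torch.no_grad():
    logits = model(x)       # (1, 1024, 20000)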

Features

Flash Attention

<img src="./images/flash-attention.png" width="500px"></img>

What originally started off as <a href="https://arxiv.org/abs/2112.05682">a short paper</a> from Markus Rabe culminated in a practical fused attention CUDA kernel, named <a href="https://arxiv.org/abs/2205.14135">Flash Attention</a>, by <a href="https://tridao.me/">Tri Dao</a>.

The technique processes the attention matrix in tiles, keeping only a running softmax normalizer and the exponentiated weighted sums. By recomputing attention in the same tiled fashion on the backward pass, memory stays linear in the sequence length. This lets many recent models reach longer context lengths without hitting the memory bottleneck.
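
For intuition only, here is a minimal (and deliberately slow) PyTorch sketch of the tiled forward pass: the full attention matrix is never materialized, and each query block keeps only a running row-wise max, a running softmax denominator, and a running weighted sum of values. The real kernel fuses this into CUDA and additionally recomputes tiles on the backward pass; none of the names below come from the library.

import torch

def tiled_attention(q, k, v, block = 128):
    # q, k, v: (batch, heads, seq_len, dim_head)
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)

    for i in range(0, q.shape[-2], block):
        qi = q[..., i:i + block, :] * scale

        running_max = torch.full(qi.shape[:-1], float('-inf'), device = q.device, dtype = q.dtype)
        running_den = torch.zeros_like(running_max)
        running_num = torch.zeros_like(qi)

        for j in range(0, k.shape[-2], block):
            kj = k[..., j:j + block, :]
            vj = v[..., j:j + block, :]

            scores = qi @ kj.transpose(-1, -2)       # one (block x block) tile of the attention matrix
            new_max = torch.maximum(running_max, scores.amax(dim = -1))

            # rescale previous accumulators to the new running max, then fold in this tile
            correction = torch.exp(running_max - new_max)
            exp_scores = torch.exp(scores - new_max.unsqueeze(-1))

            running_num = running_num * correction.unsqueeze(-1) + exp_scores @ vj
            running_den = running_den * correction + exp_scores.sum(dim = -1)
            running_max = new_max

        out[..., i:i + block, :] = running_num / running_den.unsqueeze(-1)

    return out

q, k, v = (torch.randn(1, 8, 1024, 64) for _ in range(3))
ref = torch.softmax((q @ k.transpose(-1, -2)) * (64 ** -0.5), dim = -1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol = 1e-4)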

Other engineering decisions by Tri Dao contributed to its enormous success, chiefly minimizing HBM accesses so that both the forward and backward passes outperform naive attention. In other words, flash attention is not only more memory efficient, but faster as well, making it a necessity for training transformers.

MetaAI has recently added the ability to use <a href="https://github.com/hazyresearch/flash-attention">Tri Dao's CUDA kernel</a> through the <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html">scaled_dot_product_attention</a> function in PyTorch 2.0. (They also offer a memory-efficient attention, which follows the same design as flash attention, except that the tiles are traversed in a different order.)
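
For reference, calling the fused kernel directly through PyTorch 2.0 looks roughly like this (plain PyTorch, independent of this repository); PyTorch dispatches to the flash, memory-efficient, or math backend depending on dtype, device, and inputs:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64, device = 'cuda', dtype = torch.float16)
k = torch.randn(2, 8, 1024, 64, device = 'cuda', dtype = torch.float16)
v = torch.randn(2, 8, 1024, 64, device = 'cuda', dtype = torch.float16)

out = F.scaled_dot_product_attention(q, k, v, is_causal = True) # (2, 8, 1024, 64)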

<a href="https://ai.facebook.com/blog/large-language-model-llama-meta-ai/">Llama</a> was trained using Flash Attention. The only reason to avoid it is if you require operating on the attention matrix (dynamic positional bias, talking heads, residual attention).

You can use it in this repository by setting attn_flash to True, for immediate memory savings and increased speed.

ex.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True # just set this to True if you have pytorch 2.0 installed
    )
)

Augmenting Self-attention with Persistent Memory

<img src="./images/all-attention.png" width="500px"></img>

https://arxiv.org/abs/1907.01470

Proposes adding learned memory key / values prior to attention. They were able to remove feedforwards altogether and attain similar performance to the original transformers. I have found that keeping the feedforwards and adding the memory key / values leads to even better performance.

from x_transformers import Decoder, Encoder

enc = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    attn_num_mem_kv = 16 # 16 memory key / values
)

Memory Transformers

<img src="./images/memory-transformer.png" width="500px"></img>

https://arxiv.org/abs/2006.11527

Proposes adding learned tokens, akin to CLS tokens, named memory tokens, that are passed through the attention layers alongside the input tokens. This setting is compatible with both encoder and decoder training.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20, # 20 memory tokens
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

Update: MetaAI researchers <a href="https://arxiv.org/abs/2309.16588">have found</a> that adding memory tokens (they call them register tokens) alleviates outliers (now suspected to be a pathology of attention networks that are unable to <a href="https://arxiv.org/abs/2306.12929">attend to nothing</a>).

Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper.

Update 3: further corroborated by <a href="https://arxiv.org/abs/2501.00663">a paper</a> trying to extend memory in attention networks, termed persistent memory.

Transformers Without Tears

<img src="./images/scalenorm.png"></img>

https://arxiv.org/abs/1910.05895

They experimented with alternatives to layer normalization and found one that is both effective and simpler. Researchers have shared with me that this leads to faster convergence.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_scalenorm = True # set to True to use for all layers
    )
)

You can also use the l2 normalized embeddings proposed as part of fixnorm. I have found it leads to improved convergence when paired with small initialization (proposed by <a href="https://github.com/BlinkDL">BlinkDL</a>). The small initialization will be taken care of as long as l2norm_embed is set to True.
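
A minimal sketch of how that would be enabled, assuming l2norm_embed is the TransformerWrapper flag for this feature (check the repository for the exact keyword):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    l2norm_embed = True,   # l2 normalized embeddings, paired with small embedding init
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)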