This is an unofficial PyTorch implementation created by Ignacio Oguiza (timeseriesAI@gmail.com), based on TST (Zerveas et al., 2020) and the Transformer (Vaswani et al., 2017).

References:

  • Zerveas, G., Jayaraman, S., Patel, D., Bhamidipaty, A., & Eickhoff, C. (2020). A Transformer-based Framework for Multivariate Time Series Representation Learning. arXiv:2010.02803.
  • Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems 30.

This implementation is adapted to work with the rest of the tsai library, and it contains some hyperparameters that are not available in the original implementation. I included them to facilitate experimentation.

Tips on how to use transformers:

  • In general, transformers require a lower lr than other time series models when trained on the same datasets. It's important to use learn.lr_find() to identify a good lr. In my experience, lrs between 1e-4 and 3e-4 tend to work well (see the sketch after these tips).

  • The paper authors recommend standardizing data by feature. This can be done by adding TSStandardize(by_var=True) as a batch_tfm when creating the TSDataLoaders.

  • When using TST with long time series, you can use max_seq_len to reduce the memory footprint and thus avoid GPU issues. By default it's set to 512.

  • I've tried different types of positional encoders. In my experience, the default one works just fine.
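
Putting these tips together, here's a minimal sketch of a typical workflow. It assumes the UCR 'LSST' dataset and the standard tsai helpers; adjust to your own data:

from tsai.all import *

X, y, splits = get_UCR_data('LSST', split_data=False)
tfms = [None, Categorize()]
batch_tfms = TSStandardize(by_var=True)                    # standardize by feature
dls = get_ts_dls(X, y, splits=splits, tfms=tfms, batch_tfms=batch_tfms)
model = build_ts_model(TSTPlus, dls=dls, max_seq_len=512)  # cap very long series
learn = Learner(dls, model, metrics=accuracy)
learn.lr_find()                                            # lrs between 1e-4 and 3e-4 usually work well
# learn.fit_one_cycle(20, lr_max=1e-4)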

Imports

Positional encoders

PositionalEncoding[source]

PositionalEncoding(q_len, d_model, normalize=True)

pe = PositionalEncoding(1000, 512).detach().cpu().numpy()
plt.pcolormesh(pe, cmap='viridis')
plt.title('PositionalEncoding')
plt.colorbar()
plt.show()
pe.mean(), pe.std(), pe.min(), pe.max(), pe.shape
(3.2037498e-10, 0.09999991, -0.18388666, 0.11518021, (1000, 512))

Coord2dPosEncoding[source]

Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=0.001, verbose=False, device=device(type='cpu'))

cpe = Coord2dPosEncoding(1000, 512, exponential=True, normalize=True).cpu().numpy()
plt.pcolormesh(cpe, cmap='viridis')
plt.title('Coord2dPosEncoding')
plt.colorbar()
plt.show()
plt.plot(cpe.mean(0))
plt.show()
plt.plot(cpe.mean(1))
plt.show()
cpe.mean(), cpe.std(), cpe.min(), cpe.max()
(3.695488e-09, 0.09999991, -0.22459325, 0.22487777)

Coord1dPosEncoding[source]

Coord1dPosEncoding(q_len, exponential=False, normalize=True, device=device(type='cpu'))

cpe = Coord1dPosEncoding(1000, exponential=True, normalize=True).detach().cpu().numpy()
plt.pcolormesh(cpe, cmap='viridis')
plt.title('Coord1dPosEncoding')
plt.colorbar()
plt.show()
plt.plot(cpe.mean(1))
plt.show()
cpe.mean(), cpe.std(), cpe.min(), cpe.max(), cpe.shape
(0.0, 0.099949986, -0.2820423, 0.14113107, (1000, 1))

TST

class ScaledDotProductAttention[source]

ScaledDotProductAttention(d_k:int, res_attention:bool=False) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__
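
That docstring is inherited from fastai's Module, a thin wrapper around nn.Module that tsai builds on. A minimal sketch of what it means (AddOne is just an illustration):

import torch
from fastai.torch_core import Module

class AddOne(Module):
    def __init__(self, bias=1.): self.bias = bias  # no super().__init__() call needed
    def forward(self, x): return x + self.bias

assert torch.equal(AddOne()(torch.zeros(3)), torch.ones(3))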

B = 16     # batch size
C = 3      # number of channels (aka features, variables, dimensions)
H = 1      # number of attention heads
D = 128    # model dimension (d_model)
M = 1500   # original sequence length
N = 512    # projected sequence length (q_len)
d_k = D // H  # dimension per head

xb = torch.randn(B, C, M)

# Attention
# q
lin = nn.Linear(M, N, bias=False)
Q = lin(xb).transpose(1,2)
to_q = nn.Linear(C, D, bias=False)
q = to_q(Q)

# k, v
context = xb.transpose(1,2)
to_kv = nn.Linear(C, D * 2, bias=False)
k, v = to_kv(context).chunk(2, dim = -1)
k = k.transpose(-1, -2)
q, k, v = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)

output, attn = ScaledDotProductAttention(d_k=d_k)(q, k, v)
output.shape, attn.shape, Q.shape, q.shape, k.shape, v.shape
(torch.Size([16, 1, 512, 128]),
 torch.Size([16, 1, 512, 1500]),
 torch.Size([16, 512, 3]),
 torch.Size([16, 1, 512, 128]),
 torch.Size([16, 1, 128, 1500]),
 torch.Size([16, 1, 1500, 128]))
q = torch.rand([16, 3, 50, 8]) 
k = torch.rand([16, 3, 50, 8]).transpose(-1, -2)
v = torch.rand([16, 3, 50, 6])
attn_mask = torch.triu(torch.ones(50, 50))  # shape: q_len x q_len (a float mask is added to the attention scores)
key_padding_mask = torch.zeros(16, 50)
key_padding_mask[[1, 3, 6, 15], -10:] = 1   # flag the last 10 steps of 4 samples as padding
key_padding_mask = key_padding_mask.bool()
output, attn = ScaledDotProductAttention(d_k=8)(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
output.shape, attn.shape
(torch.Size([16, 3, 50, 6]), torch.Size([16, 3, 50, 50]))
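
Note that the two masks play different roles: key_padding_mask has one row per sample and flags padded steps, while attn_mask is shared across the batch. A boolean attn_mask (True = masked) is equivalent to an additive float mask with -inf at the masked positions. Here's a minimal plain-PyTorch sketch of that equivalence (it illustrates the convention, not tsai's internals):

import torch

scores = torch.rand(2, 4, 4)                                   # unnormalized attention scores
causal_mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()  # True values will be masked

# additive float mask equivalent to the boolean mask
additive_mask = torch.zeros(4, 4).masked_fill(causal_mask, -float('inf'))

a = torch.softmax(scores.masked_fill(causal_mask, -float('inf')), dim=-1)
b = torch.softmax(scores + additive_mask, dim=-1)
assert torch.allclose(a, b)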

class MultiHeadAttention[source]

MultiHeadAttention(d_model:int, n_heads:int, d_k:int, d_v:int, res_attention:bool=False) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

t = torch.rand(16, 50, 128)
output, attn = MultiHeadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)(t, t, t, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
output.shape, attn.shape
(torch.Size([16, 50, 128]), torch.Size([16, 3, 50, 50]))
t = torch.rand(16, 50, 128)
att_mask = (torch.rand((50, 50)) > .85).float()
att_mask[att_mask == 1] = -float("Inf")

mha = MultiHeadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)
output, attn = mha(t, t, t, attn_mask=att_mask)
test_eq(torch.isnan(output).sum().item(), 0)
test_eq(torch.isnan(attn).sum().item(), 0)
loss = output[:2, :].sum()
test_eq(torch.isnan(loss).sum().item(), 0)
loss.backward()
for n, p in mha.named_parameters(): test_eq(torch.isnan(p.grad).sum().item(), 0)
t = torch.rand(16, 50, 128)
attn_mask = (torch.rand((50, 50)) > .85)

# True values will be masked
mha = MultiHeadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)
output, attn = mha(t, t, t, attn_mask=attn_mask)
test_eq(torch.isnan(output).sum().item(), 0)
test_eq(torch.isnan(attn).sum().item(), 0)
loss = output[:2, :].sum()
test_eq(torch.isnan(loss).sum().item(), 0)
loss.backward()
for n, p in mha.named_parameters(): test_eq(torch.isnan(p.grad).sum().item(), 0)
t = torch.rand(16, 50, 128)
encoder = _TSTEncoderLayer(q_len=50, d_model=128, n_heads=8, d_k=None, d_v=None, d_ff=512, res_dropout=0.1, activation='gelu')
output = encoder(t, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
output.shape
torch.Size([16, 50, 128])
cmap='viridis'
figsize=(6,5)
plt.figure(figsize=figsize)
plt.pcolormesh(encoder.attn[0][0].detach().cpu().numpy(), cmap=cmap)
plt.title('Self-attention map')
plt.colorbar()
plt.show()

class TSTPlus[source]

TSTPlus(c_in:int, c_out:int, seq_len:int, max_seq_len:Optional[int]=512, n_layers:int=3, d_model:int=128, n_heads:int=16, d_k:Optional[int]=None, d_v:Optional[int]=None, d_ff:int=256, res_dropout:float=0.0, act:str='gelu', key_padding_mask:bool=True, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, pe:str='zeros', learn_pe:bool=True, flatten:bool=False, fc_dropout:float=0.0, concat_pool:bool=False, bn:bool=True, custom_head:Optional=None, y_range:Optional[tuple]=None, verbose:bool=False, **kwargs) :: Sequential

TST (Time Series Transformer) is a Transformer that takes continuous time series as inputs

from tsai.models.utils import build_ts_model

bs = 8
c_in = 9  # aka channels, features, variables, dimensions
c_out = 2
seq_len = 1_500

xb = torch.randn(bs, c_in, seq_len).to(device)

# standardize by channel (by_var) based on the training set
xb = (xb - xb.mean((0, 2), keepdim=True)) / xb.std((0, 2), keepdim=True)

# Settings
max_seq_len = 256
d_model = 128
n_heads = 16
d_k = d_v = None  # if None --> d_model // n_heads
d_ff = 256
res_dropout = 0.1
act = "gelu"
n_layers = 3
fc_dropout = 0.1
pe = None
learn_pe = True
kwargs = {}

model = TSTPlus(c_in, c_out, seq_len, max_seq_len=max_seq_len, d_model=d_model, n_heads=n_heads,
                d_k=d_k, d_v=d_v, d_ff=d_ff, res_dropout=res_dropout, act=act, n_layers=n_layers,
                fc_dropout=fc_dropout, pe=pe, learn_pe=learn_pe, **kwargs).to(device)
test_eq(model(xb).shape, [bs, c_out])
test_eq(model[0], model.backbone)
test_eq(model[1], model.head)
model2 = build_ts_model(TSTPlus, c_in, c_out, seq_len, max_seq_len=max_seq_len, d_model=d_model, n_heads=n_heads,
                           d_k=d_k, d_v=d_v, d_ff=d_ff, res_dropout=res_dropout, act=act, n_layers=n_layers,
                           fc_dropout=fc_dropout, pe=pe, learn_pe=learn_pe, **kwargs).to(device)
test_eq(model2(xb).shape, [bs, c_out])
test_eq(model2[0], model2.backbone)
test_eq(model2[1], model2.head)
print(f'model parameters: {count_parameters(model)}')
model parameters: 404992
model = TSTPlus(c_in, c_out, seq_len, pre_norm=True).to(device)
test_eq(model(xb).shape, [bs, c_out])
bs = 8
c_in = 9  # aka channels, features, variables, dimensions
c_out = 2
seq_len = 5000

xb = torch.randn(bs, c_in, seq_len)

# standardize by channel (by_var) based on the training set
xb = (xb - xb.mean((0, 2), keepdim=True)) / xb.std((0, 2), keepdim=True)

model = TSTPlus(c_in, c_out, seq_len, res_attention=True)
test_eq(model(xb).shape, [bs, c_out])
print(f'model parameters: {count_parameters(model)}')
model parameters: 478208
custom_head = partial(create_pool_head, concat_pool=True)
model = TSTPlus(c_in, c_out, seq_len, max_seq_len=max_seq_len, d_model=d_model, n_heads=n_heads,
            d_k=d_k, d_v=d_v, d_ff=d_ff, res_dropout=res_dropout, act=act, n_layers=n_layers,
            fc_dropout=fc_dropout, pe=pe, learn_pe=learn_pe, flatten=False, custom_head=custom_head, **kwargs)
test_eq(model(xb).shape, [bs, c_out])
print(f'model parameters: {count_parameters(model)}')
model parameters: 421122
custom_head = partial(create_pool_plus_head, concat_pool=True)
model = TSTPlus(c_in, c_out, seq_len, max_seq_len=max_seq_len, d_model=d_model, n_heads=n_heads,
            d_k=d_k, d_v=d_v, d_ff=d_ff, res_dropout=res_dropout, act=act, n_layers=n_layers,
            fc_dropout=fc_dropout, pe=pe, learn_pe=learn_pe, flatten=False, custom_head=custom_head, **kwargs)
test_eq(model(xb).shape, [bs, c_out])
print(f'model parameters: {count_parameters(model)}')
model parameters: 554240
bs = 8
c_in = 9  # aka channels, features, variables, dimensions
c_out = 2
seq_len = 60

xb = torch.randn(bs, c_in, seq_len)

# standardize by channel (by_var) based on the training set
xb = (xb - xb.mean((0, 2), keepdim=True)) / xb.std((0, 2), keepdim=True)

# Settings
max_seq_len = 120
d_model = 128
n_heads = 16
d_k = d_v = None # if None --> d_model // n_heads
d_ff = 256
res_dropout = 0.1
act = "gelu"
n_layers = 3
fc_dropout = 0.1
pe = 'zeros'
learn_pe = True
kwargs = {}
# kwargs = dict(kernel_size=5, padding=2)

model = TSTPlus(c_in, c_out, seq_len, max_seq_len=max_seq_len, d_model=d_model, n_heads=n_heads,
            d_k=d_k, d_v=d_v, d_ff=d_ff, res_dropout=res_dropout, act=act, n_layers=n_layers,
            fc_dropout=fc_dropout, pe=pe, learn_pe=learn_pe, **kwargs)
test_eq(model(xb).shape, [bs, c_out])
print(f'model parameters: {count_parameters(model)}')
body, head = model[0], model[1]
test_eq(body(xb).ndim, 3)
test_eq(head(body(xb)).ndim, 2)
head
model parameters: 404560
Sequential(
  (0): GAP1d(
    (gap): AdaptiveAvgPool1d(output_size=1)
    (flatten): Flatten(full=False)
  )
  (1): LinBnDrop(
    (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Dropout(p=0.1, inplace=False)
    (2): Linear(in_features=128, out_features=2, bias=False)
  )
)
model.show_pe()
model = TSTPlus(3, 2, 10)
xb = torch.randn(4, 3, 10)
yb = torch.randint(0, 2, (4,))
test_eq(model.backbone._key_padding_mask(xb)[1], None)  # no nan values --> no key_padding_mask
random_idxs = np.random.choice(len(xb), 2, False)
xb[random_idxs, :, -5:] = float('nan')  # simulate end-of-sequence padding with nan values
xb[random_idxs, 0, 1] = float('nan')    # a nan in a single channel is not treated as padding
# a step is masked only when all channels are nan at that step
test_eq(model.backbone._key_padding_mask(xb.clone())[1].data, (torch.isnan(xb).float().mean(1)==1).bool())
test_eq(model.backbone._key_padding_mask(xb.clone())[1].data.shape, (4,10))
print(torch.isnan(xb).sum())
pred = model(xb.clone())
loss = CrossEntropyLossFlat()(pred, yb)
loss.backward()
torch.isnan(xb), model.backbone._key_padding_mask(xb)[1].data
tensor(32)
(tensor([[[False, False, False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False, False, False]],
 
         [[False, False, False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False, False, False]],
 
         [[False,  True, False, False, False,  True,  True,  True,  True,  True],
          [False, False, False, False, False,  True,  True,  True,  True,  True],
          [False, False, False, False, False,  True,  True,  True,  True,  True]],
 
         [[False,  True, False, False, False,  True,  True,  True,  True,  True],
          [False, False, False, False, False,  True,  True,  True,  True,  True],
          [False, False, False, False, False,  True,  True,  True,  True,  True]]]),
 tensor([[False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False,  True,  True,  True,  True,  True],
         [False, False, False, False, False,  True,  True,  True,  True,  True]]))

class MultiTSTPlus[source]

MultiTSTPlus(feat_list, c_out, seq_len, max_seq_len:Optional[int]=512, custom_head=None, n_layers:int=3, d_model:int=128, n_heads:int=16, d_k:Optional[int]=None, d_v:Optional[int]=None, d_ff:int=256, res_dropout:float=0.0, act:str='gelu', key_padding_mask:bool=True, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, pe:str='zeros', learn_pe:bool=True, flatten:bool=False, fc_dropout:float=0.0, concat_pool:bool=False, bn:bool=True, y_range:Optional[tuple]=None, verbose:bool=False) :: Sequential

A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.

To make it easier to understand, here is a small example:

# Example of using Sequential
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))
bs = 8
c_in = 7  # aka channels, features, variables, dimensions
c_out = 2
seq_len = 10
xb2 = torch.randn(bs, c_in, seq_len)
model1 = MultiTSTPlus([2, 5], c_out, seq_len)
model2 = MultiTSTPlus(7, c_out, seq_len)
test_eq(model1(xb2).shape, (bs, c_out))
test_eq(model1(xb2).shape, model2(xb2).shape)
test_eq(count_parameters(model1) > count_parameters(model2), True)
bs = 8
c_in = 7  # aka channels, features, variables, dimensions
c_out = 2
seq_len = 10
xb2 = torch.randn(bs, c_in, seq_len)
model1 = MultiTSTPlus([2, 5], c_out, seq_len)
model2 = MultiTSTPlus([[0,2,5], [0,1,3,4,6]], c_out, seq_len)
test_eq(model1(xb2).shape, (bs, c_out))
test_eq(model1(xb2).shape, model2(xb2).shape)
model1 = MultiTSTPlus([2, 5], c_out, seq_len, y_range=(0.5, 5.5))
body, head = split_model(model1)
test_eq(body(xb2).ndim, 3)
test_eq(head(body(xb2)).ndim, 2)
head
Sequential(
  (0): Sequential(
    (0): Flatten(full=False)
    (1): LinBnDrop(
      (0): Linear(in_features=2560, out_features=2, bias=True)
    )
  )
)
model = MultiTSTPlus([2, 5], c_out, seq_len, pre_norm=True)
bs = 8
n_vars = 3
seq_len = 12
c_out = 2
xb = torch.rand(bs, n_vars, seq_len)
net = MultiTSTPlus(n_vars, c_out, seq_len)
change_model_head(net, create_pool_plus_head, concat_pool=False)
print(net(xb).shape)
net.head
torch.Size([8, 2])
Sequential(
  (0): AdaptiveAvgPool1d(output_size=1)
  (1): Flatten(full=False)
  (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Linear(in_features=128, out_features=512, bias=False)
  (4): ReLU(inplace=True)
  (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): Linear(in_features=512, out_features=2, bias=False)
)
bs = 8
n_vars = 3
seq_len = 12
c_out = 10
xb = torch.rand(bs, n_vars, seq_len)
new_head = partial(conv_lin_3d_head, d=(5, 2))
net = MultiTSTPlus(n_vars, c_out, seq_len, custom_head=new_head)
print(net(xb).shape)
net.head
torch.Size([8, 5, 2])
Sequential(
  (0): create_conv_lin_3d_head(
    (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Conv1d(128, 5, kernel_size=(1,), stride=(1,), bias=False)
    (2): Transpose(-1, -2)
    (3): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Transpose(-1, -2)
    (5): Linear(in_features=12, out_features=2, bias=False)
  )
)