Utility functions used to build PyTorch timeseries models.

get_layers[source]

get_layers(model, cond=noop, full=True)

is_layer[source]

is_layer(*args)

is_linear[source]

is_linear(l)

is_bn[source]

is_bn(l)

is_conv_linear[source]

is_conv_linear(l)

is_affine_layer[source]

is_affine_layer(l)

is_conv[source]

is_conv(l)

has_bias[source]

has_bias(l)

has_weight[source]

has_weight(l)

has_weight_or_bias[source]

has_weight_or_bias(l)

check_bias[source]

check_bias(m, cond=noop, verbose=False)

check_weight[source]

check_weight(m, cond=noop, verbose=False)

get_nf[source]

get_nf(m)

Get `nf` from the first linear layer in the model's head.

ts_splitter[source]

ts_splitter(m)

Split a model into its body and head.

transfer_weights[source]

transfer_weights(model, weights_path:Path, device:device=None, exclude_head:bool=True)

Utility function that makes it easy to transfer weights between models. Adapted from the excellent self_supervised repository created by Kerem Turgutlu: https://github.com/KeremTurgutlu/self_supervised/blob/d87ebd9b4961c7da0efd6073c42782bbc61aaa2e/self_supervised/utils.py

build_ts_model[source]

build_ts_model(arch, c_in=None, c_out=None, seq_len=None, d=None, dls=None, device=None, verbose=False, pretrained=False, weights_path=None, exclude_head=True, cut=-1, init=None, **kwargs)

build_tabular_model[source]

build_tabular_model(arch, dls, layers=None, emb_szs=None, n_out=None, y_range=None, device=None, ps=None, embed_p=0.0, use_bn=True, bn_final=False, bn_cont=True, act_cls=ReLU(inplace=True), lin_first=True)

build_tsimage_model[source]

build_tsimage_model(arch, c_in=None, c_out=None, dls=None, pretrained=False, device=None, verbose=False, init=None, p=0.0, n_out=1000, stem_szs=(32, 32, 64), widen=1.0, sa=False, act_cls=ReLU, ndim=2, ks=3, stride=2, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1, sym=False, norm_type=<NormType.Batch: 1>, pool=AvgPool, pool_first=True, padding=None, bias=None, bn_1st=True, transpose=False, xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, padding_mode:str='zeros')

count_parameters[source]

count_parameters(model, trainable=True)

from tsai.data.external import get_UCR_data
from tsai.data.features import get_ts_features
dsid = 'NATOPS'
X, y, splits = get_UCR_data(dsid, split_data=False)
ts_features_df = get_ts_features(X, y)
Feature Extraction: 100%|██████████| 40/40 [00:04<00:00,  8.38it/s]
from tsai.data.tabular import get_tabular_dls
from tsai.models.TabModel import TabModel
cat_names = None
cont_names = ts_features_df.columns[:-2]
y_names = 'target'
tab_dls = get_tabular_dls(ts_features_df, cat_names=cat_names, cont_names=cont_names, y_names=y_names, splits=splits)
tab_model = build_tabular_model(TabModel, dls=tab_dls)
b = first(tab_dls.train)
test_eq(tab_model(*b[:-1]).shape, (64,6))
a = 'MLSTM_FCN'
if sum([1 for v in ['RNN_FCN', 'LSTM_FCN', 'GRU_FCN', 'OmniScaleCNN', 'Transformer', 'mWDN'] if v in a]): print(1)
1

get_clones[source]

get_clones(module, N)

m = nn.Conv1d(3,4,3)
get_clones(m, 3)
ModuleList(
  (0): Conv1d(3, 4, kernel_size=(3,), stride=(1,))
  (1): Conv1d(3, 4, kernel_size=(3,), stride=(1,))
  (2): Conv1d(3, 4, kernel_size=(3,), stride=(1,))
)

split_model[source]

split_model(m)

seq_len_calculator[source]

seq_len_calculator(seq_len, **kwargs)

seq_len = 345
kwargs = dict(kernel_size=5, stride=5)
seq_len_calculator(seq_len, **kwargs)
69

change_model_head[source]

change_model_head(model, custom_head, **kwargs)

Replaces a model's head with a custom head, as long as the model has `head`, `head_nf`, `c_out` and `seq_len` attributes.

naive_forecaster[source]

naive_forecaster(o, split, horizon=1)

true_forecaster[source]

true_forecaster(o, split, horizon=1)

a = np.random.rand(20).cumsum()
split = np.arange(10, 20)
a, naive_forecaster(a, split, 1), true_forecaster(a, split, 1)
(array([0.38499951, 0.49040357, 0.67978685, 0.73094369, 0.75843037,
        1.16923118, 1.65045468, 2.6244671 , 3.54431571, 3.88991105,
        4.23980872, 4.27976719, 5.13150789, 5.75946097, 5.76437458,
        6.49241016, 7.32054377, 7.35596656, 8.19922808, 9.1472387 ]),
 array([3.88991105, 4.23980872, 4.27976719, 5.13150789, 5.75946097,
        5.76437458, 6.49241016, 7.32054377, 7.35596656, 8.19922808]),
 array([4.23980872, 4.27976719, 5.13150789, 5.75946097, 5.76437458,
        6.49241016, 7.32054377, 7.35596656, 8.19922808, 9.1472387 ]))