Callback to apply Noisy Student self-training (a semi-supervised learning approach), based on: Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with Noisy Student improves ImageNet classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10687-10698).

class NoisyStudent

NoisyStudent(dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True, pseudolabel_sample_weight:float=1.0, verbose=False) :: Callback

A callback that implements the Noisy Student approach. In the original paper, it was combined with the following sources of noise:

- stochastic depth: survival probability 0.8 (final layer)
- RandAugment: N=2 operations, magnitude M=27
- dropout: 0.5 (final layer)
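
As a rough illustration of the stochastic depth setting, the sketch below computes per-layer survival probabilities under the linear decay rule, with the final layer at 0.8. The helper `survival_probs` is hypothetical and not part of tsai. It only shows the arithmetic, not how a model would drop layers during training.

```python
from typing import List


def survival_probs(n_layers: int, final_survival: float = 0.8) -> List[float]:
    """Linear decay rule: layer l (1-based) survives with probability
    1 - (l / n_layers) * (1 - final_survival)."""
    if n_layers < 1:
        raise ValueError("n_layers must be >= 1")
    return [1 - (l / n_layers) * (1 - final_survival) for l in range(1, n_layers + 1)]


# Hypothetical 4-layer network; deeper layers are dropped more often
probs = survival_probs(4)
print([round(p, 2) for p in probs])  # [0.95, 0.9, 0.85, 0.8]
```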

Steps:

1. Build the dl you will use as a teacher
2. Create dl2 with the pseudolabeled data (the pseudolabels can be either soft or hard predictions)
3. Pass any required batch_tfms to the callback
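
Step 2 depends on whether the teacher's predictions are kept as probability vectors (soft) or collapsed to a single class per sample (hard). The following is a minimal sketch in plain Python. The probabilities are made up, and the helper `to_pseudolabels` is hypothetical rather than a tsai function.

```python
from typing import List, Union


def to_pseudolabels(probs: List[List[float]], soft: bool) -> Union[List[List[float]], List[int]]:
    """Return copies of the probability vectors (soft) or the argmax class per sample (hard)."""
    if soft:
        return [list(p) for p in probs]
    return [max(range(len(p)), key=p.__getitem__) for p in probs]


# Made-up teacher outputs for 3 samples and 3 classes (illustrative only)
teacher_probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.25, 0.5, 0.25]]

hard = to_pseudolabels(teacher_probs, soft=False)
soft = to_pseudolabels(teacher_probs, soft=True)
print(hard)  # [0, 2, 1]
```

Hard pseudolabels are one-dimensional (one class index per sample), while soft pseudolabels have one column per class.
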
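
Example with pseudolabels given as class indices:

from tsai.data.all import *
from tsai.models.all import *
from tsai.tslearner import *

# Load the full NATOPS dataset together with its predefined splits
dsid = 'NATOPS'
X, y, splits = get_UCR_data(dsid, return_split=False)

# For demonstration only: reuse X and its true labels as the "pseudolabeled" set
pseudolabeled_data = X
soft_preds = True

# In this example, soft_preds=True yields class indices and soft_preds=False yields one-hot vectors
pseudolabels = ToNumpyCategory()(y) if soft_preds else OneHot()(y)
dsets2 = TSDatasets(pseudolabeled_data, pseudolabels)
dl2 = TSDataLoader(dsets2, num_workers=0)

# Mix labeled and pseudolabeled samples in each training batch of 256
noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)
learn = TSClassifier(X, y, splits=splits, batch_tfms=[TSStandardize(), TSRandomSize(.5)], cbs=noisy_student_cb)
learn.fit_one_cycle(1)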
labels / pseudolabels per training batch              : 171 / 85
relative labeled/ pseudolabel sample weight in dataset: 4.0
epoch train_loss valid_loss accuracy time
0 1.973474 1.819007 0.100000 00:08
X: (171, 24, 51)  X2: (85, 24, 51)  X_comb: (256, 24, 43)
y: (171,)  y2: torch.Size([85])  y_comb: (256,)
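
The batch split and sample weight reported in the verbose output are consistent with the arithmetic sketched below. This is an assumed rule, not the callback's confirmed implementation: the pseudolabeled share is taken as `int(bs / (1 + l2pl_ratio))` and the labeled share is the remainder. The dataset sizes (180 labeled training samples, 360 pseudolabeled samples) are also assumptions about this NATOPS example, and both helper functions are hypothetical.

```python
from typing import Tuple


def split_batch(bs: int, l2pl_ratio: int) -> Tuple[int, int]:
    """Assumed rule: pseudolabeled share is int(bs / (1 + ratio)); labeled gets the rest."""
    n_pseudo = int(bs / (1 + l2pl_ratio))
    return bs - n_pseudo, n_pseudo


def relative_weight(n_labeled_ds: int, n_pseudo_ds: int,
                    n_labeled_batch: int, n_pseudo_batch: int) -> float:
    """Ratio of per-sample draw rates: labeled samples vs pseudolabeled samples."""
    return (n_labeled_batch / n_labeled_ds) / (n_pseudo_batch / n_pseudo_ds)


n_labeled, n_pseudo = split_batch(bs=256, l2pl_ratio=2)
print(n_labeled, n_pseudo)  # 171 85

# Assumed dataset sizes: 180 labeled training samples, 360 pseudolabeled samples
weight = relative_weight(180, 360, n_labeled, n_pseudo)
print(round(weight, 1))  # 4.0
```

Under these assumptions, each labeled sample is drawn roughly four times as often as each pseudolabeled sample, even though labeled samples outnumber pseudolabeled ones only 2:1 within a batch.
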

Example with one-hot encoded pseudolabels:

# Reuses X, y and splits loaded in the first example
pseudolabeled_data = X
soft_preds = False

# With soft_preds=False the pseudolabels are one-hot vectors (one column per class)
pseudolabels = ToNumpyCategory()(y) if soft_preds else OneHot()(y)
dsets2 = TSDatasets(pseudolabeled_data, pseudolabels)
dl2 = TSDataLoader(dsets2, num_workers=0)
noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)
learn = TSClassifier(X, y, splits=splits, batch_tfms=[TSStandardize(), TSRandomSize(.5)], cbs=noisy_student_cb)
learn.fit_one_cycle(1)
labels / pseudolabels per training batch              : 171 / 85
relative labeled/ pseudolabel sample weight in dataset: 4.0
epoch train_loss valid_loss accuracy time
0 1.798368 1.785091 0.166667 00:10
X: (171, 24, 51)  X2: (85, 24, 51)  X_comb: (256, 24, 61)
y: torch.Size([171, 6])  y2: torch.Size([85, 6])  y_comb: torch.Size([256, 6])
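
In this second run all targets are two-dimensional (one column per class), so the labeled targets must share that shape with the pseudolabels. The sketch below shows one common way to handle such targets: one-hot encode class indices and compute cross-entropy against a probability vector. Whether the callback or learner uses exactly this loss is not stated here. The helpers `one_hot` and `soft_cross_entropy` are hypothetical, and the logits are made up.

```python
import math
from typing import List


def one_hot(label: int, n_classes: int) -> List[float]:
    """Encode a class index as a probability vector with all mass on one class."""
    return [1.0 if i == label else 0.0 for i in range(n_classes)]


def soft_cross_entropy(logits: List[float], target: List[float]) -> float:
    """Cross-entropy between softmax(logits) and a probability-vector target."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(t * (x - log_z) for x, t in zip(logits, target))


logits = [2.0, 0.5, -1.0]  # made-up model outputs for one sample
hard_loss = soft_cross_entropy(logits, one_hot(0, 3))
soft_loss = soft_cross_entropy(logits, [0.8, 0.15, 0.05])
print(hard_loss < soft_loss)  # True
```

With a one-hot target this reduces to the usual negative log-probability of the true class, so hard and soft targets can share a single loss function.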