pseudo_labels

Pseudo labeling is a simple approach that uses a model trained on the source domain to label the target domain. Training then continues on the combined source and pseudo-labeled target data. This process is repeated until the validation loss converges.
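The control flow described above can be sketched independently of any deep learning framework. The following is a minimal illustration, not the library's implementation: `pseudo_label_loop`, `MeanModel`, the tolerance `tol` and the round limit `max_rounds` are hypothetical names introduced here, and the toy model re-fits from scratch in each round instead of continuing gradient-based training.

```python
from typing import Callable, List, Sequence


def pseudo_label_loop(
    train: Callable[[Sequence[float], Sequence[float]], None],
    predict: Callable[[Sequence[float]], List[float]],
    validate: Callable[[], float],
    source_x: Sequence[float],
    source_y: Sequence[float],
    target_x: Sequence[float],
    tol: float = 1e-4,
    max_rounds: int = 10,
) -> List[float]:
    """Return the validation loss recorded after each training round."""
    train(source_x, source_y)  # initial training on the labeled source domain
    losses = [validate()]
    for _ in range(max_rounds):
        pseudo_y = predict(target_x)  # label the target domain
        # train again on source plus pseudo-labeled target data
        train(list(source_x) + list(target_x), list(source_y) + list(pseudo_y))
        losses.append(validate())
        if abs(losses[-2] - losses[-1]) < tol:  # validation loss has converged
            break
    return losses


class MeanModel:
    """Toy stand-in for a network: always predicts the mean training label."""

    def __init__(self) -> None:
        self.mean = 0.0

    def train(self, xs: Sequence[float], ys: Sequence[float]) -> None:
        self.mean = sum(ys) / len(ys)

    def predict(self, xs: Sequence[float]) -> List[float]:
        return [self.mean for _ in xs]


toy = MeanModel()
losses = pseudo_label_loop(
    toy.train,
    toy.predict,
    validate=lambda: abs(toy.mean - 25.0),  # toy validation loss
    source_x=[0.0, 1.0, 2.0],
    source_y=[10.0, 20.0, 30.0],
    target_x=[3.0, 4.0],
)
```

With this toy model the pseudo labels equal the current prediction, so retraining does not change the model and the loop stops after the first round. A real network would typically need several rounds before the validation loss settles.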

Used In
  • Wang et al. (2022). Residual Life Prediction of Bearings Based on SENet-TCN and Transfer Learning. IEEE Access, 10. doi:10.1109/ACCESS.2022.3223387

Examples:

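import torch
import rul_datasets
import pytorch_lightning as pl

from rul_adapt import model
from rul_adapt import approach

# Feature extractor and regression head of the RUL model
feat_ex = model.CnnExtractor(14, [16], 30, fc_units=16)
reg = model.FullyConnectedHead(16, [1])

supervised = approach.SupervisedApproach(0.001, "rmse", "adam")
supervised.set_model(feat_ex, reg)

# fd3 serves as the labeled source domain, fd1 as the target domain
fd1 = rul_datasets.RulDataModule(rul_datasets.CmapssReader(1), 32)
fd1.setup()
fd3 = rul_datasets.RulDataModule(rul_datasets.CmapssReader(3), 32)
fd3.setup()

# Step 1: train on the source domain
trainer = pl.Trainer(max_epochs=10)
trainer.fit(supervised, fd3)

# Step 2: label the target domain and clip the raw predictions to [0, max_rul]
# (125 is assumed to be the max_rul of the reader used here)
pseudo_labels = approach.generate_pseudo_labels(fd1, supervised)
pseudo_labels = [max(0, min(125, label)) for label in pseudo_labels]
# Patch the data module the pseudo labels were generated for
approach.patch_pseudo_labels(fd1, pseudo_labels)

# Step 3: continue training on the combined source and target data
combined_data = torch.utils.data.ConcatDataset(
    [fd1.to_dataset("dev"), fd3.to_dataset("dev")]
)
combined_dl = torch.utils.data.DataLoader(combined_data, batch_size=32)

trainer = pl.Trainer(max_epochs=10)
# Recent PyTorch Lightning releases name this keyword train_dataloaders
trainer.fit(supervised, train_dataloaders=combined_dl)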

generate_pseudo_labels(dm, model, inductive=False)

Generate pseudo labels for the dev set of a data module, or for its test set if inductive is True.

The pseudo labels are generated for the last timestep of each run. They are returned raw and may therefore be negative or larger than max_rul. It is recommended to clip them to the range [0, max_rul] before using them to patch a data module.
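The recommended clipping can be written as a small helper. This is a minimal sketch; `clip_pseudo_labels` is a hypothetical name and not part of the library, and the value 125 is only an example for max_rul.

```python
from typing import List


def clip_pseudo_labels(pseudo_labels: List[float], max_rul: float) -> List[float]:
    """Clip raw pseudo labels to the closed interval [0, max_rul]."""
    return [max(0.0, min(max_rul, label)) for label in pseudo_labels]


raw = [-3.2, 57.0, 140.5]  # raw predictions may be negative or exceed max_rul
clipped = clip_pseudo_labels(raw, max_rul=125.0)
```

Negative predictions are raised to zero, values above max_rul are lowered to max_rul, and everything in between is left unchanged.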

The model is assumed to reside on the CPU where the calculation will be performed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dm | RulDataModule | The data module to generate pseudo labels for. | required |
| model | Module | The model to use for generating the pseudo labels. | required |
| inductive | bool | Whether to generate pseudo labels for inductive adaption, i.e., use the 'test' instead of the 'dev' split. | False |

Returns:

| Type | Description |
| --- | --- |
| List[float] | A list of pseudo labels for the dev set of the data module (or for its test set if inductive is True). |

get_max_rul(reader)

Resolve the maximum RUL of a reader to be comparable to floats.
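One way to read this description is that a reader's maximum RUL may be unset and therefore not directly comparable to a float. The sketch below illustrates that idea under the assumption that an unset value is represented as None; `resolve_max_rul` is a hypothetical helper, not the library's implementation, and it takes the raw value instead of a reader to stay self-contained.

```python
import math
from typing import Optional


def resolve_max_rul(max_rul: Optional[float]) -> float:
    """Map an unset maximum RUL to +inf so that float comparisons always work."""
    # Assumption: None means "no upper limit on the RUL".
    return math.inf if max_rul is None else float(max_rul)


capped = min(resolve_max_rul(125), 140.5)      # limited by max_rul
uncapped = min(resolve_max_rul(None), 140.5)   # no limit applies
```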

patch_pseudo_labels(dm, pseudo_labels, inductive=False)

Patch a data module with pseudo labels in-place.

The pseudo labels are used to replace the RUL targets of the dev set of the data module. The validation and test sets are not affected.

It is not possible to patch the same data module multiple times. Instead, instantiate a fresh data module and patch that one.
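Because a pseudo label only covers the last timestep of a run, the targets of the earlier timesteps have to be derived from it. The following sketch shows one plausible way to do this and is an assumption, not taken from the library: the RUL is taken to decrease by one per timestep and to be capped at max_rul. `expand_pseudo_label` is a hypothetical helper.

```python
from typing import List


def expand_pseudo_label(
    run_length: int, pseudo_label: float, max_rul: float
) -> List[float]:
    """Build per-timestep RUL targets that end in the given pseudo label."""
    # Timestep t lies (run_length - 1 - t) steps before the last timestep,
    # so its RUL is assumed to be that much larger, capped at max_rul.
    return [
        min(max_rul, pseudo_label + (run_length - 1 - t))
        for t in range(run_length)
    ]


targets = expand_pseudo_label(run_length=5, pseudo_label=122.0, max_rul=125.0)
# targets is capped at 125.0 for the first two timesteps, then counts down to 122.0
```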

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dm | RulDataModule | The data module to patch. | required |
| pseudo_labels | List[float] | The pseudo labels to use for patching the data module. | required |
| inductive | bool | Whether to patch the data module for inductive adaption, i.e., use the 'test' instead of the 'dev' split. | False |