pseudo_labels
Pseudo labeling is a simple approach that uses a model trained on the source domain to label the target domain. Training then continues on the combined source and target data. This process is repeated until the validation loss converges.
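The repeat-until-convergence loop described above can be sketched in plain Python. This is only an illustration of the control flow, not part of rul_adapt: `pseudo_label_loop`, `label_target`, `train_on_combined`, `tol` and `max_rounds` are hypothetical names, and "convergence" is assumed to mean that the validation loss improves by less than `tol` between rounds.

```python
from typing import Callable, List


def pseudo_label_loop(
    label_target: Callable[[], List[float]],
    train_on_combined: Callable[[List[float]], float],
    tol: float = 1e-3,
    max_rounds: int = 10,
) -> List[float]:
    """Alternate labeling and training until the validation loss stops improving.

    Returns the validation loss recorded after each round.
    """
    history: List[float] = []
    for _ in range(max_rounds):
        # Label the target data with the current model (hypothetical callable).
        pseudo_labels = label_target()
        # Continue training on source + pseudo-labeled target data and
        # get the resulting validation loss (hypothetical callable).
        val_loss = train_on_combined(pseudo_labels)
        # Assumed convergence criterion: improvement smaller than tol.
        converged = bool(history) and history[-1] - val_loss < tol
        history.append(val_loss)
        if converged:
            break
    return history


# Toy stand-ins: the "training" just replays a fixed sequence of validation losses.
losses = iter([1.0, 0.5, 0.4, 0.3999, 0.1])
history = pseudo_label_loop(
    label_target=lambda: [0.0],
    train_on_combined=lambda labels: next(losses),
)
print(len(history))  # 4
```

In a real setting, the two callables would wrap the labeling and training steps shown in the example further down; the stopping rule and tolerance are design choices, not something this page prescribes.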
Used In
- Wang et al. (2022). Residual Life Prediction of Bearings Based on SENet-TCN and Transfer Learning. IEEE Access, 10, 10.1109/ACCESS.2022.3223387
Examples:
import torch
import rul_datasets
import pytorch_lightning as pl
from rul_adapt import model
from rul_adapt import approach
feat_ex = model.CnnExtractor(14, [16], 30, fc_units=16)
reg = model.FullyConnectedHead(16, [1])
supervised = approach.SupervisedApproach(0.001, "rmse", "adam")
supervised.set_model(feat_ex, reg)
fd1 = rul_datasets.RulDataModule(rul_datasets.CmapssReader(1), 32)
fd1.setup()
fd3 = rul_datasets.RulDataModule(rul_datasets.CmapssReader(3), 32)
fd3.setup()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(supervised, fd3)
pseudo_labels = approach.generate_pseudo_labels(fd1, supervised)
pseudo_labels = [max(0, min(125, label)) for label in pseudo_labels]
approach.patch_pseudo_labels(fd1, pseudo_labels)  # patch the data module the labels were generated for
combined_data = torch.utils.data.ConcatDataset(
    [fd1.to_dataset("dev"), fd3.to_dataset("dev")]
)
combined_dl = torch.utils.data.DataLoader(combined_data, batch_size=32)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(supervised, train_dataloaders=combined_dl)
generate_pseudo_labels(dm, model, inductive=False)
Generate pseudo labels for the dev set of a data module.
The pseudo labels are generated for the last timestep of each run. They are returned raw and may therefore contain negative values or values bigger than max_rul. It is recommended to clip them to zero and max_rul, respectively, before using them to patch a reader.
The model is assumed to reside on the CPU, where the calculation is performed.
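The recommended clipping can be done with a small helper. The sketch below is plain Python and not part of the library; `clip_pseudo_labels` is a hypothetical name, and the value 125 is taken from the example above as the assumed max_rul.

```python
from typing import List


def clip_pseudo_labels(pseudo_labels: List[float], max_rul: float) -> List[float]:
    """Clip each raw pseudo label to the range [0, max_rul]."""
    return [max(0.0, min(float(max_rul), label)) for label in pseudo_labels]


raw = [-3.2, 57.0, 140.5]  # made-up raw model outputs
clipped = clip_pseudo_labels(raw, max_rul=125)
print(clipped)  # [0.0, 57.0, 125.0]
```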
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dm | RulDataModule | The data module to generate pseudo labels for. | required |
model | Module | The model to use for generating the pseudo labels. | required |
inductive | bool | Whether to generate pseudo labels for inductive adaption, i.e., use 'test' instead of 'dev' split. | False |
Returns:
Type | Description |
---|---|
List[float] | A list of pseudo labels for the dev set of the data module. |
get_max_rul(reader)
Resolve the maximum RUL of a reader to be comparable to floats.
patch_pseudo_labels(dm, pseudo_labels, inductive=False)
Patch a data module with pseudo labels in-place.
The pseudo labels are used to replace the RUL targets of the dev set of the data module. The validation and test sets are not affected.
It is not possible to patch the same data module multiple times. Instead, instantiate a fresh data module and patch that one.
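Since only one pseudo label exists per run (for its last timestep), the targets of the earlier timesteps have to be derived from it. This page does not show how the library does that; the sketch below is merely one plausible scheme, assuming the RUL drops by one per timestep and is capped at max_rul. `expand_pseudo_label` is a hypothetical helper, not a rul_adapt function.

```python
from typing import List


def expand_pseudo_label(last_rul: float, num_steps: int, max_rul: float) -> List[float]:
    """Build per-timestep targets for one run from the label of its last timestep.

    Assumption: RUL decreases by one per timestep and is capped at max_rul.
    """
    return [
        min(float(max_rul), last_rul + (num_steps - 1 - step))
        for step in range(num_steps)
    ]


targets = expand_pseudo_label(last_rul=122.0, num_steps=5, max_rul=125)
print(targets)  # [125.0, 125.0, 124.0, 123.0, 122.0]
```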
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dm | RulDataModule | The data module to patch. | required |
pseudo_labels | List[float] | The pseudo labels to use for patching the data module. | required |
inductive | bool | Whether to patch the data module for inductive adaption, i.e., use 'test' instead of 'dev' split. | False |