Pseudo Label Approach
import rul_adapt
import rul_datasets
import pytorch_lightning as pl
import torch
The pseudo label approach first trains a supervised model on the labeled source domain and then uses this model to predict (pseudo) labels for the unlabeled target domain. The pseudo-labeled target data is combined with the source data, and the model is retrained on the combined dataset. This process is repeated until the model converges.
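The loop described above can be sketched in plain Python. This is a toy, library-agnostic illustration, not rul_adapt's implementation: a least-squares line (`fit_line`, a hypothetical helper) stands in for the neural network, scalar features stand in for sensor windows, and `pseudo_label_training`, `max_rounds` and `tol` are illustrative names. "Converged" is assumed to mean that the model parameters stop changing between rounds.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[float, float]  # (feature, label)
Model = Tuple[float, float]  # (slope, intercept) of a toy linear model


def fit_line(data: Sequence[Sample]) -> Model:
    """Toy stand-in for supervised training: least-squares fit of y = a * x + b."""
    n = len(data)
    mean_x = sum(x for x, _ in data) / n
    mean_y = sum(y for _, y in data) / n
    var_x = sum((x - mean_x) ** 2 for x, _ in data)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in data)
    slope = cov_xy / var_x
    return slope, mean_y - slope * mean_x


def pseudo_label_training(
    source: Sequence[Sample],
    target_x: Sequence[float],
    max_rounds: int = 10,
    tol: float = 1e-6,
) -> Model:
    # 1. Supervised training on the labeled source domain only
    model = fit_line(source)
    for _ in range(max_rounds):
        slope, intercept = model
        # 2. Pseudo label the target domain, clipping implausible negative values
        pseudo: List[Sample] = [(x, max(0.0, slope * x + intercept)) for x in target_x]
        # 3. Retrain on the source data combined with the pseudo-labeled target data
        new_model = fit_line(list(source) + pseudo)
        # 4. Stop once the model parameters no longer change noticeably
        converged = all(abs(old - new) < tol for old, new in zip(model, new_model))
        model = new_model
        if converged:
            break
    return model


source = [(x, 2 * x + 1) for x in range(5)]
print(pseudo_label_training(source, target_x=[5, 6, 7]))  # → (2.0, 1.0)
```

Because the pseudo labels in this toy example lie exactly on the source line, retraining reproduces the same model and the loop stops after the first round. With real data, each round shifts both the model and the pseudo labels, which is why the procedure is repeated.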
Here, we train a model on FD003 of the CMAPSS dataset and pseudo label the FD001 dataset.
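# CNN feature extractor: 14 input channels (the CMAPSS features), three conv layers
# with 32, 16 and 8 filters, input windows of 30 time steps, and a final 64-unit dense layer
feature_extractor = rul_adapt.model.CnnExtractor(
    14, [32, 16, 8], 30, fc_units=64
)
# Regression head: maps the 64 extracted features to a single RUL value,
# without an activation function on the output layer
regressor = rul_adapt.model.FullyConnectedHead(
    64, [1], act_func_on_last_layer=False
)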
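To see where the parameter counts in the training summary come from, the architecture can be checked with a little arithmetic. This sketch assumes a kernel size of 3, stride 1, no padding and no batch normalization; these are assumptions, not values read from the rul_adapt source. Under them, the count reproduces the 15.7 K extractor parameters and the 65 head parameters reported by the trainer.

```python
def conv1d_params(in_channels: int, out_channels: int, kernel_size: int) -> int:
    # One kernel per (input, output) channel pair plus one bias per output channel
    return in_channels * out_channels * kernel_size + out_channels


def linear_params(in_features: int, out_features: int) -> int:
    # Weight matrix plus one bias per output unit
    return in_features * out_features + out_features


def cnn_extractor_params(
    input_channels=14, conv_filters=(32, 16, 8), seq_len=30, fc_units=64, kernel_size=3
):
    total, channels, length = 0, input_channels, seq_len
    for filters in conv_filters:
        total += conv1d_params(channels, filters, kernel_size)
        # Assumed stride 1 and no padding: each conv shortens the sequence by kernel_size - 1
        channels, length = filters, length - kernel_size + 1
    # The flattened conv output feeds the final dense layer
    total += linear_params(channels * length, fc_units)
    return total, length


extractor_params, final_length = cnn_extractor_params()
head_params = linear_params(64, 1)
print(extractor_params, final_length, head_params)  # → 15672 24 65
```

Most of the extractor's parameters (12,352 of 15,672) sit in the dense layer on top of the flattened 8 × 24 feature map, not in the convolutions.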
Supervised Training
First, we set up a data module for FD003, the labeled source domain.
fd3 = rul_datasets.CmapssReader(fd=3)
dm_labeled = rul_datasets.RulDataModule(fd3, batch_size=128)
Then we set up a supervised approach and train it for 10 epochs. In practice, it should be trained until the validation loss stops decreasing.
approach = rul_adapt.approach.SupervisedApproach(
lr=0.001, loss_type="rmse", optim_type="adam"
)
approach.set_model(feature_extractor, regressor)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(approach, dm_labeled)
trainer.validate(approach, dm_labeled)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name               | Type               | Params
----------------------------------------------------------
0 | train_loss         | MeanSquaredError   | 0
1 | val_loss           | MeanSquaredError   | 0
2 | test_loss          | MeanSquaredError   | 0
3 | evaluator          | AdaptionEvaluator  | 0
4 | _feature_extractor | CnnExtractor       | 15.7 K
5 | _regressor         | FullyConnectedHead | 65
----------------------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=10` reached.
────────────────────────────────────────────────────
      Validate metric             DataLoader 0
────────────────────────────────────────────────────
         val/loss              14.083422660827637
────────────────────────────────────────────────────
[{'val/loss': 14.083422660827637}]
Pseudo Labeling
Now we can use the trained model to generate pseudo labels for FD001. We truncate each FD001 run to 80% of its length (`percent_broken=0.8`) to simulate a target domain without failure data.
fd1 = rul_datasets.CmapssReader(fd=1, percent_broken=0.8)
dm_unlabeled = rul_datasets.RulDataModule(fd1, batch_size=128)
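The effect of `percent_broken=0.8` can be illustrated with a hypothetical `truncate_run` helper that keeps only the first 80% of a run's time steps. This is a sketch of the concept only; rul_datasets' own rounding and its handling of the validation and test splits may differ. The toy run stores its true RUL at each time step so the effect of truncation is visible.

```python
from typing import List, Sequence


def truncate_run(run: Sequence[int], percent_broken: float) -> List[int]:
    """Hypothetical helper: keep only the first `percent_broken` fraction of a run."""
    if not 0.0 < percent_broken <= 1.0:
        raise ValueError("percent_broken must be in (0, 1]")
    cutoff = int(len(run) * percent_broken)
    return list(run[:cutoff])


# A toy run of 200 cycles, represented by its true RUL at each time step (199, ..., 0)
true_rul = list(range(199, -1, -1))
observed = truncate_run(true_rul, 0.8)
print(len(observed), observed[-1])  # → 160 40
```

After truncation, the last observed time step is still 40 cycles away from failure, so the truncated run never reaches the failure point.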
The pseudo labels are generated for the last time step of each sequence. In early iterations, they may be implausible, e.g. negative, and need to be clipped to zero. When the data module is patched with the pseudo labels, suitable RUL values are created for each sequence.
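# Predict a pseudo label for the last time step of each FD001 sequence
pseudo_labels = rul_adapt.approach.generate_pseudo_labels(dm_unlabeled, approach)
# Clip negative predictions to zero; `label` avoids reusing the `pl` alias of pytorch_lightning
pseudo_labels = [max(0, label) for label in pseudo_labels]
# Patch the FD001 data module so that its training data uses the pseudo labels
rul_adapt.approach.patch_pseudo_labels(dm_unlabeled, pseudo_labels)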
/home/tilman/Programming/rul-adapt/rul_adapt/approach/pseudo_labels.py:88: UserWarning: At least one of the generated pseudo labels is negative. Please consider clipping them to zero.
  warnings.warn(
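The two ideas behind this step, clipping negative predictions and turning one pseudo label for the last time step into RUL values for a whole sequence, can be sketched as follows. The sketch assumes linear degradation (the RUL grows by one cycle per earlier time step) and an optional `max_rul` cap that mimics the piecewise-linear RUL targets common for CMAPSS. `expand_pseudo_label` is a hypothetical helper; rul_adapt's `patch_pseudo_labels` may construct its targets differently.

```python
from typing import List, Optional


def expand_pseudo_label(
    pseudo_label: float, num_steps: int, max_rul: Optional[float] = None
) -> List[float]:
    """Hypothetical helper: derive per-step RUL targets from a last-step pseudo label."""
    # Clip implausible negative predictions to zero
    last_rul = max(0.0, pseudo_label)
    # Assumed linear degradation: step i lies (num_steps - 1 - i) cycles before the last step
    targets = [last_rul + (num_steps - 1 - i) for i in range(num_steps)]
    if max_rul is not None:
        # Optional cap, as in piecewise-linear RUL targets
        targets = [min(t, max_rul) for t in targets]
    return targets


print(expand_pseudo_label(-3.0, 4))  # → [3.0, 2.0, 1.0, 0.0]
print(expand_pseudo_label(10.0, 3, max_rul=11.0))  # → [11.0, 11.0, 10.0]
```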
We create a new trainer and validate our pre-trained approach on FD001 to get a baseline.
trainer = pl.Trainer(max_epochs=10)
trainer.validate(approach, dm_unlabeled)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
────────────────────────────────────────────────────
      Validate metric             DataLoader 0
────────────────────────────────────────────────────
         val/loss              36.179779052734375
────────────────────────────────────────────────────
[{'val/loss': 36.179779052734375}]
Afterward, we combine the FD003 training data with the pseudo-labeled FD001 training data and train the approach for another 10 epochs. The FD001 validation loss decreases from 36.18 to 29.43. Pseudo labeling and training can now be repeated with the updated model until the validation loss converges.
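# Merge the labeled FD003 and the pseudo-labeled FD001 training ("dev") splits
combined_train_data = torch.utils.data.ConcatDataset(
    [dm_labeled.to_dataset("dev"), dm_unlabeled.to_dataset("dev")]
)
# Shuffle so that batches mix samples from both domains
combined_train_dl = torch.utils.data.DataLoader(
    combined_train_data, batch_size=128, shuffle=True
)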
trainer.fit(approach, train_dataloaders=combined_train_dl)
trainer.validate(approach, dm_unlabeled)
/home/tilman/Programming/rul-adapt/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:108: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
  rank_zero_warn(

  | Name               | Type               | Params
----------------------------------------------------------
0 | train_loss         | MeanSquaredError   | 0
1 | val_loss           | MeanSquaredError   | 0
2 | test_loss          | MeanSquaredError   | 0
3 | evaluator          | AdaptionEvaluator  | 0
4 | _feature_extractor | CnnExtractor       | 15.7 K
5 | _regressor         | FullyConnectedHead | 65
----------------------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=10` reached.
────────────────────────────────────────────────────
      Validate metric             DataLoader 0
────────────────────────────────────────────────────
         val/loss              29.42894172668457
────────────────────────────────────────────────────
[{'val/loss': 29.42894172668457}]
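Repeating the procedure requires a stopping rule. A minimal sketch, assuming a relative-improvement criterion with an arbitrary 1% threshold: after each round (generating and clipping pseudo labels, patching the data module, retraining on the combined data and validating on FD001), append the reported `val/loss` to a list and stop once `pseudo_labeling_converged`, a hypothetical helper, returns True.

```python
from typing import Sequence


def pseudo_labeling_converged(val_losses: Sequence[float], rel_tol: float = 0.01) -> bool:
    """Return True once the latest round improved the loss by less than `rel_tol` (relative).

    A round that makes the loss worse also counts as converged, because its
    relative improvement is negative.
    """
    if len(val_losses) < 2:
        return False
    previous, current = val_losses[-2], val_losses[-1]
    return (previous - current) / previous < rel_tol


# The FD001 validation losses reported above: before and after one round
print(pseudo_labeling_converged([36.179779052734375, 29.42894172668457]))  # → False
```

With the two losses from this notebook, the relative improvement is about 19%, so another round would be warranted under this rule.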