mmd
The Maximum Mean Discrepancy (MMD) approach uses the distance measure of the same name to adapt a feature extractor. This implementation uses a multi-kernel variant of the MMD loss with bandwidths set via the median heuristic.
Source --> FeatEx --> Source Feats -----------> Regressor --> RUL Prediction
            ^  |                      |
            |  |                      v
Target ------   --> Target Feats --> MMD Loss
It was first introduced by Long et al. as the Deep Adaptation Network (DAN) for image classification.
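The multi-kernel MMD with median-heuristic bandwidths described above can be sketched in NumPy as follows. This is an illustrative sketch, not the library's PyTorch implementation. It assumes Gaussian kernels on squared Euclidean distances, a base bandwidth equal to the median of the non-zero pairwise squared distances in the joint source/target batch, and bandwidths spread around that median by powers of two. The function names pairwise_sq_dists and multi_kernel_mmd are hypothetical.

```python
import numpy as np


def pairwise_sq_dists(x, y):
    """Squared Euclidean distance ||x_i - y_j||^2 for every pair of rows."""
    diff = x[:, None, :] - y[None, :, :]
    return np.sum(diff ** 2, axis=-1)


def multi_kernel_mmd(source, target, num_kernels=5):
    """Biased estimate of the squared MMD using a sum of Gaussian kernels.

    Assumption: bandwidths are median * 2**e for e centered on zero.
    """
    joint = np.concatenate([source, target], axis=0)
    dists = pairwise_sq_dists(joint, joint)
    # Median heuristic: base bandwidth from the non-zero squared distances
    median = np.median(dists[dists > 0])
    exponents = np.arange(num_kernels) - num_kernels // 2
    bandwidths = median * 2.0 ** exponents
    # Sum the Gaussian kernels over all bandwidths -> (N, N) kernel matrix
    kernel = np.exp(-dists[..., None] / bandwidths).sum(axis=-1)
    n = source.shape[0]
    k_ss = kernel[:n, :n]  # source vs. source
    k_tt = kernel[n:, n:]  # target vs. target
    k_st = kernel[:n, n:]  # source vs. target
    return k_ss.mean() + k_tt.mean() - 2.0 * k_st.mean()
```

Because the bandwidths are recomputed from each batch, the kernel scale follows the magnitude of the features, which is the usual motivation for the median heuristic.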
Used In
- Cao et al. (2021). Transfer learning for remaining useful life prediction of multi-conditions bearings based on bidirectional-GRU network. Measurement: Journal of the International Measurement Confederation, 178. 10.1016/j.measurement.2021.109287
- Krokotsch et al. (2020). A Novel Evaluation Framework for Unsupervised Domain Adaption on Remaining Useful Lifetime Estimation. 2020 IEEE International Conference on Prognostics and Health Management (ICPHM). 10.1109/ICPHM49022.2020.9187058
MmdApproach
Bases: AdaptionApproach
The MMD approach uses the Maximum Mean Discrepancy to adapt a feature extractor for use with the source regressor.
The regressor needs the same number of input units as the feature extractor has output units.
Examples:
>>> from rul_adapt import model
>>> from rul_adapt import approach
>>> feat_ex = model.CnnExtractor(1, [16, 16, 1], 10, fc_units=16)
>>> reg = model.FullyConnectedHead(16, [1])
>>> mmd = approach.MmdApproach(0.01)
>>> mmd.set_model(feat_ex, reg)
__init__(mmd_factor, num_mmd_kernels=5, loss_type='mse', rul_score_mode='phm08', evaluate_degraded_only=False, **optim_kwargs)
Create a new MMD approach.
The strength of the MMD loss's influence on the feature extractor is controlled by mmd_factor. The higher it is, the stronger the influence.
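Written out, the training objective implied by this description is a weighted sum of the source regression loss and the MMD term. The additive form, the symbol λ for mmd_factor, and the use of the biased squared-MMD estimate are assumptions made for illustration. Here f is the feature extractor, g the regressor, (x^s, y^s) a source batch of size n, x^t a target batch of size m, z^s_i = f(x^s_i), z^t_j = f(x^t_j), and σ_1 to σ_K are the kernel bandwidths with K = num_mmd_kernels:

```latex
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{\mathrm{reg}}\bigl(g(z^s), y^s\bigr) + \lambda\,\widehat{\mathrm{MMD}}^2(z^s, z^t) \\
\widehat{\mathrm{MMD}}^2(z^s, z^t) &= \frac{1}{n^2}\sum_{i=1}^{n}\sum_{i'=1}^{n} k(z^s_i, z^s_{i'})
  + \frac{1}{m^2}\sum_{j=1}^{m}\sum_{j'=1}^{m} k(z^t_j, z^t_{j'})
  - \frac{2}{nm}\sum_{i=1}^{n}\sum_{j=1}^{m} k(z^s_i, z^t_j) \\
k(a, b) &= \sum_{u=1}^{K} \exp\!\left(-\frac{\lVert a - b \rVert^2}{\sigma_u}\right)
\end{aligned}
```

A larger λ pulls the source and target feature distributions closer together at the cost of giving the regression fit relatively less weight.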
For more information about the possible optimizer keyword arguments, see here.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mmd_factor | float | The strength of the MMD loss's influence. | required |
num_mmd_kernels | int | The number of kernels for the MMD loss. | 5 |
loss_type | Literal['mse', 'rmse', 'mae'] | The type of regression loss, either 'mse', 'rmse' or 'mae'. | 'mse' |
rul_score_mode | Literal['phm08', 'phm12'] | The mode for the validation and test RUL score, either 'phm08' or 'phm12'. | 'phm08' |
evaluate_degraded_only | bool | Whether to only evaluate the RUL score on degraded samples. | False |
**optim_kwargs | Any | Keyword arguments for the optimizer, e.g. learning rate. | {} |
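The options of loss_type and rul_score_mode can be illustrated with the following NumPy sketch. The regression losses are the standard definitions. The two scores follow the usual PHM 2008 and PHM 2012 challenge formulas: the 2008 score sums asymmetric exponential penalties (late predictions cost more), and the 2012 score averages an accuracy in (0, 1] computed from the percent error. Whether the library sums or averages, and how it treats edge cases, is not stated here, so treat these hypothetical functions as a reference sketch only.

```python
import numpy as np


def regression_loss(pred, target, loss_type="mse"):
    """Source-domain regression loss selected by loss_type."""
    err = pred - target
    if loss_type == "mse":
        return np.mean(err ** 2)
    if loss_type == "rmse":
        return np.sqrt(np.mean(err ** 2))
    if loss_type == "mae":
        return np.mean(np.abs(err))
    raise ValueError(f"Unknown loss_type: {loss_type!r}")


def phm08_score(pred, target):
    """PHM 2008 score: lower is better, late predictions are penalized more."""
    d = pred - target
    penalty = np.where(d < 0, np.exp(-d / 13.0) - 1.0, np.exp(d / 10.0) - 1.0)
    return np.sum(penalty)


def phm12_score(pred, target):
    """PHM 2012 score: mean accuracy, 1.0 is perfect, overestimates decay faster."""
    percent_err = 100.0 * (target - pred) / target
    accuracy = np.where(
        percent_err <= 0,
        np.exp(-np.log(0.5) * percent_err / 5.0),  # overestimate (late)
        np.exp(np.log(0.5) * percent_err / 20.0),  # underestimate (early)
    )
    return np.mean(accuracy)
```

Note that phm12_score divides by the true RUL, so this sketch assumes strictly positive labels.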
configure_optimizers()
Configure an optimizer.
forward(inputs)
Predict the RUL values for a batch of input features.
test_step(batch, batch_idx, dataloader_idx)
Execute one test step.
The batch argument is a list of two tensors representing features and labels. A RUL prediction is made from the features, and the test RMSE and RUL score are calculated. The metrics recorded for dataloader idx zero are assumed to be from the source domain and those for dataloader idx one from the target domain. The metrics are written to the configured logger under the prefix test.
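The per-domain evaluation described above, which applies to the validation step as well, can be sketched with plain NumPy and a dict standing in for the configured logger. The function evaluation_step, its score_fn argument, the stand-in model, and the key format prefix/domain/metric are all hypothetical; the library's actual metric names may differ.

```python
import numpy as np

# Assumed mapping from dataloader index to domain, as described above
DOMAIN_NAMES = {0: "source", 1: "target"}


def evaluation_step(predict, batch, dataloader_idx, prefix, score_fn, log):
    """Predict RUL for one batch and record RMSE and score under the domain name."""
    features, labels = batch
    preds = predict(features)
    rmse = float(np.sqrt(np.mean((preds - labels) ** 2)))
    score = float(score_fn(preds, labels))
    domain = DOMAIN_NAMES[dataloader_idx]
    # Hypothetical key format, e.g. "test/target/rmse"
    log[f"{prefix}/{domain}/rmse"] = rmse
    log[f"{prefix}/{domain}/score"] = score
    return log


def stand_in_model(x):
    # Hypothetical model: predicts the row sum of the features
    return x.sum(axis=1)


def mean_abs_error(preds, labels):
    # Hypothetical score function used only for this example
    return np.mean(np.abs(preds - labels))


log = {}
batch = (np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([3.0, 5.0]))
evaluation_step(stand_in_model, batch, 1, "test", mean_abs_error, log)
```

Keying the metrics by domain keeps source and target performance separate, which is what makes the adaptation gap visible during evaluation.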
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | List[Tensor] | A list containing a feature and a label tensor. | required |
batch_idx | int | The index of the current batch. | required |
dataloader_idx | int | The index of the current dataloader (0: source, 1: target). | required |
training_step(batch, batch_idx)
Execute one training step.
The batch argument is a list of three tensors representing the source features, source labels and target features. Both types of features are fed to the feature extractor. Then the regression loss for the source domain and the MMD loss between domains are computed. The regression, MMD and combined losses are logged.
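The data flow of a training step can be sketched as follows. To keep the example short, this sketch swaps in a linear-kernel MMD (the squared distance between the batch means of the features) for the multi-kernel loss, uses MSE as the regression loss, and assumes the combined loss is regression + mmd_factor * mmd. The feature extractor and regressor below are hypothetical NumPy callables, not rul_adapt models.

```python
import numpy as np


def linear_mmd(source_feats, target_feats):
    # Linear-kernel MMD^2: squared distance between the feature means.
    # A simplified stand-in for the multi-kernel MMD loss.
    diff = source_feats.mean(axis=0) - target_feats.mean(axis=0)
    return float(diff @ diff)


def training_step(feature_extractor, regressor, batch, mmd_factor):
    source_x, source_y, target_x = batch
    # Both domains pass through the same feature extractor
    source_feats = feature_extractor(source_x)
    target_feats = feature_extractor(target_x)
    # The regression loss uses source labels only
    preds = regressor(source_feats)
    regression = float(np.mean((preds - source_y) ** 2))
    mmd = linear_mmd(source_feats, target_feats)
    combined = regression + mmd_factor * mmd  # assumed weighting
    return {"regression": regression, "mmd": mmd, "combined": combined}


rng = np.random.default_rng(0)
w_feat = rng.normal(size=(4, 8))  # hypothetical extractor: 4 inputs -> 8 features
w_reg = rng.normal(size=8)  # regressor input units match extractor output units


def feature_extractor(x):
    return np.maximum(x @ w_feat, 0.0)  # one ReLU layer


def regressor(z):
    return z @ w_reg


batch = (
    rng.normal(size=(16, 4)),  # source features
    rng.uniform(0.0, 125.0, size=16),  # source RUL labels
    rng.normal(size=(16, 4)) + 1.0,  # target features from a shifted domain
)
losses = training_step(feature_extractor, regressor, batch, mmd_factor=0.01)
```

The target features never meet a label; they only enter through the MMD term, which is what lets the approach train on unlabeled target data.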
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | List[Tensor] | A list of source feature, source label and target feature tensors. | required |
batch_idx | int | The index of the current batch. | required |
Returns: The combined loss.
validation_step(batch, batch_idx, dataloader_idx)
Execute one validation step.
The batch argument is a list of two tensors representing features and labels. A RUL prediction is made from the features and the validation RMSE and RUL score are calculated. The metrics recorded for dataloader idx zero are assumed to be from the source domain and for dataloader idx one from the target domain. The metrics are written to the configured logger under the prefix val.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | List[Tensor] | A list containing a feature and a label tensor. | required |
batch_idx | int | The index of the current batch. | required |
dataloader_idx | int | The index of the current dataloader (0: source, 1: target). | required |