abstract
A module for the abstract base class of all approaches.
AdaptionApproach
Bases: LightningModule
This abstract class is the base of all adaption approaches.
It requires every approach to have a feature_extractor and a regressor, both of
which are trainable neural networks. These members can be accessed via
read-only properties.
All child classes are supposed to implement their own constructors. The
feature_extractor and regressor should explicitly not be arguments of the
constructor and should be set by calling set_model. This way, the approach can
be initialized with all hyperparameters first and afterward supplied with the
networks. This is useful for initializing the networks with pre-trained weights.
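The two-step initialization described above can be sketched without any dependencies. SketchApproach is a hypothetical stand-in, not the library class: it does not inherit from LightningModule, the lr hyperparameter is invented for illustration, and raising a RuntimeError when a network is accessed before set_model is an assumption.

```python
class SketchApproach:
    """Hypothetical stand-in for an AdaptionApproach subclass."""

    def __init__(self, lr: float) -> None:
        # Only hyperparameters are constructor arguments (lr is illustrative).
        self.lr = lr
        self._feature_extractor = None
        self._regressor = None

    def set_model(self, feature_extractor, regressor, *args, **kwargs) -> None:
        # The networks are supplied after construction.
        self._feature_extractor = feature_extractor
        self._regressor = regressor

    @property
    def feature_extractor(self):
        # Read-only: there is no setter, so assignment raises AttributeError.
        if self._feature_extractor is None:
            raise RuntimeError("Call set_model before accessing feature_extractor.")
        return self._feature_extractor

    @property
    def regressor(self):
        if self._regressor is None:
            raise RuntimeError("Call set_model before accessing regressor.")
        return self._regressor


# Placeholders for networks that would carry pre-trained weights.
pretrained_fe = object()
pretrained_reg = object()

approach = SketchApproach(lr=0.001)  # step 1: hyperparameters only
approach.set_model(pretrained_fe, pretrained_reg)  # step 2: networks
```

Because the networks arrive through set_model, they can be built and loaded with pre-trained weights independently of the approach's hyperparameters.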
Because models are constructed outside the approach, the default checkpointing
mechanism of PyTorch Lightning fails to load checkpoints of AdaptionApproaches.
We extended the checkpointing mechanism by implementing the on_save_checkpoint
and on_load_checkpoint callbacks to make it work. If a subclass uses an
additional model, besides the feature extractor and regressor, that is not
initialized in the constructor, the subclass needs to define the
CHECKPOINT_MODELS class variable. This variable is a list of the attribute
names of the models to be included in the checkpoint. For example, if your
approach has an additional model self._domain_disc, the CHECKPOINT_MODELS
variable should be set to ['_domain_disc']. Otherwise, loading a checkpoint of
this approach will fail.
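The role of CHECKPOINT_MODELS can be illustrated with a dependency-free sketch. How the real callbacks serialize the networks is not described here, so the classes below are hypothetical: they simply copy the listed attributes into the checkpoint dictionary under an invented "models" key, and strings stand in for the networks.

```python
class CheckpointSketch:
    """Hypothetical stand-in for the base class's checkpoint hooks."""

    # Attribute names of additional models; subclasses override this list.
    CHECKPOINT_MODELS = []

    def set_model(self, feature_extractor, regressor, *args, **kwargs):
        self._feature_extractor = feature_extractor
        self._regressor = regressor

    def on_save_checkpoint(self, checkpoint):
        # Always store the two core networks plus any listed extras.
        names = ["_feature_extractor", "_regressor", *self.CHECKPOINT_MODELS]
        checkpoint["models"] = {name: getattr(self, name) for name in names}

    def on_load_checkpoint(self, checkpoint):
        # Restore every stored model onto the freshly constructed approach.
        for name, model in checkpoint["models"].items():
            setattr(self, name, model)


class DomainDiscSketch(CheckpointSketch):
    # Without this entry, _domain_disc would be missing after loading.
    CHECKPOINT_MODELS = ["_domain_disc"]

    def set_model(self, feature_extractor, regressor, domain_disc=None,
                  *args, **kwargs):
        super().set_model(feature_extractor, regressor)
        self._domain_disc = domain_disc


saved = {}
source = DomainDiscSketch()
source.set_model("fe", "reg", "disc")
source.on_save_checkpoint(saved)

restored = DomainDiscSketch()  # constructed without any networks
restored.on_load_checkpoint(saved)
```

If DomainDiscSketch left CHECKPOINT_MODELS empty, the restored object in this sketch would have no _domain_disc attribute, which mirrors the loading failure described above.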
feature_extractor: nn.Module
property
The feature extraction network.
regressor: nn.Module
property
The RUL regression network.
set_model(feature_extractor, regressor, *args, **kwargs)
Set the feature extractor and regressor for this approach.
Child classes can override this function to add additional models to an
approach. The args and kwargs parameters exist to make such overrides possible
and are ignored in this function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
feature_extractor | Module | The feature extraction network. | required |
regressor | Module | The RUL regression network. | required |
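A minimal sketch of how an override of set_model might use the extra arguments. Both classes are hypothetical stand-ins rather than the library's implementation; the domain_disc parameter and the ValueError are assumptions made for illustration.

```python
class BaseSketch:
    """Hypothetical stand-in for the base set_model behavior."""

    def set_model(self, feature_extractor, regressor, *args, **kwargs):
        # Extra positional and keyword arguments are accepted but ignored.
        self.feature_extractor = feature_extractor
        self.regressor = regressor


class WithDiscSketch(BaseSketch):
    """Hypothetical child class that needs one additional model."""

    def set_model(self, feature_extractor, regressor, domain_disc=None,
                  *args, **kwargs):
        # Let the base class store the two core networks first.
        super().set_model(feature_extractor, regressor)
        if domain_disc is None:
            raise ValueError("This sketch requires a domain_disc.")
        self.domain_disc = domain_disc


base = BaseSketch()
base.set_model("fe", "reg", "ignored", extra="also ignored")

child = WithDiscSketch()
child.set_model("fe", "reg", domain_disc="disc")
```

Keeping *args and **kwargs in the base signature lets every approach be configured through the same call, while each child picks out only the additional models it needs.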