
abstract

A module for the abstract base class of all approaches.

AdaptionApproach

Bases: LightningModule

This abstract class is the base of all adaption approaches.

It requires a feature_extractor and a regressor, both of which are trainable neural networks. These members can be accessed via read-only properties.

All child classes are expected to implement their own constructors. The feature_extractor and regressor should explicitly not be constructor arguments; instead, they are set by calling set_model. This way, the approach can first be initialized with all hyperparameters and afterward be supplied with the networks, which is useful for initializing the networks with pre-trained weights.
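The deferred-construction pattern can be illustrated with a minimal, pure-Python sketch. `SketchApproach` and the `Module` stand-in are hypothetical and only mimic the documented behavior without requiring PyTorch. The `RuntimeError` raised when a network is accessed before `set_model` is an assumption of this sketch, not documented behavior of AdaptionApproach.

```python
from typing import Optional


class Module:
    """Hypothetical stand-in for torch.nn.Module so the sketch runs without PyTorch."""

    def __init__(self, name: str) -> None:
        self.name = name


class SketchApproach:
    """Hypothetical approach: hyperparameters in the constructor, networks via set_model."""

    def __init__(self, lr: float) -> None:
        self.lr = lr  # hyperparameters are fixed at construction time
        self._feature_extractor: Optional[Module] = None
        self._regressor: Optional[Module] = None

    def set_model(
        self, feature_extractor: Module, regressor: Module, *args, **kwargs
    ) -> None:
        # Extra args/kwargs are accepted but ignored, mirroring the documented signature.
        self._feature_extractor = feature_extractor
        self._regressor = regressor

    @property
    def feature_extractor(self) -> Module:
        # No setter is defined, so assigning to this property raises AttributeError.
        if self._feature_extractor is None:
            raise RuntimeError("call set_model before accessing feature_extractor")
        return self._feature_extractor

    @property
    def regressor(self) -> Module:
        if self._regressor is None:
            raise RuntimeError("call set_model before accessing regressor")
        return self._regressor


# Usage: configure the approach first, then supply (possibly pre-trained) networks.
approach = SketchApproach(lr=1e-3)
pretrained_fe = Module("pretrained feature extractor")
approach.set_model(pretrained_fe, Module("regressor"))
```

Keeping the networks out of the constructor means the same configured approach object can receive freshly initialized or pre-trained networks without changing its hyperparameters.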

Because the models are constructed outside the approach, the default checkpointing mechanism of PyTorch Lightning fails to load checkpoints of AdaptionApproaches. We extended this mechanism by implementing the on_save_checkpoint and on_load_checkpoint hooks. If a subclass uses an additional model besides the feature extractor and regressor that is not initialized in the constructor, the subclass needs to define the CHECKPOINT_MODELS class variable. This variable is a list of the names of the models to include in the checkpoint. For example, if your approach has an additional model self._domain_disc, set CHECKPOINT_MODELS to ['_domain_disc']. Otherwise, loading a checkpoint of this approach will fail.
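A minimal sketch of how such checkpoint hooks could use CHECKPOINT_MODELS, written in plain Python so it runs without Lightning. In a LightningModule, on_save_checkpoint and on_load_checkpoint receive the checkpoint dictionary; here they are called by hand. The class names, the `sketch_models` key, and the choice to store deep copies of whole model objects are assumptions for illustration, not the library's actual implementation.

```python
import copy
from typing import Any, Dict, List


class Net:
    """Hypothetical stand-in for an nn.Module with some state."""

    def __init__(self, weight: float) -> None:
        self.weight = weight


class SketchApproach:
    # Names of additional models to checkpoint; empty in the base class.
    CHECKPOINT_MODELS: List[str] = []

    def __init__(self) -> None:
        self._feature_extractor = None
        self._regressor = None

    def set_model(self, feature_extractor, regressor, *args, **kwargs) -> None:
        self._feature_extractor = feature_extractor
        self._regressor = regressor

    def _models_to_save(self) -> List[str]:
        # Feature extractor and regressor are always saved; subclasses add more names.
        return ["_feature_extractor", "_regressor"] + self.CHECKPOINT_MODELS

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Store the models themselves, because the constructor cannot rebuild them.
        checkpoint["sketch_models"] = {
            name: copy.deepcopy(getattr(self, name)) for name in self._models_to_save()
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Re-attach the stored models so the approach is complete after loading.
        for name, model in checkpoint["sketch_models"].items():
            setattr(self, name, model)


class SketchDiscApproach(SketchApproach):
    # The extra model is set in set_model, not the constructor, so it must be listed.
    CHECKPOINT_MODELS = ["_domain_disc"]

    def set_model(self, feature_extractor, regressor, domain_disc=None, *args, **kwargs) -> None:
        super().set_model(feature_extractor, regressor)
        self._domain_disc = domain_disc


# Usage: save from a fully configured approach, load into a freshly constructed one.
source = SketchDiscApproach()
source.set_model(Net(1.0), Net(2.0), Net(3.0))
ckpt: Dict[str, Any] = {}
source.on_save_checkpoint(ckpt)

restored = SketchDiscApproach()  # constructed with hyperparameters only
restored.on_load_checkpoint(ckpt)
```

Without `"_domain_disc"` in CHECKPOINT_MODELS, the restored object in this sketch would have no `_domain_disc` attribute, which mirrors the loading failure described above.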

feature_extractor: nn.Module property

The feature extraction network.

regressor: nn.Module property

The RUL regression network.

set_model(feature_extractor, regressor, *args, **kwargs)

Set the feature extractor and regressor for this approach.

Child classes can override this function to add further models to an approach. The additional args and kwargs that make this possible are ignored by this base implementation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_extractor | Module | The feature extraction network. | required |
| regressor | Module | The RUL regression network. | required |
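Because the base signature accepts and ignores extra arguments, calling code can pass the same arguments to any approach. The following is a hypothetical sketch; `BaseSketch`, `DiscSketch`, `Net`, and the `build` helper are illustrative stand-ins, not part of the library. The base class ignores the extra `domain_disc` argument, while a child that needs a domain discriminator consumes it.

```python
class Net:
    """Hypothetical stand-in for an nn.Module."""

    def __init__(self, tag: str) -> None:
        self.tag = tag


class BaseSketch:
    def set_model(self, feature_extractor, regressor, *args, **kwargs) -> None:
        # Extra positional and keyword arguments are accepted but deliberately ignored.
        self._feature_extractor = feature_extractor
        self._regressor = regressor


class DiscSketch(BaseSketch):
    def set_model(self, feature_extractor, regressor, *args, **kwargs) -> None:
        super().set_model(feature_extractor, regressor)
        # Hypothetical: this child consumes one extra model, passed positionally or by keyword.
        self._domain_disc = args[0] if args else kwargs["domain_disc"]


def build(approach: BaseSketch) -> BaseSketch:
    # The same call works for every approach; those without a discriminator ignore it.
    approach.set_model(Net("fe"), Net("reg"), domain_disc=Net("disc"))
    return approach


base = build(BaseSketch())
child = build(DiscSketch())
```

This keeps model-building code independent of which approach is being configured, at the cost of silently dropping arguments that an approach does not use.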