baseline
Higher-order data modules to establish a baseline for transfer learning and domain adaptation experiments.
BaselineDataModule
Bases: LightningDataModule
A higher-order data module that takes a RulDataModule. It provides the training and validation splits of the sub-dataset selected in the underlying data module but provides the test splits of all available subsets of the dataset. This makes it easy to evaluate the generalization of a supervised model on all sub-datasets.
Examples:
>>> import rul_datasets
>>> cmapss = rul_datasets.reader.CmapssReader(fd=1)
>>> dm = rul_datasets.RulDataModule(cmapss, batch_size=32)
>>> baseline_dm = rul_datasets.BaselineDataModule(dm)
>>> baseline_dm.prepare_data()
>>> baseline_dm.setup()
>>> train_fd1 = baseline_dm.train_dataloader()
>>> val_fd1 = baseline_dm.val_dataloader()
>>> test_fd1, test_fd2, test_fd3, test_fd4 = baseline_dm.test_dataloader()
__init__(data_module)
Create a new baseline data module from a RulDataModule.
It will provide a data loader of the underlying data module's training and validation splits. Additionally, it provides a data loader of the test split of all sub-datasets.
The data module keeps the configuration made in the underlying data module. The same configuration is then used to create RulDataModules for all sub-datasets, except for percent_fail_runs and percent_broken.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_module | RulDataModule | The underlying RulDataModule. | required |
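The configuration hand-down described above can be pictured with a small, self-contained sketch. It does not use rul_datasets: ReaderConfig and its window_size field are hypothetical stand-ins, and resetting the two excluded options to None for the other sub-datasets is an assumption made for illustration only.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Optional


@dataclass(frozen=True)
class ReaderConfig:
    """Hypothetical stand-in for a reader's configuration (not a library class)."""

    fd: int
    window_size: int = 30
    percent_broken: Optional[float] = None
    percent_fail_runs: Optional[float] = None


def configs_for_all_subsets(base: ReaderConfig, fds: List[int]) -> Dict[int, ReaderConfig]:
    """Copy `base` for every sub-dataset, leaving out the two truncation options."""
    configs = {}
    for fd in fds:
        if fd == base.fd:
            # The selected sub-dataset keeps its full configuration.
            configs[fd] = base
        else:
            # Assumption: the excluded options fall back to None here.
            configs[fd] = replace(
                base, fd=fd, percent_broken=None, percent_fail_runs=None
            )
    return configs


base = ReaderConfig(fd=1, window_size=20, percent_broken=0.8)
all_configs = configs_for_all_subsets(base, fds=[1, 2, 3, 4])
```

In this sketch, shared settings such as window_size carry over to every sub-dataset, while percent_broken only applies to the sub-dataset selected in the base configuration.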
prepare_data(*args, **kwargs)
Download and pre-process the underlying data.
This calls the prepare_data function for all sub-datasets. All previously completed preparation steps are skipped. It is called automatically by pytorch_lightning and executed on the first GPU in distributed mode.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Passed down to each data module's prepare_data function. | () |
**kwargs | Any | Passed down to each data module's prepare_data function. | {} |
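The delegation described for prepare_data can be sketched without the library. ToySubModule and ToyBaseline below are hypothetical stand-ins, not rul_datasets classes; the sketch only assumes that each sub-dataset's module remembers whether it has already been prepared, so that repeated calls do no extra work.

```python
from typing import Any, List


class ToySubModule:
    """Hypothetical stand-in for the data module of one sub-dataset."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.prepared = False
        self.work_done = 0  # counts how often preparation actually ran

    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
        if self.prepared:
            return  # previously completed preparation is skipped
        self.work_done += 1
        self.prepared = True


class ToyBaseline:
    """Hypothetical higher-order module that forwards calls to all sub-modules."""

    def __init__(self, subsets: List[ToySubModule]) -> None:
        self.subsets = subsets

    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
        # Forward the arguments unchanged to every sub-dataset's module.
        for dm in self.subsets:
            dm.prepare_data(*args, **kwargs)


toy = ToyBaseline([ToySubModule(f"FD{i}") for i in range(1, 5)])
toy.prepare_data()
toy.prepare_data()  # second call finds everything prepared already
```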
setup(stage=None)
Load all splits as tensors into memory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage | Optional[str] | Passed down to each data module's setup function. | None |
test_dataloader(*args, **kwargs)
Return data loaders for all sub-datasets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Passed down to each data module. | () |
**kwargs | Any | Passed down to each data module. | {} |
Returns:
Type | Description |
---|---|
List[DataLoader] | The test dataloaders of all sub-datasets. |
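Because the test loaders come back as a list with one entry per sub-dataset, a per-sub-dataset score can be computed by looping over that list. The following is a minimal sketch under stated assumptions: rmse_per_subset is a hypothetical helper, the loaders are replaced by plain lists of (features, targets) batches instead of real DataLoader objects, and RMSE is chosen only as an example metric.

```python
import math
from typing import Callable, Iterable, List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # (features, RUL targets)


def rmse_per_subset(
    loaders: List[Iterable[Batch]],
    predict: Callable[[Sequence[float]], Sequence[float]],
) -> List[float]:
    """Compute one RMSE value per test loader, in the order of the loaders."""
    scores = []
    for loader in loaders:
        squared_error, count = 0.0, 0
        for features, targets in loader:
            predictions = predict(features)
            squared_error += sum(
                (p - t) ** 2 for p, t in zip(predictions, targets)
            )
            count += len(targets)
        scores.append(math.sqrt(squared_error / count))
    return scores


# Toy stand-ins for two test loaders: lists of (features, targets) batches.
toy_loaders = [
    [([1.0, 2.0], [1.0, 2.0])],
    [([1.0, 2.0], [4.0, 6.0])],
]
# The "model" here simply returns its input as the prediction.
scores = rmse_per_subset(toy_loaders, predict=lambda x: list(x))
```

With real data loaders, the batches would be tensors and predict would be a trained model, but the loop over the returned list would keep the same shape.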