ssl
A module with higher-order data modules for semi-supervised learning.
SemiSupervisedDataModule
Bases: LightningDataModule
A higher-order data module used for semi-supervised learning with a labeled data module and an unlabeled one. It makes sure that both data modules come from the same sub-dataset.
Examples:
>>> import rul_datasets
>>> fd1 = rul_datasets.CmapssReader(fd=1, window_size=20, percent_fail_runs=0.5)
>>> fd1_complement = fd1.get_complement(percent_broken=0.8)
>>> labeled = rul_datasets.RulDataModule(fd1, 32)
>>> unlabeled = rul_datasets.RulDataModule(fd1_complement, 32)
>>> dm = rul_datasets.SemiSupervisedDataModule(labeled, unlabeled)
>>> dm.prepare_data()
>>> dm.setup()
>>> train_ssl = dm.train_dataloader()
>>> val = dm.val_dataloader()
>>> test = dm.test_dataloader()
__init__(labeled, unlabeled)
Create a new semi-supervised data module from a labeled and an unlabeled RulDataModule.
Both data modules are checked for compatibility (see RulDataModule). These checks include that the fd values match between them.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labeled | RulDataModule | The data module of the labeled dataset. | required |
unlabeled | RulDataModule | The data module of the unlabeled dataset. | required |
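The fd check described above can be illustrated with a minimal sketch. It does not use the library's own classes: `ReaderConfig` and `check_same_sub_dataset` are hypothetical names introduced only for this illustration, and the actual compatibility checks in RulDataModule may cover more settings than the fd.

```python
from dataclasses import dataclass


@dataclass
class ReaderConfig:
    """Hypothetical stand-in for reader settings; not a rul_datasets class."""

    fd: int
    window_size: int


def check_same_sub_dataset(labeled: ReaderConfig, unlabeled: ReaderConfig) -> None:
    """Raise ValueError if the two configurations use different sub-datasets."""
    if labeled.fd != unlabeled.fd:
        raise ValueError(
            f"FD of labeled ({labeled.fd}) and unlabeled ({unlabeled.fd}) data differ."
        )


# Matching sub-datasets pass silently.
check_same_sub_dataset(ReaderConfig(fd=1, window_size=20), ReaderConfig(fd=1, window_size=20))

# Mismatching sub-datasets are rejected.
try:
    check_same_sub_dataset(ReaderConfig(fd=1, window_size=20), ReaderConfig(fd=2, window_size=20))
except ValueError as err:
    mismatch_message = str(err)
```

Failing early at construction time keeps a mismatch from surfacing later as a hard-to-trace training problem.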
prepare_data(*args, **kwargs)
Download and pre-process the underlying data.
This calls the prepare_data function of the labeled and the unlabeled data module. All previously completed preparation steps are skipped. It is called automatically by pytorch_lightning and executed on the first GPU in distributed mode.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Passed down to each data module's prepare_data function. | () |
**kwargs | Any | Passed down to each data module's prepare_data function. | {} |
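As a rough sketch of this delegation pattern, the wrapper below forwards its arguments unchanged to both wrapped modules. `StubDataModule` and `HigherOrderModule` are hypothetical stand-ins written for this illustration and are not part of rul_datasets.

```python
class StubDataModule:
    """Hypothetical stand-in for a data module; records prepare_data calls."""

    def __init__(self, name):
        self.name = name
        self.calls = []

    def prepare_data(self, *args, **kwargs):
        # A real data module would download and pre-process data here.
        self.calls.append((args, kwargs))


class HigherOrderModule:
    """Hypothetical wrapper that forwards preparation to both wrapped modules."""

    def __init__(self, labeled, unlabeled):
        self.labeled = labeled
        self.unlabeled = unlabeled

    def prepare_data(self, *args, **kwargs):
        # Both modules receive exactly the same arguments.
        self.labeled.prepare_data(*args, **kwargs)
        self.unlabeled.prepare_data(*args, **kwargs)


labeled = StubDataModule("labeled")
unlabeled = StubDataModule("unlabeled")
HigherOrderModule(labeled, unlabeled).prepare_data(1, flag=True)
```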
setup(stage=None)
Load labeled and unlabeled data into memory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage | Optional[str] | Passed down to each data module's setup function. | None |
test_dataloader(*args, **kwargs)
Create a data loader of the labeled test data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Ignored. Only for adhering to parent class interface. | () |
**kwargs | Any | Ignored. Only for adhering to parent class interface. | {} |
Returns:
Type | Description |
---|---|
DataLoader | The labeled test data loader. |
train_dataloader(*args, **kwargs)
Create a data loader of an AdaptionDataset using the labeled and unlabeled data.
The data loader is configured to shuffle the data. The pin_memory option is activated to achieve maximum transfer speed to the GPU.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Ignored. Only for adhering to parent class interface. | () |
**kwargs | Any | Ignored. Only for adhering to parent class interface. | {} |
Returns:
Type | Description |
---|---|
DataLoader | The training data loader. |
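To give an idea of what combining labeled and unlabeled data can look like, here is a minimal sketch in plain Python. It assumes that each labeled sample is paired with a randomly drawn unlabeled sample whose label is discarded, and that the length follows the labeled data. `PairedDataset` is a hypothetical class for illustration only; the behavior of the actual AdaptionDataset should be taken from its own documentation.

```python
import random


class PairedDataset:
    """Hypothetical sketch: pairs each labeled sample with a random unlabeled one."""

    def __init__(self, labeled, unlabeled, seed=None):
        self.labeled = labeled  # list of (features, label) tuples
        self.unlabeled = unlabeled  # list of (features, label) tuples; labels are dropped
        self._rng = random.Random(seed)

    def __len__(self):
        # Assumption: the labeled data determines the length.
        return len(self.labeled)

    def __getitem__(self, idx):
        features, label = self.labeled[idx]
        # Draw a random unlabeled sample and discard its label.
        unlabeled_features, _ = self._rng.choice(self.unlabeled)
        return features, label, unlabeled_features


labeled_data = [([0.1, 0.2], 120.0), ([0.3, 0.4], 119.0)]
unlabeled_data = [([0.5, 0.6], 80.0), ([0.7, 0.8], 79.0), ([0.9, 1.0], 78.0)]
paired = PairedDataset(labeled_data, unlabeled_data, seed=0)
sample = paired[0]  # (labeled features, label, unlabeled features)
```

Each training batch built from such a dataset carries both supervised targets and extra unlabeled features, which is what a semi-supervised loss typically needs.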
val_dataloader(*args, **kwargs)
Create a data loader of the labeled validation data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Ignored. Only for adhering to parent class interface. | () |
**kwargs | Any | Ignored. Only for adhering to parent class interface. | {} |
Returns:
Type | Description |
---|---|
DataLoader | The labeled validation data loader. |