adaption
A module for unsupervised domain adaptation losses.
DomainAdversarialLoss
Bases: Metric
The Domain Adversarial Neural Network Loss (DANN) uses a domain discriminator to measure the distance between two feature distributions.
The domain discriminator is a neural network that is trained jointly to classify its input as coming from one of two domains. Its output should be a single unscaled score (logit), which is fed to a binary cross entropy loss.
The domain discriminator is preceded by a GradientReversalLayer. This way, the discriminator is trained to separate the domains while the network generating the inputs is trained to marginalize the domain difference.
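As a rough orientation, the snippet below sketches how these pieces could be wired together. The feature dimension, the discriminator architecture, and the import path are illustrative assumptions, not part of this API.

```python
import torch
from torch import nn

from adaption import DomainAdversarialLoss  # placeholder import path; adjust to your package layout

# A small discriminator that maps features to a single unscaled logit.
feature_dim = 64  # illustrative choice
domain_disc = nn.Sequential(
    nn.Linear(feature_dim, 32),
    nn.ReLU(),
    nn.Linear(32, 1),  # one logit for the binary domain classification
)

# Per the description above, the discriminator is preceded by a
# GradientReversalLayer, so gradients flowing back into the feature
# extractor are reversed while the discriminator learns to separate domains.
dann_loss = DomainAdversarialLoss(domain_disc)
```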
__init__(domain_disc)
Create a new DANN loss module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `domain_disc` | `Module` | The neural network to act as the domain discriminator. | required |
update(source, target)
Calculate the DANN loss as the binary cross entropy of the discriminator's predictions for the source and target features.
The source features receive a domain label of zero and the target features a domain label of one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `source` | `Tensor` | The source features with domain label zero. | required |
| `target` | `Tensor` | The target features with domain label one. | required |
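Continuing the construction sketch above, a hypothetical training step might look as follows; the batch size and the torchmetrics-style call convention (calling the metric runs `update` and returns the batch value) are assumptions.

```python
import torch

# Feature batches from the same extractor for both domains (sizes are illustrative).
source_feats = torch.randn(32, feature_dim)  # receives domain label 0
target_feats = torch.randn(32, feature_dim)  # receives domain label 1

# Assuming the usual torchmetrics convention, calling the metric runs
# update() for this batch and returns the resulting BCE loss.
loss = dann_loss(source_feats, target_feats)
loss.backward()  # with real extractor outputs, the GRL reverses the gradient
                 # flowing back into the feature extractor
```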
GradientReversalLayer
Bases: Module
The gradient reversal layer (GRL) acts as an identity function in the forward pass and scales the gradient by a negative scalar in the backward pass.
__init__(grad_weight=1.0)
Create a new Gradient Reversal Layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `grad_weight` | `float` | The scalar that weights the negative gradient. | `1.0` |
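To make the forward/backward behaviour concrete, here is an illustrative re-implementation of the idea with `torch.autograd.Function`; it is a sketch of the mechanism, not the library's actual code.

```python
import torch
from torch import nn


class _ReverseGrad(torch.autograd.Function):
    """Identity in the forward pass; scaled, negated gradient in the backward pass."""

    @staticmethod
    def forward(ctx, inputs, grad_weight):
        ctx.grad_weight = grad_weight
        return inputs.view_as(inputs)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the incoming gradient; grad_weight itself gets no gradient.
        return -ctx.grad_weight * grad_output, None


class GradientReversalSketch(nn.Module):
    """Illustrative stand-in for GradientReversalLayer."""

    def __init__(self, grad_weight: float = 1.0) -> None:
        super().__init__()
        self.grad_weight = grad_weight

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return _ReverseGrad.apply(inputs, self.grad_weight)


x = torch.randn(4, 8, requires_grad=True)
GradientReversalSketch(grad_weight=0.5)(x).sum().backward()
print(x.grad)  # all entries are -0.5: the identity's gradient, reversed and scaled
```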
JointMaximumMeanDiscrepancyLoss
Bases: Metric
The Joint Maximum Mean Discrepancy Loss (JMMD) is a distance measure between multiple pairs of arbitrary distributions.
It is related to the MMD insofar as the kernel distances of the distribution pairs in a reproducing kernel Hilbert space (RKHS) are calculated and multiplied together before the discrepancy is computed.
For more information see MaximumMeanDiscrepancyLoss.
__init__()
Create a new JMMD loss module.
It features a single Gaussian kernel with a bandwidth chosen by the median heuristic.
update(source_features, target_features)
Compute the JMMD loss between multiple feature distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `source_features` | `List[Tensor]` | The list of source features. | required |
| `target_features` | `List[Tensor]` | The list of target features. | required |
Returns:
| Type | Description |
|---|---|
| `None` | scalar JMMD loss |
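A minimal usage sketch, assuming the torchmetrics-style call convention and a placeholder import path; the number of layers and feature sizes are arbitrary.

```python
import torch

from adaption import JointMaximumMeanDiscrepancyLoss  # placeholder import path

jmmd_loss = JointMaximumMeanDiscrepancyLoss()

# One feature tensor per considered layer, for each domain.
source_features = [torch.randn(32, 64), torch.randn(32, 10)]
target_features = [torch.randn(32, 64), torch.randn(32, 10)]

loss = jmmd_loss(source_features, target_features)  # scalar JMMD loss for this batch
```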
MaximumMeanDiscrepancyLoss
Bases: Metric
The Maximum Mean Discrepancy Loss (MMD) is a distance measure between two arbitrary distributions.
The distance is defined as a dot product in a reproducing kernel Hilbert space (RKHS) and is zero if and only if the two distributions are identical. The RKHS is the space induced by a linear combination of multiple Gaussian kernels with bandwidths derived from the median heuristic.
The source and target feature batches are treated as samples from their respective distributions. The linear pairwise distances between the two batches are transformed into distances in the RKHS via the kernel trick:
`rkhs(x, y) = dot(to_rkhs(x), to_rkhs(y)) = multi_kernel(dot(x, y))`
`multi_kernel(distance) = mean([gaussian(distance, bw) for bw in bandwidths])`
`gaussian(distance, bandwidth) = exp(-distance * bandwidth)`
The `n` kernels will use bandwidths between `median / (2**(n / 2))` and `median * (2**(n / 2))`, where `median` is the median of the linear distances.
The MMD loss is then calculated as the biased estimator
`mmd = mean(rkhs(s, s')) + mean(rkhs(t, t')) - 2 * mean(rkhs(s, t))`
where the means run over all pairs of source samples `s, s'`, all pairs of target samples `t, t'`, and all mixed pairs `s, t`. This version of the MMD is biased, which is acceptable for training purposes.
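The following self-contained sketch spells the computation out, assuming that "linear distance" refers to pairwise squared Euclidean distances and writing the Gaussian kernel as `exp(-distance / bandwidth)`; the library's implementation may differ in these details.

```python
import torch


def multi_kernel_mmd(source: torch.Tensor, target: torch.Tensor, num_kernels: int = 5) -> torch.Tensor:
    """Biased multi-kernel MMD with median-heuristic bandwidths (illustrative sketch)."""
    features = torch.cat([source, target])             # [2 * batch, feature]
    distances = torch.cdist(features, features) ** 2   # pairwise squared Euclidean distances

    # Median heuristic: a geometric series of bandwidths centred on the median
    # distance, roughly spanning median / 2**(n/2) .. median * 2**(n/2).
    median = distances.detach().median()
    exponents = torch.arange(num_kernels) - (num_kernels - 1) / 2
    bandwidths = median * 2.0 ** exponents

    # multi_kernel(distance) = mean of the Gaussian kernels over all bandwidths
    kernels = torch.stack([torch.exp(-distances / bw) for bw in bandwidths]).mean(dim=0)

    n = source.shape[0]
    k_ss = kernels[:n, :n].mean()  # source vs. source
    k_tt = kernels[n:, n:].mean()  # target vs. target
    k_st = kernels[:n, n:].mean()  # source vs. target
    return k_ss + k_tt - 2 * k_st  # biased MMD estimate


mmd = multi_kernel_mmd(torch.randn(32, 64), torch.randn(32, 64) + 1.0)
```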
__init__(num_kernels)
Create a new MMD loss module with `n` kernels.
The bandwidths of the Gaussian kernels are derived by the median heuristic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_kernels` | `int` | Number of Gaussian kernels to use. | required |
update(source_features, target_features)
Compute the MMD loss between source and target feature distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `source_features` | `Tensor` | The source features. | required |
| `target_features` | `Tensor` | The target features. | required |
Returns:
| Type | Description |
|---|---|
| `None` | scalar MMD loss |
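A short usage sketch for the loss module itself, again assuming the torchmetrics-style call convention and a placeholder import path.

```python
import torch

from adaption import MaximumMeanDiscrepancyLoss  # placeholder import path

mmd_loss = MaximumMeanDiscrepancyLoss(num_kernels=5)

source_features = torch.randn(32, 64)
target_features = torch.randn(32, 64)

# A domain adaptation objective would typically add this as a weighted regularizer:
# total_loss = task_loss + trade_off * mmd_loss(source_features, target_features)
discrepancy = mmd_loss(source_features, target_features)  # scalar MMD loss for this batch
```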