
adaption

A module for unsupervised domain adaptation losses.

DomainAdversarialLoss

Bases: Metric

The Domain Adversarial Neural Network Loss (DANN) uses a domain discriminator to measure the distance between two feature distributions.

The domain discriminator is a neural network that is jointly trained to classify its input as coming from one of two domains. Its output should be a single unscaled score (logit), which is fed to a binary cross-entropy loss.

The domain discriminator is preceded by a GradientReversalLayer. This way, the discriminator is trained to separate the domains while the network generating the inputs is trained to marginalize the domain difference.
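
A rough usage sketch (the discriminator architecture, the feature size, and the assumption that calling the metric runs update and returns the current loss, as is typical for torchmetrics, are all illustrative, not part of this module):

import torch
from torch import nn

# Hypothetical discriminator: any module mapping features to a single logit.
domain_disc = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))
dann_loss = DomainAdversarialLoss(domain_disc)  # the class documented below; import path omitted

source_features = torch.randn(8, 64)  # stand-ins for feature extractor outputs
target_features = torch.randn(8, 64)

loss = dann_loss(source_features, target_features)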

__init__(domain_disc)

Create a new DANN loss module.

Parameters:

Name         Type    Description                                              Default
domain_disc  Module  The neural network to act as the domain discriminator.  required

update(source, target)

Calculate the DANN loss as the binary cross entropy of the discriminator's predictions for the source and target features.

The source features receive a domain label of zero and the target features a domain label of one (see the sketch after the parameter table).

Parameters:

Name    Type    Description                                   Default
source  Tensor  The source features with domain label zero.  required
target  Tensor  The target features with domain label one.   required
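
As a rough sketch of what this computes (the helper name and the shape handling are assumptions, not this module's code):

import torch
import torch.nn.functional as F

def dann_update(domain_disc, source, target):
    # One logit per sample; source samples are class 0, target samples class 1.
    logits = torch.cat([domain_disc(source), domain_disc(target)]).squeeze(-1)
    labels = torch.cat([torch.zeros(len(source)), torch.ones(len(target))])
    return F.binary_cross_entropy_with_logits(logits, labels)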

GradientReversalLayer

Bases: Module

The gradient reversal layer (GRL) acts as an identity function in the forward pass and scales the gradient by a negative scalar in the backward pass.

GRL(f(x)) = f(x)
GRL'(f(x)) = -lambda * f'(x)

where lambda is the grad_weight.
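
A common way to implement this behavior in PyTorch is a custom autograd function; a minimal sketch, not necessarily this module's implementation:

import torch

class _ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, grad_weight):
        ctx.grad_weight = grad_weight
        return x.view_as(x)  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Scale the incoming gradient by the negative weight; no gradient for grad_weight itself.
        return -ctx.grad_weight * grad_output, None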

__init__(grad_weight=1.0)

Create a new Gradient Reversal Layer.

Parameters:

Name         Type   Description                                     Default
grad_weight  float  The scalar that weights the negative gradient.  1.0
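
In a DANN setup, the layer is placed directly in front of the domain discriminator, for example (domain_disc is the hypothetical discriminator from the sketch above):

from torch import nn

grl = GradientReversalLayer(grad_weight=1.0)
adversarial_head = nn.Sequential(grl, domain_disc)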

JointMaximumMeanDiscrepancyLoss

Bases: Metric

The Joint Maximum Mean Discrepancy Loss (JMMD) is a distance measure between multiple pairs of arbitrary distributions.

It is related to the MMD insofar as the distance of each distribution pair in a reproducing kernel Hilbert space (RKHS) is calculated and the per-pair kernel values are then multiplied before the discrepancy is computed:

joint_rkhs(xs, ys) = prod(rkhs(x, y) for x, y in zip(xs, ys))

For more information see MaximumMeanDiscrepancyLoss.
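
A minimal sketch of the joint kernel idea, reusing the Gaussian kernel convention from MaximumMeanDiscrepancyLoss below (the function names, the per-pair bandwidths, and the biased estimator are assumptions, not this module's code):

import torch

def gaussian_kernel(x, y, bandwidth):
    # Pairwise squared Euclidean distances mapped through a Gaussian kernel.
    return torch.exp(-torch.cdist(x, y) ** 2 * bandwidth)

def jmmd(source_features, target_features, bandwidths):
    # Multiply the kernel matrices of all feature pairs to form the joint kernel.
    k_ss = k_tt = k_st = 1.0
    for xs, xt, bw in zip(source_features, target_features, bandwidths):
        k_ss = k_ss * gaussian_kernel(xs, xs, bw)
        k_tt = k_tt * gaussian_kernel(xt, xt, bw)
        k_st = k_st * gaussian_kernel(xs, xt, bw)
    # Biased MMD estimate on the joint kernel.
    return k_ss.mean() + k_tt.mean() - 2 * k_st.mean()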

__init__()

Create a new JMMD loss module.

It features a single Gaussian kernel with a bandwidth chosen by the median heuristic.

update(source_features, target_features)

Compute the JMMD loss between multiple feature distributions.

Parameters:

Name             Type          Description                                                        Default
source_features  List[Tensor]  The list of source features of shape [batch_size, num_features].  required
target_features  List[Tensor]  The list of target features of shape [batch_size, num_features].  required

Returns:

Type  Description
None  scalar JMMD loss

MaximumMeanDiscrepancyLoss

Bases: Metric

The Maximum Mean Discrepancy Loss (MMD) is a distance measure between two arbitrary distributions.

The distance is defined as the dot product in a reproducing kernel Hilbert space (RKHS) and is zero if and only if the distributions are identical. The RKHS is induced by a linear combination of multiple Gaussian kernels with bandwidths derived by the median heuristic.

The source and target feature batches are treated as samples from their respective distributions. The pairwise distances between the two batches in the input space are transformed into distances in the RKHS via the kernel trick:

rkhs(x, y) = dot(to_rkhs(x), to_rkhs(y)) = multi_kernel(dot(x, y))
multi_kernel(distance) = mean([gaussian(distance, bw) for bw in bandwidths])
gaussian(distance, bandwidth) = exp(-distance * bandwidth)

The n kernels use bandwidths between median / 2**(n/2) and median * 2**(n/2), where median is the median of the pairwise distances in the input space.

The MMD loss is then calculated as:

mean(rkhs(source, source) + rkhs(target, target) - 2 * rkhs(source, target))

This version of MMD is biased, which is acceptable for training purposes.
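
Putting these formulas together, a minimal sketch of the computation (the median-heuristic bandwidth ladder and the batch handling are assumptions in the spirit of the description, not the module's exact code):

import torch

def gaussian(distance, bandwidth):
    return torch.exp(-distance * bandwidth)

def multi_kernel(distance, bandwidths):
    # Equal-weight mixture of Gaussian kernels with different bandwidths.
    return torch.stack([gaussian(distance, bw) for bw in bandwidths]).mean(0)

def mmd(source, target, num_kernels=5):
    feats = torch.cat([source, target])
    distances = torch.cdist(feats, feats) ** 2  # pairwise squared distances
    median = distances[distances > 0].median()
    # Geometric ladder of bandwidths centred on the median distance.
    exponents = torch.arange(num_kernels, dtype=torch.float32) - (num_kernels - 1) / 2
    bandwidths = median * 2.0 ** exponents
    k = multi_kernel(distances, bandwidths)
    n = len(source)
    # Biased estimate: within-domain similarity minus cross-domain similarity.
    return k[:n, :n].mean() + k[n:, n:].mean() - 2 * k[:n, n:].mean()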

__init__(num_kernels)

Create a new MMD loss module with n kernels.

The bandwidths of the Gaussian kernels are derived by the median heuristic.

Parameters:

Name         Type  Description                         Default
num_kernels  int   Number of Gaussian kernels to use.  required

update(source_features, target_features)

Compute the MMD loss between source and target feature distributions.

Parameters:

Name             Type    Description                                              Default
source_features  Tensor  Source features with shape [batch_size, num_features].  required
target_features  Tensor  Target features with shape [batch_size, num_features].  required

Returns:

Type  Description
None  scalar MMD loss