bertblocks.modeling.loss

Functions

get_loss_function(problem_type) → torch.nn.Module

Return the applicable loss function for a given problem type.

Module Contents

bertblocks.modeling.loss.get_loss_function(
    problem_type: Literal['regression', 'single_label_classification', 'multi_label_classification'] | None,
) → torch.nn.Module

Return the applicable loss function for a given problem type.

Parameters:

problem_type (Literal["regression", "single_label_classification", "multi_label_classification"] | None) – The type of problem, which determines the loss function that is returned.

Returns:

The appropriate loss function module.

Return type:

torch.nn.Module

Raises:

ValueError – If the problem type is not supported.
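
The dispatch described above can be sketched as follows. This is an illustrative sketch only, not the bertblocks implementation: the names `sketch_loss_class_name` and `sketch_get_loss_function` are hypothetical, and the mapping from problem type to loss class follows the convention commonly used in Hugging Face Transformers (`MSELoss`, `CrossEntropyLoss`, `BCEWithLogitsLoss`), which this page does not confirm. How the real function treats `None` is also not documented here; the sketch assumes it is rejected.

```python
from typing import Literal, Optional

ProblemType = Literal[
    "regression",
    "single_label_classification",
    "multi_label_classification",
]

# Assumed mapping (Hugging Face Transformers convention); the losses that
# bertblocks actually returns are not stated in this reference page.
_LOSS_CLASS_NAMES = {
    "regression": "MSELoss",
    "single_label_classification": "CrossEntropyLoss",
    "multi_label_classification": "BCEWithLogitsLoss",
}


def sketch_loss_class_name(problem_type: Optional[ProblemType]) -> str:
    """Return the name of the torch.nn loss class assumed for problem_type."""
    if problem_type not in _LOSS_CLASS_NAMES:
        # Assumption: None is treated like any other unsupported value.
        raise ValueError(f"Unsupported problem type: {problem_type!r}")
    return _LOSS_CLASS_NAMES[problem_type]


def sketch_get_loss_function(problem_type: Optional[ProblemType]):
    """Instantiate the assumed torch.nn loss module for problem_type."""
    # Imported here so the name mapping above can be inspected without PyTorch.
    import torch.nn as nn

    return getattr(nn, sketch_loss_class_name(problem_type))()
```

With PyTorch installed, `sketch_get_loss_function("regression")` would return an `nn.MSELoss` instance under these assumptions, and an unrecognized problem type raises `ValueError`, mirroring the documented behavior.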