bertblocks.modeling.loss
Functions
| get_loss_function(problem_type) | Return the applicable loss function for a given problem type. |
Module Contents
- bertblocks.modeling.loss.get_loss_function(problem_type: Literal['regression', 'single_label_classification', 'multi_label_classification'] | None)
Return the applicable loss function for a given problem type.
- Parameters:
problem_type (Literal["regression", "single_label_classification", "multi_label_classification"] | None) – The type of task being solved. This determines which loss function is returned.
- Returns:
The appropriate loss function module.
- Return type:
nn.Module
- Raises:
ValueError – If the problem type is not supported.
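For illustration only, the following is a minimal sketch of how a dispatch like this could look. It assumes PyTorch and the mapping that many Hugging Face style model heads use: MSELoss for regression, CrossEntropyLoss for single-label classification, and BCEWithLogitsLoss for multi-label classification. The concrete classes that bertblocks returns are not stated above. The names get_loss_function_sketch and ProblemType are hypothetical. How the real function treats None is not documented here, so the sketch rejects None along with any other unsupported value.

```python
from typing import Literal, Optional

import torch
from torch import nn

# Hypothetical alias mirroring the documented Literal values.
ProblemType = Literal["regression", "single_label_classification", "multi_label_classification"]


def get_loss_function_sketch(problem_type: Optional[ProblemType]) -> nn.Module:
    """Illustrative dispatch; the problem-type-to-loss mapping is an assumption."""
    if problem_type == "regression":
        # Continuous targets: mean squared error between predictions and targets.
        return nn.MSELoss()
    if problem_type == "single_label_classification":
        # One class index per example; takes raw logits (N, C) and int64 targets (N,).
        return nn.CrossEntropyLoss()
    if problem_type == "multi_label_classification":
        # Independent yes/no per label; takes raw logits and float 0/1 targets, both (N, C).
        return nn.BCEWithLogitsLoss()
    # None and any other value are rejected here (assumption: real None handling is undocumented).
    raise ValueError(f"Unsupported problem type: {problem_type!r}")


if __name__ == "__main__":
    logits = torch.tensor([[2.0, -1.0, 0.5], [0.1, 0.3, -0.2]])

    # Single-label: one target class index per row.
    ce = get_loss_function_sketch("single_label_classification")
    print(ce(logits, torch.tensor([0, 1])).item())

    # Multi-label: a float 0/1 target for every label of every row.
    bce = get_loss_function_sketch("multi_label_classification")
    print(bce(logits, torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])).item())

    # Regression: predictions and targets share the same shape.
    mse = get_loss_function_sketch("regression")
    print(mse(torch.tensor([1.5, 2.0]), torch.tensor([1.0, 2.0])).item())
```

In PyTorch, both CrossEntropyLoss and BCEWithLogitsLoss expect raw logits. Under this mapping, a classification head would therefore pass its outputs to the loss directly rather than applying softmax or sigmoid first.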