# Source code for decomon.metrics.utils

from typing import Optional, Union

from decomon.core import BoxDomain, ForwardMode, PerturbationDomain
from decomon.layers.utils import exp, expand_dims, log, sum
from decomon.types import Tensor
from decomon.utils import add, minus

# compute categorical cross entropy


def categorical_cross_entropy(
    inputs: list[Tensor],
    dc_decomp: bool = False,
    mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
    perturbation_domain: Optional[PerturbationDomain] = None,
) -> list[Tensor]:
    """Chain decomon's ``exp``, ``sum``, ``log``, ``expand_dims``, ``minus`` and ``add`` helpers.

    Assuming each helper implements the operation its name suggests, the
    composition evaluated is ``log(sum(exp(inputs), axis=-1)) - inputs``,
    with the reduced term given back a trailing axis before the final
    addition. The helpers live in ``decomon.layers.utils`` and
    ``decomon.utils``; their exact bound-propagation semantics are not
    visible from this module.

    Args:
        inputs: list of tensors passed unchanged to both ``exp`` and
            ``minus``. NOTE(review): the expected layout of this list
            presumably depends on ``mode`` and ``dc_decomp`` -- confirm
            against the helpers' documentation.
        dc_decomp: forwarded to every helper except ``sum``.
        mode: forwarded to every helper. Defaults to ``ForwardMode.HYBRID``.
        perturbation_domain: forwarded to every helper except ``sum``.
            When ``None``, a fresh ``BoxDomain()`` is used instead.

    Returns:
        The value returned by the final ``add`` call.

    """
    domain = BoxDomain() if perturbation_domain is None else perturbation_domain

    # step 1: exponential of the inputs
    exponentiated = exp(inputs, mode=mode, perturbation_domain=domain, dc_decomp=dc_decomp)

    # step 2: sum along the last axis
    # NOTE(review): unlike the other helpers, `sum` is only given `axis` and
    # `mode` here -- confirm this matches its signature/intent.
    summed = sum(exponentiated, axis=-1, mode=mode)

    # step 3: log of the sum, then add an axis at position -1
    # (presumably so the result lines up with `inputs` again -- TODO confirm)
    log_of_sum = log(summed, dc_decomp=dc_decomp, mode=mode, perturbation_domain=domain)
    log_of_sum = expand_dims(
        log_of_sum,
        dc_decomp=dc_decomp,
        mode=mode,
        perturbation_domain=domain,
        axis=-1,
    )

    # step 4: combine with the negated inputs
    negated_inputs = minus(inputs, mode=mode, perturbation_domain=domain, dc_decomp=dc_decomp)
    return add(
        negated_inputs,
        log_of_sum,
        mode=mode,
        perturbation_domain=domain,
        dc_decomp=dc_decomp,
    )