"""Source code for decomon.layers.decomon_merge_layers: LiRPA versions of the Keras merge layers."""

from typing import Any, Optional, Union

import keras.ops as K
from keras.layers import (
    Add,
    Average,
    Concatenate,
    Dot,
    Lambda,
    Layer,
    Maximum,
    Minimum,
    Multiply,
    Subtract,
)

from decomon.core import ForwardMode, PerturbationDomain
from decomon.layers.core import DecomonLayer
from decomon.layers.utils import broadcast, multiply, permute_dimensions
from decomon.types import BackendTensor
from decomon.utils import maximum, minus, subtract

##### Merge Layer ####


class DecomonMerge(DecomonLayer):
    """Base class for Decomon layers based on merging Keras layers.

    Concrete subclasses also inherit from the Keras merge layer they mimic and
    expose it as ``original_keras_layer_class``; shape inference and building
    are delegated to that class.
    """

    def compute_output_shape(self, input_shape: list[tuple[Optional[int], ...]]) -> list[tuple[Optional[int], ...]]:  # type: ignore
        """Compute output shapes from input shapes.

        By default, all inputs are assumed to be merged into a single output
        (which is still a list of tensors, one per component of the current mode).

        Args:
            input_shape: flat list of the shapes of every tensor of every merged input.

        Returns:
            the list of decomon output shapes for the current mode.
        """
        spec = self.inputs_outputs_spec

        # one group of shapes per merged input
        per_input_shapes = spec.split_inputsformode_to_merge(input_shape)

        # shapes as seen by the wrapped keras layer, and the shape it would output
        keras_in_shapes = [spec.get_kerasinputshape_from_inputshapesformode(shapes) for shapes in per_input_shapes]
        keras_out_shape = self.original_keras_layer_class.compute_output_shape(self, keras_in_shapes)

        if keras_out_shape == keras_in_shapes[0]:
            # the merge keeps the shape of the first input: reuse its decomon shapes
            return per_input_shapes[0]

        # the merge changes the shape (cf Concatenate): rebuild each decomon output shape
        (
            x_shape,
            _u_c_shape,
            _w_u_shape,
            _b_u_shape,
            _l_c_shape,
            _w_l_shape,
            _b_l_shape,
            _h_shape,
            _g_shape,
        ) = spec.get_fullinputshapes_from_inputshapesformode(per_input_shapes[0])
        non_batch_dims = keras_out_shape[1:]
        if not spec.affine:
            weights_shape = tuple()
        else:
            # affine weights: (batch size, model input dim) followed by the non-batch output dims
            weights_shape = (x_shape[0], x_shape[-1]) + non_batch_dims

        y = keras_out_shape
        return spec.extract_inputshapesformode_from_fullinputshapes(
            [x_shape, y, weights_shape, y, y, weights_shape, y, y, y]
        )

    def build(self, input_shape: list[tuple[Optional[int], ...]]) -> None:
        """Build the wrapped keras layer from one shape per merged input.

        The flat ``input_shape`` list is read in groups of ``self.nb_tensors``
        shapes and the last shape of each group is forwarded to the keras layer.
        NOTE(review): assumes the last tensor of each group has the keras (y)
        shape for every mode — confirm against the inputs/outputs spec.
        """
        group_size = self.nb_tensors
        keras_shapes = input_shape[group_size - 1 :: group_size]
        self.original_keras_layer_class.build(self, keras_shapes)
class DecomonAdd(DecomonMerge, Add):
    """LiRPA implementation of Add layers.

    See Keras official documentation for further details on the Add operator.
    """

    original_keras_layer_class = Add

    def __init__(
        self,
        perturbation_domain: Optional[PerturbationDomain] = None,
        dc_decomp: bool = False,
        mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
        finetune: bool = False,
        shared: bool = False,
        fast: bool = True,
        **kwargs: Any,
    ):
        # every option is forwarded unchanged to the decomon / keras base classes
        super().__init__(
            mode=mode,
            perturbation_domain=perturbation_domain,
            dc_decomp=dc_decomp,
            fast=fast,
            finetune=finetune,
            shared=shared,
            **kwargs,
        )

    def call(self, inputs: list[BackendTensor], **kwargs: Any) -> list[BackendTensor]:
        """Sum the merged decomon inputs component by component.

        Addition is linear, so the constant bounds, affine weights and affine
        biases of the output are the sums of the matching components of all
        merged inputs. The x tensor of the first merged input is passed through.

        Raises:
            NotImplementedError: if ``dc_decomp`` is enabled.
        """
        spec = self.inputs_outputs_spec
        (
            all_x,
            all_u_c,
            all_w_u,
            all_b_u,
            all_l_c,
            all_w_l,
            all_b_l,
            _all_h,
            _all_g,
        ) = spec.get_fullinputs_by_type_from_inputsformode_to_merge(inputs)

        x_out = all_x[0]
        empty = spec.get_empty_tensor(dtype=x_out.dtype)

        if not self.ibp:
            u_c_out = l_c_out = empty
        else:
            u_c_out = sum(all_u_c)
            l_c_out = sum(all_l_c)

        if not self.affine:
            w_u_out = b_u_out = w_l_out = b_l_out = empty
        else:
            b_u_out, b_l_out = sum(all_b_u), sum(all_b_l)
            w_u_out, w_l_out = sum(all_w_u), sum(all_w_l)

        if self.dc_decomp:
            raise NotImplementedError()
        h_out = g_out = empty

        return spec.extract_outputsformode_from_fulloutputs(
            [x_out, u_c_out, w_u_out, b_u_out, l_c_out, w_l_out, b_l_out, h_out, g_out]
        )
class DecomonAverage(DecomonMerge, Average):
    """LiRPA implementation of Average layers.

    See Keras official documentation for further details on the Average operator.
    """

    original_keras_layer_class = Average

    def __init__(
        self,
        perturbation_domain: Optional[PerturbationDomain] = None,
        dc_decomp: bool = False,
        mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
        finetune: bool = False,
        shared: bool = False,
        fast: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            mode=mode,
            perturbation_domain=perturbation_domain,
            dc_decomp=dc_decomp,
            fast=fast,
            finetune=finetune,
            shared=shared,
            **kwargs,
        )
        # element-wise mean of a list of tensors, wrapped in a keras Lambda layer
        self.op = Lambda(lambda x: sum(x) / len(x))

    def call(self, inputs: list[BackendTensor], **kwargs: Any) -> list[BackendTensor]:
        """Average the merged decomon inputs component by component.

        The mean is linear, so the constant bounds, affine weights and affine
        biases of the output are the means (``self.op``) of the matching
        components of all merged inputs. The x tensor of the first merged input
        is passed through.

        Raises:
            NotImplementedError: if ``dc_decomp`` is enabled.
        """
        spec = self.inputs_outputs_spec
        (
            all_x,
            all_u_c,
            all_w_u,
            all_b_u,
            all_l_c,
            all_w_l,
            all_b_l,
            _all_h,
            _all_g,
        ) = spec.get_fullinputs_by_type_from_inputsformode_to_merge(inputs)

        x_out = all_x[0]
        empty = spec.get_empty_tensor(dtype=x_out.dtype)
        mean = self.op

        if not self.ibp:
            u_c_out = l_c_out = empty
        else:
            u_c_out = mean(all_u_c)
            l_c_out = mean(all_l_c)

        if not self.affine:
            w_u_out = b_u_out = w_l_out = b_l_out = empty
        else:
            b_u_out = mean(all_b_u)
            b_l_out = mean(all_b_l)
            w_u_out = mean(all_w_u)
            w_l_out = mean(all_w_l)

        if self.dc_decomp:
            raise NotImplementedError()
        h_out = g_out = empty

        return spec.extract_outputsformode_from_fulloutputs(
            [x_out, u_c_out, w_u_out, b_u_out, l_c_out, w_l_out, b_l_out, h_out, g_out]
        )
class DecomonSubtract(DecomonMerge, Subtract):
    """LiRPA implementation of Subtract layers.

    See Keras official documentation for further details on the Subtract operator.
    """

    original_keras_layer_class = Subtract

    def __init__(
        self,
        perturbation_domain: Optional[PerturbationDomain] = None,
        dc_decomp: bool = False,
        mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
        finetune: bool = False,
        shared: bool = False,
        fast: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            mode=mode,
            perturbation_domain=perturbation_domain,
            dc_decomp=dc_decomp,
            fast=fast,
            finetune=finetune,
            shared=shared,
            **kwargs,
        )

    def call(self, inputs: list[BackendTensor], **kwargs: Any) -> list[BackendTensor]:
        """Subtract the second merged decomon input from the first one.

        The computation is delegated to the ``subtract`` helper.

        Raises:
            NotImplementedError: if ``dc_decomp`` is enabled.
            ValueError: if the number of merged inputs is not exactly 2.
        """
        if self.dc_decomp:
            raise NotImplementedError()

        operands = self.inputs_outputs_spec.split_inputsformode_to_merge(inputs)
        if len(operands) != 2:
            raise ValueError("This layer is intended to merge only 2 layers.")

        minuend, subtrahend = operands[0], operands[1]
        return subtract(
            minuend,
            subtrahend,
            mode=self.mode,
            perturbation_domain=self.perturbation_domain,
            dc_decomp=self.dc_decomp,
        )
class DecomonMinimum(DecomonMerge, Minimum):
    """LiRPA implementation of Minimum layers.

    See Keras official documentation for further details on the Minimum operator.
    """

    original_keras_layer_class = Minimum

    def __init__(
        self,
        perturbation_domain: Optional[PerturbationDomain] = None,
        dc_decomp: bool = False,
        mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
        finetune: bool = False,
        shared: bool = False,
        fast: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            perturbation_domain=perturbation_domain,
            dc_decomp=dc_decomp,
            mode=mode,
            finetune=finetune,
            shared=shared,
            fast=fast,
            **kwargs,
        )

    def call(self, inputs: list[BackendTensor], **kwargs: Any) -> list[BackendTensor]:
        """Merge the decomon inputs with an element-wise minimum.

        The minimum is computed through the ``maximum`` helper using
        ``min(a_1, ..., a_n) = -max(-a_1, ..., -a_n)``.

        Returns:
            ``inputs`` unchanged when there is a single merged input, otherwise
            the decomon outputs of the minimum.

        Raises:
            NotImplementedError: if ``dc_decomp`` is enabled.
        """
        if self.dc_decomp:
            raise NotImplementedError()

        inputs_list = self.inputs_outputs_spec.split_inputsformode_to_merge(inputs)
        if len(inputs_list) == 1:
            # nothing to merge: return before spending any work negating the input
            return inputs

        helper_kwargs = dict(
            dc_decomp=self.dc_decomp,
            perturbation_domain=self.perturbation_domain,
            mode=self.mode,
        )
        # look at minus the inputs to apply maximum
        negated = [minus(single_inputs, **helper_kwargs) for single_inputs in inputs_list]

        output = maximum(negated[0], negated[1], **helper_kwargs)
        for other in negated[2:]:
            output = maximum(output, other, **helper_kwargs)

        return minus(output, **helper_kwargs)
class DecomonMaximum(DecomonMerge, Maximum):
    """LiRPA implementation of Maximum layers.

    See Keras official documentation for further details on the Maximum operator.
    """

    original_keras_layer_class = Maximum

    def __init__(
        self,
        perturbation_domain: Optional[PerturbationDomain] = None,
        dc_decomp: bool = False,
        mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
        finetune: bool = False,
        shared: bool = False,
        fast: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            mode=mode,
            perturbation_domain=perturbation_domain,
            dc_decomp=dc_decomp,
            fast=fast,
            finetune=finetune,
            shared=shared,
            **kwargs,
        )

    def call(self, inputs: list[BackendTensor], **kwargs: Any) -> list[BackendTensor]:
        """Merge the decomon inputs with an element-wise maximum.

        The inputs are folded pairwise, left to right, with the ``maximum`` helper.

        Returns:
            ``inputs`` unchanged when there is a single merged input, otherwise
            the decomon outputs of the maximum.

        Raises:
            NotImplementedError: if ``dc_decomp`` is enabled.
        """
        if self.dc_decomp:
            raise NotImplementedError()

        operands = self.inputs_outputs_spec.split_inputsformode_to_merge(inputs)
        if len(operands) == 1:
            # nothing to merge
            return inputs

        merged = maximum(
            operands[0],
            operands[1],
            mode=self.mode,
            perturbation_domain=self.perturbation_domain,
            dc_decomp=self.dc_decomp,
        )
        for operand in operands[2:]:
            merged = maximum(
                merged,
                operand,
                mode=self.mode,
                perturbation_domain=self.perturbation_domain,
                dc_decomp=self.dc_decomp,
            )
        return merged
class DecomonConcatenate(DecomonMerge, Concatenate):
    """LiRPA implementation of Concatenate layers.

    See Keras official documentation for further details on the Concatenate operator.
    """

    original_keras_layer_class = Concatenate

    def __init__(
        self,
        axis: int = -1,
        perturbation_domain: Optional[PerturbationDomain] = None,
        dc_decomp: bool = False,
        mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
        finetune: bool = False,
        shared: bool = False,
        fast: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            axis=axis,
            perturbation_domain=perturbation_domain,
            dc_decomp=dc_decomp,
            mode=mode,
            finetune=finetune,
            shared=shared,
            fast=fast,
            **kwargs,
        )

        def func(inputs: list[BackendTensor]) -> BackendTensor:
            # concatenation along self.axis, exactly as the keras layer does it
            return Concatenate.call(self, inputs)

        self.op = func
        # Affine weights carry one extra axis (the model input dimension) right
        # after the batch axis, so a non-negative axis must be shifted by one for
        # them. A negative axis counts from the end and therefore designates the
        # same dimension for the outputs and for the weights: no shift is needed
        # (shifting e.g. -2 to -1 would concatenate along the wrong dimension).
        if self.axis < 0:
            self.op_w = self.op
        else:
            self.op_w = Concatenate(axis=self.axis + 1)

    def call(self, inputs: list[BackendTensor], **kwargs: Any) -> list[BackendTensor]:
        """Concatenate the merged decomon inputs component by component.

        Constant bounds and affine biases are concatenated like the keras outputs
        (``self.op``); affine weights use ``self.op_w``, whose axis accounts for
        the extra model-input axis of the weights. The x tensor of the first
        merged input is passed through.

        Raises:
            NotImplementedError: if ``dc_decomp`` is enabled.
        """
        # splits the inputs
        (
            inputs_x,
            inputs_u_c,
            inputs_w_u,
            inputs_b_u,
            inputs_l_c,
            inputs_w_l,
            inputs_b_l,
            _inputs_h,
            _inputs_g,
        ) = self.inputs_outputs_spec.get_fullinputs_by_type_from_inputsformode_to_merge(inputs)

        x_out = inputs_x[0]
        empty_tensor = self.inputs_outputs_spec.get_empty_tensor(dtype=x_out.dtype)

        # outputs
        if self.ibp:
            u_c_out = self.op(inputs_u_c)
            l_c_out = self.op(inputs_l_c)
        else:
            u_c_out, l_c_out = empty_tensor, empty_tensor

        if self.affine:
            b_u_out = self.op(inputs_b_u)
            b_l_out = self.op(inputs_b_l)
            w_u_out = self.op_w(inputs_w_u)
            w_l_out = self.op_w(inputs_w_l)
        else:
            w_u_out, b_u_out, w_l_out, b_l_out = empty_tensor, empty_tensor, empty_tensor, empty_tensor

        if self.dc_decomp:
            raise NotImplementedError()
        h_out, g_out = empty_tensor, empty_tensor

        return self.inputs_outputs_spec.extract_outputsformode_from_fulloutputs(
            [x_out, u_c_out, w_u_out, b_u_out, l_c_out, w_l_out, b_l_out, h_out, g_out]
        )
class DecomonMultiply(DecomonMerge, Multiply):
    """LiRPA implementation of Multiply layers.

    See Keras official documentation for further details on the Multiply operator.
    """

    original_keras_layer_class = Multiply

    def __init__(
        self,
        perturbation_domain: Optional[PerturbationDomain] = None,
        dc_decomp: bool = False,
        mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
        finetune: bool = False,
        shared: bool = False,
        fast: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            mode=mode,
            perturbation_domain=perturbation_domain,
            dc_decomp=dc_decomp,
            fast=fast,
            finetune=finetune,
            shared=shared,
            **kwargs,
        )

    def call(self, inputs: list[BackendTensor], **kwargs: Any) -> list[BackendTensor]:
        """Merge the decomon inputs with an element-wise product.

        The inputs are folded pairwise, left to right, with the ``multiply`` helper.

        Returns:
            ``inputs`` unchanged when there is a single merged input, otherwise
            the decomon outputs of the product.

        Raises:
            NotImplementedError: if ``dc_decomp`` is enabled.
        """
        if self.dc_decomp:
            raise NotImplementedError()

        factors = self.inputs_outputs_spec.split_inputsformode_to_merge(inputs)
        if len(factors) == 1:
            # nothing to merge
            return inputs

        product = multiply(
            factors[0],
            factors[1],
            mode=self.mode,
            perturbation_domain=self.perturbation_domain,
            dc_decomp=self.dc_decomp,
        )
        for factor in factors[2:]:
            product = multiply(
                product,
                factor,
                mode=self.mode,
                perturbation_domain=self.perturbation_domain,
                dc_decomp=self.dc_decomp,
            )
        return product
class DecomonDot(DecomonMerge, Dot):
    """LiRPA implementation of Dot layers.

    See Keras official documentation for further details on the Dot operator.
    """

    original_keras_layer_class = Dot

    def __init__(
        self,
        axes: Union[int, tuple[int, int]] = (-1, -1),
        perturbation_domain: Optional[PerturbationDomain] = None,
        dc_decomp: bool = False,
        mode: Union[str, ForwardMode] = ForwardMode.HYBRID,
        finetune: bool = False,
        shared: bool = False,
        fast: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            axes=axes,
            mode=mode,
            perturbation_domain=perturbation_domain,
            dc_decomp=dc_decomp,
            fast=fast,
            finetune=finetune,
            shared=shared,
            **kwargs,
        )
        # always keep one axis per merged input, even when a single int was given
        self.axes = (axes, axes) if isinstance(axes, int) else axes

    def call(self, inputs: list[BackendTensor], **kwargs: Any) -> list[BackendTensor]:
        """Merge two decomon inputs with a dot product along ``self.axes``.

        The dot product is decomposed into: ``permute_dimensions`` on each operand
        with its own axis, ``broadcast`` of each operand against the other,
        an element-wise ``multiply``, then a sum over the contracted axis.
        NOTE(review): assumes the helpers move the contracted axis to position 1
        (position 2 for affine weights) — confirm in decomon.layers.utils.

        Returns:
            ``inputs`` unchanged when there is a single merged input, otherwise
            the decomon outputs of the dot product.

        Raises:
            NotImplementedError: if ``dc_decomp`` is enabled, or if the number of
                merged inputs is neither 1 nor 2.
        """
        if self.dc_decomp:
            raise NotImplementedError()

        spec = self.inputs_outputs_spec
        inputs_list = spec.split_inputsformode_to_merge(inputs)
        n_merged = len(inputs_list)
        if n_merged == 1:
            # nothing to merge
            return inputs
        if n_merged != 2:
            raise NotImplementedError("This layer is not implemented to merge more than 2 layers.")
        lhs, rhs = inputs_list

        # rank of each operand's keras tensor, minus 2
        extra_dims_lhs = len(spec.get_kerasinputshape(lhs)) - 2
        extra_dims_rhs = len(spec.get_kerasinputshape(rhs)) - 2

        lhs = permute_dimensions(lhs, self.axes[0], mode=self.mode, dc_decomp=self.dc_decomp)
        rhs = permute_dimensions(rhs, self.axes[1], mode=self.mode, dc_decomp=self.dc_decomp)

        lhs = broadcast(
            lhs, extra_dims_rhs, -1, mode=self.mode, dc_decomp=self.dc_decomp, perturbation_domain=self.perturbation_domain
        )
        rhs = broadcast(
            rhs, extra_dims_lhs, 2, mode=self.mode, dc_decomp=self.dc_decomp, perturbation_domain=self.perturbation_domain
        )

        product = multiply(
            lhs, rhs, dc_decomp=self.dc_decomp, perturbation_domain=self.perturbation_domain, mode=self.mode
        )

        x, u_c, w_u, b_u, l_c, w_l, b_l, _h, _g = spec.get_fullinputs_from_inputsformode(
            product, compute_ibp_from_affine=False
        )
        empty = spec.get_empty_tensor(dtype=x.dtype)

        # contract along the shared axis: axis 1 for bounds and biases, axis 2 for
        # the affine weights (which carry an extra model-input axis after the batch)
        if not self.ibp:
            u_c_out = l_c_out = empty
        else:
            u_c_out = K.sum(u_c, 1)
            l_c_out = K.sum(l_c, 1)

        if not self.affine:
            w_u_out = b_u_out = w_l_out = b_l_out = empty
        else:
            w_u_out, b_u_out = K.sum(w_u, 2), K.sum(b_u, 1)
            w_l_out, b_l_out = K.sum(w_l, 2), K.sum(b_l, 1)

        # dc_decomp was rejected at the top of the method: h and g are always empty
        h_out = g_out = empty

        return spec.extract_outputsformode_from_fulloutputs(
            [x, u_c_out, w_u_out, b_u_out, l_c_out, w_l_out, b_l_out, h_out, g_out]
        )