Source code for decomon.keras_utils

from typing import Any

import keras
import keras.ops as K
import numpy as np
from keras.layers import Layer

from decomon.types import BackendTensor, Tensor

BACKEND_TENSORFLOW = "tensorflow"
BACKEND_PYTORCH = "torch"
BACKEND_NUMPY = "numpy"
BACKEND_JAX = "jax"


class BatchedIdentityLike(keras.Operation):
    """Keras Operation creating an identity tensor with shape (including batch_size) based on input.

    The output shape is tuple(x.shape) + (x.shape[-1],), the tensor being the identity
    along the 2 last dimensions.

    """

    def call(self, x: BackendTensor) -> Tensor:
        input_shape = x.shape
        identity_tensor = K.identity(input_shape[-1], dtype=x.dtype)
        n_repeat = int(np.prod(input_shape[:-1]))
        return K.reshape(K.repeat(identity_tensor[None], n_repeat, axis=0), tuple(input_shape) + (-1,))

    def compute_output_spec(self, x: Tensor) -> keras.KerasTensor:
        x_shape = x.shape
        x_type = getattr(x, "dtype", type(x))
        x_sparse = getattr(x, "sparse", False)
        return keras.KerasTensor(
            shape=tuple(x_shape) + (x_shape[-1],),
            dtype=x_type,
            sparse=x_sparse,
        )
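

# Usage sketch (illustrative, not part of the original module): applied to a
# batch of flat tensors, BatchedIdentityLike yields one identity matrix per
# batch entry. The helper name `_demo_batched_identity_like` is hypothetical.
def _demo_batched_identity_like() -> None:
    x = K.ones((3, 4))  # batch of 3 vectors of size 4
    eye = BatchedIdentityLike()(x)
    assert tuple(eye.shape) == (3, 4, 4)  # identity along the 2 last dimensions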


class BatchedDiagLike(keras.Operation):
    """Keras Operation transforming the last dimension into a diagonal tensor.

    The output shape is tuple(x.shape) + (x.shape[-1],).
    When fixing all but the 2 last dimensions, the output tensor is a square tensor
    whose main diagonal is the input tensor with the same first dimensions fixed,
    and 0 elsewhere.

    This is a replacement for tensorflow.linalg.diag().

    """

    def call(self, x: BackendTensor) -> Tensor:
        return K.concatenate([K.diag(K.ravel(w_part))[None] for w_part in K.split(x, len(x), axis=0)], axis=0)

    def compute_output_spec(self, x: Tensor) -> keras.KerasTensor:
        x_shape = x.shape
        x_type = getattr(x, "dtype", type(x))
        x_sparse = getattr(x, "sparse", False)
        return keras.KerasTensor(
            shape=tuple(x_shape) + (x_shape[-1],),
            dtype=x_type,
            sparse=x_sparse,
        )
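

# Usage sketch (illustrative, not part of the original module): each batch
# entry of shape (n,) becomes an (n, n) diagonal matrix. The helper name
# `_demo_batched_diag_like` is hypothetical.
def _demo_batched_diag_like() -> None:
    x = K.ones((2, 3))  # batch of 2 vectors of size 3
    out = BatchedDiagLike()(x)
    assert tuple(out.shape) == (2, 3, 3)  # one diagonal matrix per batch entry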


def is_symbolic_tensor(x: Tensor) -> bool:
    """Check whether the tensor is symbolic or not.

    Works even during backend calls made by layers without an actual compute_output_shape().
    In that case, x is no longer a KerasTensor but a backend tensor with None in its shape.

    """
    return None in x.shape
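

# Usage sketch (illustrative, not part of the original module): a KerasTensor
# with an unknown batch dimension is detected as symbolic, while a concrete
# backend tensor is not. The helper name `_demo_is_symbolic_tensor` is hypothetical.
def _demo_is_symbolic_tensor() -> None:
    symbolic = keras.KerasTensor(shape=(None, 3))  # batch dimension unknown
    concrete = K.ones((2, 3))  # fully determined shape
    assert is_symbolic_tensor(symbolic)
    assert not is_symbolic_tensor(concrete)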


def get_weight_index(layer: Layer, weight: keras.Variable) -> int:
    """Get the index of a weight among the layer's tracked weights.

    Args:
        layer: layer whose tracked weights are searched
        weight: weight supposed to be tracked by the layer

    Returns:
        the index of the weight in the `layer.weights` list

    Raises:
        IndexError: if `weight` is not part of `layer.weights`

    """
    indexes = [i for i, w in enumerate(layer.weights) if w is weight]
    try:
        return indexes[0]
    except IndexError:
        raise IndexError(f"The weight {weight} is not tracked by the layer {layer}.")


def get_weight_index_from_name(layer: Layer, weight_name: str) -> int:
    """Get the index of a weight among the layer's tracked weights, given its name.

    Args:
        layer: layer whose tracked weights are searched
        weight_name: name of the weight supposed to be tracked by the layer

    Returns:
        the index of the weight in the `layer.weights` list

    Raises:
        AttributeError: if `weight_name` is not the name of an attribute of `layer`
        IndexError: if the corresponding layer attribute is not part of `layer.weights`

    """
    weight = getattr(layer, weight_name)
    try:
        return get_weight_index(layer=layer, weight=weight)
    except IndexError:
        raise IndexError(f"The weight {weight_name} is not tracked by the layer {layer}.")
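

# Usage sketch (illustrative, not part of the original module): on a built
# Dense layer, looking up a weight by variable or by attribute name yields the
# same index. The helper name `_demo_get_weight_index` is hypothetical.
def _demo_get_weight_index() -> None:
    layer = keras.layers.Dense(2)
    layer.build((None, 3))  # creates the `kernel` and `bias` variables
    assert get_weight_index(layer, layer.bias) == get_weight_index_from_name(layer, "bias")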


def reset_layer(new_layer: Layer, original_layer: Layer, weight_names: list[str]) -> None:
    """Reset some weights of a layer by using the weights of another layer.

    Args:
        new_layer: the decomon layer whose weights will be updated
        original_layer: the layer used to update the weights
        weight_names: the names of the weights to update

    Returns:

    """
    if not original_layer.built:
        raise ValueError(f"The layer {original_layer.name} has not been built yet.")
    if not new_layer.built:
        raise ValueError(f"The layer {new_layer.name} has not been built yet.")
    new_params = new_layer.get_weights()
    original_params = original_layer.get_weights()
    for weight_name in weight_names:
        new_params[get_weight_index_from_name(new_layer, weight_name)] = original_params[
            get_weight_index_from_name(original_layer, weight_name)
        ]
    new_layer.set_weights(new_params)


def reset_layer_all_weights(new_layer: Layer, original_layer: Layer) -> None:
    """Reset all the weights of a layer by using the weights of another layer.

    Args:
        new_layer: the decomon layer whose weights will be updated
        original_layer: the layer used to update the weights

    Returns:

    """
    reset_layer(new_layer=new_layer, original_layer=original_layer, weight_names=[w.name for w in new_layer.weights])
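

# Usage sketch (illustrative, not part of the original module): copy all the
# weight values of one built Dense layer into another built layer with the
# same architecture. The helper name `_demo_reset_layer_all_weights` is hypothetical.
def _demo_reset_layer_all_weights() -> None:
    original = keras.layers.Dense(2)
    original.build((None, 3))
    new = keras.layers.Dense(2)
    new.build((None, 3))
    reset_layer_all_weights(new_layer=new, original_layer=original)
    np.testing.assert_allclose(new.get_weights()[0], original.get_weights()[0])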


def share_layer_all_weights(
    original_layer: Layer,
    new_layer: Layer,
) -> None:
    """Share all the weights of an already built layer with another unbuilt layer.

    Args:
        original_layer: the layer used to share the weights
        new_layer: the new layer which will be built and will share the weights of the original layer

    Returns:

    """
    share_weights_and_build(
        new_layer=new_layer, original_layer=original_layer, weight_names=[w.name for w in original_layer.weights]
    )


def share_weights_and_build(original_layer: Layer, new_layer: Layer, weight_names: list[str]) -> None:
    """Share the weights specified by names of an already built layer with another unbuilt layer.

    We assume that each weight is also an original_layer's attribute whose name is the weight name.

    Args:
        original_layer: the layer used to share the weights
        new_layer: the new layer which will be built and will share the weights of the original layer
        weight_names: names of the weights to share

    Returns:

    """
    # Check that the original_layer is built and that the new_layer is not built
    if not original_layer.built:
        raise ValueError("The original layer must already be built for sharing its weights.")
    if new_layer.built:
        raise ValueError("The new layer must not be built to get the weights of the original layer.")

    # Check that the input really exists (i.e. that the layer has already been called on a symbolic KerasTensor)
    inp = original_layer.input  # will raise a ValueError if not existing

    # Store the weights as new_layer attributes before build (i.e. before the lock)
    for w_name in weight_names:
        w = getattr(original_layer, w_name)
        try:
            setattr(new_layer, w_name, w)
        except AttributeError:
            # manage hidden weights introduced for LoRA https://github.com/keras-team/keras/pull/18942
            w_name = f"_{w_name}"
            w = getattr(original_layer, w_name)
            setattr(new_layer, w_name, w)

    # Build the layer
    new_layer(inp)

    # Overwrite the newly generated weights and untrack them
    for w_name in weight_names:
        w = getattr(original_layer, w_name)
        w_to_drop = getattr(new_layer, w_name)
        try:
            setattr(new_layer, w_name, w)
        except AttributeError:
            # manage hidden weights introduced for LoRA https://github.com/keras-team/keras/pull/18942
            w_name = f"_{w_name}"
            w = getattr(original_layer, w_name)
            w_to_drop = getattr(new_layer, w_name)
            setattr(new_layer, w_name, w)
        # untrack the weight that is no longer used
        new_layer._tracker.untrack(w_to_drop)
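

# Usage sketch (illustrative, not part of the original module): after calling
# the original layer on a symbolic input, the new layer is built sharing the
# very same weight variables. The helper name `_demo_share_layer_all_weights`
# is hypothetical.
def _demo_share_layer_all_weights() -> None:
    original = keras.layers.Dense(2)
    original(keras.Input((3,)))  # build by calling on a symbolic input
    new = keras.layers.Dense(2)  # not built yet
    share_layer_all_weights(original_layer=original, new_layer=new)
    assert new.kernel is original.kernel  # the variable itself is shared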


def check_if_single_shape(shape: Any) -> bool:
    """Check whether `shape` describes a single tensor shape rather than a collection of shapes.

    Args:
        shape: a shape-like object (tuple, list, or similar) or a collection of such objects

    Returns:
        True if `shape` looks like a single shape, i.e. a non-empty tuple or list of ints (or None)

    """
    if isinstance(shape, list) and shape and isinstance(shape[0], (int, type(None))):
        return True
    if not isinstance(shape, (list, tuple, dict)):
        shape = tuple(shape)
    return isinstance(shape, tuple) and len(shape) > 0 and isinstance(shape[0], (int, type(None)))
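

# Usage sketch (illustrative, not part of the original module): a single shape
# is a flat tuple or list of ints (or None); a list of shapes is not.
# The helper name `_demo_check_if_single_shape` is hypothetical.
def _demo_check_if_single_shape() -> None:
    assert check_if_single_shape((None, 3))  # a single shape as a tuple
    assert check_if_single_shape([None, 3])  # a single shape as a list
    assert not check_if_single_shape([(None, 3), (None, 4)])  # a list of shapes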