Source code for discrete_optimization.generic_tools.transformation.composite

#  Copyright (c) 2026 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.

"""Composite transformations for chaining multiple transformations."""

from __future__ import annotations

from typing import Optional

from discrete_optimization.generic_tools.do_problem import Problem, Solution
from discrete_optimization.generic_tools.transformation.problem_transformation import (
    ProblemTransformation,
)


class CompositeTransformation(ProblemTransformation):
    """Compose multiple transformations into a single transformation.

    Problems are transformed left to right through ``transformations``;
    solutions are back-transformed right to left.

    Example:
        Chain transformations T1: RCPSP → Multiskill and T2: Multiskill → Preemptive

        # >>> t1 = RcpspToMultiskillTransformation()
        # >>> t2 = MultiskillToPreemptiveTransformation()
        # >>> composite = CompositeTransformation([t1, t2])
        # >>> # Now composite: RCPSP → Preemptive

        Back-transformation automatically chains in reverse:
        S_preemptive → T2⁻¹ → S_multiskill → T1⁻¹ → S_rcpsp
    """

    transformations: list[ProblemTransformation]
    # Problems produced by the most recent forward pass (transform_problem):
    # key 0 is the source problem, key i + 1 is the output of transformations[i].
    _intermediate_problems: dict[int, Problem]

    def __init__(self, transformations: list[ProblemTransformation]):
        """Initialize composite transformation.

        Args:
            transformations: Transformations to chain, in application order.
                The sequence is copied, so mutating the caller's list afterwards
                does not change this composite.

        Raises:
            ValueError: If transformations is empty.
        """
        # Copy so the chain cannot be altered through the caller's list.
        transformations = list(transformations)
        if not transformations:
            raise ValueError("Must provide at least one transformation")
        self.transformations = transformations
        self._intermediate_problems = {}

    def _has_cached_chain_for(self, source_problem: Problem) -> bool:
        """Return True if the cache holds a complete forward pass of source_problem."""
        cache = self._intermediate_problems
        # Identity (not equality) check: the cache is only valid for the exact
        # problem object it was computed from. The length check rejects partial
        # passes (a child transformation raised) and chains that changed since.
        return (
            len(cache) == len(self.transformations) + 1
            and cache.get(0) is source_problem
        )

    def _get_intermediate_problems(self, source_problem: Problem) -> list[Problem]:
        """Return [P0, P1, ..., Pn] for source_problem, refreshing the cache if stale.

        P0 is source_problem and P(i + 1) is the output of transformations[i].
        """
        if not self._has_cached_chain_for(source_problem):
            # Missing or stale cache (transform_problem never called, or last
            # called on another problem): recompute so every back-transformation
            # receives its true source problem.
            self.transform_problem(source_problem)
        cache = self._intermediate_problems
        return [cache[i] for i in range(len(self.transformations) + 1)]

    def transform_problem(self, source_problem: Problem) -> Problem:
        """Apply all transformations in sequence.

        Also records every intermediate problem, which is later used by
        back_transform_solution, forward_transform_solution and is_bidirectional.

        Args:
            source_problem: Original problem

        Returns:
            Final transformed problem
        """
        self._intermediate_problems.clear()
        self._intermediate_problems[0] = source_problem
        current_problem = source_problem
        for i, transformation in enumerate(self.transformations, start=1):
            current_problem = transformation.transform_problem(current_problem)
            self._intermediate_problems[i] = current_problem
        return current_problem

    def back_transform_solution(
        self, solution: Solution, source_problem: Problem
    ) -> Solution:
        """Back-transform through the chain in reverse order.

        If we have transformations [T1, T2, T3]:
        - T1: P1 → P2
        - T2: P2 → P3
        - T3: P3 → P4

        Then back-transform: S4 → T3⁻¹ → S3 → T2⁻¹ → S2 → T1⁻¹ → S1

        The intermediate problems from the last transform_problem call are
        reused when that call was made on ``source_problem`` itself. Otherwise
        the chain is recomputed from ``source_problem``, which refreshes the
        cache.

        Args:
            solution: Solution in final target problem space
            source_problem: Original source problem

        Returns:
            Solution in original source problem space
        """
        problems = self._get_intermediate_problems(source_problem)
        current_solution = solution
        # zip pairs transformations[i] with problems[i], its own source problem
        # (problems has one more entry, which zip drops).
        for transformation, step_source in reversed(
            list(zip(self.transformations, problems))
        ):
            current_solution = transformation.back_transform_solution(
                current_solution, step_source
            )
        return current_solution

    def forward_transform_solution(
        self, solution: Solution, target_problem: Problem
    ) -> Optional[Solution]:
        """Forward-transform through the chain.

        Only works if ALL transformations support forward transformation.

        The final step always targets ``target_problem``. Intermediate targets
        are taken from the last transform_problem call; call it first when the
        chain has more than one transformation. Without a cached intermediate,
        ``target_problem`` is used as a fallback.

        Args:
            solution: Solution in source problem space
            target_problem: Final target problem

        Returns:
            Solution in final target problem space, or None if any
            transformation doesn't support forward transformation
        """
        last_index = len(self.transformations) - 1
        current_solution = solution
        for i, transformation in enumerate(self.transformations):
            if i == last_index:
                # The caller names the final target explicitly; a cached problem
                # from an earlier forward pass must not override it.
                step_target = target_problem
            else:
                step_target = self._intermediate_problems.get(i + 1, target_problem)
            current_solution = transformation.forward_transform_solution(
                current_solution, step_target
            )
            # If any transformation doesn't support forward, abort
            if current_solution is None:
                return None
        return current_solution

    def is_bidirectional(self, source_problem: Problem) -> bool:
        """Check if all transformations are bidirectional.

        Args:
            source_problem: Source problem to check

        Returns:
            True if all transformations support forward transformation
        """
        if self._has_cached_chain_for(source_problem):
            # Reuse the problems of the last forward pass instead of recomputing.
            cache = self._intermediate_problems
            return all(
                transformation.is_bidirectional(cache[i])
                for i, transformation in enumerate(self.transformations)
            )
        last_index = len(self.transformations) - 1
        current_problem = source_problem
        for i, transformation in enumerate(self.transformations):
            if not transformation.is_bidirectional(current_problem):
                return False
            if i < last_index:
                # Only transform when a later step needs its source problem.
                current_problem = transformation.transform_problem(current_problem)
        return True

    def __repr__(self) -> str:
        """Nice representation showing the transformation chain."""
        # Reachable only if the public list is emptied after construction.
        if not self.transformations:
            return "CompositeTransformation(empty)"
        chain = " → ".join(type(t).__name__ for t in self.transformations)
        return f"CompositeTransformation({chain})"
def chain_transformations(
    *transformations: ProblemTransformation,
) -> CompositeTransformation:
    """Build a CompositeTransformation from positional arguments.

    Variadic convenience wrapper: ``chain_transformations(t1, t2)`` builds the
    same composite as ``CompositeTransformation([t1, t2])``.

    Example:
        # >>> t1 = RcpspToMultiskillTransformation()
        # >>> t2 = MultiskillToPreemptiveTransformation()
        # >>> t3 = PreemptiveToMultiskillTransformation()
        # >>>
        # >>> composite = chain_transformations(t1, t2, t3)
        # >>> # Equivalent to: RCPSP → Multiskill → Preemptive → Multiskill

    Args:
        *transformations: Transformations to chain; the first one is applied
            first to the source problem.

    Returns:
        CompositeTransformation instance wrapping the given transformations.

    Raises:
        ValueError: If called with no transformation (raised by the
            CompositeTransformation constructor).
    """
    ordered = [*transformations]
    return CompositeTransformation(ordered)