Source code for discrete_optimization.generic_tools.transformation.transformation_graph

#  Copyright (c) 2026 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.

"""Transformation graph for discovering solver accessibility.

This module builds a graph of problem transformations and uses it to:
- Discover all transformations in the codebase
- Find paths between problem types
- Calculate transformation quality (based on losses)
- List all solvers accessible for a given problem
"""

from __future__ import annotations

import importlib
import inspect
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional

import networkx as nx

from discrete_optimization.generic_tools.transformation.problem_transformation import (
    ProblemTransformation,
)
from discrete_optimization.generic_tools.transformation.transformation_metadata import (
    LossImpact,
)


class WeightingStrategy(Enum):
    """Available schemes for turning a transformation edge into a weight."""

    # Every edge weighs 1, so shortest paths minimize the number of hops.
    UNIFORM = "uniform"
    # Edge weight follows the maximum loss impact of the transformation.
    BY_IMPACT = "by_impact"
    # Edge weight follows the number of losses of the transformation.
    BY_LOSS_COUNT = "by_loss_count"
    # Lossy transformations are heavily penalized.
    PREFER_EXACT = "prefer_exact"
@dataclass
class TransformationEdge:
    """Edge in transformation graph representing a transformation."""

    source_problem: str  # Problem type name
    target_problem: str  # Problem type name
    transformation_class: type[ProblemTransformation]
    transformation_instance: ProblemTransformation
    forward_exact: bool
    backward_exact: bool
    max_impact: LossImpact
    num_losses: int

    def get_weight(
        self, strategy: WeightingStrategy = WeightingStrategy.UNIFORM
    ) -> float:
        """Get edge weight based on strategy.

        Args:
            strategy: Weighting strategy

        Returns:
            Edge weight (lower = better)
        """
        if strategy == WeightingStrategy.UNIFORM:
            return 1.0
        elif strategy == WeightingStrategy.BY_IMPACT:
            # Weight by impact level: 0 for no impact up to 10 for critical.
            # An impact missing from the table is treated as the worst case.
            impact_weights = {
                LossImpact.NONE: 0.0,
                LossImpact.MINOR: 1.0,
                LossImpact.MODERATE: 2.0,
                LossImpact.MAJOR: 5.0,
                LossImpact.CRITICAL: 10.0,
            }
            return impact_weights.get(self.max_impact, 10.0)
        elif strategy == WeightingStrategy.BY_LOSS_COUNT:
            # Weight by number of losses; loss-free edges keep a small
            # positive weight instead of 0.
            return float(self.num_losses) if self.num_losses > 0 else 0.1
        elif strategy == WeightingStrategy.PREFER_EXACT:
            # Heavily penalize lossy transformations
            if self.forward_exact and self.backward_exact:
                return 0.1
            elif self.forward_exact or self.backward_exact:
                return 5.0  # One direction lossy
            else:
                return 20.0  # Both directions lossy
        else:
            # Unknown strategy: fall back to the uniform weight.
            return 1.0

    def __str__(self) -> str:
        """Return ``"<source> → <target>"`` followed by an exactness tag."""
        if self.forward_exact and self.backward_exact:
            exact_str = " (exact)"
        elif self.forward_exact:
            exact_str = " (forward exact)"
        elif self.backward_exact:
            exact_str = " (backward exact)"
        else:
            exact_str = f" (lossy, impact: {self.max_impact.value})"
        # Separate the two problem names; without a separator they run
        # together into a single unreadable token.
        return f"{self.source_problem} → {self.target_problem}{exact_str}"
@dataclass
class TransformationPath:
    """A route through the transformation graph."""

    # Problem type names visited along the path, in order.
    problem_sequence: list[str]
    # Edges used along the path.
    transformations: list[TransformationEdge]
    total_weight: float
    # True if all transformations are exact.
    is_exact: bool

    def __str__(self) -> str:
        """Render the node sequence followed by an exactness or weight tag."""
        if self.is_exact:
            suffix = " (exact)"
        else:
            suffix = f" (weight: {self.total_weight:.2f})"
        return " → ".join(self.problem_sequence) + suffix
class TransformationGraph:
    """Graph of problem transformations.

    Builds and analyzes a graph where:
    - Nodes = problem types (e.g., "BinPackProblem", "SalbpProblem")
    - Edges = transformations with metadata
    """

    # Short names (as they appear in transformation class names) whose problem
    # class name is not simply "<short name>Problem".
    _PROBLEM_NAME_SPECIAL_CASES = {
        "Binpack": "BinPackProblem",
        "Salbp": "SalbpProblem",
        "RcalbpL": "RCALBPLProblem",
        "Rcpsp": "RcpspProblem",
        "Multiskill": "MultiskillRcpspProblem",
        "Preemptive": "PreemptiveRcpspProblem",
        "Fjsp": "FJobShopProblem",
        "Jsp": "JobShopProblem",
        "Facility": "FacilityProblem",
        "Singlebatch": "SingleBatchProcessingProblem",
        "Ovensched": "OvenSchedulingProblem",
        "WorkforceAllocation": "TeamAllocationProblem",
        "WorkforceScheduling": "AllocSchedulingProblem",
        "Coloring": "ColoringProblem",
        "ListColoring": "ListColoringProblem",
        "Tsp": "TspProblem",
        "Vrp": "VrpProblem",
        "Vrptw": "VRPTWProblem",
        "Gpdp": "GpdpProblem",
        "Top": "TeamOrienteeringProblem",
    }

    def __init__(self) -> None:
        """Initialize empty transformation graph."""
        self.graph: nx.DiGraph = nx.DiGraph()
        # Edge metadata indexed by (source problem name, target problem name).
        self.transformations: dict[tuple[str, str], TransformationEdge] = {}

    def add_transformation(
        self, transformation_class: type[ProblemTransformation]
    ) -> None:
        """Add a transformation to the graph.

        The class is instantiated without arguments. If that fails, or if the
        problem type names cannot be derived from the class name, a warning is
        printed and the graph is left unchanged. A transformation registered
        for an already known (source, target) pair replaces the previous one.

        Args:
            transformation_class: Transformation class to add
        """
        # Instantiate transformation
        try:
            transformation = transformation_class()
        except Exception as e:
            print(
                f"Warning: Could not instantiate {transformation_class.__name__}: {e}"
            )
            return

        # Problem type names are derived from the class name convention.
        source_name, target_name = self._extract_problem_types(transformation_class)
        if not source_name or not target_name:
            print(
                f"Warning: Could not extract types from {transformation_class.__name__}"
            )
            return

        # Get metadata (only forward - backward is deprecated)
        forward_metadata = transformation.get_forward_metadata()

        edge = TransformationEdge(
            source_problem=source_name,
            target_problem=target_name,
            transformation_class=transformation_class,
            transformation_instance=transformation,
            forward_exact=forward_metadata.is_exact(),
            # Solution mapping is always mechanical (deprecated concept)
            backward_exact=True,
            max_impact=forward_metadata.get_max_impact(),
            num_losses=len(forward_metadata.losses),
        )

        self.graph.add_node(source_name)
        self.graph.add_node(target_name)
        self.graph.add_edge(source_name, target_name, transformation=edge)
        self.transformations[(source_name, target_name)] = edge

    def _extract_problem_types(
        self, transformation_class: type[ProblemTransformation]
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract source and target problem type names from the class name.

        The class name is expected to follow the convention
        ``<Source>To<Target>Transformation``
        (e.g. ``BinpackToSalbpTransformation``).

        Args:
            transformation_class: Transformation class

        Returns:
            Tuple of (source_problem_name, target_problem_name), or
            ``(None, None)`` when the name does not follow the convention or
            cannot be split unambiguously.
        """
        class_name = transformation_class.__name__
        if "To" not in class_name or "Transformation" not in class_name:
            return None, None

        stem = class_name.replace("Transformation", "")
        parts = stem.split("To")
        if len(parts) != 2:
            # "To" also occurs inside some short names (e.g. "Top"), which
            # defeats the plain split. Only a "To" that is not at the start
            # and is followed by an uppercase letter can be the separator.
            cut_points = [
                i
                for i in range(1, len(stem) - 2)
                if stem.startswith("To", i) and stem[i + 2].isupper()
            ]
            if len(cut_points) != 1:
                return None, None
            cut = cut_points[0]
            parts = [stem[:cut], stem[cut + 2 :]]

        source, target = (part.strip() for part in parts)
        return self._to_problem_name(source), self._to_problem_name(target)

    def _to_problem_name(self, short_name: str) -> str:
        """Convert short name to problem class name.

        Args:
            short_name: Short name (e.g., "Binpack", "Salbp")

        Returns:
            Problem class name (e.g., "BinPackProblem", "SalbpProblem")
        """
        # Default: add "Problem" suffix
        return self._PROBLEM_NAME_SPECIAL_CASES.get(
            short_name, f"{short_name}Problem"
        )

    def find_path(
        self,
        source: str,
        target: str,
        strategy: WeightingStrategy = WeightingStrategy.UNIFORM,
    ) -> Optional[TransformationPath]:
        """Find shortest path between two problem types.

        As a side effect, the ``"weight"`` attribute of every graph edge is
        (re)written according to ``strategy``.

        Args:
            source: Source problem type name
            target: Target problem type name
            strategy: Weighting strategy for path finding

        Returns:
            TransformationPath if path exists, None otherwise (including when
            ``source`` or ``target`` is not a node of the graph)
        """
        # Set edge weights based on strategy
        for (src, tgt), edge in self.transformations.items():
            self.graph[src][tgt]["weight"] = edge.get_weight(strategy)

        try:
            path = nx.shortest_path(self.graph, source, target, weight="weight")
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            # Unknown problem names are reported like unreachable ones.
            return None

        # Extract transformations along path
        transformations = [
            self.transformations[step] for step in zip(path, path[1:])
        ]
        # The weights were just written from the same strategy, so summing
        # them along the path gives the shortest-path length without running
        # the search a second time.
        total_weight = sum(edge.get_weight(strategy) for edge in transformations)

        return TransformationPath(
            problem_sequence=path,
            transformations=transformations,
            total_weight=total_weight,
            is_exact=all(
                t.forward_exact and t.backward_exact for t in transformations
            ),
        )

    def find_all_paths(
        self,
        source: str,
        target: str,
        max_length: int = 5,
    ) -> list[TransformationPath]:
        """Find all simple paths between two problem types.

        Each path's ``total_weight`` is computed with
        ``WeightingStrategy.BY_IMPACT``.

        Args:
            source: Source problem type name
            target: Target problem type name
            max_length: Maximum path length (passed as ``cutoff``)

        Returns:
            List of TransformationPath objects; empty when there is no path
            or when ``source`` or ``target`` is not a node of the graph
        """
        try:
            node_paths = list(
                nx.all_simple_paths(self.graph, source, target, cutoff=max_length)
            )
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []

        paths = []
        for path_nodes in node_paths:
            transformations = [
                self.transformations[step]
                for step in zip(path_nodes, path_nodes[1:])
            ]
            total_weight = 0.0
            for edge in transformations:
                total_weight += edge.get_weight(WeightingStrategy.BY_IMPACT)
            paths.append(
                TransformationPath(
                    problem_sequence=path_nodes,
                    transformations=transformations,
                    total_weight=total_weight,
                    is_exact=all(
                        t.forward_exact and t.backward_exact
                        for t in transformations
                    ),
                )
            )
        return paths

    def get_reachable_problems(self, source: str) -> set[str]:
        """Get all problem types reachable from source.

        Args:
            source: Source problem type

        Returns:
            Set of reachable problem type names, always including ``source``
        """
        try:
            return set(nx.descendants(self.graph, source)) | {source}
        except nx.NetworkXError:
            return {source}

    def get_connected_components(self) -> list[set[str]]:
        """Get weakly connected components.

        Returns:
            List of sets, each containing connected problem types
        """
        return list(nx.weakly_connected_components(self.graph))

    def print_summary(self) -> None:
        """Print summary of transformation graph."""
        print("Transformation Graph Summary")
        print("=" * 80)
        print(f"Nodes (problem types): {self.graph.number_of_nodes()}")
        print(f"Edges (transformations): {self.graph.number_of_edges()}")

        # Count exact vs lossy
        exact_count = sum(
            1
            for edge in self.transformations.values()
            if edge.forward_exact and edge.backward_exact
        )
        print(f" - Exact transformations: {exact_count}")
        print(f" - Lossy transformations: {len(self.transformations) - exact_count}")

        # Connected components
        components = self.get_connected_components()
        print(f"\nConnected components: {len(components)}")
        for i, component in enumerate(components, 1):
            print(f" Component {i}: {', '.join(sorted(component))}")

    def visualize(self, output_file: Optional[str] = None) -> None:
        """Visualize transformation graph.

        Nodes are colored by the number of problems reachable from them;
        exact edges are drawn in solid green, lossy ones in dashed red.

        Args:
            output_file: Optional output file path (requires matplotlib).
                When given, the figure is saved there; otherwise it is shown.
        """
        # Only the optional import is guarded, so an ImportError raised while
        # drawing is not misreported as "matplotlib missing".
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print(
                "Matplotlib required for visualization. Install with: pip install matplotlib"
            )
            return

        # Draw on a fresh figure so repeated calls do not pile up on whatever
        # figure happens to be current.
        figure = plt.figure()
        pos = nx.spring_layout(self.graph, k=2, iterations=50)

        # Color by reachability
        node_colors = [
            len(self.get_reachable_problems(node)) for node in self.graph.nodes()
        ]
        nx.draw_networkx_nodes(
            self.graph, pos, node_color=node_colors, cmap="Blues", node_size=1000
        )
        nx.draw_networkx_labels(self.graph, pos, font_size=8)

        # Split edges by exactness in a single pass
        exact_edges = []
        lossy_edges = []
        for u, v, data in self.graph.edges(data=True):
            edge = data["transformation"]
            if edge.forward_exact and edge.backward_exact:
                exact_edges.append((u, v))
            else:
                lossy_edges.append((u, v))

        nx.draw_networkx_edges(
            self.graph, pos, edgelist=exact_edges, edge_color="green", width=2
        )
        nx.draw_networkx_edges(
            self.graph,
            pos,
            edgelist=lossy_edges,
            edge_color="red",
            width=1,
            style="dashed",
        )

        plt.title("Transformation Graph (green=exact, red=lossy)")
        plt.axis("off")

        if output_file:
            plt.savefig(output_file, dpi=150, bbox_inches="tight")
            print(f"Saved visualization to {output_file}")
            # Release the figure once it is on disk.
            plt.close(figure)
        else:
            plt.show()
def discover_transformations(
    base_module: str = "discrete_optimization",
) -> TransformationGraph:
    """Discover all transformations in codebase.

    Every directory below the base module that contains a ``transformations``
    entry is imported as ``<package>.transformations``; each concrete
    ``ProblemTransformation`` subclass found among its members is added to
    the graph. Modules that fail with ``ImportError`` are skipped silently
    (missing optional dependencies); other errors are printed and skipped.

    Args:
        base_module: Base module to search (default: discrete_optimization)

    Returns:
        TransformationGraph with all discovered transformations (empty when
        the base module cannot be imported or located on disk)
    """
    graph = TransformationGraph()

    # Import base module
    try:
        base = importlib.import_module(base_module)
    except ImportError as e:
        print(f"Could not import base module {base_module}: {e}")
        return graph

    # Namespace packages have no __file__; there is no directory to walk.
    base_file = getattr(base, "__file__", None)
    if base_file is None:
        print(f"Could not locate base module {base_module} on disk")
        return graph
    base_path = Path(base_file).parent

    def find_transformation_modules(root_path: Path, module_prefix: str) -> list[str]:
        """Find all transformation module paths recursively."""
        transformation_modules = []
        # Sorted so the discovery order does not depend on the filesystem.
        for item in sorted(root_path.iterdir()):
            # Skips private/dunder directories, which includes __pycache__.
            if not item.is_dir() or item.name.startswith("_"):
                continue

            # Check if this directory has a transformations subdir
            if (item / "transformations").exists():
                transformation_modules.append(
                    f"{module_prefix}.{item.name}.transformations"
                )

            # Recursively search subdirectories
            # (for cases like workforce/scheduling/transformations)
            transformation_modules.extend(
                find_transformation_modules(item, f"{module_prefix}.{item.name}")
            )
        return transformation_modules

    # Classes re-exported by several modules are registered only once.
    seen_classes = set()

    # Import and process each transformation module (including nested ones)
    for module_path in find_transformation_modules(base_path, base_module):
        try:
            transformations_module = importlib.import_module(module_path)

            # Find all transformation classes
            for _name, obj in inspect.getmembers(transformations_module):
                if (
                    inspect.isclass(obj)
                    and issubclass(obj, ProblemTransformation)
                    and obj is not ProblemTransformation
                    and not inspect.isabstract(obj)
                    and obj not in seen_classes
                ):
                    seen_classes.add(obj)
                    graph.add_transformation(obj)
        except ImportError:
            # Skip modules that can't be imported (missing dependencies)
            continue
        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    return graph
def build_transformation_graph() -> TransformationGraph:
    """Build the transformation graph by scanning the codebase.

    Convenience entry point that runs ``discover_transformations`` with its
    default arguments.

    Returns:
        TransformationGraph with all discovered transformations
    """
    discovered_graph = discover_transformations()
    return discovered_graph