# Copyright (c) 2026 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""DAG-based solver workflow (SolverGraph)."""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Optional
import networkx as nx
from discrete_optimization.generic_tools.do_problem import (
Problem,
Solution,
build_aggreg_function_and_params_objective,
)
from discrete_optimization.generic_tools.do_solver import SolverDO, WarmstartMixin
from discrete_optimization.generic_tools.hyperparameters.hyperparameter import SubBrick
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
)
from discrete_optimization.generic_tools.transformation.problem_transformation import (
ProblemTransformation,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
@dataclass
class NodeData:
    """Container for the data exchanged between graph nodes.

    Attributes:
        problem: Problem instance carried along the edge (needed by most nodes).
        result: ResultStorage produced by a solver or merge node, if any.
        solution: Single solution used for warm start / kwargs extraction.
            When not given and ``result`` is non-empty, it defaults to the
            best solution of ``result``.
    """

    problem: Optional[Problem] = None
    result: Optional[ResultStorage] = None
    solution: Optional[Solution] = None

    def __post_init__(self):
        """Fill ``solution`` from ``result`` when it was not provided."""
        # Guard clauses: nothing to do if a solution was given or no result exists.
        if self.solution is not None or self.result is None:
            return
        if len(self.result) == 0:
            return
        self.solution = self.result.get_best_solution()
class GraphNode(ABC):
    """Abstract vertex of a solver graph.

    A node receives the ``NodeData`` of its upstream neighbours in ``inputs``
    (keyed by upstream node id), runs :meth:`execute`, and produces a new
    ``NodeData``.
    """

    node_id: str
    inputs: dict[str, NodeData]  # upstream node id -> data it produced
    output: Optional[NodeData]  # data produced by this node (None initially)

    def __init__(self, node_id: str):
        """Create a node with empty inputs and no output yet.

        Args:
            node_id: Identifier of the node, unique within its graph
        """
        self.node_id, self.inputs, self.output = node_id, {}, None

    @abstractmethod
    def execute(self, **kwargs: Any) -> NodeData:
        """Run the node's operation.

        Args:
            **kwargs: Execution parameters (e.g. time_limit)

        Returns:
            NodeData holding a problem, a result and/or a solution
        """
        ...

    def can_execute(self) -> bool:
        """Tell whether the node has what it needs to run.

        The default rule only asks for at least one input; subclasses with
        other requirements override it.

        Returns:
            True if the node can execute
        """
        return bool(self.inputs)

    def __repr__(self) -> str:
        """Short ``ClassName(id=...)`` representation."""
        return f"{type(self).__name__}(id={self.node_id})"
class RootNode(GraphNode):
    """Entry point of a solver graph: hands out the source problem."""

    problem: Problem  # problem emitted on every execution

    def __init__(self, node_id: str, problem: Problem):
        """Create the root node.

        Args:
            node_id: Node identifier
            problem: Source problem to emit
        """
        super().__init__(node_id)
        self.problem = problem

    def execute(self, **kwargs: Any) -> NodeData:
        """Wrap the source problem in a ``NodeData`` (``kwargs`` are ignored).

        Returns:
            NodeData whose only populated field is ``problem``
        """
        source = self.problem
        return NodeData(problem=source)

    def can_execute(self) -> bool:
        """Always True: the root does not depend on any upstream node."""
        return True
class SolverNode(GraphNode):
    """Node that runs a solver on the problem received from upstream.

    The solver is instantiated from ``solver_brick`` (``cls`` + ``kwargs``).
    When an upstream node provides a solution, that solution is used in two ways:
    to evaluate ``solver_brick.kwargs_from_solution``, and as a warm start if
    the solver implements ``WarmstartMixin``. This mirrors the scheme used by
    SequentialMetasolver.
    """

    solver_brick: SubBrick
    solver: Optional[SolverDO]  # last instantiated solver, reused when possible
    problem: Optional[Problem]  # problem ``solver`` was instantiated for

    def __init__(self, node_id: str, solver_brick: SubBrick):
        """Initialize solver node.

        Args:
            node_id: Node identifier
            solver_brick: Solver specification (class, kwargs and optional
                kwargs_from_solution)
        """
        super().__init__(node_id)
        self.solver_brick = solver_brick
        self.solver = None
        self.problem = None

    def get_input_problem(self) -> Problem:
        """Return the first problem found among the upstream inputs.

        Raises:
            ValueError: If no input carries a problem.
        """
        for input_data in self.inputs.values():
            if input_data.problem is not None:
                return input_data.problem
        raise ValueError(
            f"SolverNode {self.node_id} has no input providing a problem"
        )

    def get_input_solution(self) -> Optional[Solution]:
        """Return the first solution found among the upstream inputs, or None."""
        for input_data in self.inputs.values():
            if input_data.solution is not None:
                return input_data.solution
        return None

    def execute(self, **kwargs: Any) -> NodeData:
        """Solve the input problem.

        Args:
            **kwargs: Graph-level solve parameters (e.g. time_limit). A key
                that is also present in the brick kwargs is overridden by the
                brick value.

        Returns:
            NodeData with problem and result (solution auto-extracted)

        Raises:
            ValueError: If no upstream input provides a problem.
        """
        problem = self.get_input_problem()
        warmstart_solution = self.get_input_solution()

        # Brick kwargs, completed by kwargs_from_solution (like SequentialMetasolver)
        kwargs_updated = dict(self.solver_brick.kwargs)
        kwargs_depend_on_solution = (
            self.solver_brick.kwargs_from_solution is not None
            and warmstart_solution is not None
        )
        if kwargs_depend_on_solution:
            kwargs_updated.update(
                {
                    k: fun(warmstart_solution)
                    for k, fun in self.solver_brick.kwargs_from_solution.items()
                }
            )

        # (Re)build the solver in three cases: none exists yet, the problem
        # object changed, or its kwargs were derived from the current input
        # solution (a cached solver would have been initialized with stale
        # values).
        if (
            self.solver is None
            or self.problem is not problem
            or kwargs_depend_on_solution
        ):
            self.problem = problem
            self.solver = self.solver_brick.cls(problem=problem, **kwargs_updated)
            self.solver.init_model(**kwargs_updated)

        # Warmstart if solution available and solver supports it
        if warmstart_solution is not None and isinstance(
            self.solver, WarmstartMixin
        ):
            self.solver.set_warm_start(warmstart_solution)

        # Solve with the brick kwargs too (as SequentialMetasolver does);
        # node-specific values take precedence over graph-level ones.
        result = self.solver.solve(**{**kwargs, **kwargs_updated})
        # Return NodeData (solution auto-extracted in __post_init__)
        return NodeData(problem=problem, result=result)
class MergeNode(GraphNode):
    """Node that merges the result storages of its upstream nodes."""

    strategy: str  # "best" or "all"

    def __init__(self, node_id: str, strategy: str = "best"):
        """Initialize merge node.

        Args:
            node_id: Node identifier
            strategy: Merge strategy ("best" or "all")
        """
        super().__init__(node_id)
        self.strategy = strategy

    def execute(self, **kwargs: Any) -> NodeData:
        """Merge input result storages.

        Strategies:
            - ``"best"``: keep the best (solution, fitness) pair of each input.
            - ``"all"``: concatenate every input result storage.

        The merged storage uses the ``mode_optim`` of the first input result.
        The problem of the first input that carries one is forwarded, so that
        a solver node placed after the merge still receives a problem.

        Returns:
            NodeData with problem (when available), merged result and solution

        Raises:
            ValueError: If no input carries a result, or the strategy is unknown.
        """
        results = [
            input_data.result
            for input_data in self.inputs.values()
            if input_data.result is not None
        ]
        if len(results) == 0:
            raise ValueError("MergeNode requires at least one result in inputs")
        problem = next(
            (
                input_data.problem
                for input_data in self.inputs.values()
                if input_data.problem is not None
            ),
            None,
        )
        merged = ResultStorage(list_solution_fits=[], mode_optim=results[0].mode_optim)
        if self.strategy == "best":
            # Keep only best solution from each result
            for res in results:
                if len(res) > 0:
                    best_sol, best_fit = res.get_best_solution_fit()
                    if best_sol is not None:
                        merged.append((best_sol, best_fit))
        elif self.strategy == "all":
            # Combine all solutions
            for res in results:
                merged.extend(res)
        else:
            raise ValueError(f"Unknown merge strategy: {self.strategy}")
        # Return NodeData (solution auto-extracted in __post_init__)
        return NodeData(problem=problem, result=merged)

    def can_execute(self) -> bool:
        """Can execute if we have at least one input."""
        return len(self.inputs) >= 1
class SolverGraph:
    """DAG-based solver workflow.

    A virtual ``"root"`` node holding the source problem is created at
    construction; connect the first node(s) of the workflow to it.

    Supports:
        - Branching (parallel strategies)
        - Merging (combine results)
        - Arbitrary directed acyclic graphs

    Example (linear, like SequentialMetasolver)::

        graph = SolverGraph(problem)
        graph.add_solver("solver1", SubBrick(cls=Solver1, kwargs={}))
        graph.add_solver("solver2", SubBrick(cls=Solver2, kwargs={}))
        graph.add_edge("root", "solver1")
        graph.add_edge("solver1", "solver2")
        result = graph.run()

    Example (branching)::

        graph = SolverGraph(problem)
        graph.add_solver("cpsat", SubBrick(cls=CPSat, kwargs={}))
        graph.add_solver("lp", SubBrick(cls=LP, kwargs={}))
        graph.add_merge("merge", strategy="best")
        graph.add_edge("root", "cpsat")
        graph.add_edge("root", "lp")
        graph.add_edge("cpsat", "merge")
        graph.add_edge("lp", "merge")
        result = graph.run()
    """

    source_problem: Problem
    nodes: dict[str, GraphNode]
    edges: dict[str, list[str]]  # node_id -> list of downstream node_ids
    reverse_edges: dict[str, list[str]]  # node_id -> list of upstream node_ids
    # Execution state
    node_outputs: dict[str, NodeData]  # node_id -> NodeData of the last run
    # NetworkX graph cache, reset whenever nodes or edges are added
    _nx_graph: Optional[nx.DiGraph]

    def __init__(self, source_problem: Problem):
        """Initialize solver graph.

        Args:
            source_problem: The problem to solve
        """
        self.source_problem = source_problem
        self.nodes = {"root": RootNode("root", source_problem)}
        self.edges = defaultdict(list)
        self.reverse_edges = defaultdict(list)
        self.node_outputs = {}
        self._nx_graph = None

    def _reject_duplicate(self, node_id: str) -> bool:
        """Log an error and return True when ``node_id`` is already used."""
        if node_id in self.nodes:
            logger.error("%s already added in the graph computation", node_id)
            return True
        return False

    def _register_node(self, node: GraphNode) -> str:
        """Insert ``node`` and invalidate the cached NetworkX graph."""
        self.nodes[node.node_id] = node
        # The cached DiGraph would otherwise miss this node.
        self._nx_graph = None
        return node.node_id

    def add_solver(self, node_id: str, solver_brick: SubBrick) -> Optional[str]:
        """Add a solver node.

        Args:
            node_id: Unique identifier for this node
            solver_brick: Solver specification

        Returns:
            Node ID (for chaining), or None if ``node_id`` is already used
            (an error is logged and the graph is left unchanged)
        """
        if self._reject_duplicate(node_id):
            return None
        return self._register_node(SolverNode(node_id, solver_brick))

    def add_merge(self, node_id: str, strategy: str = "best") -> Optional[str]:
        """Add a merge node.

        Args:
            node_id: Unique identifier for this node
            strategy: Merge strategy ("best" or "all")

        Returns:
            Node ID (for chaining), or None if ``node_id`` is already used
            (an error is logged and the graph is left unchanged)
        """
        if self._reject_duplicate(node_id):
            return None
        return self._register_node(MergeNode(node_id, strategy))

    def add_edge(self, from_node: str, to_node: str) -> None:
        """Add an edge between nodes (adding an existing edge is a no-op).

        Args:
            from_node: Source node ID
            to_node: Target node ID

        Raises:
            ValueError: If nodes don't exist
        """
        if from_node not in self.nodes:
            raise ValueError(f"Node {from_node} does not exist")
        if to_node not in self.nodes:
            raise ValueError(f"Node {to_node} does not exist")
        if to_node in self.edges[from_node]:
            # Already connected: avoid duplicated adjacency entries.
            return
        self.edges[from_node].append(to_node)
        self.reverse_edges[to_node].append(from_node)
        # Invalidate NetworkX cache
        self._nx_graph = None

    def _build_networkx_graph(self) -> nx.DiGraph:
        """Build (or return the cached) NetworkX DiGraph representation.

        Returns:
            NetworkX directed graph representing the solver graph
        """
        if self._nx_graph is None:
            graph = nx.DiGraph()
            graph.add_nodes_from(self.nodes)
            graph.add_edges_from(
                (from_node, to_node)
                for from_node, to_nodes in self.edges.items()
                for to_node in to_nodes
            )
            self._nx_graph = graph
        return self._nx_graph

    def topological_sort(self) -> list[str]:
        """Return nodes in topological order using NetworkX.

        Returns:
            List of node IDs in execution order

        Raises:
            ValueError: If graph contains a cycle
        """
        nx_graph = self._build_networkx_graph()
        try:
            return list(nx.topological_sort(nx_graph))
        except (nx.NetworkXError, nx.NetworkXUnfeasible) as e:
            # NetworkX raises NetworkXUnfeasible for cycles
            raise ValueError("Graph contains a cycle!") from e

    def run(self, **solve_kwargs: Any) -> ResultStorage:
        """Execute the graph and return final results.

        Nodes run in topological order. Each node receives the outputs of its
        upstream nodes in ``inputs``. Outputs of a previous run are discarded
        first.

        Args:
            **solve_kwargs: Keyword arguments passed to every node's ``execute``

        Returns:
            The result of the single terminal node. When there are several
            terminal nodes, the concatenation of their results. A terminal
            node is a node without outgoing edges, the root excluded.

        Raises:
            ValueError: If the graph contains a cycle, has no terminal node,
                or its terminal node(s) produced no result
            RuntimeError: If a node cannot execute
        """
        execution_order = self.topological_sort()
        logger.info("Execution order: %s", " → ".join(execution_order))
        # Start from a clean state so repeated runs never reuse stale outputs.
        self.node_outputs = {}
        for node_id in execution_order:
            self._execute_node(node_id, solve_kwargs)
        return self._collect_terminal_results()

    def _execute_node(self, node_id: str, solve_kwargs: dict[str, Any]) -> None:
        """Feed upstream outputs to ``node_id``, run it and store its output."""
        node = self.nodes[node_id]
        # Rebuild inputs from scratch: node.inputs maps upstream_id -> NodeData
        node.inputs = {}
        for upstream_id in self.reverse_edges.get(node_id, []):
            if upstream_id not in self.node_outputs:
                raise RuntimeError(f"Node {upstream_id} has not been executed")
            node.inputs[upstream_id] = self.node_outputs[upstream_id]
        if not node.can_execute():
            raise RuntimeError(f"Node {node_id} cannot execute (missing inputs)")
        logger.info("Executing %s: %s", node_id, node)
        output = node.execute(**solve_kwargs)
        node.output = output
        self.node_outputs[node_id] = output

    def _collect_terminal_results(self) -> ResultStorage:
        """Return the result of the terminal node(s) of the last run."""
        # .get() avoids inserting empty lists into the edges defaultdict.
        terminal_nodes = [
            node_id
            for node_id in self.nodes
            if node_id != "root" and not self.edges.get(node_id)
        ]
        if len(terminal_nodes) == 0:
            raise ValueError("Graph has no terminal nodes")
        if len(terminal_nodes) == 1:
            result = self.node_outputs[terminal_nodes[0]].result
            if result is None:
                raise ValueError("Terminal node has no result")
            return result
        # Multiple terminal nodes - concatenate their results
        results = [
            self.node_outputs[terminal_id].result
            for terminal_id in terminal_nodes
            if self.node_outputs[terminal_id].result is not None
        ]
        if len(results) == 0:
            raise ValueError("No results from terminal nodes")
        merged = ResultStorage(list_solution_fits=[], mode_optim=results[0].mode_optim)
        for res in results:
            merged.extend(res)
        return merged

    def visualize(self) -> str:
        """Create ASCII art visualization of the graph.

        Returns:
            String representation of the graph
        """
        lines = ["SolverGraph:"]
        lines.append(f"  Source Problem: {type(self.source_problem).__name__}")
        lines.append("")
        lines.append("Nodes:")
        for node_id, node in self.nodes.items():
            if node_id == "root":
                continue
            lines.append(f"  - {node_id}: {type(node).__name__}")
        lines.append("")
        lines.append("Edges:")
        for from_node, to_nodes in self.edges.items():
            for to_node in to_nodes:
                lines.append(f"  {from_node} → {to_node}")
        return "\n".join(lines)