Source code for discrete_optimization.tsptw.solvers.cpsat

#  Copyright (c) 2025 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional

from ortools.sat.python.cp_model import CpModel, CpSolverSolutionCallback

from discrete_optimization.generic_tools.do_problem import ParamsObjectiveFunction
from discrete_optimization.generic_tools.do_solver import WarmstartMixin
from discrete_optimization.generic_tools.ortools_cpsat_tools import OrtoolsCpSatSolver
from discrete_optimization.tsptw.problem import TSPTWProblem, TSPTWSolution

# Module-level logger, named after this module so its output can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)


class CpSatTSPTWSolver(OrtoolsCpSatSolver, WarmstartMixin):
    """
    CP-SAT solver for the Traveling Salesman Problem with Time Windows.

    This solver uses a circuit constraint to ensure a valid tour and enforces
    time window constraints through implications based on the selected arcs.
    CP-SAT works on integers, so every time window bound and travel time is
    multiplied by ``scaling_factor`` (see ``init_model``) and truncated with
    ``int()``.

    Attributes:
        problem (TSPTWProblem): The TSP-TW problem instance.
        variables (Dict[str, Any]): The CP-SAT model variables: arc variables
            ('x_arc'), service start time variables ('t_time'), the time of
            return to the depot ('t_return') and the makespan variable
            ('makespan').
        scaling_factor (float): The factor used by the last call to
            ``init_model``; ``set_warm_start`` reuses it so that hinted times
            are expressed in the same (scaled) units as the model variables.
    """

    problem: TSPTWProblem
    # Same default as init_model; overwritten every time init_model runs.
    scaling_factor: float = 10.0

    def __init__(
        self,
        problem: TSPTWProblem,
        params_objective_function: Optional[ParamsObjectiveFunction] = None,
        **kwargs,
    ):
        # NOTE(review): extra kwargs are accepted but not forwarded to the
        # base class -- confirm this is intended.
        super().__init__(
            problem=problem, params_objective_function=params_objective_function
        )
        self.variables: Dict[str, Any] = {}

    def init_model(self, scaling_factor: float = 10.0, **kwargs: Any) -> None:
        """Initialise the CP-SAT model.

        Args:
            scaling_factor: Multiplier applied to time windows and distances
                before they are truncated to integers; larger values keep more
                of the fractional part of the data. It is stored on ``self``
                so that warm-start hints can be scaled the same way.
            **kwargs: Unused.
        """
        self.scaling_factor = scaling_factor
        self.cp_model = CpModel()
        n = self.problem.nb_nodes
        depot = self.problem.depot_node
        time_windows = self.problem.time_windows
        distance_matrix = self.problem.distance_matrix

        def scaled(value: float) -> int:
            # Single place where problem data is converted to model units, so
            # bounds and travel times are truncated in exactly the same way.
            return int(scaling_factor * value)

        # --- Create variables ---
        # Arc variables: x_arc[i, j] is true if the tour includes the arc from i to j.
        x_arc = {
            (i, j): self.cp_model.NewBoolVar(f"x_{i},{j}")
            for i in range(n)
            for j in range(n)
            if i != j
        }
        # Time variables: t_time[i] is the (scaled) time at which service starts
        # at node i, bounded by the node's (scaled) time window.
        t_time = [
            self.cp_model.NewIntVar(
                lb=scaled(time_windows[i][0]),
                ub=scaled(time_windows[i][1]),
                name=f"t_{i}",
            )
            for i in range(n)
        ]
        # t_time[depot] holds the departure time, so the arrival when the tour
        # closes back at the depot needs its own variable.
        t_time_return_depot = self.cp_model.NewIntVar(
            lb=scaled(time_windows[depot][0]),
            ub=scaled(time_windows[depot][1]),
            name="t_return_depot",
        )
        # Makespan variable: the (scaled) arrival time back at the depot.
        makespan_ub = scaled(time_windows[depot][1])
        makespan = self.cp_model.NewIntVar(0, makespan_ub, "makespan")
        self.variables = {
            "x_arc": x_arc,
            "t_time": t_time,
            "t_return": t_time_return_depot,
            "makespan": makespan,
        }

        # --- Add constraints ---
        # A single circuit through all nodes (x_arc keeps the (i, j) loop order).
        self.cp_model.AddCircuit(
            [(i, j, literal) for (i, j), literal in x_arc.items()]
        )

        # Fix the departure from the depot at the opening of its time window.
        self.cp_model.Add(t_time[depot] == scaled(time_windows[depot][0]))

        # Time window propagation constraints.
        for (i, j), literal in x_arc.items():
            # Scaled travel time from i to j. Only distance_matrix is read here;
            # any service time at i is presumably already folded into it -- TODO confirm.
            travel_time = scaled(distance_matrix[i, j])
            # If arc (i, j) is taken, the time at j is at least t_i + travel.
            # Arcs entering the depot close the tour, so they bound the return
            # time rather than the (fixed) departure time t_time[depot].
            target = t_time_return_depot if j == depot else t_time[j]
            self.cp_model.Add(target >= t_time[i] + travel_time).OnlyEnforceIf(
                literal
            )

        # Makespan constraints: if arc (i, depot) is taken, makespan is at least
        # t_i + travel back to the depot.
        for i in range(n):
            if i == depot:
                continue
            travel_to_depot = scaled(distance_matrix[i, depot])
            self.cp_model.Add(makespan >= t_time[i] + travel_to_depot).OnlyEnforceIf(
                x_arc[i, depot]
            )

        # --- Set objective ---
        self.cp_model.Minimize(makespan)
        logger.info("CP-SAT model initialized.")

    def retrieve_solution(self, cpsolvercb: CpSolverSolutionCallback) -> TSPTWSolution:
        """
        Build a TSPTWSolution from the CP-SAT solver's callback.

        The tour is rebuilt by starting at the depot and repeatedly following
        the active outgoing arc, nb_nodes - 1 times.

        Args:
            cpsolvercb: The ortools callback object containing the current solution.

        Returns:
            The current solution in the TSPTWSolution format; its permutation
            lists the visited nodes in order, excluding the depot.
        """
        x_arc = self.variables["x_arc"]
        nb_nodes = self.problem.nb_nodes
        permutation = []
        current_node = self.problem.depot_node
        for _ in range(nb_nodes - 1):
            for j in range(nb_nodes):
                if j != current_node and cpsolvercb.Value(x_arc[current_node, j]):
                    permutation.append(j)
                    current_node = j
                    break
        return TSPTWSolution(problem=self.problem, permutation=permutation)

    def set_warm_start(self, solution: TSPTWSolution) -> None:
        """
        Provide a warm start hint to the CP-SAT solver from an existing solution.

        Builds the model with default parameters if it does not exist yet.
        Every arc variable is hinted (1 on the solution's tour, 0 otherwise).
        Start time variables are hinted from ``solution.start_service_times``
        when available, converted to the model's scaled integer units.

        Args:
            solution: A TSPTWSolution object.
        """
        if self.cp_model is None:
            self.init_model()
        self.cp_model.ClearHints()
        logger.info("Setting warm start from solution.")

        # Hint arc variables: the tour is depot -> permutation -> depot.
        depot = self.problem.depot_node
        path = [depot] + list(solution.permutation) + [depot]
        tour_arcs = set(zip(path[:-1], path[1:]))
        # A complete hint (0 for unused arcs too) is more useful to CP-SAT than
        # a partial one. Pairs absent from x_arc (e.g. (depot, depot)) are never hinted.
        for arc, literal in self.variables["x_arc"].items():
            self.cp_model.AddHint(literal, 1 if arc in tour_arcs else 0)

        # Hint time variables if they have been calculated. The model variables
        # are in scaled units (see init_model), so the hints must be scaled with
        # the same factor and truncation, otherwise they fall outside the domains.
        # NOTE(review): `i in ...` assumes start_service_times is keyed by node
        # index (e.g. a dict); for a list it would test values -- confirm type.
        if solution.start_service_times:
            for i in range(self.problem.nb_nodes):
                if i in solution.start_service_times:
                    self.cp_model.AddHint(
                        self.variables["t_time"][i],
                        int(self.scaling_factor * solution.start_service_times[i]),
                    )