# Source code for discrete_optimization.tsp.solvers.cpsat

#  Copyright (c) 2024 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Optional

from ortools.sat.python.cp_model import CpModel

from discrete_optimization.generic_tools.do_problem import (
    ParamsObjectiveFunction,
    Problem,
    Solution,
)
from discrete_optimization.generic_tools.do_solver import WarmstartMixin
from discrete_optimization.generic_tools.ortools_cpsat_tools import (
    CpSolverSolutionCallback,
    OrtoolsCpSatSolver,
)
from discrete_optimization.tsp.problem import TspProblem, TspSolution
from discrete_optimization.tsp.solvers import TspSolver
from discrete_optimization.tsp.utils import build_matrice_distance

logger = logging.getLogger(__name__)


class CpSatTspSolver(OrtoolsCpSatSolver, TspSolver, WarmstartMixin):
    """CP-SAT model of the TSP built around a single circuit constraint.

    One boolean literal is created per ordered pair of distinct nodes
    ``(i, j)``; it is true when ``j`` directly follows ``i`` in the tour.
    When ``start_index != end_index`` the arc ``end -> start`` is forced to
    true and given a zero cost, which turns the open path into a circuit.
    """

    def __init__(
        self,
        problem: Problem,
        params_objective_function: Optional[ParamsObjectiveFunction] = None,
        **kwargs: Any,
    ):
        super().__init__(problem, params_objective_function, **kwargs)
        # Filled by init_model: "arc_literals" maps (i, j) -> BoolVar.
        self.variables = {}
        self.distance_matrix = build_matrice_distance(
            self.problem.node_count,
            method=self.problem.evaluate_function_indexes,
        )
        # The closing arc end -> start only exists to turn the open path into
        # a circuit, so it must not contribute to the tour length.
        self.distance_matrix[self.problem.end_index, self.problem.start_index] = 0

    def set_warm_start(self, solution: TspSolution) -> None:
        """Make the solver warm start from the given solution.

        Every arc literal receives a hint. Arcs used by the tour encoded by
        ``solution.permutation`` get 1. That tour is completed with the arc to
        the end node and the arc back to the start node when those nodes are
        not already in the permutation. All other arcs get 0.

        Args:
            solution: tour used as hint; its permutation may be empty.
        """
        self.cp_model.clear_hints()
        start_index = self.problem.start_index
        end_index = self.problem.end_index

        used_arcs = set()
        current_node = start_index
        for next_node in solution.permutation:
            used_arcs.add((current_node, next_node))
            current_node = next_node

        # Close the loop from the last node actually reached (not
        # permutation[-1]) so that an empty permutation does not raise.
        last_node = current_node
        visited = set(solution.permutation)
        if end_index not in visited:
            # go to the end node if the permutation did not reach it
            used_arcs.add((last_node, end_index))
            last_node = end_index
        if start_index not in visited:
            # close the cycle
            used_arcs.add((last_node, start_index))

        for arc, literal in self.variables["arc_literals"].items():
            self.cp_model.add_hint(literal, 1 if arc in used_arcs else 0)

    def retrieve_solution(self, cpsolvercb: CpSolverSolutionCallback) -> Solution:
        """Rebuild the tour from the arc literals of the current solution.

        The circuit is followed from ``start_index`` until it returns to it.
        Nodes met along the way (start excluded) are collected in order. When
        ``start_index != end_index`` the last collected node is the end node,
        because the ``end -> start`` arc is forced. That node is therefore
        dropped from the permutation.

        Args:
            cpsolvercb: callback giving access to the solution values.

        Returns:
            the corresponding ``TspSolution``.

        Raises:
            RuntimeError: if the arc values do not describe a circuit through
                the start node. This covers a node without a successor, or a
                walk that does not return to the start within ``node_count``
                arcs.
        """
        arc_literals = self.variables["arc_literals"]
        node_count = self.problem.node_count
        start_index = self.problem.start_index
        end_index = self.problem.end_index

        current_node = start_index
        path = []
        route_distance = 0
        # A valid circuit returns to the start after at most node_count arcs;
        # bounding the walk avoids looping forever on inconsistent values.
        for _ in range(node_count):
            next_node = next(
                (
                    j
                    for j in range(node_count)
                    if j != current_node
                    and cpsolvercb.boolean_value(arc_literals[current_node, j])
                ),
                None,
            )
            if next_node is None:
                raise RuntimeError(
                    f"No successor found for node {current_node} "
                    "in the CP-SAT solution."
                )
            route_distance += self.distance_matrix[current_node, next_node]
            current_node = next_node
            if current_node == start_index:
                break
            path.append(current_node)
        else:
            raise RuntimeError(
                "The arc literals of the CP-SAT solution do not form "
                "a circuit back to the start node."
            )

        logger.info("Recomputed sol length = %s", route_distance)
        if start_index != end_index:
            # Drop the end node, visited just before closing the circuit.
            path = path[:-1]
        return TspSolution(
            problem=self.problem,
            start_index=start_index,
            end_index=end_index,
            permutation=path,
        )

    def init_model(self, **args: Any) -> None:
        """Build the CP-SAT model.

        The model contains:

        - one boolean literal per ordered pair of distinct nodes;
        - an ``add_circuit`` constraint over those arcs. No self-loop arcs
          are given, so the circuit has to go through every node;
        - when the path is open (``start_index != end_index``), the
          ``end -> start`` arc forced to true;
        - minimisation of the total arc cost.

        Distances are converted with ``int`` (truncation) because CP-SAT only
        accepts integer coefficients.
        """
        model = CpModel()
        node_count = self.problem.node_count
        arcs = []
        arc_literals = {}
        cost_terms = []
        for i in range(node_count):
            for j in range(node_count):
                if i == j:
                    continue
                lit = model.new_bool_var(f"{j} follows {i}")
                arcs.append((i, j, lit))
                arc_literals[i, j] = lit
                cost_terms.append(int(self.distance_matrix[i, j]) * lit)
        model.add_circuit(arcs)
        if self.problem.start_index != self.problem.end_index:
            # Close the open path: end node is followed by start node.
            model.add(
                arc_literals[self.problem.end_index, self.problem.start_index] == 1
            )
        model.minimize(sum(cost_terms))
        self.variables["arc_literals"] = arc_literals
        self.cp_model = model