Source code for discrete_optimization.generic_tools.ortools_cpsat_tools
# Copyright (c) 2024 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from abc import abstractmethod
from collections.abc import Iterable
from typing import Any, Optional
from ortools.sat.python.cp_model import (
FEASIBLE,
INFEASIBLE,
OPTIMAL,
UNKNOWN,
Constraint,
CpModel,
)
from ortools.sat.python.cp_model import CpSolver as OrtoolsInternalCpSolver
from ortools.sat.python.cp_model import CpSolverSolutionCallback
from discrete_optimization.generic_tools.callbacks.callback import (
Callback,
CallbackList,
)
from discrete_optimization.generic_tools.cp_tools import CpSolver, ParametersCp
from discrete_optimization.generic_tools.do_problem import Solution
from discrete_optimization.generic_tools.do_solver import StatusSolver
from discrete_optimization.generic_tools.exceptions import SolveEarlyStop
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
)
# Module-level logger (named after this module); used to report solves
# stopped early by a user callback.
logger = logging.getLogger(name=__name__)
[docs]
class OrtoolsCpSatSolver(CpSolver):
    """Generic ortools cp-sat solver.

    Subclasses implement `retrieve_solution()` to convert the cpsat internal
    solution into a do Solution. `solve()` calls `self.init_model(**kwargs)`
    when `cp_model` is still None, and then solves `self.cp_model`.
    """

    # Internal cpsat model solved by `solve()`; built via `init_model()` if None.
    cp_model: Optional[CpModel] = None
    # Ortools solver instance created by the last call to `solve()`.
    solver: Optional[OrtoolsInternalCpSolver] = None
    # Ortools solution callback used by the last call to `solve()`.
    clb: Optional[CpSolverSolutionCallback] = None
    # Exception captured by the ortools callback during the last solve
    # (a SolveEarlyStop when a user callback asked to stop); None otherwise.
    early_stopping_exception: Optional[Exception] = None

    @abstractmethod
    def retrieve_solution(self, cpsolvercb: CpSolverSolutionCallback) -> Solution:
        """Construct a do solution from the cpsat solver internal solution.

        It will be called each time the cpsat solver finds a new solution.
        At that point, value of internal variables are accessible via `cpsolvercb.Value(VARIABLE_NAME)`.

        Args:
            cpsolvercb: the ortools callback called when the cpsat solver finds a new solution.

        Returns:
            the intermediate solution, at do format.

        """
        ...

    def solve(
        self,
        callbacks: Optional[list[Callback]] = None,
        parameters_cp: Optional[ParametersCp] = None,
        time_limit: Optional[float] = 100.0,
        ortools_cpsat_solver_kwargs: Optional[dict[str, Any]] = None,
        retrieve_stats: bool = False,
        **kwargs: Any,
    ) -> ResultStorage:
        """Solve the problem with a CpSat solver from ortools library.

        A dedicated ortools callback is used to:
        - update a result storage each time a new solution is found by the cpsat solver;
        - call the user (do) callbacks at each new solution, with the possibility of
          early stopping if a callback returns True.
        This ortools callback uses the method `self.retrieve_solution()` to reconstruct
        a do Solution from the cpsat solver internal state.

        Args:
            callbacks: list of callbacks used to hook into the various stage of the solve
            time_limit: the solve process stops after this time limit (in seconds).
                If None, no time limit is applied.
            parameters_cp: parameters specific to cp solvers.
                We use here only `parameters_cp.nb_process`.
            ortools_cpsat_solver_kwargs: used to customize the underlying ortools solver.
                Each key/value will update the corresponding attribute from the
                ortools.sat.python.cp_model.CpSolver parameters.
            retrieve_stats: retrieve detailed stats of cpsat solving in the cpsat callback
                and store it in the res object.
            **kwargs: keyword arguments passed to `self.init_model()`

        Returns:
            the result storage filled by the ortools callback with every solution found.

        Raises:
            Exception: any exception captured by the ortools callback during the search
                (e.g. raised by a user callback) is re-raised after the solver stops,
                except `SolveEarlyStop`, which is only logged.

        """
        self.early_stopping_exception = None
        callbacks_list = CallbackList(callbacks=callbacks)
        callbacks_list.on_solve_start(solver=self)
        if self.cp_model is None:
            self.init_model(**kwargs)
        if parameters_cp is None:
            parameters_cp = ParametersCp.default_cpsat()
        solver = OrtoolsInternalCpSolver()
        self.solver = solver
        if time_limit is not None:
            solver.parameters.max_time_in_seconds = time_limit
        solver.parameters.num_workers = parameters_cp.nb_process
        if ortools_cpsat_solver_kwargs is not None:
            # customize solver (applied last, so it can override the settings above)
            for k, v in ortools_cpsat_solver_kwargs.items():
                setattr(solver.parameters, k, v)
        ortools_callback = OrtoolsCpSatCallback(
            do_solver=self, callback=callbacks_list, retrieve_stats=retrieve_stats
        )
        self.clb = ortools_callback
        status = solver.Solve(self.cp_model, ortools_callback)
        self.status_solver = cpstatus_to_dostatus(status_from_cpsat=status)
        # Explicit None check: the stored value is an exception object, not a flag.
        if self.early_stopping_exception is not None:
            if isinstance(self.early_stopping_exception, SolveEarlyStop):
                # Deliberate stop requested by a user callback: not an error.
                logger.info(self.early_stopping_exception)
            else:
                raise self.early_stopping_exception
        res = ortools_callback.res
        callbacks_list.on_solve_end(res=res, solver=self)
        return res

    def remove_constraints(self, constraints: Iterable[Any]) -> None:
        """Remove the internal model constraints.

        Each constraint is neutralized by clearing its underlying proto message.
        All items are type-checked before any of them is cleared, so that an
        invalid item does not leave the model partially modified.

        Args:
            constraints: constraints created for instance with `add_lexico_constraint()`

        Raises:
            RuntimeError: if an item is not an ortools cp_model `Constraint`.

        """
        # Materialize first: `constraints` may be a one-shot iterator and is traversed twice.
        constraints_to_remove = list(constraints)
        for cstr in constraints_to_remove:
            if not isinstance(cstr, Constraint):
                raise RuntimeError(
                    f"Cannot remove {cstr!r}: expected an ortools cp_model Constraint, "
                    f"got {type(cstr).__name__}."
                )
        for cstr in constraints_to_remove:
            cstr.proto.Clear()
[docs]
class OrtoolsCpSatCallback(CpSolverSolutionCallback):
    """Ortools solution callback bridging cpsat solutions to a do solver.

    At each new solution found by the cpsat solver, it rebuilds a do solution
    with `do_solver.retrieve_solution()`, stores it (with its fitness) in
    `self.res`, then calls the do callback. The search is stopped when the do
    callback returns True or when an exception occurs.
    """

    def __init__(
        self,
        do_solver: OrtoolsCpSatSolver,
        callback: Callback,
        retrieve_stats: bool = False,
    ):
        """
        Args:
            do_solver: the do solver whose `retrieve_solution()` and `aggreg_from_sol()`
                are used to build the stored solutions and their fitness, and whose
                `early_stopping_exception` records why the search was stopped.
            callback: do callback whose `on_step_end()` is called after each new
                solution (a `CallbackList` when built by `OrtoolsCpSatSolver.solve()`).
            retrieve_stats: if True, cpsat statistics are appended to `self.res.stats`
                at each new solution.

        """
        super().__init__()
        self.do_solver = do_solver
        self.callback = callback
        self.retrieve_stats = retrieve_stats
        self.res = do_solver.create_result_storage()
        if retrieve_stats:
            self.res.stats = []
        self.nb_solutions = 0

    def on_solution_callback(self) -> None:
        """Hook called by the cpsat solver at each new solution.

        Any exception raised while storing the solution (e.g. by a user-implemented
        `retrieve_solution()`) or by the do callback is recorded in
        `do_solver.early_stopping_exception` and the search is stopped, rather
        than letting the exception escape into the ortools solver.
        """
        try:
            self.store_current_solution()
            self.nb_solutions += 1
            # end of step callback: stopping?
            stopping = self.callback.on_step_end(
                step=self.nb_solutions, res=self.res, solver=self.do_solver
            )
        except Exception as e:
            self.do_solver.early_stopping_exception = e
            stopping = True
        else:
            if stopping:
                self.do_solver.early_stopping_exception = SolveEarlyStop(
                    f"{self.do_solver.__class__.__name__}.solve() stopped by user callback."
                )
        if stopping:
            self.StopSearch()

    def store_current_solution(self):
        """Append the current cpsat solution (and optionally solver stats) to `self.res`."""
        sol = self.do_solver.retrieve_solution(cpsolvercb=self)
        fit = self.do_solver.aggreg_from_sol(sol)
        self.res.append((sol, fit))
        if self.retrieve_stats:
            self.res.stats.append(
                {
                    "bound": self.BestObjectiveBound(),
                    "obj": self.ObjectiveValue(),
                    "time": self.UserTime(),
                    "num_conflicts": self.NumConflicts(),
                }
            )
[docs]
def cpstatus_to_dostatus(status_from_cpsat) -> StatusSolver:
    """Convert a cpsat solve status into a do solver status.

    :param status_from_cpsat: status returned by the ortools cpsat solver, normally one of
        [UNKNOWN, INFEASIBLE, OPTIMAL, FEASIBLE] from ortools.sat.python.cp_model.
    :return: the matching StatusSolver. Any other status (e.g. MODEL_INVALID) is mapped
        to StatusSolver.UNKNOWN with a warning, so the function never returns None.
    """
    if status_from_cpsat == UNKNOWN:
        return StatusSolver.UNKNOWN
    if status_from_cpsat == INFEASIBLE:
        return StatusSolver.UNSATISFIABLE
    if status_from_cpsat == OPTIMAL:
        return StatusSolver.OPTIMAL
    if status_from_cpsat == FEASIBLE:
        return StatusSolver.SATISFIED
    # Previously fell through and implicitly returned None, breaking the annotated contract.
    logger.warning(
        "Unexpected cpsat status %r, mapped to %s.",
        status_from_cpsat,
        StatusSolver.UNKNOWN,
    )
    return StatusSolver.UNKNOWN