# Copyright (c) 2024 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
from abc import abstractmethod
from typing import Any, List, Optional
import didppy as dp
from discrete_optimization.generic_tools.callbacks.callback import (
Callback,
CallbackList,
)
from discrete_optimization.generic_tools.do_problem import Solution
from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.exceptions import SolveEarlyStop
from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
CategoricalHyperparameter,
)
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
)
solvers = {
x.__name__: x
for x in [
dp.ForwardRecursion,
dp.CABS,
dp.CAASDy,
dp.LNBS,
dp.DFBB,
dp.CBFS,
dp.ACPS,
dp.APPS,
dp.DBDFS,
dp.BreadthFirstSearch,
dp.DDLNS,
dp.WeightedAstar,
dp.ExpressionBeamSearch,
]
}
logger = logging.getLogger(__name__)
[docs]
class DpCallback:
def __init__(self, do_solver: "DpSolver", callback: Callback):
super().__init__()
self.do_solver = do_solver
self.callback = callback
self.res = do_solver.create_result_storage()
self.nb_solutions = 0
[docs]
def on_solution_callback(self, sol: dp.Solution) -> bool:
self.nb_solutions += 1
self.store_current_solution(sol)
try:
stopping = self.callback.on_step_end(
step=self.nb_solutions, res=self.res, solver=self.do_solver
)
except Exception as e:
self.do_solver.early_stopping_exception = e
stopping = True
else:
if stopping:
self.do_solver.early_stopping_exception = SolveEarlyStop(
f"{self.do_solver.__class__.__name__}.solve() stopped by user callback."
)
return stopping
[docs]
def store_current_solution(self, sol: dp.Solution):
solution = self.do_solver.retrieve_solution(sol)
fit = self.do_solver.aggreg_from_sol(solution)
self.res.append((solution, fit))
[docs]
class DpSolver(SolverDO):
early_stopping_exception: Optional[Exception] = None
model: dp.Model = None
hyperparameters = [
CategoricalHyperparameter(name="solver", choices=solvers, default=dp.CABS)
]
initial_solution: Optional[list[dp.Transition]] = None
[docs]
@abstractmethod
def init_model(self, **kwargs: Any) -> None:
...
[docs]
@abstractmethod
def retrieve_solution(self, sol: dp.Solution) -> Solution:
...
[docs]
def solve(
self,
callbacks: Optional[List[Callback]] = None,
time_limit: Optional[float] = 100.0,
retrieve_intermediate_solutions: bool = True,
**kwargs: Any,
) -> ResultStorage:
self.early_stopping_exception = None
callbacks_list = CallbackList(callbacks=callbacks)
callbacks_list.on_solve_start(solver=self)
if self.model is None:
self.init_model(**kwargs)
did_callback = DpCallback(do_solver=self, callback=callbacks_list)
kwargs = self.complete_with_default_hyperparameters(kwargs)
if self.initial_solution is not None:
kwargs["initial_solution"] = self.initial_solution
if "quiet" not in kwargs:
kwargs["quiet"] = False
solver_cls = kwargs["solver"]
try:
solver_allowed_params = inspect.signature(solver_cls).parameters
kwargs_solver = {
k: v for k, v in kwargs.items() if k in solver_allowed_params
}
except:
# Previous mode, for python<=3.9
for k in list(kwargs.keys()):
if k not in {"threads", "initial_solution"}:
kwargs.pop(k)
if k == "threads" and solver_cls in {dp.DDLNS, dp.DFBB}:
kwargs.pop(k)
kwargs_solver = kwargs
solver = solver_cls(self.model, time_limit=time_limit, **kwargs_solver)
if retrieve_intermediate_solutions:
while True:
solution, terminated = solver.search_next()
logger.info(f"Objective = {solution.cost}, {solution.is_infeasible}")
stopping = did_callback.on_solution_callback(solution)
if terminated or stopping:
break
else:
solution = solver.search()
did_callback.on_solution_callback(solution)
logger.info(f"Is optimal {solution.is_optimal}")
logger.info(f"Is infeasible {solution.is_infeasible}")
logger.info(f"Best bound {solution.best_bound}")
if self.early_stopping_exception:
if isinstance(self.early_stopping_exception, SolveEarlyStop):
logger.info(self.early_stopping_exception)
else:
raise self.early_stopping_exception
res = did_callback.res
callbacks_list.on_solve_end(res=res, solver=self)
return res