Source code for discrete_optimization.generic_tools.callbacks.warm_start_callback

#  Copyright (c) 2025 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.
import logging
from typing import Optional

from discrete_optimization.generic_tools.callbacks.callback import Callback
from discrete_optimization.generic_tools.do_solver import SolverDO, WarmstartMixin
from discrete_optimization.generic_tools.lexico_tools import LexicoSolver
from discrete_optimization.generic_tools.lns_cp import BaseLnsCp
from discrete_optimization.generic_tools.result_storage.result_storage import (
    ResultStorage,
)

logger = logging.getLogger(__name__)


class WarmStartCallback(Callback):
    """Callback warm-starting a solver with a stored solution at each step end.

    At the end of every step, a solution is picked from the result storage
    and passed to ``set_warm_start`` of the solver able to use it.

    The solver to warm-start is resolved as follows (later rules win):

    - ``solver.subsolver`` if ``solver`` is a ``LexicoSolver`` whose
      subsolver is a ``WarmstartMixin``;
    - ``solver`` itself if it is a ``WarmstartMixin``;
    - ``solver.subsolver`` if ``solver`` is a ``BaseLnsCp``.

    Args:
        warm_start_best_solution: warm-start with the best solution of the
            result storage.
        warm_start_last_solution: warm-start with the last solution of the
            result storage. Takes precedence over ``warm_start_best_solution``
            when both flags are True.

    """

    def __init__(
        self,
        warm_start_best_solution: bool = True,
        warm_start_last_solution: bool = False,
    ):
        self.warm_start_best_solution = warm_start_best_solution
        self.warm_start_last_solution = warm_start_last_solution

    @staticmethod
    def _get_warm_startable_solver(solver: SolverDO) -> Optional[SolverDO]:
        """Return the (sub)solver to warm-start, or None if none applies.

        The checks are sequential and each match overrides the previous one,
        so their order defines the precedence.
        """
        solver_ = None
        if isinstance(solver, LexicoSolver) and isinstance(
            solver.subsolver, WarmstartMixin
        ):
            solver_ = solver.subsolver
        if isinstance(solver, WarmstartMixin):
            solver_ = solver
        if isinstance(solver, BaseLnsCp):
            solver_ = solver.subsolver
        return solver_

    def _get_warm_start_solution(self, res: ResultStorage):
        """Return the solution to warm-start from, or None if there is none.

        When both flags are set, the last solution overrides the best one.
        """
        sol = None
        if self.warm_start_best_solution:
            # NOTE(review): presumably yields no solution (None) when the
            # storage is empty — the caller guards against a None result.
            sol, _ = res.get_best_solution_fit()
        if self.warm_start_last_solution:
            try:
                sol, _ = res[-1]
            except IndexError:
                # Empty result storage: nothing has been found yet.
                sol = None
        return sol

    def on_step_end(
        self, step: int, res: ResultStorage, solver: SolverDO
    ) -> Optional[bool]:
        """Warm-start the resolved solver with the selected solution.

        Nothing is done when no warm-startable solver is found, or when no
        solution is available (both flags False, or empty storage): the
        solver's ``set_warm_start`` is never called with ``None``.

        Args:
            step: index of the step that just ended (unused).
            res: result storage the solution is taken from.
            solver: solver (or meta-solver) running the steps.

        Returns:
            Always None.

        """
        solver_ = self._get_warm_startable_solver(solver)
        if solver_ is None:
            return None
        sol = self._get_warm_start_solution(res)
        if sol is None:
            logger.debug("No solution available for warm-start, skipping.")
            return None
        solver_.set_warm_start(sol)
        logger.info("Warm-start done")
        return None
class WarmStartCallbackLastRun(WarmStartCallback):
    """Only works for Cp-sat solver

    At each step end, if the resolved (sub)solver has already been run (its
    ``solver`` attribute is set), it is warm-started from its previous run via
    ``set_warm_start_from_previous_run``. Otherwise this falls back to the
    behaviour of :class:`WarmStartCallback` (warm-start from the result
    storage).
    """

    def on_step_end(
        self, step: int, res: ResultStorage, solver: SolverDO
    ) -> Optional[bool]:
        """Warm-start from the previous run when possible, else from ``res``.

        Args:
            step: index of the step that just ended.
            res: result storage, used only by the fallback path.
            solver: solver (or meta-solver) running the steps.

        Returns:
            None on the previous-run path; otherwise whatever the parent
            callback returns.

        """
        # Local import kept from the original (only used for the annotation).
        from discrete_optimization.generic_tools.ortools_cpsat_tools import (
            OrtoolsCpSatSolver,
        )

        # Same resolution rules as the parent; later matches override earlier.
        solver_ = None
        if isinstance(solver, LexicoSolver) and isinstance(
            solver.subsolver, WarmstartMixin
        ):
            solver_ = solver.subsolver
        if isinstance(solver, WarmstartMixin):
            solver_ = solver
        if isinstance(solver, BaseLnsCp):
            solver_ = solver.subsolver

        # Previously, a None solver_ or a solver without a `solver` attribute
        # crashed with AttributeError here; such cases now use the fallback.
        if (
            solver_ is not None
            and getattr(solver_, "solver", None) is not None
            and hasattr(solver_, "set_warm_start_from_previous_run")
        ):
            cpsat_solver: OrtoolsCpSatSolver = solver_
            cpsat_solver.set_warm_start_from_previous_run()
            return None
        return super().on_step_end(step, res, solver)