Source code for discrete_optimization.knapsack.solvers.cpsat

#  Copyright (c) 2024 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.
import logging
from import Iterable
from typing import Any, Optional

from ortools.sat.python.cp_model import (

from discrete_optimization.generic_tools.do_problem import (
from discrete_optimization.generic_tools.do_solver import WarmstartMixin
from discrete_optimization.generic_tools.ortools_cpsat_tools import OrtoolsCpSatSolver
from discrete_optimization.knapsack.problem import KnapsackProblem, KnapsackSolution
from discrete_optimization.knapsack.solvers import KnapsackSolver

logger = logging.getLogger(__name__)

class CpSatKnapsackSolver(OrtoolsCpSatSolver, KnapsackSolver, WarmstartMixin):
    """Knapsack solver built on an OR-Tools CP-SAT model.

    The model holds one boolean variable per item, stored under
    ``self.variables["taken"]``, plus a single capacity constraint.
    Objectives are always handed to the model through ``Maximize``, which is
    why the internal weight expression is the *negated* total weight.
    """

    def __init__(
        self,
        problem: KnapsackProblem,
        params_objective_function: Optional[ParamsObjectiveFunction] = None,
        **kwargs: Any,
    ):
        """Create the solver; the CP model itself is built by `init_model`.

        Args:
            problem: the knapsack instance to solve.
            params_objective_function: forwarded to the parent constructor.
            **kwargs: accepted but not used by this constructor.

        """
        super().__init__(
            problem=problem, params_objective_function=params_objective_function
        )
        # Maps a variable family name to the CP-SAT variables of that family.
        self.variables: dict[str, list[IntVar]] = {}

    def init_model(self, **args: Any) -> None:
        """Init CP model.

        Creates a fresh `CpModel` with one boolean variable per item, adds the
        capacity constraint and sets "value" as the objective.

        Args:
            **args: accepted but not used by this method.

        """
        model = CpModel()
        # A lazily created "heaviest_item" variable belongs to the model it was
        # created on. Drop any cached one so that it is rebuilt (together with
        # its max-equality constraint) on the new model when next requested.
        self.variables.pop("heaviest_item", None)
        self.variables["taken"] = [
            model.NewBoolVar(name=f"x_{i}") for i in range(self.problem.nb_items)
        ]
        self.cp_model = model
        # `_internal_weight()` is the negated total weight, hence the leading
        # minus sign to recover "total weight <= capacity".
        model.Add(-self._internal_weight() <= self.problem.max_capacity)
        self.set_lexico_objective("value")

    def set_warm_start(self, solution: KnapsackSolution) -> None:
        """Make the solver warm start from the given solution."""
        self.cp_model.clear_hints()
        for i, taken in enumerate(solution.list_taken):
            self.cp_model.AddHint(self.variables["taken"][i], taken)

    def _internal_value(self) -> LinearExpr:
        """Return the total value of the taken items as a CP-SAT expression."""
        return sum(
            self.variables["taken"][i] * self.problem.list_items[i].value
            for i in range(self.problem.nb_items)
        )

    def _internal_weight(self) -> LinearExpr:
        """Return the *negated* total weight of the taken items.

        The sign is flipped so that maximizing this expression minimizes the
        actual weight.
        """
        return sum(
            -self.variables["taken"][i] * self.problem.list_items[i].weight
            for i in range(self.problem.nb_items)
        )

    def _internal_heaviest_item(self) -> IntVar:
        """Return the variable equal to the weight of the heaviest taken item.

        The variable and its max-equality constraint are created on first use
        and cached under ``self.variables["heaviest_item"]``.
        """
        if "heaviest_item" in self.variables:
            return self.variables["heaviest_item"][0]

        nb_items = self.problem.nb_items
        items = self.problem.list_items
        heaviest_item_var = self.cp_model.new_int_var(
            name="heaviest_item",
            lb=0,
            ub=int(max(items[i].weight for i in range(nb_items))),
        )
        self.variables["heaviest_item"] = [heaviest_item_var]
        self.cp_model.add_max_equality(
            target=heaviest_item_var,
            exprs=[
                self.variables["taken"][i] * items[i].weight
                for i in range(nb_items)
            ],
        )
        return heaviest_item_var

    def _internal_objective(self, obj: str) -> ObjLinearExprT:
        """Return the CP-SAT expression associated with an objective name.

        Args:
            obj: one of "value", "weight" or "heaviest_item".

        Returns:
            the expression (or variable) to maximize for that objective.

        Raises:
            ValueError: if `obj` is "weight_violation" or an unknown name.

        """
        internal_objective_mapping = {
            "value": self._internal_value,
            "weight": self._internal_weight,
            "heaviest_item": self._internal_heaviest_item,
        }
        if obj in internal_objective_mapping:
            return internal_objective_mapping[obj]()
        if obj == "weight_violation":
            raise ValueError(
                "weight_violation cannot be used as objective. "
                "Indeed, no violation is allowed with this solver."
            )
        raise ValueError(f"Unknown objective '{obj}'.")

    def set_lexico_objective(self, obj: str) -> None:
        """Set the model objective to maximize the expression named `obj`.

        Args:
            obj: a string representing the desired objective, as accepted by
                `_internal_objective`.

        """
        self.cp_model.Maximize(self._internal_objective(obj))

    def add_lexico_constraint(self, obj: str, value: float) -> Iterable[Constraint]:
        """Constrain the expression named `obj` to be at least `value`.

        The bound is always a lower bound on the internal expression returned
        by `_internal_objective`, since these are the expressions that
        `set_lexico_objective` maximizes.

        Args:
            obj: a string representing the desired objective.
                Should be one of `self.problem.get_objective_names()`.
            value: the limiting value. It is converted with `int()`, i.e.
                truncated toward zero, before being used as the bound.

        Returns:
            the created constraints.

        """
        return [self.cp_model.Add(self._internal_objective(obj) >= int(value))]

    @staticmethod
    def implements_lexico_api() -> bool:
        """Tell that this solver implements the lexicographic API."""
        return True

    def get_lexico_objectives_available(self) -> list[str]:
        """Return the objective names accepted by the lexicographic API."""
        return ["value", "weight", "heaviest_item"]

    def retrieve_solution(
        self, cpsolvercb: CpSolverSolutionCallback
    ) -> KnapsackSolution:
        """Construct a do solution from the cpsat solver internal solution.

        It will be called each time the cpsat solver find a new solution.
        At that point, value of internal variables are accessible via
        `cpsolvercb.Value(VARIABLE_NAME)`.

        Args:
            cpsolvercb: the ortools callback called when the cpsat solver
                finds a new solution.

        Returns:
            the intermediate solution, at do format.

        """
        taken = [int(cpsolvercb.Value(var)) for var in self.variables["taken"]]
        return KnapsackSolution(problem=self.problem, list_taken=taken)