Source code for discrete_optimization.singlemachine.solvers.cpsat

#  Copyright (c) 2025 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.
import logging
from typing import Any

from ortools.sat.python.cp_model import (
    CpSolverSolutionCallback,
    LinearExpr,
    LinearExprT,
)

from discrete_optimization.generic_tasks_tools.enums import StartOrEnd
from discrete_optimization.generic_tasks_tools.solvers.cpsat import (
    SchedulingCpSatSolver,
)
from discrete_optimization.generic_tools.do_problem import Solution
from discrete_optimization.generic_tools.do_solver import WarmstartMixin
from discrete_optimization.singlemachine.problem import (
    Task,
    WeightedTardinessProblem,
    WTSolution,
)

logger = logging.getLogger(__name__)


class CpsatWTSolver(SchedulingCpSatSolver[Task], WarmstartMixin):
    """CP-SAT solver for the single-machine weighted tardiness problem.

    For every job the model holds an integer start variable, a fixed-size
    interval of the job's processing time, and a non-negative variable equal
    to ``max(end - due_date, 0)`` (stored under the ``"lateness"`` key).
    Intervals may not overlap and the weighted sum of these variables is
    minimized.
    """

    problem: WeightedTardinessProblem
    # Filled by init_model; each list is indexed by job:
    # "starts" and "lateness" hold integer variables, "intervals" interval variables.
    variables: dict[str, list[Any]]

    def retrieve_solution(self, cpsolvercb: CpSolverSolutionCallback) -> Solution:
        """Build a ``WTSolution`` from the values exposed by the solver callback.

        Args:
            cpsolvercb: callback giving access to the current variable values
                and objective value.

        Returns:
            A ``WTSolution`` whose schedule has one ``(start, end)`` pair per
            job, with ``end = start + processing_time``.
        """
        schedule = []
        for i in range(self.problem.num_jobs):
            start = cpsolvercb.value(self.variables["starts"][i])
            schedule.append((start, start + self.problem.processing_times[i]))
        # Lazy %-style arguments: the message is only rendered if INFO is enabled.
        logger.info("Obj = %s", cpsolvercb.objective_value)
        return WTSolution(problem=self.problem, schedule=schedule)

    def init_model(self, **args: Any) -> None:
        """Create the CP-SAT model: variables, constraints and objective.

        Args:
            **args: forwarded unchanged to the parent ``init_model``.
        """
        self.variables = {"starts": [], "intervals": [], "lateness": []}
        super().init_model(**args)
        # Total processing time, used as the upper bound of every start variable.
        max_time = sum(self.problem.processing_times)
        for i in range(self.problem.num_jobs):
            duration = self.problem.processing_times[i]
            due_date = self.problem.due_dates[i]
            start = self.cp_model.new_int_var(lb=0, ub=max_time, name=f"start_{i}")
            interval = self.cp_model.new_fixed_size_interval_var(
                start=start, size=duration, name=f"interval_{i}"
            )
            self.variables["starts"].append(start)
            self.variables["intervals"].append(interval)
            lateness = self.cp_model.new_int_var(
                lb=0,
                ub=max(max_time - due_date, 0),
                name=f"lateness_{i}",
            )
            # lateness == max(end - due_date, 0): zero when the job ends on time.
            self.cp_model.add_max_equality(
                lateness,
                [start + duration - due_date, 0],
            )
            self.variables["lateness"].append(lateness)
        # Single machine: no two jobs may be processed at the same time.
        self.cp_model.add_no_overlap(self.variables["intervals"])
        self.cp_model.minimize(
            LinearExpr.weighted_sum(self.variables["lateness"], self.problem.weights)
        )

    def set_warm_start(self, solution: WTSolution) -> None:
        """Replace the model hints with the values of ``solution``.

        Existing hints are cleared, then every job's start variable and
        lateness variable receive a hint derived from ``solution.schedule``.

        Args:
            solution: solution whose ``schedule[i]`` is indexed as
                ``(start, end)`` for job ``i``.
        """
        self.cp_model.clear_hints()
        for i in range(self.problem.num_jobs):
            start = solution.schedule[i][0]
            end = solution.schedule[i][1]
            self.cp_model.add_hint(self.variables["starts"][i], start)
            self.cp_model.add_hint(
                self.variables["lateness"][i],
                max(0, end - self.problem.due_dates[i]),
            )

    def get_task_start_or_end_variable(
        self, task: Task, start_or_end: StartOrEnd
    ) -> LinearExprT:
        """Return the expression for the start or the end of ``task``.

        Args:
            task: job identifier, used directly as an index into the start
                variables and processing times.
            start_or_end: which end of the task is requested.

        Returns:
            The start variable for ``StartOrEnd.START``; the start variable
            plus the task's processing time for ``StartOrEnd.END``.

        Raises:
            ValueError: if ``start_or_end`` is neither START nor END.
        """
        if start_or_end == StartOrEnd.START:
            return self.variables["starts"][task]
        if start_or_end == StartOrEnd.END:
            return self.variables["starts"][task] + self.problem.processing_times[task]
        # Previously fell through and returned None, violating the return type.
        raise ValueError(f"Unsupported start_or_end value: {start_or_end!r}")