# Copyright (c) 2026 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from abc import abstractmethod
from typing import Any, Iterable, Optional
from ortools.sat.python.cp_model import IntVar, LinearExprT
from discrete_optimization.generic_tasks_tools.base import Task
from discrete_optimization.generic_tasks_tools.enums import StartOrEnd
from discrete_optimization.generic_tasks_tools.scheduling import SchedulingCpSolver
from discrete_optimization.generic_tools.cp_tools import SignEnum
from discrete_optimization.generic_tools.ortools_cpsat_tools import OrtoolsCpSatSolver
[docs]
class SchedulingCpSatSolver(OrtoolsCpSatSolver, SchedulingCpSolver[Task]):
"""Base class for most ortools/cpsat solvers handling scheduling problems.
Allows to have common code.
"""
_makespan: Optional[IntVar] = None
"""Internal variable use to define the global makespan."""
_subtasks_makespan: Optional[IntVar] = None
"""Internal variable use to define the partial makespan."""
constraints_on_makespan: Optional[list[Any]] = None
"""Constraints on partial makespan so that it can be considered as the objective."""
[docs]
def init_model(self, **kwargs: Any) -> None:
"""Init cp model and reset stored variables if any."""
super().init_model(**kwargs)
self._makespan = None
self._subtasks_makespan = None
self.constraints_on_makespan = None
[docs]
@abstractmethod
def get_task_start_or_end_variable(
self, task: Task, start_or_end: StartOrEnd
) -> LinearExprT:
"""Retrieve the variable storing the start or end time of given task.
Args:
task:
start_or_end:
Returns:
"""
...
[docs]
def add_constraint_on_task(
self, task: Task, start_or_end: StartOrEnd, sign: SignEnum, time: int
) -> list[Any]:
var = self.get_task_start_or_end_variable(task=task, start_or_end=start_or_end)
return self.add_bound_constraint(var=var, sign=sign, value=time)
[docs]
def add_constraint_chaining_tasks(self, task1: Task, task2: Task) -> list[Any]:
var1 = self.get_task_start_or_end_variable(
task=task1, start_or_end=StartOrEnd.END
)
var2 = self.get_task_start_or_end_variable(
task=task2, start_or_end=StartOrEnd.START
)
return [self.cp_model.add(var1 == var2)]
def _get_makespan_var(self) -> IntVar:
"""Get the makespan variable used to track global makespan."""
if self._makespan is None:
self._makespan = self.cp_model.NewIntVar(
lb=self.get_makespan_lower_bound(),
ub=self.get_makespan_upper_bound(),
name="makespan",
)
return self._makespan
def _get_subtasks_makespan_var(self) -> IntVar:
"""Get the makespan variable used to track subtasks makespan."""
if self._subtasks_makespan is None:
self._subtasks_makespan = self.cp_model.NewIntVar(
lb=0, # lower bound for any tasks subset
ub=self.get_makespan_upper_bound(),
name="subtasks_makespan",
)
return self._subtasks_makespan
[docs]
def remove_constraints_on_objective(self) -> None:
if self.constraints_on_makespan is not None:
self.remove_constraints(self.constraints_on_makespan)
[docs]
def get_global_makespan_variable(self) -> Any:
# remove previous constraints on makespan variable from cp model
self.remove_constraints_on_objective()
# get makespan variable
makespan = self._get_makespan_var()
# update those constraints
self.constraints_on_makespan = [
self.cp_model.AddMaxEquality(
makespan,
[
self.get_task_start_or_end_variable(task, StartOrEnd.END)
for task in self.problem.get_last_tasks()
],
)
]
return makespan
[docs]
def get_subtasks_makespan_variable(self, subtasks: Iterable[Task]) -> Any:
# remove previous constraints on makespan variable from cp model
self.remove_constraints_on_objective()
# get makespan variable
makespan = self._get_subtasks_makespan_var()
# update those constraints
self.constraints_on_makespan = [
self.cp_model.AddMaxEquality(
makespan,
[
self.get_task_start_or_end_variable(task, StartOrEnd.END)
for task in subtasks
],
)
]
return makespan
[docs]
def get_subtasks_sum_end_time_variable(self, subtasks: Iterable[Task]) -> Any:
self.remove_constraints_on_objective()
return sum(
self.get_task_start_or_end_variable(task, StartOrEnd.END)
for task in subtasks
)
[docs]
def get_subtasks_sum_start_time_variable(self, subtasks: Iterable[Task]) -> Any:
self.remove_constraints_on_objective()
return sum(
self.get_task_start_or_end_variable(task, StartOrEnd.START)
for task in subtasks
)