Source code for discrete_optimization.shop.solvers.cpsat

#  Copyright (c) 2026 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.

import logging
from typing import Any

from discrete_optimization.generic_tasks_tools.generic_scheduling_utils import (
    RawSolution,
)
from discrete_optimization.generic_tasks_tools.solvers.cpsat.auto import (
    GenericSchedulingAutoCpSatSolver,
)
from discrete_optimization.generic_tasks_tools.solvers.cpsat.skill import (
    WithoutSkillSchedulingCpSatSolver,
)
from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
    CategoricalHyperparameter,
)
from discrete_optimization.shop.base import (
    AnyShopSolution,
    CommonShopProblem,
    NoNonRenewableResource,
    NonSkillCumulativeResource,
    NoSkill,
    NoUnaryResource,
    Task,
)

logger = logging.getLogger(__name__)


class CommonShopCpSatSolver(
    GenericSchedulingAutoCpSatSolver[
        Task,
        NoUnaryResource,
        NoSkill,
        NonSkillCumulativeResource,
        NoNonRenewableResource,
    ],
    # NOTE(review): the last type argument below is ``NoUnaryResource`` whereas
    # the first base ends with ``NoNonRenewableResource`` -- confirm against the
    # type parameters declared by ``WithoutSkillSchedulingCpSatSolver``.
    WithoutSkillSchedulingCpSatSolver[
        Task, NoUnaryResource, NonSkillCumulativeResource, NoUnaryResource
    ],
):
    """Common CP-SAT base class for the shop solvers of this module.

    Combines ``GenericSchedulingAutoCpSatSolver`` and
    ``WithoutSkillSchedulingCpSatSolver``, both parameterised with the shop
    type arguments imported from ``discrete_optimization.shop.base``.
    The class defines no attribute or method of its own.
    """
class CpSatShopSolver(CommonShopCpSatSolver):
    """CP-SAT solver for a :class:`CommonShopProblem`.

    Hyperparameters:
        duplicate_temporal_var: stored as ``duplicate_start_var_per_mode``
            by :meth:`init_model`.
        add_cumulative_constraint: whether to add a cumulative constraint on
            top of the no-overlap constraint; stored as
            ``use_cumulative_for_capa_1``.
    """

    hyperparameters = [
        CategoricalHyperparameter(
            name="duplicate_temporal_var", choices=[True, False], default=False
        ),
        CategoricalHyperparameter(
            name="add_cumulative_constraint", choices=[True, False], default=False
        ),
    ]
    problem: CommonShopProblem
    # Optional user-supplied cap on the makespan upper bound. ``None`` (the
    # default, also before ``init_model`` has run) means "no cap".
    _max_time: int | None = None

    def init_model(self, **kwargs: Any) -> None:
        """Read the solver options, then delegate model creation to the parent.

        Args:
            **kwargs: hyperparameters of this class (missing ones are filled
                with their defaults) plus the optional ``max_time`` entry used
                by :meth:`get_makespan_upper_bound`. The completed mapping is
                forwarded unchanged to ``super().init_model``.
        """
        kwargs = self.complete_with_default_hyperparameters(kwargs)
        # optional parameter: caps the upper bound used for the makespan
        self._max_time = kwargs.get("max_time")
        self.duplicate_start_var_per_mode = kwargs["duplicate_temporal_var"]
        # whether to add cumulative constraint on top of no_overlap constraint
        self.use_cumulative_for_capa_1 = bool(kwargs["add_cumulative_constraint"])
        # use cpm to compute start/end bounds
        self.use_cpm_for_task_bounds = True
        super().init_model(**kwargs)

    def get_makespan_upper_bound(self) -> int:
        """Return the parent's makespan upper bound, capped by ``max_time``.

        Returns:
            The parent's bound when no ``max_time`` was given, otherwise the
            smaller of ``max_time`` and the parent's bound.
        """
        upper_bound = super().get_makespan_upper_bound()
        if self._max_time is None:
            return upper_bound
        return min(self._max_time, upper_bound)

    def convert_task_variables_to_solution(
        self, raw_sol: RawSolution[Task, NoUnaryResource, NoSkill]
    ) -> AnyShopSolution:
        """Build an :class:`AnyShopSolution` from solved task variables.

        For every job index ``j`` and sub-job index ``k`` of
        ``self.problem.list_jobs``, the start, end and mode are read from
        ``raw_sol.task_variables[j, k]``; the machine is looked up in
        ``self.problem.mode2machine[j, k]`` with that mode.

        Args:
            raw_sol: raw solution holding one task variable per ``(j, k)``.

        Returns:
            A solution whose ``schedule``, ``machine_index`` and
            ``recipe_index`` are lists indexed first by job, then by sub-job.
        """
        mode2machine = self.problem.mode2machine
        schedule = []
        machine_index = []
        recipe_index = []
        for j, job in enumerate(self.problem.list_jobs):
            # Only the sub-job positions are needed, not the sub-jobs themselves.
            job_vars = [
                raw_sol.task_variables[j, k] for k, _ in enumerate(job.subjobs)
            ]
            schedule.append([(var.start, var.end) for var in job_vars])
            machine_index.append(
                [mode2machine[j, k][var.mode] for k, var in enumerate(job_vars)]
            )
            recipe_index.append([var.mode for var in job_vars])
        return AnyShopSolution(
            problem=self.problem,
            schedule=schedule,
            machine_index=machine_index,
            recipe_index=recipe_index,
        )