# Source code for discrete_optimization.tsptw.problem

#  Copyright (c) 2025 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Dict, List, Optional, Tuple

import numpy as np

from discrete_optimization.generic_tools.do_problem import (
    ModeOptim,
    ObjectiveDoc,
    ObjectiveHandling,
    ObjectiveRegister,
    Problem,
    Solution,
    TypeObjective,
)
from discrete_optimization.generic_tools.encoding_register import (
    EncodingRegister,
    Permutation,
)


class TSPTWSolution(Solution):
    """
    Solution class for the TSP-TW problem.

    Attributes:
        problem (TSPTWProblem): The problem instance.
        permutation (List[int]): A list of customer node indices in the order
            they are visited. The depot is not included in this list.
        arrival_times (Dict[int, float]): A dictionary mapping each node index
            to its arrival time. Empty until the solution has been evaluated.
        start_service_times (Dict[int, float]): A dictionary mapping each node
            to the time service begins. Empty until the solution has been
            evaluated.
        makespan (Optional[float]): The total time of the tour, from leaving
            the depot to returning. This is the primary objective. ``None``
            until the solution has been evaluated.
        tw_violation (Optional[float]): The total violation of time windows
            (sum of lateness at each node). ``None`` until the solution has
            been evaluated.
    """

    problem: TSPTWProblem

    def __init__(
        self,
        problem: TSPTWProblem,
        permutation: List[int],
        arrival_times: Optional[Dict[int, float]] = None,
        start_service_times: Optional[Dict[int, float]] = None,
        makespan: Optional[float] = None,
        tw_violation: Optional[float] = None,
    ):
        """
        Build a solution from a visiting order and, optionally, cached metrics.

        Args:
            problem: The problem instance this solution belongs to.
            permutation: Customer nodes in visiting order (depot excluded).
            arrival_times: Cached arrival time per node; defaults to ``{}``.
            start_service_times: Cached service start per node; defaults
                to ``{}``.
            makespan: Cached tour duration, or ``None`` if not evaluated.
            tw_violation: Cached total lateness, or ``None`` if not evaluated.
        """
        super().__init__(problem=problem)
        self.permutation = permutation
        # A fresh dict per instance avoids sharing a mutable default.
        self.arrival_times = arrival_times if arrival_times is not None else {}
        self.start_service_times = (
            start_service_times if start_service_times is not None else {}
        )
        self.makespan = makespan
        self.tw_violation = tw_violation

    def copy(self) -> TSPTWSolution:
        """Return a copy whose permutation and time dictionaries are new containers."""
        return TSPTWSolution(
            problem=self.problem,
            permutation=list(self.permutation),
            arrival_times=self.arrival_times.copy(),
            start_service_times=self.start_service_times.copy(),
            makespan=self.makespan,
            tw_violation=self.tw_violation,
        )

    def lazy_copy(self) -> TSPTWSolution:
        """Return a new solution object that shares its containers with this one."""
        return TSPTWSolution(
            problem=self.problem,
            permutation=self.permutation,
            arrival_times=self.arrival_times,
            start_service_times=self.start_service_times,
            makespan=self.makespan,
            tw_violation=self.tw_violation,
        )

    def change_problem(self, new_problem: Problem) -> None:
        """
        Attach this solution to ``new_problem`` and drop all cached metrics.

        Args:
            new_problem: The problem instance to attach to.
        """
        super().change_problem(new_problem=new_problem)
        # Invalidate evaluated metrics as they depend on the problem
        self.arrival_times = {}
        self.start_service_times = {}
        self.makespan = None
        self.tw_violation = None

    def __str__(self) -> str:
        """
        Return a human-readable summary: the depot-to-depot path and metrics.

        Metrics that have not been computed yet (value ``None``) are shown as
        ``not evaluated`` instead of raising on the numeric format spec.
        """
        depot = self.problem.depot_node
        path_str = " -> ".join(map(str, [depot, *self.permutation, depot]))
        # makespan / tw_violation stay None until the solution is evaluated;
        # applying ":.2f" to None would raise TypeError.
        makespan_str = (
            "not evaluated" if self.makespan is None else f"{self.makespan:.2f}"
        )
        violation_str = (
            "not evaluated"
            if self.tw_violation is None
            else f"{self.tw_violation:.2f}"
        )
        return (
            f"Path: {path_str}\n"
            f"Makespan: {makespan_str}\n"
            f"Time Window Violation: {violation_str}"
        )
class TSPTWProblem(Problem):
    """
    Traveling Salesman Problem with Time Windows (TSP-TW) Problem class.

    Attributes:
        nb_nodes (int): Total number of nodes, depot included.
        distance_matrix (np.ndarray): Matrix indexed as ``D[u, v]`` giving the
            travel value from node ``u`` to node ``v``.
        time_windows (List[Tuple[int, int]]): ``(earliest, latest)`` pair for
            each node, indexed by node.
        depot_node (int): Index of the depot node.
        customers (List[int]): All node indices except the depot, ascending.
        nb_customers (int): Number of customer nodes.
    """

    def __init__(
        self,
        nb_nodes: int,
        distance_matrix: np.ndarray,
        time_windows: List[Tuple[int, int]],
        depot_node: int = 0,
    ):
        """
        Store the instance data and derive the list of customer nodes.

        Args:
            nb_nodes: Total number of nodes, depot included.
            distance_matrix: Matrix indexed as ``D[u, v]``.
            time_windows: ``(earliest, latest)`` pair for each node.
            depot_node: Index of the depot node.
        """
        self.nb_nodes = nb_nodes
        self.distance_matrix = distance_matrix
        self.time_windows = time_windows
        self.depot_node = depot_node
        # range() is already ascending, so no extra sort is required.
        self.customers = [i for i in range(self.nb_nodes) if i != self.depot_node]
        self.nb_customers = len(self.customers)

    def get_attribute_register(self) -> EncodingRegister:
        """Describe the solution encoding: a permutation over the customer nodes."""
        return EncodingRegister({"permutation": Permutation(range=self.customers)})

    def get_objective_register(self) -> ObjectiveRegister:
        """
        Declare the objectives: minimise ``makespan`` and penalise ``tw_violation``.

        ``tw_violation`` carries a default weight of ``-1000.0``, and
        ``evaluate`` reports it negated.
        """
        return ObjectiveRegister(
            objective_sense=ModeOptim.MINIMIZATION,
            objective_handling=ObjectiveHandling.AGGREGATE,
            dict_objective_to_doc={
                "makespan": ObjectiveDoc(
                    type=TypeObjective.OBJECTIVE, default_weight=1.0
                ),
                "tw_violation": ObjectiveDoc(
                    type=TypeObjective.PENALTY, default_weight=-1000.0
                ),
            },
        )

    def evaluate(self, variable: TSPTWSolution) -> Dict[str, float]:
        """
        Evaluates a solution by calculating the makespan and time window violations.

        This evaluation assumes the distance matrix D[u,v] includes the service
        time at node u. The timeline is calculated as follows:

        1. Arrival at node v = Start of service at u + D[u,v]
        2. Start of service at v = max(Arrival at v, Earliest time for v)
        3. Violation at v = max(0, Start of service at v - Latest time for v)

        The tour starts at the depot at time 0 and ends back at the depot; a
        return later than the depot's latest time also counts as a violation.

        Side effects: ``variable.arrival_times``, ``variable.start_service_times``,
        ``variable.makespan`` and ``variable.tw_violation`` are overwritten.

        Args:
            variable: The solution to evaluate.

        Returns:
            A dict with ``"makespan"`` and ``"tw_violation"``; the latter is the
            negated total violation (the stored attribute stays non-negative).
        """
        # Initialize at depot
        current_node = self.depot_node
        start_service_time = 0.0
        variable.start_service_times = {current_node: 0.0}
        variable.arrival_times = {current_node: 0.0}
        total_violation = 0.0

        # Travel to all customers in the permutation
        for next_node in variable.permutation:
            dist = self.distance_matrix[current_node, next_node]
            arrival_time = start_service_time + dist
            earliest, latest = self.time_windows[next_node]
            # Early arrivals wait until the window opens.
            start_service_time = max(arrival_time, earliest)
            total_violation += max(0, start_service_time - latest)
            variable.arrival_times[next_node] = arrival_time
            variable.start_service_times[next_node] = start_service_time
            current_node = next_node

        # Travel back to the depot
        dist_to_depot = self.distance_matrix[current_node, self.depot_node]
        arrival_back_at_depot = start_service_time + dist_to_depot
        # Only the closing time of the depot matters on return.
        _, latest_depot = self.time_windows[self.depot_node]
        # Violation for returning to the depot
        total_violation += max(0, arrival_back_at_depot - latest_depot)

        variable.makespan = arrival_back_at_depot
        variable.tw_violation = total_violation
        return {"makespan": variable.makespan, "tw_violation": -variable.tw_violation}

    def satisfy(self, variable: TSPTWSolution) -> bool:
        """
        Return True when the solution has no time-window violation.

        The solution is evaluated first only if ``tw_violation`` is still
        ``None``; an already stored value is trusted as is.
        """
        if variable.tw_violation is None:
            self.evaluate(variable)
        # bool() yields a builtin bool even when the stored value is a NumPy scalar.
        return bool(variable.tw_violation == 0)

    def get_dummy_solution(self) -> TSPTWSolution:
        """Returns a simple, non-random dummy solution (e.g., customers in order)."""
        # Pass a copy so that in-place edits of the solution's permutation
        # cannot alter this problem's customer list.
        return TSPTWSolution(problem=self, permutation=list(self.customers))

    def get_solution_type(self) -> type:
        """Return the solution class associated with this problem."""
        return TSPTWSolution