Source code for discrete_optimization.tsptw.problem

#  Copyright (c) 2025 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple

import numpy as np

from discrete_optimization.generic_tools.do_problem import (
    EncodingRegister,
    ModeOptim,
    ObjectiveDoc,
    ObjectiveHandling,
    ObjectiveRegister,
    Problem,
    Solution,
    TypeAttribute,
    TypeObjective,
)


class TSPTWSolution(Solution):
    """
    Solution class for the TSP-TW problem.

    Attributes:
        problem (TSPTWProblem): The problem instance.
        permutation (List[int]): A list of customer node indices in the order
            they are visited. The depot is not included in this list.
        arrival_times (Dict[int, float]): A dictionary mapping each node index
            to its arrival time.
        start_service_times (Dict[int, float]): A dictionary mapping each node
            to the time service begins.
        makespan (float): The total time of the tour, from leaving the depot to
            returning. This is the primary objective.
        tw_violation (float): The total violation of time windows (sum of
            lateness at each node).

    ``arrival_times``, ``start_service_times``, ``makespan`` and
    ``tw_violation`` are filled in by ``TSPTWProblem.evaluate``. Until then
    they are empty dicts / ``None``, unless passed to the constructor.
    """

    def __init__(
        self,
        problem: "TSPTWProblem",
        permutation: List[int],
        arrival_times: Optional[Dict[int, float]] = None,
        start_service_times: Optional[Dict[int, float]] = None,
        makespan: Optional[float] = None,
        tw_violation: Optional[float] = None,
    ):
        self.problem = problem
        self.permutation = permutation
        # Fresh dicts per instance (never shared mutable defaults).
        self.arrival_times = arrival_times if arrival_times is not None else {}
        self.start_service_times = (
            start_service_times if start_service_times is not None else {}
        )
        self.makespan = makespan
        self.tw_violation = tw_violation

    def copy(self) -> "TSPTWSolution":
        """Return an independent copy of this solution.

        The permutation and both timing dictionaries are copied (shallowly).
        Mutating them on the copy therefore does not affect this solution.
        The problem instance is shared.
        """
        return TSPTWSolution(
            problem=self.problem,
            permutation=list(self.permutation),
            arrival_times=self.arrival_times.copy(),
            start_service_times=self.start_service_times.copy(),
            makespan=self.makespan,
            tw_violation=self.tw_violation,
        )

    def lazy_copy(self) -> "TSPTWSolution":
        """Return a new solution object that shares its containers with this one.

        The permutation list and the timing dictionaries are NOT copied, so
        in-place mutations are visible through both objects.
        """
        return TSPTWSolution(
            problem=self.problem,
            permutation=self.permutation,
            arrival_times=self.arrival_times,
            start_service_times=self.start_service_times,
            makespan=self.makespan,
            tw_violation=self.tw_violation,
        )

    def change_problem(self, new_problem: Problem) -> None:
        """Attach this solution to ``new_problem`` and discard evaluated metrics.

        Raises:
            ValueError: If ``new_problem`` is not a ``TSPTWProblem``.
        """
        if not isinstance(new_problem, TSPTWProblem):
            raise ValueError("new_problem must be a TSPTWProblem instance.")
        self.problem = new_problem
        # Invalidate evaluated metrics as they depend on the problem
        self.arrival_times = {}
        self.start_service_times = {}
        self.makespan = None
        self.tw_violation = None

    @staticmethod
    def _format_metric(value: Optional[float]) -> str:
        """Format a metric with two decimals, or a placeholder when unevaluated."""
        # ``format(None, ".2f")`` raises TypeError, so handle None explicitly.
        return "not evaluated" if value is None else f"{value:.2f}"

    def __str__(self) -> str:
        depot = self.problem.depot_node
        # Unpacking works for lists, tuples and numpy arrays alike. By contrast,
        # ``[depot] + ndarray`` would broadcast an element-wise addition.
        path_str = " -> ".join(map(str, [depot, *self.permutation, depot]))
        return (
            f"Path: {path_str}\n"
            f"Makespan: {self._format_metric(self.makespan)}\n"
            f"Time Window Violation: {self._format_metric(self.tw_violation)}"
        )
class TSPTWProblem(Problem):
    """
    Traveling Salesman Problem with Time Windows (TSP-TW) Problem class.

    Args:
        nb_nodes: Total number of nodes, depot included. Nodes are identified
            by the integers ``0 .. nb_nodes - 1``.
        distance_matrix: Travel-time matrix indexed as ``distance_matrix[u, v]``
            by node id. ``evaluate`` assumes it already includes the service
            time at ``u``.
        time_windows: ``(earliest, latest)`` pair for each node, indexed by
            node id.
        depot_node: Node where the tour starts (at time 0) and ends.
    """

    def __init__(
        self,
        nb_nodes: int,
        distance_matrix: np.ndarray,
        time_windows: List[Tuple[int, int]],
        depot_node: int = 0,
    ):
        self.nb_nodes = nb_nodes
        self.distance_matrix = distance_matrix
        self.time_windows = time_windows
        self.depot_node = depot_node
        # Every non-depot node, in ascending order (``range`` is already sorted).
        self.customers = [i for i in range(self.nb_nodes) if i != self.depot_node]
        self.nb_customers = len(self.customers)

    def get_attribute_register(self) -> EncodingRegister:
        """Describe the ``permutation`` attribute of ``TSPTWSolution``."""
        return EncodingRegister(
            {
                "permutation": {
                    "name": "permutation",
                    "type": [TypeAttribute.PERMUTATION],
                    "n": self.nb_customers,
                    # Copy so consumers cannot mutate the problem's own list.
                    "arr": list(self.customers),
                }
            }
        )

    def get_objective_register(self) -> ObjectiveRegister:
        """Declare the makespan objective and the time-window penalty.

        Both are aggregated under minimization. ``evaluate`` returns the
        violation negated, and its default weight here is negative. The product
        is presumably a positive penalty added to the makespan — confirm against
        the aggregation logic in ``do_problem``.
        """
        return ObjectiveRegister(
            objective_sense=ModeOptim.MINIMIZATION,
            objective_handling=ObjectiveHandling.AGGREGATE,
            dict_objective_to_doc={
                "makespan": ObjectiveDoc(
                    type=TypeObjective.OBJECTIVE, default_weight=1.0
                ),
                "tw_violation": ObjectiveDoc(
                    type=TypeObjective.PENALTY, default_weight=-1000.0
                ),
            },
        )

    def evaluate(self, solution: TSPTWSolution) -> Dict[str, float]:
        """
        Evaluates a solution by calculating the makespan and time window violations.

        This evaluation assumes the distance matrix D[u,v] includes the service
        time at node u. The timeline is calculated as follows:
        1. Arrival at node v = Start of service at u + D[u,v]
        2. Start of service at v = max(Arrival at v, Earliest time for v)
        3. Violation at v = max(0, Start of service at v - Latest time for v)

        The return to the depot is penalized by how much the arrival exceeds
        the depot's latest time.

        Side effects:
            Overwrites ``arrival_times``, ``start_service_times``, ``makespan``
            and ``tw_violation`` on ``solution`` (new dicts are assigned).

        Returns:
            ``{"makespan": ..., "tw_violation": -total_violation}``. Note the
            sign of the violation.
        """
        # Initialize at depot
        current_node = self.depot_node
        start_service_time = 0.0
        solution.start_service_times = {current_node: 0.0}
        solution.arrival_times = {current_node: 0.0}
        total_violation = 0.0

        # Travel to all customers in the permutation
        for next_node in solution.permutation:
            arrival_time = (
                start_service_time + self.distance_matrix[current_node, next_node]
            )
            earliest, latest = self.time_windows[next_node]
            # Arriving early means waiting until the window opens.
            start_service_time = max(arrival_time, earliest)
            total_violation += max(0, start_service_time - latest)
            solution.arrival_times[next_node] = arrival_time
            solution.start_service_times[next_node] = start_service_time
            current_node = next_node

        # Travel back to the depot; only the depot's latest time matters here.
        arrival_back_at_depot = (
            start_service_time + self.distance_matrix[current_node, self.depot_node]
        )
        _, latest_depot = self.time_windows[self.depot_node]
        total_violation += max(0, arrival_back_at_depot - latest_depot)

        solution.makespan = arrival_back_at_depot
        solution.tw_violation = total_violation
        return {"makespan": solution.makespan, "tw_violation": -solution.tw_violation}

    def evaluate_from_encoding(
        self, int_vector: List[int], encoding_name: str
    ) -> Dict[str, float]:
        """Evaluate a raw encoded vector.

        Only ``"permutation"`` is supported. In that case ``int_vector`` holds
        positions ``0 .. nb_customers - 1`` into ``self.customers``, not node
        ids.

        Raises:
            NotImplementedError: For any other ``encoding_name``.
        """
        if encoding_name == "permutation":
            # The encoding gives a permutation of indices from 0 to N-2
            # We map these indices back to the actual customer node IDs
            perm_customers = [self.customers[i] for i in int_vector]
            sol = TSPTWSolution(problem=self, permutation=perm_customers)
        else:
            raise NotImplementedError(f"Encoding '{encoding_name}' is not supported.")
        return self.evaluate(sol)

    def satisfy(self, solution: TSPTWSolution) -> bool:
        """Return True when ``solution`` is a feasible TSP-TW tour.

        A tour is feasible when its permutation visits every customer exactly
        once and no time window is violated. The solution is evaluated first
        if it has not been evaluated yet.
        """
        # Reject tours that skip, repeat or add nodes. ``evaluate`` would
        # otherwise score such a partial tour and possibly report no violation.
        if sorted(solution.permutation) != self.customers:
            return False
        if solution.tw_violation is None:
            self.evaluate(solution)
        return solution.tw_violation == 0

    def get_dummy_solution(self) -> TSPTWSolution:
        """Returns a simple, non-random dummy solution (e.g., customers in order)."""
        # Copy: the solution must not alias ``self.customers``. Otherwise any
        # in-place move on the solution would corrupt the problem definition.
        return TSPTWSolution(problem=self, permutation=list(self.customers))

    def get_solution_type(self) -> type:
        """Return the solution class associated with this problem."""
        return TSPTWSolution