# Source code for discrete_optimization.generic_tools.graph_api

#  Copyright (c) 2022 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.
import string
from collections.abc import Hashable, KeysView
from typing import Any, Optional, Union

import networkx as nx


class Graph:
    """Graph container with cached adjacency, attribute lookups and a networkx mirror.

    The graph is described by explicit node and edge lists, each entry carrying an
    attribute dictionary. Construction builds successor/predecessor dictionaries,
    node/edge attribute lookup tables and a networkx copy (``graph_nx``). The
    path and reachability helpers rely on that copy.
    """

    def __init__(
        self,
        nodes: list[tuple[Hashable, dict[str, Any]]],
        edges: list[tuple[Hashable, Hashable, dict[str, Any]]],
        undirected: bool = True,
        compute_predecessors: bool = True,
    ):
        """Build the graph and its derived lookup structures.

        Args:
            nodes: ``(node, attributes)`` pairs.
            edges: ``(source, target, attributes)`` triples.
            undirected: when True, every edge is also registered in the reverse
                direction and ``graph_nx`` is an ``nx.Graph``; otherwise it is an
                ``nx.DiGraph``.
            compute_predecessors: when True, precompute for every node the full
                set of ancestors (``full_predecessors``) and descendants
                (``full_successors``); otherwise both attributes are None.
        """
        self.nodes = nodes
        self.edges = edges
        self.undirected = undirected
        # Direct successors / direct predecessors per node. A node only gets an
        # entry once an edge leaves / enters it.
        self.neighbors_dict: dict[Hashable, set[Hashable]] = {}
        self.predecessors_dict: dict[Hashable, set[Hashable]] = {}
        self.edges_infos_dict: dict[tuple[Hashable, Hashable], dict[str, Any]] = {}
        self.nodes_infos_dict: dict[Hashable, dict[str, Any]] = {}
        self.build_nodes_infos_dict()
        self.build_edges()
        # Node identifiers in first-seen order, without duplicates.
        self.nodes_name = list(self.nodes_infos_dict)
        self.graph_nx = self.to_networkx()
        self.full_predecessors: Optional[dict[Hashable, set[Hashable]]]
        self.full_successors: Optional[dict[Hashable, set[Hashable]]]
        if compute_predecessors:
            self.full_predecessors = self.ancestors_map()
            self.full_successors = self.descendants_map()
        else:
            self.full_predecessors = None
            self.full_successors = None

    def get_edges(self) -> KeysView[tuple[Hashable, Hashable]]:
        """Return a live view of the ``(source, target)`` edge keys.

        For undirected graphs both orientations of every edge are present.
        """
        return self.edges_infos_dict.keys()

    def get_nodes(self) -> list[Hashable]:
        """Return the node identifiers in first-seen order (internal list, not a copy)."""
        return self.nodes_name

    def build_nodes_infos_dict(self) -> None:
        """Index node attribute dicts by node; a repeated node keeps its last dict."""
        for node, infos in self.nodes:
            self.nodes_infos_dict[node] = infos

    def build_edges(self) -> None:
        """Fill the edge-attribute table and the successor/predecessor dictionaries.

        In undirected mode each edge is also recorded in the reverse direction,
        sharing the same attribute dict.
        """
        for n1, n2, infos in self.edges:
            self.edges_infos_dict[(n1, n2)] = infos
            self.predecessors_dict.setdefault(n2, set()).add(n1)
            self.neighbors_dict.setdefault(n1, set()).add(n2)
            if self.undirected:
                self.predecessors_dict.setdefault(n1, set()).add(n2)
                self.neighbors_dict.setdefault(n2, set()).add(n1)
                self.edges_infos_dict[(n2, n1)] = infos

    def get_neighbors(self, node: Hashable) -> set[Hashable]:
        """Return the direct successors of ``node`` (empty set if it has none).

        The returned set is the internal one; do not mutate it.
        """
        return self.neighbors_dict.get(node, set())

    def get_predecessors(self, node: Hashable) -> set[Hashable]:
        """Return the direct predecessors of ``node`` (empty set if it has none).

        The returned set is the internal one; do not mutate it.
        """
        return self.predecessors_dict.get(node, set())

    def get_attr_node(self, node: Hashable, attr: str) -> Any:
        """Return attribute ``attr`` of ``node``, or None if the node or attribute is missing."""
        return self.nodes_infos_dict.get(node, {}).get(attr, None)

    def get_attr_edge(self, node1: Hashable, node2: Hashable, attr: str) -> Any:
        """Return attribute ``attr`` of edge ``(node1, node2)``, or None if absent."""
        return self.edges_infos_dict.get((node1, node2), {}).get(attr, None)

    def to_networkx(self) -> Union[nx.Graph, nx.DiGraph]:
        """Build the networkx mirror: ``nx.Graph`` if undirected, else ``nx.DiGraph``."""
        graph_nx = nx.DiGraph() if not self.undirected else nx.Graph()
        graph_nx.add_nodes_from(self.nodes)
        graph_nx.add_edges_from(self.edges)
        return graph_nx

    def check_loop(self) -> Optional[list[tuple[Hashable, Hashable, str]]]:
        """Return one cycle of the graph, or None if the graph is acyclic.

        The cycle is what ``nx.find_cycle(..., orientation="original")`` returns.
        """
        try:
            return nx.find_cycle(self.graph_nx, orientation="original")
        except nx.NetworkXNoCycle:
            # Only "no cycle" means None; any other error is a real problem and
            # must not be silently swallowed.
            return None

    def precedessors_nodes(self, n: Hashable) -> set[Hashable]:
        """Return all ancestors of ``n``, not only its direct predecessors.

        The misspelled name is kept for backward compatibility.
        """
        return nx.algorithms.ancestors(self.graph_nx, n)

    def ancestors_map(self) -> dict[Hashable, set[Hashable]]:
        """Map every node to the set of all its ancestors in ``graph_nx``."""
        return {
            n: nx.algorithms.ancestors(self.graph_nx, n) for n in self.graph_nx.nodes()
        }

    def descendants_map(self) -> dict[Hashable, set[Hashable]]:
        """Map every node to the set of all its descendants in ``graph_nx``."""
        return {
            n: nx.algorithms.descendants(self.graph_nx, n)
            for n in self.graph_nx.nodes()
        }

    def successors_map(self) -> dict[Hashable, list[Hashable]]:
        """Map every node to its direct successors (its neighbors when undirected)."""
        return {n: list(nx.neighbors(self.graph_nx, n)) for n in self.graph_nx.nodes()}

    def predecessors_map(self) -> dict[Hashable, list[Hashable]]:
        """Map every node to its direct predecessors (its neighbors when undirected)."""
        # nx.Graph has no ``predecessors`` method. In an undirected graph the
        # predecessors of a node are exactly its neighbors.
        if self.graph_nx.is_directed():
            get_preds = self.graph_nx.predecessors
        else:
            get_preds = self.graph_nx.neighbors
        return {n: list(get_preds(n)) for n in self.graph_nx.nodes()}

    def compute_length(
        self, path: list[Hashable], attribute_name: Optional[str] = None
    ) -> float:
        """Return the length of ``path``.

        Args:
            path: consecutive nodes of the path.
            attribute_name: edge attribute to sum along the path; when None the
                length is the number of hops, ``len(path) - 1``.

        Raises:
            KeyError: if a consecutive pair is not an edge, or an edge lacks
                ``attribute_name``.
        """
        if attribute_name is None:
            return len(path) - 1
        return sum(
            self.graph_nx.edges[(i1, i2)][attribute_name]
            for i1, i2 in zip(path[:-1], path[1:])
        )

    def compute_shortest_path(
        self, source: Hashable, target: Hashable, attribute_name: Optional[str] = None
    ) -> tuple[list[Hashable], float]:
        """Return ``(path, length)`` of a Dijkstra shortest path from source to target.

        ``attribute_name`` is passed to networkx as the edge weight and is also
        used by :meth:`compute_length` to measure the path. Errors from networkx
        (e.g. an unreachable target) propagate.
        """
        path = nx.dijkstra_path(
            G=self.graph_nx, source=source, target=target, weight=attribute_name
        )
        length = self.compute_length(path=path, attribute_name=attribute_name)
        return path, length

    def compute_all_shortest_path(
        self, attribute_name: Optional[str] = None
    ) -> dict[Hashable, dict[Hashable, tuple[list[Hashable], float]]]:
        """Return ``{source: {target: (path, length)}}`` for all reachable pairs.

        Paths come from networkx's all-pairs Dijkstra weighted by
        ``attribute_name``; lengths are measured with :meth:`compute_length`.
        """
        all_path = nx.all_pairs_dijkstra_path(G=self.graph_nx, weight=attribute_name)
        dict_path_and_distance = {}
        for source, paths in all_path:
            dict_path_and_distance[source] = {
                target: (
                    path,
                    self.compute_length(path=path, attribute_name=attribute_name),
                )
                for target, path in paths.items()
            }
        return dict_path_and_distance
def from_networkx(
    graph_nx: Union[nx.DiGraph, nx.Graph],
    undirected: Optional[bool] = None,
    compute_predecessors: bool = False,
):
    """Convert a networkx graph into a ``Graph``.

    Node and edge attribute dictionaries are passed through without copying.

    Args:
        graph_nx: source networkx graph.
        undirected: orientation of the result. When None it is inferred: the
            result is directed only if ``graph_nx`` is an ``nx.DiGraph``.
        compute_predecessors: forwarded to ``Graph``. The default here is
            False, whereas the ``Graph`` constructor defaults to True.

    Returns:
        The newly built ``Graph``.
    """
    node_records = []
    for node in graph_nx.nodes():
        node_records.append((node, graph_nx.nodes[node]))

    edge_records = []
    for edge in graph_nx.edges():
        edge_records.append((edge[0], edge[1], graph_nx.edges[edge]))

    if undirected is None:
        undirected = not isinstance(graph_nx, nx.DiGraph)

    return Graph(
        nodes=node_records,
        edges=edge_records,
        undirected=undirected,
        compute_predecessors=compute_predecessors,
    )
# This function is implemented locally to work around the fact that
# networkx >= 3.2 is not compatible with Python 3.8.
def get_node_attributes(
    graph: "nx.Graph", name: str, default: Any = None
) -> dict[Hashable, Any]:
    """Collect the value of one node attribute over the nodes of a graph.

    @param graph: a networkx graph whose ``nodes`` view is inspected
    @param name: name of the node attribute of interest
    @param default: value used for nodes lacking the attribute; when None,
        nodes without the attribute are left out of the result
    @return: a dictionary mapping each (retained) node of graph to its value
        for the attribute
    """
    if default is not None:
        return {n: d.get(name, default) for n, d in graph.nodes.items()}
    return {n: d[name] for n, d in graph.nodes.items() if name in d}