# Copyright (c) 2022 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import string
from collections.abc import Hashable, KeysView
from typing import Any, Optional, Union
import networkx as nx
class Graph:
    """Graph wrapper holding adjacency dictionaries and a networkx mirror.

    Nodes are given as ``(node, attributes)`` pairs and edges as
    ``(node1, node2, attributes)`` triples. When ``undirected`` is True every
    edge is also registered in the reverse direction in the adjacency and
    edge-attribute dictionaries, and the networkx mirror is an ``nx.Graph``;
    otherwise it is an ``nx.DiGraph``.
    """

    def __init__(
        self,
        nodes: list[tuple[Hashable, dict[str, Any]]],
        edges: list[tuple[Hashable, Hashable, dict[str, Any]]],
        undirected: bool = True,
        compute_predecessors: bool = True,
    ):
        """
        Args:
            nodes: list of ``(node, attributes)`` pairs.
            edges: list of ``(node1, node2, attributes)`` triples.
            undirected: if True, edges are stored in both directions and the
                networkx mirror is an ``nx.Graph``; else an ``nx.DiGraph``.
            compute_predecessors: if True, precompute the transitive
                ancestors and descendants of every node into
                ``full_predecessors`` / ``full_successors``; otherwise both
                attributes are set to None.
        """
        self.nodes = nodes
        self.edges = edges
        self.undirected = undirected
        # node -> direct successors (both directions when undirected)
        self.neighbors_dict: dict[Hashable, set[Hashable]] = {}
        # node -> direct predecessors (both directions when undirected)
        self.predecessors_dict: dict[Hashable, set[Hashable]] = {}
        # (node1, node2) -> edge attribute dict
        self.edges_infos_dict: dict[tuple[Hashable, Hashable], dict[str, Any]] = {}
        # node -> node attribute dict
        self.nodes_infos_dict: dict[Hashable, dict[str, Any]] = {}
        self.build_nodes_infos_dict()
        self.build_edges()
        # node identifiers in first-insertion order (duplicates collapsed)
        self.nodes_name = list(self.nodes_infos_dict)
        self.graph_nx = self.to_networkx()
        self.full_predecessors: Optional[dict[Hashable, set[Hashable]]]
        self.full_successors: Optional[dict[Hashable, set[Hashable]]]
        if compute_predecessors:
            self.full_predecessors = self.ancestors_map()
            self.full_successors = self.descendants_map()
        else:
            self.full_predecessors = None
            self.full_successors = None

    def get_edges(self) -> KeysView[tuple[Hashable, Hashable]]:
        """Return a live view of the ``(node1, node2)`` edge keys.

        For undirected graphs both orientations of each edge are present.
        """
        return self.edges_infos_dict.keys()

    def get_nodes(self) -> list[Hashable]:
        """Return the node identifiers, in the order they were first given."""
        return self.nodes_name

    def build_nodes_infos_dict(self) -> None:
        """Index node attribute dicts by node (a later duplicate wins)."""
        for n, d in self.nodes:
            self.nodes_infos_dict[n] = d

    def build_edges(self) -> None:
        """Fill the adjacency and edge-attribute dicts from ``self.edges``."""
        for n1, n2, d in self.edges:
            self.edges_infos_dict[(n1, n2)] = d
            self.predecessors_dict.setdefault(n2, set()).add(n1)
            self.neighbors_dict.setdefault(n1, set()).add(n2)
            if self.undirected:
                self.predecessors_dict.setdefault(n1, set()).add(n2)
                self.neighbors_dict.setdefault(n2, set()).add(n1)
                # Both orientations share the same attribute dict object.
                self.edges_infos_dict[(n2, n1)] = d

    def get_neighbors(self, node: Hashable) -> set[Hashable]:
        """Return the direct successors of ``node`` (empty set if none).

        The returned set is the internal one when ``node`` has successors;
        mutating it changes this graph's adjacency.
        """
        return self.neighbors_dict.get(node, set())

    def get_predecessors(self, node: Hashable) -> set[Hashable]:
        """Return the direct predecessors of ``node`` (empty set if none).

        The returned set is the internal one when ``node`` has predecessors;
        mutating it changes this graph's adjacency.
        """
        return self.predecessors_dict.get(node, set())

    def get_attr_node(self, node: Hashable, attr: str) -> Any:
        """Return attribute ``attr`` of ``node``, or None if node/attr is missing."""
        return self.nodes_infos_dict.get(node, {}).get(attr, None)

    def get_attr_edge(self, node1: Hashable, node2: Hashable, attr: str) -> Any:
        """Return attribute ``attr`` of edge ``(node1, node2)``, or None if missing."""
        return self.edges_infos_dict.get((node1, node2), {}).get(attr, None)

    def to_networkx(self) -> Union[nx.DiGraph, nx.Graph]:
        """Build the networkx mirror: ``nx.Graph`` if undirected, else ``nx.DiGraph``."""
        graph_nx = nx.DiGraph() if not self.undirected else nx.Graph()
        graph_nx.add_nodes_from(self.nodes)
        graph_nx.add_edges_from(self.edges)
        return graph_nx

    def check_loop(self) -> Optional[list[tuple[Hashable, Hashable, str]]]:
        """Return one cycle of the graph as a list of edges, or None if there is none.

        Uses ``nx.find_cycle`` with ``orientation="original"``, so each edge is
        reported as a ``(u, v, direction)`` triple.
        """
        try:
            return nx.find_cycle(self.graph_nx, orientation="original")
        except nx.NetworkXNoCycle:
            # find_cycle signals "no cycle" by raising; any other exception is
            # a genuine error and is no longer silently swallowed.
            return None

    def precedessors_nodes(self, n: Hashable) -> set[Hashable]:
        """Return all (transitive) ancestors of ``n`` in the networkx mirror.

        The misspelled name is kept for backward compatibility.
        """
        return nx.algorithms.ancestors(self.graph_nx, n)

    def ancestors_map(self) -> dict[Hashable, set[Hashable]]:
        """Map every node to the set of its (transitive) ancestors."""
        return {
            n: nx.algorithms.ancestors(self.graph_nx, n) for n in self.graph_nx.nodes()
        }

    def descendants_map(self) -> dict[Hashable, set[Hashable]]:
        """Map every node to the set of its (transitive) descendants."""
        return {
            n: nx.algorithms.descendants(self.graph_nx, n)
            for n in self.graph_nx.nodes()
        }

    def successors_map(self) -> dict[Hashable, list[Hashable]]:
        """Map every node to the list of its direct successors (neighbours if undirected)."""
        return {n: list(nx.neighbors(self.graph_nx, n)) for n in self.graph_nx.nodes()}

    def predecessors_map(self) -> dict[Hashable, list[Hashable]]:
        """Map every node to the list of its direct predecessors.

        For undirected graphs every neighbour is also a predecessor, matching
        how ``build_edges`` fills ``predecessors_dict``.
        """
        graph = self.graph_nx
        # nx.Graph has no `predecessors` method; use neighbours instead.
        get_preds = graph.predecessors if graph.is_directed() else graph.neighbors
        return {n: list(get_preds(n)) for n in graph.nodes()}

    def compute_length(
        self, path: list[Hashable], attribute_name: Optional[str] = None
    ) -> float:
        """Return the length of ``path`` (a sequence of nodes).

        Without ``attribute_name`` this is the number of edges (hops); a path
        with fewer than two nodes has length 0. Otherwise it is the sum of the
        ``attribute_name`` attribute over consecutive edges of ``graph_nx``.
        An edge lacking that attribute counts as 1, which is the convention
        ``nx.dijkstra_path`` uses for its ``weight`` argument, so lengths agree
        with the paths it returns.
        """
        if attribute_name is None:
            return max(len(path) - 1, 0)
        edges_data = self.graph_nx.edges
        return sum(
            edges_data[(i1, i2)].get(attribute_name, 1)
            for i1, i2 in zip(path[:-1], path[1:])
        )

    def compute_shortest_path(
        self, source: Hashable, target: Hashable, attribute_name: Optional[str] = None
    ) -> tuple[list[Hashable], float]:
        """Return ``(path, length)`` of a Dijkstra shortest path source -> target.

        ``attribute_name`` is the edge weight attribute (None = hop count).
        networkx's exception (e.g. ``NetworkXNoPath``) propagates when
        ``target`` is unreachable.
        """
        path = nx.dijkstra_path(
            G=self.graph_nx, source=source, target=target, weight=attribute_name
        )
        length = self.compute_length(path=path, attribute_name=attribute_name)
        return path, length

    def compute_all_shortest_path(
        self, attribute_name: Optional[str] = None
    ) -> dict[Hashable, dict[Hashable, tuple[list[Hashable], float]]]:
        """Return ``{source: {target: (path, length)}}`` for all reachable pairs.

        Paths come from ``nx.all_pairs_dijkstra_path`` with ``attribute_name``
        as the edge weight (None = hop count).
        """
        all_path = nx.all_pairs_dijkstra_path(G=self.graph_nx, weight=attribute_name)
        dict_path_and_distance = {}
        for source, dict_path in all_path:
            dict_path_and_distance[source] = {
                target: (
                    path,
                    self.compute_length(path=path, attribute_name=attribute_name),
                )
                for target, path in dict_path.items()
            }
        return dict_path_and_distance
def from_networkx(
    graph_nx: Union[nx.DiGraph, nx.Graph],
    undirected: Optional[bool] = None,
    compute_predecessors: bool = False,
):
    """Convert a networkx graph into a :class:`Graph`.

    Node and edge attribute dicts are passed through as-is (the same dict
    objects held by ``graph_nx``).

    Args:
        graph_nx: the networkx graph to convert.
        undirected: orientation of the result; when None it is inferred,
            with only ``nx.DiGraph`` instances treated as directed.
        compute_predecessors: forwarded to the :class:`Graph` constructor.

    Returns:
        A new :class:`Graph`.
    """
    node_view = graph_nx.nodes
    edge_view = graph_nx.edges
    node_list = [(node, node_view[node]) for node in graph_nx.nodes()]
    edge_list = [(pair[0], pair[1], edge_view[pair]) for pair in graph_nx.edges()]
    if undirected is None:
        is_undirected = not isinstance(graph_nx, nx.DiGraph)
    else:
        is_undirected = undirected
    return Graph(
        nodes=node_list,
        edges=edge_list,
        undirected=is_undirected,
        compute_predecessors=compute_predecessors,
    )
# This function is implemented here to work around the fact that
# networkx >= 3.2 is not compatible with Python 3.8.
def get_node_attributes(graph: nx.Graph, name: str, default: Any = None):
    """
    Return a dictionary mapping each node of ``graph`` to its ``name`` attribute.

    @param graph: a networkx graph
    @param name: name of the node attribute of interest
    @param default: value used for nodes that lack the attribute; when None
        (the default), nodes lacking the attribute are omitted instead
    @return: a dictionary node -> attribute value
    """
    nodes_data = graph.nodes.items()
    if default is None:
        return {n: d[name] for n, d in nodes_data if name in d}
    return {n: d.get(name, default) for n, d in nodes_data}