# Copyright (c) 2022 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import string
from collections.abc import Hashable, KeysView
from typing import Any, Optional, Union
import networkx as nx
import numpy as np
[docs]
class Graph:
    """Graph wrapper keeping node/edge attributes both in dictionaries and in networkx.

    Adjacency and attribute lookups are served from plain dictionaries built at
    construction time; path and ancestry computations are delegated to a networkx
    graph (``graph_nx``) built from the same nodes and edges.
    """

    def __init__(
        self,
        nodes: list[tuple[Hashable, dict[str, Any]]],
        edges: list[tuple[Hashable, Hashable, dict[str, Any]]],
        undirected: bool = True,
        compute_predecessors: bool = True,
    ):
        """
        @param nodes: list of (node, attribute dict) pairs
        @param edges: list of (node1, node2, attribute dict) triples
        @param undirected: if True, every edge is also stored in the reverse direction
            and ``graph_nx`` is an ``nx.Graph``; otherwise it is an ``nx.DiGraph``
        @param compute_predecessors: if True, precompute for every node its full set
            of ancestors (``full_predecessors``) and descendants (``full_successors``);
            otherwise both attributes are None
        """
        self.nodes = nodes
        self.edges = edges
        self.undirected = undirected
        # node -> direct successors (all adjacent nodes when undirected)
        self.neighbors_dict: dict[Hashable, set[Hashable]] = {}
        # node -> direct predecessors (all adjacent nodes when undirected)
        self.predecessors_dict: dict[Hashable, set[Hashable]] = {}
        self.edges_infos_dict: dict[tuple[Hashable, Hashable], dict[str, Any]] = {}
        self.nodes_infos_dict: dict[Hashable, dict[str, Any]] = {}
        self.build_nodes_infos_dict()
        self.build_edges()
        self.nodes_name = list(self.nodes_infos_dict)
        self.graph_nx = self.to_networkx()
        self.full_predecessors: Optional[dict[Hashable, set[Hashable]]]
        self.full_successors: Optional[dict[Hashable, set[Hashable]]]
        if compute_predecessors:
            # Transitive closures: one networkx traversal per node.
            self.full_predecessors = self.ancestors_map()
            self.full_successors = self.descendants_map()
        else:
            self.full_predecessors = None
            self.full_successors = None

    def get_edges(self) -> KeysView[tuple[Hashable, Hashable]]:
        """Return a view of the stored (node1, node2) edge keys.

        For undirected graphs both orientations of every edge are present.
        """
        return self.edges_infos_dict.keys()

    def get_nodes(self) -> list[Hashable]:
        """Return the node identifiers, in the order they were first given."""
        return self.nodes_name

    def build_nodes_infos_dict(self) -> None:
        """Map each node to its attribute dict (a later duplicate overwrites an earlier one)."""
        for n, d in self.nodes:
            self.nodes_infos_dict[n] = d

    def build_edges(self) -> None:
        """Fill the adjacency and edge-attribute dictionaries from ``self.edges``.

        For undirected graphs every edge is also registered in the reverse
        direction, sharing the same attribute dictionary.
        """
        for n1, n2, d in self.edges:
            self.edges_infos_dict[(n1, n2)] = d
            self.predecessors_dict.setdefault(n2, set()).add(n1)
            self.neighbors_dict.setdefault(n1, set()).add(n2)
            if self.undirected:
                self.predecessors_dict.setdefault(n1, set()).add(n2)
                self.neighbors_dict.setdefault(n2, set()).add(n1)
                self.edges_infos_dict[(n2, n1)] = d

    def get_neighbors(self, node: Hashable) -> set[Hashable]:
        """Return the direct successors of ``node`` (empty set if it has none).

        The returned set is the internally stored one; do not mutate it.
        """
        return self.neighbors_dict.get(node, set())

    def get_predecessors(self, node: Hashable) -> set[Hashable]:
        """Return the direct predecessors of ``node`` (empty set if it has none).

        The returned set is the internally stored one; do not mutate it.
        """
        return self.predecessors_dict.get(node, set())

    def get_attr_node(self, node: Hashable, attr: str) -> Any:
        """Return attribute ``attr`` of ``node``, or None if the node or attribute is missing."""
        return self.nodes_infos_dict.get(node, {}).get(attr, None)

    def get_attr_edge(self, node1: Hashable, node2: Hashable, attr: str) -> Any:
        """Return attribute ``attr`` of edge (node1, node2), or None if missing."""
        return self.edges_infos_dict.get((node1, node2), {}).get(attr, None)

    def to_networkx(self) -> "nx.Graph":
        """Build the networkx counterpart: ``nx.Graph`` if undirected, else ``nx.DiGraph``."""
        graph_nx = nx.Graph() if self.undirected else nx.DiGraph()
        graph_nx.add_nodes_from(self.nodes)
        graph_nx.add_edges_from(self.edges)
        return graph_nx

    def check_loop(self) -> Optional[list[tuple[Hashable, Hashable, str]]]:
        """Return one cycle of the graph as a list of oriented edges, or None if acyclic."""
        try:
            return nx.find_cycle(self.graph_nx, orientation="original")
        except nx.NetworkXNoCycle:
            # Only "no cycle" means acyclic; any other error must surface.
            return None

    # NOTE: the name is misspelled ("precedessors"); kept for backward compatibility.
    def precedessors_nodes(self, n: Hashable) -> set[Hashable]:
        """Return all ancestors of node ``n`` (transitive predecessors)."""
        return nx.algorithms.ancestors(self.graph_nx, n)

    def ancestors_map(self) -> dict[Hashable, set[Hashable]]:
        """Map every node to the set of all its ancestors."""
        return {
            n: nx.algorithms.ancestors(self.graph_nx, n) for n in self.graph_nx.nodes()
        }

    def descendants_map(self) -> dict[Hashable, set[Hashable]]:
        """Map every node to the set of all its descendants."""
        return {
            n: nx.algorithms.descendants(self.graph_nx, n)
            for n in self.graph_nx.nodes()
        }

    def successors_map(self) -> dict[Hashable, list[Hashable]]:
        """Map every node to the list of its direct successors (neighbors if undirected)."""
        return {n: list(nx.neighbors(self.graph_nx, n)) for n in self.graph_nx.nodes()}

    def predecessors_map(self) -> dict[Hashable, list[Hashable]]:
        """Map every node to the list of its direct predecessors (neighbors if undirected)."""
        if self.graph_nx.is_directed():
            return {
                n: list(self.graph_nx.predecessors(n)) for n in self.graph_nx.nodes()
            }
        # nx.Graph has no `predecessors` method: in an undirected graph the
        # predecessors of a node are exactly its neighbors.
        return {n: list(self.graph_nx.neighbors(n)) for n in self.graph_nx.nodes()}

    def compute_length(
        self, path: list[Hashable], attribute_name: Optional[str] = None
    ) -> float:
        """Compute the length of a path.

        @param path: sequence of nodes, consecutive nodes being joined by an edge
        @param attribute_name: edge attribute holding the edge weight; if None the
            length is the number of edges of the path
        @return: the path length (0 for an empty or single-node path). Edges lacking
            ``attribute_name`` count as 1, as in networkx's Dijkstra functions, so the
            result agrees with the paths returned by ``compute_shortest_path``.
        """
        if attribute_name is None:
            return max(len(path) - 1, 0)
        return sum(
            self.graph_nx.edges[(i1, i2)].get(attribute_name, 1)
            for i1, i2 in zip(path[:-1], path[1:])
        )

    def compute_shortest_path(
        self, source: Hashable, target: Hashable, attribute_name: Optional[str] = None
    ) -> tuple[list[Hashable], float]:
        """Return a shortest (Dijkstra) path from ``source`` to ``target`` and its length.

        @param attribute_name: edge weight attribute, or None to count edges
        Exceptions raised by ``nx.dijkstra_path`` (e.g. when no path exists) propagate.
        """
        path = nx.dijkstra_path(
            G=self.graph_nx, source=source, target=target, weight=attribute_name
        )
        length = self.compute_length(path=path, attribute_name=attribute_name)
        return path, length

    def compute_all_shortest_path(
        self, attribute_name: Optional[str] = None
    ) -> dict[Hashable, dict[Hashable, tuple[list[Hashable], float]]]:
        """Return ``{source: {target: (path, length)}}`` for all reachable pairs.

        @param attribute_name: edge weight attribute, or None to count edges
        """
        all_path = nx.all_pairs_dijkstra_path(G=self.graph_nx, weight=attribute_name)
        return {
            source: {
                target: (
                    path,
                    self.compute_length(path=path, attribute_name=attribute_name),
                )
                for target, path in dict_path.items()
            }
            for source, dict_path in all_path
        }

    def compute_weighted_adjacency(
        self, nodes_order: list[Hashable], attribute_edge: str
    ) -> np.ndarray:
        """Return the adjacency matrix weighted by ``attribute_edge``.

        Rows and columns follow ``nodes_order``; see ``nx.to_numpy_array`` for how
        absent edges and missing attributes are filled.
        """
        return nx.to_numpy_array(
            self.graph_nx, nodelist=nodes_order, weight=attribute_edge
        )
[docs]
def from_networkx(
    graph_nx: Union[nx.DiGraph, nx.Graph],
    undirected: Optional[bool] = None,
    compute_predecessors: bool = False,
):
    """Build a Graph from a networkx graph.

    @param graph_nx: source networkx graph; its node and edge attribute
        dictionaries become the attributes of the new Graph
    @param undirected: whether the new Graph is undirected; when None it is
        inferred from graph_nx (undirected unless graph_nx is an nx.DiGraph)
    @param compute_predecessors: forwarded to the Graph constructor
    @return: the newly built Graph
    """
    if undirected is None:
        undirected = not isinstance(graph_nx, nx.DiGraph)
    node_list = [(node, graph_nx.nodes[node]) for node in graph_nx.nodes()]
    edge_list = [(u, v, graph_nx.edges[u, v]) for u, v in graph_nx.edges()]
    return Graph(
        nodes=node_list,
        edges=edge_list,
        undirected=undirected,
        compute_predecessors=compute_predecessors,
    )
# This function is implemented to work around the fact that networkx >= 3.2 is not compatible with Python 3.8.
[docs]
def get_node_attributes(
    graph: "nx.Graph", name: str, default: Any = None
) -> dict[Hashable, Any]:
    """Collect the value of one node attribute over all nodes of a graph.

    Mirrors the behaviour of ``networkx.get_node_attributes`` with a ``default``
    argument.

    @param graph: a networkx graph
    @param name: name of the node attribute of interest
    @param default: value used for nodes lacking the attribute; if None, those
        nodes are left out of the result instead
    @return: a dictionary mapping each (retained) node of graph to its attribute value
    """
    if default is not None:
        return {n: d.get(name, default) for n, d in graph.nodes.items()}
    # No usable default: keep only nodes that actually carry the attribute.
    return {n: d[name] for n, d in graph.nodes.items() if name in d}