diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py')
| -rw-r--r-- | .venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py | 3021 |
1 files changed, 3021 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py b/.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py new file mode 100644 index 0000000..778d6e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py @@ -0,0 +1,3021 @@ +""" +********** +Matplotlib +********** + +Draw networks with matplotlib. + +Examples +-------- +>>> G = nx.complete_graph(5) +>>> nx.draw(G) + +See Also +-------- + - :doc:`matplotlib <matplotlib:index>` + - :func:`matplotlib.pyplot.scatter` + - :obj:`matplotlib.patches.FancyArrowPatch` +""" + +import collections +import itertools +from numbers import Number + +import networkx as nx + +__all__ = [ + "display", + "apply_matplotlib_colors", + "draw", + "draw_networkx", + "draw_networkx_nodes", + "draw_networkx_edges", + "draw_networkx_labels", + "draw_networkx_edge_labels", + "draw_bipartite", + "draw_circular", + "draw_kamada_kawai", + "draw_random", + "draw_spectral", + "draw_spring", + "draw_planar", + "draw_shell", + "draw_forceatlas2", +] + + +def apply_matplotlib_colors( + G, src_attr, dest_attr, map, vmin=None, vmax=None, nodes=True +): + """ + Apply colors from a matplotlib colormap to a graph. + + Reads values from the `src_attr` and use a matplotlib colormap + to produce a color. Write the color to `dest_attr`. + + Parameters + ---------- + G : nx.Graph + The graph to read and compute colors for. + + src_attr : str or other attribute name + The name of the attribute to read from the graph. + + dest_attr : str or other attribute name + The name of the attribute to write to on the graph. + + map : matplotlib.colormap + The matplotlib colormap to use. + + vmin : float, default None + The minimum value for scaling the colormap. If `None`, find the + minimum value of `src_attr`. + + vmax : float, default None + The maximum value for scaling the colormap. If `None`, find the + maximum value of `src_attr`. + + nodes : bool, default True + Whether the attribute names are edge attributes or node attributes. + """ + import matplotlib as mpl + + if nodes: + type_iter = G.nodes() + elif G.is_multigraph(): + type_iter = G.edges(keys=True) + else: + type_iter = G.edges() + + if vmin is None or vmax is None: + vals = [type_iter[a][src_attr] for a in type_iter] + if vmin is None: + vmin = min(vals) + if vmax is None: + vmax = max(vals) + + mapper = mpl.cm.ScalarMappable(cmap=map) + mapper.set_clim(vmin, vmax) + + def do_map(x): + # Cast numpy scalars to float + return tuple(float(x) for x in mapper.to_rgba(x)) + + if nodes: + nx.set_node_attributes( + G, {n: do_map(G.nodes[n][src_attr]) for n in G.nodes()}, dest_attr + ) + else: + nx.set_edge_attributes( + G, {e: do_map(G.edges[e][src_attr]) for e in type_iter}, dest_attr + ) + + +def display( + G, + canvas=None, + **kwargs, +): + """Draw the graph G. + + Draw the graph as a collection of nodes connected by edges. + The exact details of what the graph looks like are controled by the below + attributes. All nodes and nodes at the end of visible edges must have a + position set, but nearly all other node and edge attributes are options and + nodes or edges missing the attribute will use the default listed below. A more + complete discription of each parameter is given below this summary. + + .. list-table:: Default Visualization Attributes + :widths: 25 25 50 + :header-rows: 1 + + * - Parameter + - Default Attribute + - Default Value + * - pos + - `"pos"` + - If there is not position, a layout will be calculated with `nx.spring_layout`. + * - node_visible + - `"visible"` + - True + * - node_color + - `"color"` + - #1f78b4 + * - node_size + - `"size"` + - 300 + * - node_label + - `"label"` + - Dict describing the node label. Defaults create a black text with + the node name as the label. The dict respects these keys and defaults: + + * size : 12 + * color : black + * family : sans serif + * weight : normal + * alpha : 1.0 + * h_align : center + * v_align : center + * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`. + Default is None. + + * - node_shape + - `"shape"` + - "o" + * - node_alpha + - `"alpha"` + - 1.0 + * - node_border_width + - `"border_width"` + - 1.0 + * - node_border_color + - `"border_color"` + - Matching node_color + * - edge_visible + - `"visible"` + - True + * - edge_width + - `"width"` + - 1.0 + * - edge_color + - `"color"` + - Black (#000000) + * - edge_label + - `"label"` + - Dict describing the edge label. Defaults create black text with a + white bounding box. The dictionary respects these keys and defaults: + + * size : 12 + * color : black + * family : sans serif + * weight : normal + * alpha : 1.0 + * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`. + Default {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)} + * h_align : "center" + * v_align : "center" + * pos : 0.5 + * rotate : True + + * - edge_style + - `"style"` + - "-" + * - edge_alpha + - `"alpha"` + - 1.0 + * - arrowstyle + - `"arrowstyle"` + - ``"-|>"`` if `G` is directed else ``"-"`` + * - arrowsize + - `"arrowsize"` + - 10 if `G` is directed else 0 + * - edge_curvature + - `"curvature"` + - arc3 + * - edge_source_margin + - `"source_margin"` + - 0 + * - edge_target_margin + - `"target_margin"` + - 0 + + Parameters + ---------- + G : graph + A networkx graph + + canvas : Matplotlib Axes object, optional + Draw the graph in specified Matplotlib axes + + pos : string or function, default "pos" + A string naming the node attribute storing the position of nodes as a tuple. + Or a function to be called with input `G` which returns the layout as a dict keyed + by node to position tuple like the NetworkX layout functions. + If no nodes in the graph has the attribute, a spring layout is calculated. + + node_visible : string or bool, default visible + A string naming the node attribute which stores if a node should be drawn. + If `True`, all nodes will be visible while if `False` no nodes will be visible. + If incomplete, nodes missing this attribute will be shown by default. + + node_color : string, default "color" + A string naming the node attribute which stores the color of each node. + Visible nodes without this attribute will use '#1f78b4' as a default. + + node_size : string or number, default "size" + A string naming the node attribute which stores the size of each node. + Visible nodes without this attribute will use a default size of 300. + + node_label : string or bool, default "label" + A string naming the node attribute which stores the label of each node. + The attribute value can be a string, False (no label for that node), + True (the node is the label) or a dict keyed by node to the label. + + If a dict is specified, these keys are read to further control the label: + + * label : The text of the label; default: name of the node + * size : Font size of the label; default: 12 + * color : Font color of the label; default: black + * family : Font family of the label; default: "sans-serif" + * weight : Font weight of the label; default: "normal" + * alpha : Alpha value of the label; default: 1.0 + * h_align : The horizontal alignment of the label. + one of "left", "center", "right"; default: "center" + * v_align : The vertical alignment of the label. + one of "top", "center", "bottom"; default: "center" + * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`. + + Visible nodes without this attribute will be treated as if the value was True. + + node_shape : string, default "shape" + A string naming the node attribute which stores the label of each node. + The values of this attribute are expected to be one of the matplotlib shapes, + one of 'so^>v<dph8'. Visible nodes without this attribute will use 'o'. + + node_alpha : string, default "alpha" + A string naming the node attribute which stores the alpha of each node. + The values of this attribute are expected to be floats between 0.0 and 1.0. + Visible nodes without this attribute will be treated as if the value was 1.0. + + node_border_width : string, default "border_width" + A string naming the node attribute storing the width of the border of the node. + The values of this attribute are expected to be numeric. Visible nodes without + this attribute will use the assumed default of 1.0. + + node_border_color : string, default "border_color" + A string naming the node attribute which storing the color of the border of the node. + Visible nodes missing this attribute will use the final node_color value. + + edge_visible : string or bool, default "visible" + A string nameing the edge attribute which stores if an edge should be drawn. + If `True`, all edges will be drawn while if `False` no edges will be visible. + If incomplete, edges missing this attribute will be shown by default. Values + of this attribute are expected to be booleans. + + edge_width : string or int, default "width" + A string nameing the edge attribute which stores the width of each edge. + Visible edges without this attribute will use a default width of 1.0. + + edge_color : string or color, default "color" + A string nameing the edge attribute which stores of color of each edge. + Visible edges without this attribute will be drawn black. Each color can be + a string or rgb (or rgba) tuple of floats from 0.0 to 1.0. + + edge_label : string, default "label" + A string naming the edge attribute which stores the label of each edge. + The values of this attribute can be a string, number or False or None. In + the latter two cases, no edge label is displayed. + + If a dict is specified, these keys are read to further control the label: + + * label : The text of the label, or the name of an edge attribute holding the label. + * size : Font size of the label; default: 12 + * color : Font color of the label; default: black + * family : Font family of the label; default: "sans-serif" + * weight : Font weight of the label; default: "normal" + * alpha : Alpha value of the label; default: 1.0 + * h_align : The horizontal alignment of the label. + one of "left", "center", "right"; default: "center" + * v_align : The vertical alignment of the label. + one of "top", "center", "bottom"; default: "center" + * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`. + * rotate : Whether to rotate labels to lie parallel to the edge, default: True. + * pos : A float showing how far along the edge to put the label; default: 0.5. + + edge_style : string, default "style" + A string naming the edge attribute which stores the style of each edge. + Visible edges without this attribute will be drawn solid. Values of this + attribute can be line styles, e.g. '-', '--', '-.' or ':' or words like 'solid' + or 'dashed'. If no edge in the graph has this attribute and it is a non-default + value, assume that it describes the edge style for all edges in the graph. + + edge_alpha : string or float, default "alpha" + A string naming the edge attribute which stores the alpha value of each edge. + Visible edges without this attribute will use an alpha value of 1.0. + + arrowstyle : string, default "arrow" + A string naming the edge attribute which stores the type of arrowhead to use for + each edge. Visible edges without this attribute use ``"-"`` for undirected graphs + and ``"-|>"`` for directed graphs. + + See `matplotlib.patches.ArrowStyle` for more options + + arrowsize : string or int, default "arrow_size" + A string naming the edge attribute which stores the size of the arrowhead for each + edge. Visible edges without this attribute will use a default value of 10. + + edge_curvature : string, default "curvature" + A string naming the edge attribute storing the curvature and connection style + of each edge. Visible edges without this attribute will use "arc3" as a default + value, resulting an a straight line between the two nodes. Curvature can be given + as 'arc3,rad=0.2' to specify both the style and radius of curvature. + + Please see `matplotlib.patches.ConnectionStyle` and + `matplotlib.patches.FancyArrowPatch` for more information. + + edge_source_margin : string or int, default "source_margin" + A string naming the edge attribute which stores the minimum margin (gap) between + the source node and the start of the edge. Visible edges without this attribute + will use a default value of 0. + + edge_target_margin : string or int, default "target_margin" + A string naming the edge attribute which stores the minimumm margin (gap) between + the target node and the end of the edge. Visible edges without this attribute + will use a default value of 0. + + hide_ticks : bool, default True + Weather to remove the ticks from the axes of the matplotlib object. + + Raises + ------ + NetworkXError + If a node or edge is missing a required parameter such as `pos` or + if `display` receives an argument not listed above. + + ValueError + If a node or edge has an invalid color format, i.e. not a color string, + rgb tuple or rgba tuple. + + Returns + ------- + The input graph. This is potentially useful for dispatching visualization + functions. + """ + from collections import Counter + + import matplotlib as mpl + import matplotlib.pyplot as plt + import numpy as np + + defaults = { + "node_pos": None, + "node_visible": True, + "node_color": "#1f78b4", + "node_size": 300, + "node_label": { + "size": 12, + "color": "#000000", + "family": "sans-serif", + "weight": "normal", + "alpha": 1.0, + "h_align": "center", + "v_align": "center", + "bbox": None, + }, + "node_shape": "o", + "node_alpha": 1.0, + "node_border_width": 1.0, + "node_border_color": "face", + "edge_visible": True, + "edge_width": 1.0, + "edge_color": "#000000", + "edge_label": { + "size": 12, + "color": "#000000", + "family": "sans-serif", + "weight": "normal", + "alpha": 1.0, + "bbox": {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}, + "h_align": "center", + "v_align": "center", + "pos": 0.5, + "rotate": True, + }, + "edge_style": "-", + "edge_alpha": 1.0, + "edge_arrowstyle": "-|>" if G.is_directed() else "-", + "edge_arrowsize": 10 if G.is_directed() else 0, + "edge_curvature": "arc3", + "edge_source_margin": 0, + "edge_target_margin": 0, + "hide_ticks": True, + } + + # Check arguments + for kwarg in kwargs: + if kwarg not in defaults: + raise nx.NetworkXError( + f"Unrecongized visualization keyword argument: {kwarg}" + ) + + if canvas is None: + canvas = plt.gca() + + if kwargs.get("hide_ticks", defaults["hide_ticks"]): + canvas.tick_params( + axis="both", + which="both", + bottom=False, + left=False, + labelbottom=False, + labelleft=False, + ) + + ### Helper methods and classes + + def node_property_sequence(seq, attr): + """Return a list of attribute values for `seq`, using a default if needed""" + + # All node attribute parameters start with "node_" + param_name = f"node_{attr}" + default = defaults[param_name] + attr = kwargs.get(param_name, attr) + + if default is None: + # raise instead of using non-existant default value + for n in seq: + if attr not in node_subgraph.nodes[n]: + raise nx.NetworkXError(f"Attribute '{attr}' missing for node {n}") + + # If `attr` is not a graph attr and was explicitly passed as an argument + # it must be a user-default value. Allow attr=None to tell draw to skip + # attributes which are on the graph + if ( + attr is not None + and nx.get_node_attributes(node_subgraph, attr) == {} + and any(attr == v for k, v in kwargs.items() if "node" in k) + ): + return [attr for _ in seq] + + return [node_subgraph.nodes[n].get(attr, default) for n in seq] + + def compute_colors(color, alpha): + if isinstance(color, str): + rgba = mpl.colors.colorConverter.to_rgba(color) + # Using a non-default alpha value overrides any alpha value in the color + if alpha != defaults["node_alpha"]: + return (rgba[0], rgba[1], rgba[2], alpha) + return rgba + + if isinstance(color, tuple) and len(color) == 3: + return (color[0], color[1], color[2], alpha) + + if isinstance(color, tuple) and len(color) == 4: + return color + + raise ValueError(f"Invalid format for color: {color}") + + # Find which edges can be plotted as a line collection + # + # Non-default values for these attributes require fancy arrow patches: + # - any arrow style (including the default -|> for directed graphs) + # - arrow size (by extension of style) + # - connection style + # - min_source_margin + # - min_target_margin + + def collection_compatible(e): + return ( + get_edge_attr(e, "arrowstyle") == "-" + and get_edge_attr(e, "curvature") == "arc3" + and get_edge_attr(e, "source_margin") == 0 + and get_edge_attr(e, "target_margin") == 0 + # Self-loops will use fancy arrow patches + and e[0] != e[1] + ) + + def edge_property_sequence(seq, attr): + """Return a list of attribute values for `seq`, using a default if needed""" + + param_name = f"edge_{attr}" + default = defaults[param_name] + attr = kwargs.get(param_name, attr) + + if default is None: + # raise instead of using non-existant default value + for e in seq: + if attr not in edge_subgraph.edges[e]: + raise nx.NetworkXError(f"Attribute '{attr}' missing for edge {e}") + + if ( + attr is not None + and nx.get_edge_attributes(edge_subgraph, attr) == {} + and any(attr == v for k, v in kwargs.items() if "edge" in k) + ): + return [attr for _ in seq] + + return [edge_subgraph.edges[e].get(attr, default) for e in seq] + + def get_edge_attr(e, attr): + """Return the final edge attribute value, using default if not None""" + + param_name = f"edge_{attr}" + default = defaults[param_name] + attr = kwargs.get(param_name, attr) + + if default is None and attr not in edge_subgraph.edges[e]: + raise nx.NetworkXError(f"Attribute '{attr}' missing from edge {e}") + + if ( + attr is not None + and nx.get_edge_attributes(edge_subgraph, attr) == {} + and attr in kwargs.values() + ): + return attr + + return edge_subgraph.edges[e].get(attr, default) + + def get_node_attr(n, attr, use_edge_subgraph=True): + """Return the final node attribute value, using default if not None""" + subgraph = edge_subgraph if use_edge_subgraph else node_subgraph + + param_name = f"node_{attr}" + default = defaults[param_name] + attr = kwargs.get(param_name, attr) + + if default is None and attr not in subgraph.nodes[n]: + raise nx.NetworkXError(f"Attribute '{attr}' missing from node {n}") + + if ( + attr is not None + and nx.get_node_attributes(subgraph, attr) == {} + and attr in kwargs.values() + ): + return attr + + return subgraph.nodes[n].get(attr, default) + + # Taken from ConnectionStyleFactory + def self_loop(edge_index, node_size): + def self_loop_connection(posA, posB, *args, **kwargs): + if not np.all(posA == posB): + raise nx.NetworkXError( + "`self_loop` connection style method" + "is only to be used for self-loops" + ) + # this is called with _screen space_ values + # so convert back to data space + data_loc = canvas.transData.inverted().transform(posA) + # Scale self loop based on the size of the base node + # Size of nodes are given in points ** 2 and each point is 1/72 of an inch + v_shift = np.sqrt(node_size) / 72 + h_shift = v_shift * 0.5 + # put the top of the loop first so arrow is not hidden by node + path = np.asarray( + [ + # 1 + [0, v_shift], + # 4 4 4 + [h_shift, v_shift], + [h_shift, 0], + [0, 0], + # 4 4 4 + [-h_shift, 0], + [-h_shift, v_shift], + [0, v_shift], + ] + ) + # Rotate self loop 90 deg. if more than 1 + # This will allow for maximum of 4 visible self loops + if edge_index % 4: + x, y = path.T + for _ in range(edge_index % 4): + x, y = y, -x + path = np.array([x, y]).T + return mpl.path.Path( + canvas.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4] + ) + + return self_loop_connection + + def to_marker_edge(size, marker): + if marker in "s^>v<d": + return np.sqrt(2 * size) / 2 + else: + return np.sqrt(size) / 2 + + def build_fancy_arrow(e): + source_margin = to_marker_edge( + get_node_attr(e[0], "size"), + get_node_attr(e[0], "shape"), + ) + source_margin = max( + source_margin, + get_edge_attr(e, "source_margin"), + ) + + target_margin = to_marker_edge( + get_node_attr(e[1], "size"), + get_node_attr(e[1], "shape"), + ) + target_margin = max( + target_margin, + get_edge_attr(e, "target_margin"), + ) + return mpl.patches.FancyArrowPatch( + edge_subgraph.nodes[e[0]][pos], + edge_subgraph.nodes[e[1]][pos], + arrowstyle=get_edge_attr(e, "arrowstyle"), + connectionstyle=( + get_edge_attr(e, "curvature") + if e[0] != e[1] + else self_loop( + 0 if len(e) == 2 else e[2] % 4, + get_node_attr(e[0], "size"), + ) + ), + color=get_edge_attr(e, "color"), + linestyle=get_edge_attr(e, "style"), + linewidth=get_edge_attr(e, "width"), + mutation_scale=get_edge_attr(e, "arrowsize"), + shrinkA=source_margin, + shrinkB=source_margin, + zorder=1, + ) + + class CurvedArrowText(mpl.text.Text): + def __init__( + self, + arrow, + *args, + label_pos=0.5, + labels_horizontal=False, + ax=None, + **kwargs, + ): + # Bind to FancyArrowPatch + self.arrow = arrow + # how far along the text should be on the curve, + # 0 is at start, 1 is at end etc. + self.label_pos = label_pos + self.labels_horizontal = labels_horizontal + if ax is None: + ax = plt.gca() + self.ax = ax + self.x, self.y, self.angle = self._update_text_pos_angle(arrow) + + # Create text object + super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs) + # Bind to axis + self.ax.add_artist(self) + + def _get_arrow_path_disp(self, arrow): + """ + This is part of FancyArrowPatch._get_path_in_displaycoord + It omits the second part of the method where path is converted + to polygon based on width + The transform is taken from ax, not the object, as the object + has not been added yet, and doesn't have transform + """ + dpi_cor = arrow._dpi_cor + # trans_data = arrow.get_transform() + trans_data = self.ax.transData + if arrow._posA_posB is not None: + posA = arrow._convert_xy_units(arrow._posA_posB[0]) + posB = arrow._convert_xy_units(arrow._posA_posB[1]) + (posA, posB) = trans_data.transform((posA, posB)) + _path = arrow.get_connectionstyle()( + posA, + posB, + patchA=arrow.patchA, + patchB=arrow.patchB, + shrinkA=arrow.shrinkA * dpi_cor, + shrinkB=arrow.shrinkB * dpi_cor, + ) + else: + _path = trans_data.transform_path(arrow._path_original) + # Return is in display coordinates + return _path + + def _update_text_pos_angle(self, arrow): + # Fractional label position + path_disp = self._get_arrow_path_disp(arrow) + (x1, y1), (cx, cy), (x2, y2) = path_disp.vertices + # Text position at a proportion t along the line in display coords + # default is 0.5 so text appears at the halfway point + t = self.label_pos + tt = 1 - t + x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2 + y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2 + if self.labels_horizontal: + # Horizontal text labels + angle = 0 + else: + # Labels parallel to curve + change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx) + change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy) + angle = (np.arctan2(change_y, change_x) / (2 * np.pi)) * 360 + # Text is "right way up" + if angle > 90: + angle -= 180 + if angle < -90: + angle += 180 + (x, y) = self.ax.transData.inverted().transform((x, y)) + return x, y, angle + + def draw(self, renderer): + # recalculate the text position and angle + self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow) + self.set_position((self.x, self.y)) + self.set_rotation(self.angle) + # redraw text + super().draw(renderer) + + ### Draw the nodes first + node_visible = kwargs.get("node_visible", "visible") + if isinstance(node_visible, bool): + if node_visible: + visible_nodes = G.nodes() + else: + visible_nodes = [] + else: + visible_nodes = [ + n for n, v in nx.get_node_attributes(G, node_visible, True).items() if v + ] + + node_subgraph = G.subgraph(visible_nodes) + + # Ignore the default dict value since that's for default values to use, not + # default attribute name + pos = kwargs.get("node_pos", "pos") + + default_display_pos_attr = "display's position attribute name" + if callable(pos): + nx.set_node_attributes( + node_subgraph, pos(node_subgraph), default_display_pos_attr + ) + pos = default_display_pos_attr + kwargs["node_pos"] = default_display_pos_attr + elif nx.get_node_attributes(G, pos) == {}: + nx.set_node_attributes( + node_subgraph, nx.spring_layout(node_subgraph), default_display_pos_attr + ) + pos = default_display_pos_attr + kwargs["node_pos"] = default_display_pos_attr + + # Each shape requires a new scatter object since they can't have different + # shapes. + if len(visible_nodes) > 0: + node_shape = kwargs.get("node_shape", "shape") + for shape in Counter( + nx.get_node_attributes( + node_subgraph, node_shape, defaults["node_shape"] + ).values() + ): + # Filter position just on this shape. + nodes_with_shape = [ + n + for n, s in node_subgraph.nodes(data=node_shape) + if s == shape or (s is None and shape == defaults["node_shape"]) + ] + # There are two property sequences to create before hand. + # 1. position, since it is used for x and y parameters to scatter + # 2. edgecolor, since the spaeical 'face' parameter value can only be + # be passed in as the sole string, not part of a list of strings. + position = np.asarray(node_property_sequence(nodes_with_shape, "pos")) + color = np.asarray( + [ + compute_colors(c, a) + for c, a in zip( + node_property_sequence(nodes_with_shape, "color"), + node_property_sequence(nodes_with_shape, "alpha"), + ) + ] + ) + border_color = np.asarray( + [ + ( + c + if ( + c := get_node_attr( + n, + "border_color", + False, + ) + ) + != "face" + else color[i] + ) + for i, n in enumerate(nodes_with_shape) + ] + ) + canvas.scatter( + position[:, 0], + position[:, 1], + s=node_property_sequence(nodes_with_shape, "size"), + c=color, + marker=shape, + linewidths=node_property_sequence(nodes_with_shape, "border_width"), + edgecolors=border_color, + zorder=2, + ) + + ### Draw node labels + node_label = kwargs.get("node_label", "label") + # Plot labels if node_label is not None and not False + if node_label is not None and node_label is not False: + default_dict = {} + if isinstance(node_label, dict): + default_dict = node_label + node_label = None + + for n, lbl in node_subgraph.nodes(data=node_label): + if lbl is False: + continue + + # We work with label dicts down here... + if not isinstance(lbl, dict): + lbl = {"label": lbl if lbl is not None else n} + + lbl_text = lbl.get("label", n) + if not isinstance(lbl_text, str): + lbl_text = str(lbl_text) + + lbl.update(default_dict) + x, y = node_subgraph.nodes[n][pos] + canvas.text( + x, + y, + lbl_text, + size=lbl.get("size", defaults["node_label"]["size"]), + color=lbl.get("color", defaults["node_label"]["color"]), + family=lbl.get("family", defaults["node_label"]["family"]), + weight=lbl.get("weight", defaults["node_label"]["weight"]), + horizontalalignment=lbl.get( + "h_align", defaults["node_label"]["h_align"] + ), + verticalalignment=lbl.get("v_align", defaults["node_label"]["v_align"]), + transform=canvas.transData, + bbox=lbl.get("bbox", defaults["node_label"]["bbox"]), + ) + + ### Draw edges + + edge_visible = kwargs.get("edge_visible", "visible") + if isinstance(edge_visible, bool): + if edge_visible: + visible_edges = G.edges() + else: + visible_edges = [] + else: + visible_edges = [ + e for e, v in nx.get_edge_attributes(G, edge_visible, True).items() if v + ] + + edge_subgraph = G.edge_subgraph(visible_edges) + print(nx.get_node_attributes(node_subgraph, pos)) + nx.set_node_attributes( + edge_subgraph, nx.get_node_attributes(node_subgraph, pos), name=pos + ) + + collection_edges = ( + [e for e in edge_subgraph.edges(keys=True) if collection_compatible(e)] + if edge_subgraph.is_multigraph() + else [e for e in edge_subgraph.edges() if collection_compatible(e)] + ) + non_collection_edges = ( + [e for e in edge_subgraph.edges(keys=True) if not collection_compatible(e)] + if edge_subgraph.is_multigraph() + else [e for e in edge_subgraph.edges() if not collection_compatible(e)] + ) + edge_position = np.asarray( + [ + ( + get_node_attr(u, "pos", use_edge_subgraph=True), + get_node_attr(v, "pos", use_edge_subgraph=True), + ) + for u, v, *_ in collection_edges + ] + ) + + # Only plot a line collection if needed + if len(collection_edges) > 0: + edge_collection = mpl.collections.LineCollection( + edge_position, + colors=edge_property_sequence(collection_edges, "color"), + linewidths=edge_property_sequence(collection_edges, "width"), + linestyle=edge_property_sequence(collection_edges, "style"), + alpha=edge_property_sequence(collection_edges, "alpha"), + antialiaseds=(1,), + zorder=1, + ) + canvas.add_collection(edge_collection) + + fancy_arrows = {} + if len(non_collection_edges) > 0: + for e in non_collection_edges: + # Cache results for use in edge labels + fancy_arrows[e] = build_fancy_arrow(e) + canvas.add_patch(fancy_arrows[e]) + + ### Draw edge labels + edge_label = kwargs.get("edge_label", "label") + default_dict = {} + if isinstance(edge_label, dict): + default_dict = edge_label + # Restore the default label attribute key of 'label' + edge_label = "label" + + # Handle multigraphs + edge_label_data = ( + edge_subgraph.edges(data=edge_label, keys=True) + if edge_subgraph.is_multigraph() + else edge_subgraph.edges(data=edge_label) + ) + if edge_label is not None and edge_label is not False: + for *e, lbl in edge_label_data: + e = tuple(e) + # I'm not sure how I want to handle None here... For now it means no label + if lbl is False or lbl is None: + continue + + if not isinstance(lbl, dict): + lbl = {"label": lbl} + + lbl.update(default_dict) + lbl_text = lbl.get("label") + if not isinstance(lbl_text, str): + lbl_text = str(lbl_text) + + # In the old code, every non-self-loop is placed via a fancy arrow patch + # Only compute a new fancy arrow if needed by caching the results from + # edge placement. + try: + arrow = fancy_arrows[e] + except KeyError: + arrow = build_fancy_arrow(e) + + if e[0] == e[1]: + # Taken directly from draw_networkx_edge_labels + connectionstyle_obj = arrow.get_connectionstyle() + posA = canvas.transData.transform(edge_subgraph.nodes[e[0]][pos]) + path_disp = connectionstyle_obj(posA, posA) + path_data = canvas.transData.inverted().transform_path(path_disp) + x, y = path_data.vertices[0] + canvas.text( + x, + y, + lbl_text, + size=lbl.get("size", defaults["edge_label"]["size"]), + color=lbl.get("color", defaults["edge_label"]["color"]), + family=lbl.get("family", defaults["edge_label"]["family"]), + weight=lbl.get("weight", defaults["edge_label"]["weight"]), + alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]), + horizontalalignment=lbl.get( + "h_align", defaults["edge_label"]["h_align"] + ), + verticalalignment=lbl.get( + "v_align", defaults["edge_label"]["v_align"] + ), + rotation=0, + transform=canvas.transData, + bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]), + zorder=1, + ) + continue + + CurvedArrowText( + arrow, + lbl_text, + size=lbl.get("size", defaults["edge_label"]["size"]), + color=lbl.get("color", defaults["edge_label"]["color"]), + family=lbl.get("family", defaults["edge_label"]["family"]), + weight=lbl.get("weight", defaults["edge_label"]["weight"]), + alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]), + bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]), + horizontalalignment=lbl.get( + "h_align", defaults["edge_label"]["h_align"] + ), + verticalalignment=lbl.get("v_align", defaults["edge_label"]["v_align"]), + label_pos=lbl.get("pos", defaults["edge_label"]["pos"]), + labels_horizontal=lbl.get("rotate", defaults["edge_label"]["rotate"]), + transform=canvas.transData, + zorder=1, + ax=canvas, + ) + + # If we had to add an attribute, remove it here + if pos == default_display_pos_attr: + nx.remove_node_attributes(G, default_display_pos_attr) + + return G + + +def draw(G, pos=None, ax=None, **kwds): + """Draw the graph G with Matplotlib. + + Draw the graph as a simple representation with no node + labels or edge labels and using the full Matplotlib figure area + and no axis labels by default. See draw_networkx() for more + full-featured drawing that allows title, axis labels etc. + + Parameters + ---------- + G : graph + A networkx graph + + pos : dictionary, optional + A dictionary with nodes as keys and positions as values. + If not specified a spring layout positioning will be computed. + See :py:mod:`networkx.drawing.layout` for functions that + compute node positions. + + ax : Matplotlib Axes object, optional + Draw the graph in specified Matplotlib axes. + + kwds : optional keywords + See networkx.draw_networkx() for a description of optional keywords. + + Examples + -------- + >>> G = nx.dodecahedral_graph() + >>> nx.draw(G) + >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout + + See Also + -------- + draw_networkx + draw_networkx_nodes + draw_networkx_edges + draw_networkx_labels + draw_networkx_edge_labels + + Notes + ----- + This function has the same name as pylab.draw and pyplot.draw + so beware when using `from networkx import *` + + since you might overwrite the pylab.draw function. + + With pyplot use + + >>> import matplotlib.pyplot as plt + >>> G = nx.dodecahedral_graph() + >>> nx.draw(G) # networkx draw() + >>> plt.draw() # pyplot draw() + + Also see the NetworkX drawing examples at + https://networkx.org/documentation/latest/auto_examples/index.html + """ + + import matplotlib.pyplot as plt + + if ax is None: + cf = plt.gcf() + else: + cf = ax.get_figure() + cf.set_facecolor("w") + if ax is None: + if cf.axes: + ax = cf.gca() + else: + ax = cf.add_axes((0, 0, 1, 1)) + + if "with_labels" not in kwds: + kwds["with_labels"] = "labels" in kwds + + draw_networkx(G, pos=pos, ax=ax, **kwds) + ax.set_axis_off() + plt.draw_if_interactive() + return + + +def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds): + r"""Draw the graph G using Matplotlib. + + Draw the graph with Matplotlib with options for node positions, + labeling, titles, and many other drawing features. + See draw() for simple drawing without labels or axes. + + Parameters + ---------- + G : graph + A networkx graph + + pos : dictionary, optional + A dictionary with nodes as keys and positions as values. + If not specified a spring layout positioning will be computed. + See :py:mod:`networkx.drawing.layout` for functions that + compute node positions. + + arrows : bool or None, optional (default=None) + If `None`, directed graphs draw arrowheads with + `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges + via `~matplotlib.collections.LineCollection` for speed. + If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish). + If `False`, draw edges using LineCollection (linear and fast). + For directed graphs, if True draw arrowheads. + Note: Arrows will be the same color as edges. + + arrowstyle : str (default='-\|>' for directed graphs) + For directed graphs, choose the style of the arrowsheads. + For undirected graphs default to '-' + + See `matplotlib.patches.ArrowStyle` for more options. + + arrowsize : int or list (default=10) + For directed graphs, choose the size of the arrow head's length and + width. A list of values can be passed in to assign a different size for arrow head's length and width. + See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale` + for more info. + + with_labels : bool (default=True) + Set to True to draw labels on the nodes. + + ax : Matplotlib Axes object, optional + Draw the graph in the specified Matplotlib axes. + + nodelist : list (default=list(G)) + Draw only specified nodes + + edgelist : list (default=list(G.edges())) + Draw only specified edges + + node_size : scalar or array (default=300) + Size of nodes. If an array is specified it must be the + same length as nodelist. + + node_color : color or array of colors (default='#1f78b4') + Node color. Can be a single color or a sequence of colors with the same + length as nodelist. Color can be string or rgb (or rgba) tuple of + floats from 0-1. If numeric values are specified they will be + mapped to colors using the cmap and vmin,vmax parameters. See + matplotlib.scatter for more details. + + node_shape : string (default='o') + The shape of the node. Specification is as matplotlib.scatter + marker, one of 'so^>v<dph8'. + + alpha : float or None (default=None) + The node and edge transparency + + cmap : Matplotlib colormap, optional + Colormap for mapping intensities of nodes + + vmin,vmax : float, optional + Minimum and maximum for node colormap scaling + + linewidths : scalar or sequence (default=1.0) + Line width of symbol border + + width : float or array of floats (default=1.0) + Line width of edges + + edge_color : color or array of colors (default='k') + Edge color. Can be a single color or a sequence of colors with the same + length as edgelist. Color can be string or rgb (or rgba) tuple of + floats from 0-1. If numeric values are specified they will be + mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters. + + edge_cmap : Matplotlib colormap, optional + Colormap for mapping intensities of edges + + edge_vmin,edge_vmax : floats, optional + Minimum and maximum for edge colormap scaling + + style : string (default=solid line) + Edge line style e.g.: '-', '--', '-.', ':' + or words like 'solid' or 'dashed'. + (See `matplotlib.patches.FancyArrowPatch`: `linestyle`) + + labels : dictionary (default=None) + Node labels in a dictionary of text labels keyed by node + + font_size : int (default=12 for nodes, 10 for edges) + Font size for text labels + + font_color : color (default='k' black) + Font color string. Color can be string or rgb (or rgba) tuple of + floats from 0-1. + + font_weight : string (default='normal') + Font weight + + font_family : string (default='sans-serif') + Font family + + label : string, optional + Label for graph legend + + hide_ticks : bool, optional + Hide ticks of axes. When `True` (the default), ticks and ticklabels + are removed from the axes. To set ticks and tick labels to the pyplot default, + use ``hide_ticks=False``. + + kwds : optional keywords + See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and + networkx.draw_networkx_labels() for a description of optional keywords. + + Notes + ----- + For directed graphs, arrows are drawn at the head end. Arrows can be + turned off with keyword arrows=False. + + Examples + -------- + >>> G = nx.dodecahedral_graph() + >>> nx.draw(G) + >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout + + >>> import matplotlib.pyplot as plt + >>> limits = plt.axis("off") # turn off axis + + Also see the NetworkX drawing examples at + https://networkx.org/documentation/latest/auto_examples/index.html + + See Also + -------- + draw + draw_networkx_nodes + draw_networkx_edges + draw_networkx_labels + draw_networkx_edge_labels + """ + from inspect import signature + + import matplotlib.pyplot as plt + + # Get all valid keywords by inspecting the signatures of draw_networkx_nodes, + # draw_networkx_edges, draw_networkx_labels + + valid_node_kwds = signature(draw_networkx_nodes).parameters.keys() + valid_edge_kwds = signature(draw_networkx_edges).parameters.keys() + valid_label_kwds = signature(draw_networkx_labels).parameters.keys() + + # Create a set with all valid keywords across the three functions and + # remove the arguments of this function (draw_networkx) + valid_kwds = (valid_node_kwds | valid_edge_kwds | valid_label_kwds) - { + "G", + "pos", + "arrows", + "with_labels", + } + + if any(k not in valid_kwds for k in kwds): + invalid_args = ", ".join([k for k in kwds if k not in valid_kwds]) + raise ValueError(f"Received invalid argument(s): {invalid_args}") + + node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds} + edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds} + label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds} + + if pos is None: + pos = nx.drawing.spring_layout(G) # default to spring layout + + draw_networkx_nodes(G, pos, **node_kwds) + draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds) + if with_labels: + draw_networkx_labels(G, pos, **label_kwds) + plt.draw_if_interactive() + + +def draw_networkx_nodes( + G, + pos, + nodelist=None, + node_size=300, + node_color="#1f78b4", + node_shape="o", + alpha=None, + cmap=None, + vmin=None, + vmax=None, + ax=None, + linewidths=None, + edgecolors=None, + label=None, + margins=None, + hide_ticks=True, +): + """Draw the nodes of the graph G. + + This draws only the nodes of the graph G. + + Parameters + ---------- + G : graph + A networkx graph + + pos : dictionary + A dictionary with nodes as keys and positions as values. + Positions should be sequences of length 2. + + ax : Matplotlib Axes object, optional + Draw the graph in the specified Matplotlib axes. + + nodelist : list (default list(G)) + Draw only specified nodes + + node_size : scalar or array (default=300) + Size of nodes. If an array it must be the same length as nodelist. + + node_color : color or array of colors (default='#1f78b4') + Node color. Can be a single color or a sequence of colors with the same + length as nodelist. Color can be string or rgb (or rgba) tuple of + floats from 0-1. If numeric values are specified they will be + mapped to colors using the cmap and vmin,vmax parameters. See + matplotlib.scatter for more details. + + node_shape : string (default='o') + The shape of the node. Specification is as matplotlib.scatter + marker, one of 'so^>v<dph8'. + + alpha : float or array of floats (default=None) + The node transparency. This can be a single alpha value, + in which case it will be applied to all the nodes of color. Otherwise, + if it is an array, the elements of alpha will be applied to the colors + in order (cycling through alpha multiple times if necessary). + + cmap : Matplotlib colormap (default=None) + Colormap for mapping intensities of nodes + + vmin,vmax : floats or None (default=None) + Minimum and maximum for node colormap scaling + + linewidths : [None | scalar | sequence] (default=1.0) + Line width of symbol border + + edgecolors : [None | scalar | sequence] (default = node_color) + Colors of node borders. Can be a single color or a sequence of colors with the + same length as nodelist. Color can be string or rgb (or rgba) tuple of floats + from 0-1. If numeric values are specified they will be mapped to colors + using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details. + + label : [None | string] + Label for legend + + margins : float or 2-tuple, optional + Sets the padding for axis autoscaling. Increase margin to prevent + clipping for nodes that are near the edges of an image. Values should + be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins` + for details. The default is `None`, which uses the Matplotlib default. + + hide_ticks : bool, optional + Hide ticks of axes. When `True` (the default), ticks and ticklabels + are removed from the axes. To set ticks and tick labels to the pyplot default, + use ``hide_ticks=False``. + + Returns + ------- + matplotlib.collections.PathCollection + `PathCollection` of the nodes. + + Examples + -------- + >>> G = nx.dodecahedral_graph() + >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G)) + + Also see the NetworkX drawing examples at + https://networkx.org/documentation/latest/auto_examples/index.html + + See Also + -------- + draw + draw_networkx + draw_networkx_edges + draw_networkx_labels + draw_networkx_edge_labels + """ + from collections.abc import Iterable + + import matplotlib as mpl + import matplotlib.collections # call as mpl.collections + import matplotlib.pyplot as plt + import numpy as np + + if ax is None: + ax = plt.gca() + + if nodelist is None: + nodelist = list(G) + + if len(nodelist) == 0: # empty nodelist, no drawing + return mpl.collections.PathCollection(None) + + try: + xy = np.asarray([pos[v] for v in nodelist]) + except KeyError as err: + raise nx.NetworkXError(f"Node {err} has no position.") from err + + if isinstance(alpha, Iterable): + node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax) + alpha = None + + if not isinstance(node_shape, np.ndarray) and not isinstance(node_shape, list): + node_shape = np.array([node_shape for _ in range(len(nodelist))]) + + for shape in np.unique(node_shape): + node_collection = ax.scatter( + xy[node_shape == shape, 0], + xy[node_shape == shape, 1], + s=node_size, + c=node_color, + marker=shape, + cmap=cmap, + vmin=vmin, + vmax=vmax, + alpha=alpha, + linewidths=linewidths, + edgecolors=edgecolors, + label=label, + ) + if hide_ticks: + ax.tick_params( + axis="both", + which="both", + bottom=False, + left=False, + labelbottom=False, + labelleft=False, + ) + + if margins is not None: + if isinstance(margins, Iterable): + ax.margins(*margins) + else: + ax.margins(margins) + + node_collection.set_zorder(2) + return node_collection + + +class FancyArrowFactory: + """Draw arrows with `matplotlib.patches.FancyarrowPatch`""" + + class ConnectionStyleFactory: + def __init__(self, connectionstyles, selfloop_height, ax=None): + import matplotlib as mpl + import matplotlib.path # call as mpl.path + import numpy as np + + self.ax = ax + self.mpl = mpl + self.np = np + self.base_connection_styles = [ + mpl.patches.ConnectionStyle(cs) for cs in connectionstyles + ] + self.n = len(self.base_connection_styles) + self.selfloop_height = selfloop_height + + def curved(self, edge_index): + return self.base_connection_styles[edge_index % self.n] + + def self_loop(self, edge_index): + def self_loop_connection(posA, posB, *args, **kwargs): + if not self.np.all(posA == posB): + raise nx.NetworkXError( + "`self_loop` connection style method" + "is only to be used for self-loops" + ) + # this is called with _screen space_ values + # so convert back to data space + data_loc = self.ax.transData.inverted().transform(posA) + v_shift = 0.1 * self.selfloop_height + h_shift = v_shift * 0.5 + # put the top of the loop first so arrow is not hidden by node + path = self.np.asarray( + [ + # 1 + [0, v_shift], + # 4 4 4 + [h_shift, v_shift], + [h_shift, 0], + [0, 0], + # 4 4 4 + [-h_shift, 0], + [-h_shift, v_shift], + [0, v_shift], + ] + ) + # Rotate self loop 90 deg. if more than 1 + # This will allow for maximum of 4 visible self loops + if edge_index % 4: + x, y = path.T + for _ in range(edge_index % 4): + x, y = y, -x + path = self.np.array([x, y]).T + return self.mpl.path.Path( + self.ax.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4] + ) + + return self_loop_connection + + def __init__( + self, + edge_pos, + edgelist, + nodelist, + edge_indices, + node_size, + selfloop_height, + connectionstyle="arc3", + node_shape="o", + arrowstyle="-", + arrowsize=10, + edge_color="k", + alpha=None, + linewidth=1.0, + style="solid", + min_source_margin=0, + min_target_margin=0, + ax=None, + ): + import matplotlib as mpl + import matplotlib.patches # call as mpl.patches + import matplotlib.pyplot as plt + import numpy as np + + if isinstance(connectionstyle, str): + connectionstyle = [connectionstyle] + elif np.iterable(connectionstyle): + connectionstyle = list(connectionstyle) + else: + msg = "ConnectionStyleFactory arg `connectionstyle` must be str or iterable" + raise nx.NetworkXError(msg) + self.ax = ax + self.mpl = mpl + self.np = np + self.edge_pos = edge_pos + self.edgelist = edgelist + self.nodelist = nodelist + self.node_shape = node_shape + self.min_source_margin = min_source_margin + self.min_target_margin = min_target_margin + self.edge_indices = edge_indices + self.node_size = node_size + self.connectionstyle_factory = self.ConnectionStyleFactory( + connectionstyle, selfloop_height, ax + ) + self.arrowstyle = arrowstyle + self.arrowsize = arrowsize + self.arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha) + self.linewidth = linewidth + self.style = style + if isinstance(arrowsize, list) and len(arrowsize) != len(edge_pos): + raise ValueError("arrowsize should have the same length as edgelist") + + def __call__(self, i): + (x1, y1), (x2, y2) = self.edge_pos[i] + shrink_source = 0 # space from source to tail + shrink_target = 0 # space from head to target + if ( + self.np.iterable(self.min_source_margin) + and not isinstance(self.min_source_margin, str) + and not isinstance(self.min_source_margin, tuple) + ): + min_source_margin = self.min_source_margin[i] + else: + min_source_margin = self.min_source_margin + + if ( + self.np.iterable(self.min_target_margin) + and not isinstance(self.min_target_margin, str) + and not isinstance(self.min_target_margin, tuple) + ): + min_target_margin = self.min_target_margin[i] + else: + min_target_margin = self.min_target_margin + + if self.np.iterable(self.node_size): # many node sizes + source, target = self.edgelist[i][:2] + source_node_size = self.node_size[self.nodelist.index(source)] + target_node_size = self.node_size[self.nodelist.index(target)] + shrink_source = self.to_marker_edge(source_node_size, self.node_shape) + shrink_target = self.to_marker_edge(target_node_size, self.node_shape) + else: + shrink_source = self.to_marker_edge(self.node_size, self.node_shape) + shrink_target = shrink_source + shrink_source = max(shrink_source, min_source_margin) + shrink_target = max(shrink_target, min_target_margin) + + # scale factor of arrow head + if isinstance(self.arrowsize, list): + mutation_scale = self.arrowsize[i] + else: + mutation_scale = self.arrowsize + + if len(self.arrow_colors) > i: + arrow_color = self.arrow_colors[i] + elif len(self.arrow_colors) == 1: + arrow_color = self.arrow_colors[0] + else: # Cycle through colors + arrow_color = self.arrow_colors[i % len(self.arrow_colors)] + + if self.np.iterable(self.linewidth): + if len(self.linewidth) > i: + linewidth = self.linewidth[i] + else: + linewidth = self.linewidth[i % len(self.linewidth)] + else: + linewidth = self.linewidth + + if ( + self.np.iterable(self.style) + and not isinstance(self.style, str) + and not isinstance(self.style, tuple) + ): + if len(self.style) > i: + linestyle = self.style[i] + else: # Cycle through styles + linestyle = self.style[i % len(self.style)] + else: + linestyle = self.style + + if x1 == x2 and y1 == y2: + connectionstyle = self.connectionstyle_factory.self_loop( + self.edge_indices[i] + ) + else: + connectionstyle = self.connectionstyle_factory.curved(self.edge_indices[i]) + + if ( + self.np.iterable(self.arrowstyle) + and not isinstance(self.arrowstyle, str) + and not isinstance(self.arrowstyle, tuple) + ): + arrowstyle = self.arrowstyle[i] + else: + arrowstyle = self.arrowstyle + + return self.mpl.patches.FancyArrowPatch( + (x1, y1), + (x2, y2), + arrowstyle=arrowstyle, + shrinkA=shrink_source, + shrinkB=shrink_target, + mutation_scale=mutation_scale, + color=arrow_color, + linewidth=linewidth, + connectionstyle=connectionstyle, + linestyle=linestyle, + zorder=1, # arrows go behind nodes + ) + + def to_marker_edge(self, marker_size, marker): + if marker in "s^>v<d": # `large` markers need extra space + return self.np.sqrt(2 * marker_size) / 2 + else: + return self.np.sqrt(marker_size) / 2 + + +def draw_networkx_edges( + G, + pos, + edgelist=None, + width=1.0, + edge_color="k", + style="solid", + alpha=None, + arrowstyle=None, + arrowsize=10, + edge_cmap=None, + edge_vmin=None, + edge_vmax=None, + ax=None, + arrows=None, + label=None, + node_size=300, + nodelist=None, + node_shape="o", + connectionstyle="arc3", + min_source_margin=0, + min_target_margin=0, + hide_ticks=True, +): + r"""Draw the edges of the graph G. + + This draws only the edges of the graph G. + + Parameters + ---------- + G : graph + A networkx graph + + pos : dictionary + A dictionary with nodes as keys and positions as values. + Positions should be sequences of length 2. + + edgelist : collection of edge tuples (default=G.edges()) + Draw only specified edges + + width : float or array of floats (default=1.0) + Line width of edges + + edge_color : color or array of colors (default='k') + Edge color. Can be a single color or a sequence of colors with the same + length as edgelist. Color can be string or rgb (or rgba) tuple of + floats from 0-1. If numeric values are specified they will be + mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters. + + style : string or array of strings (default='solid') + Edge line style e.g.: '-', '--', '-.', ':' + or words like 'solid' or 'dashed'. + Can be a single style or a sequence of styles with the same + length as the edge list. + If less styles than edges are given the styles will cycle. + If more styles than edges are given the styles will be used sequentially + and not be exhausted. + Also, `(offset, onoffseq)` tuples can be used as style instead of a strings. + (See `matplotlib.patches.FancyArrowPatch`: `linestyle`) + + alpha : float or array of floats (default=None) + The edge transparency. This can be a single alpha value, + in which case it will be applied to all specified edges. Otherwise, + if it is an array, the elements of alpha will be applied to the colors + in order (cycling through alpha multiple times if necessary). + + edge_cmap : Matplotlib colormap, optional + Colormap for mapping intensities of edges + + edge_vmin,edge_vmax : floats, optional + Minimum and maximum for edge colormap scaling + + ax : Matplotlib Axes object, optional + Draw the graph in the specified Matplotlib axes. + + arrows : bool or None, optional (default=None) + If `None`, directed graphs draw arrowheads with + `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges + via `~matplotlib.collections.LineCollection` for speed. + If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish). + If `False`, draw edges using LineCollection (linear and fast). + + Note: Arrowheads will be the same color as edges. + + arrowstyle : str or list of strs (default='-\|>' for directed graphs) + For directed graphs and `arrows==True` defaults to '-\|>', + For undirected graphs default to '-'. + + See `matplotlib.patches.ArrowStyle` for more options. + + arrowsize : int or list of ints(default=10) + For directed graphs, choose the size of the arrow head's length and + width. See `matplotlib.patches.FancyArrowPatch` for attribute + `mutation_scale` for more info. + + connectionstyle : string or iterable of strings (default="arc3") + Pass the connectionstyle parameter to create curved arc of rounding + radius rad. For example, connectionstyle='arc3,rad=0.2'. + See `matplotlib.patches.ConnectionStyle` and + `matplotlib.patches.FancyArrowPatch` for more info. + If Iterable, index indicates i'th edge key of MultiGraph + + node_size : scalar or array (default=300) + Size of nodes. Though the nodes are not drawn with this function, the + node size is used in determining edge positioning. + + nodelist : list, optional (default=G.nodes()) + This provides the node order for the `node_size` array (if it is an array). + + node_shape : string (default='o') + The marker used for nodes, used in determining edge positioning. + Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'. + + label : None or string + Label for legend + + min_source_margin : int or list of ints (default=0) + The minimum margin (gap) at the beginning of the edge at the source. + + min_target_margin : int or list of ints (default=0) + The minimum margin (gap) at the end of the edge at the target. + + hide_ticks : bool, optional + Hide ticks of axes. When `True` (the default), ticks and ticklabels + are removed from the axes. To set ticks and tick labels to the pyplot default, + use ``hide_ticks=False``. + + Returns + ------- + matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch + If ``arrows=True``, a list of FancyArrowPatches is returned. + If ``arrows=False``, a LineCollection is returned. + If ``arrows=None`` (the default), then a LineCollection is returned if + `G` is undirected, otherwise returns a list of FancyArrowPatches. + + Notes + ----- + For directed graphs, arrows are drawn at the head end. Arrows can be + turned off with keyword arrows=False or by passing an arrowstyle without + an arrow on the end. + + Be sure to include `node_size` as a keyword argument; arrows are + drawn considering the size of nodes. + + Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch` + regardless of the value of `arrows` or whether `G` is directed. + When ``arrows=False`` or ``arrows=None`` and `G` is undirected, the + FancyArrowPatches corresponding to the self-loops are not explicitly + returned. They should instead be accessed via the ``Axes.patches`` + attribute (see examples). + + Examples + -------- + >>> G = nx.dodecahedral_graph() + >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G)) + + >>> G = nx.DiGraph() + >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)]) + >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G)) + >>> alphas = [0.3, 0.4, 0.5] + >>> for i, arc in enumerate(arcs): # change alpha values of arcs + ... arc.set_alpha(alphas[i]) + + The FancyArrowPatches corresponding to self-loops are not always + returned, but can always be accessed via the ``patches`` attribute of the + `matplotlib.Axes` object. + + >>> import matplotlib.pyplot as plt + >>> fig, ax = plt.subplots() + >>> G = nx.Graph([(0, 1), (0, 0)]) # Self-loop at node 0 + >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax) + >>> self_loop_fap = ax.patches[0] + + Also see the NetworkX drawing examples at + https://networkx.org/documentation/latest/auto_examples/index.html + + See Also + -------- + draw + draw_networkx + draw_networkx_nodes + draw_networkx_labels + draw_networkx_edge_labels + + """ + import warnings + + import matplotlib as mpl + import matplotlib.collections # call as mpl.collections + import matplotlib.colors # call as mpl.colors + import matplotlib.pyplot as plt + import numpy as np + + # The default behavior is to use LineCollection to draw edges for + # undirected graphs (for performance reasons) and use FancyArrowPatches + # for directed graphs. + # The `arrows` keyword can be used to override the default behavior + if arrows is None: + use_linecollection = not (G.is_directed() or G.is_multigraph()) + else: + if not isinstance(arrows, bool): + raise TypeError("Argument `arrows` must be of type bool or None") + use_linecollection = not arrows + + if isinstance(connectionstyle, str): + connectionstyle = [connectionstyle] + elif np.iterable(connectionstyle): + connectionstyle = list(connectionstyle) + else: + msg = "draw_networkx_edges arg `connectionstyle` must be str or iterable" + raise nx.NetworkXError(msg) + + # Some kwargs only apply to FancyArrowPatches. Warn users when they use + # non-default values for these kwargs when LineCollection is being used + # instead of silently ignoring the specified option + if use_linecollection: + msg = ( + "\n\nThe {0} keyword argument is not applicable when drawing edges\n" + "with LineCollection.\n\n" + "To make this warning go away, either specify `arrows=True` to\n" + "force FancyArrowPatches or use the default values.\n" + "Note that using FancyArrowPatches may be slow for large graphs.\n" + ) + if arrowstyle is not None: + warnings.warn(msg.format("arrowstyle"), category=UserWarning, stacklevel=2) + if arrowsize != 10: + warnings.warn(msg.format("arrowsize"), category=UserWarning, stacklevel=2) + if min_source_margin != 0: + warnings.warn( + msg.format("min_source_margin"), category=UserWarning, stacklevel=2 + ) + if min_target_margin != 0: + warnings.warn( + msg.format("min_target_margin"), category=UserWarning, stacklevel=2 + ) + if any(cs != "arc3" for cs in connectionstyle): + warnings.warn( + msg.format("connectionstyle"), category=UserWarning, stacklevel=2 + ) + + # NOTE: Arrowstyle modification must occur after the warnings section + if arrowstyle is None: + arrowstyle = "-|>" if G.is_directed() else "-" + + if ax is None: + ax = plt.gca() + + if edgelist is None: + edgelist = list(G.edges) # (u, v, k) for multigraph (u, v) otherwise + + if len(edgelist): + if G.is_multigraph(): + key_count = collections.defaultdict(lambda: itertools.count(0)) + edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist] + else: + edge_indices = [0] * len(edgelist) + else: # no edges! + return [] + + if nodelist is None: + nodelist = list(G.nodes()) + + # FancyArrowPatch handles color=None different from LineCollection + if edge_color is None: + edge_color = "k" + + # set edge positions + edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist]) + + # Check if edge_color is an array of floats and map to edge_cmap. + # This is the only case handled differently from matplotlib + if ( + np.iterable(edge_color) + and (len(edge_color) == len(edge_pos)) + and np.all([isinstance(c, Number) for c in edge_color]) + ): + if edge_cmap is not None: + assert isinstance(edge_cmap, mpl.colors.Colormap) + else: + edge_cmap = plt.get_cmap() + if edge_vmin is None: + edge_vmin = min(edge_color) + if edge_vmax is None: + edge_vmax = max(edge_color) + color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax) + edge_color = [edge_cmap(color_normal(e)) for e in edge_color] + + # compute initial view + minx = np.amin(np.ravel(edge_pos[:, :, 0])) + maxx = np.amax(np.ravel(edge_pos[:, :, 0])) + miny = np.amin(np.ravel(edge_pos[:, :, 1])) + maxy = np.amax(np.ravel(edge_pos[:, :, 1])) + w = maxx - minx + h = maxy - miny + + # Self-loops are scaled by view extent, except in cases the extent + # is 0, e.g. for a single node. In this case, fall back to scaling + # by the maximum node size + selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max() + fancy_arrow_factory = FancyArrowFactory( + edge_pos, + edgelist, + nodelist, + edge_indices, + node_size, + selfloop_height, + connectionstyle, + node_shape, + arrowstyle, + arrowsize, + edge_color, + alpha, + width, + style, + min_source_margin, + min_target_margin, + ax=ax, + ) + + # Draw the edges + if use_linecollection: + edge_collection = mpl.collections.LineCollection( + edge_pos, + colors=edge_color, + linewidths=width, + antialiaseds=(1,), + linestyle=style, + alpha=alpha, + ) + edge_collection.set_cmap(edge_cmap) + edge_collection.set_clim(edge_vmin, edge_vmax) + edge_collection.set_zorder(1) # edges go behind nodes + edge_collection.set_label(label) + ax.add_collection(edge_collection) + edge_viz_obj = edge_collection + + # Make sure selfloop edges are also drawn + # --------------------------------------- + selfloops_to_draw = [loop for loop in nx.selfloop_edges(G) if loop in edgelist] + if selfloops_to_draw: + edgelist_tuple = list(map(tuple, edgelist)) + arrow_collection = [] + for loop in selfloops_to_draw: + i = edgelist_tuple.index(loop) + arrow = fancy_arrow_factory(i) + arrow_collection.append(arrow) + ax.add_patch(arrow) + else: + edge_viz_obj = [] + for i in range(len(edgelist)): + arrow = fancy_arrow_factory(i) + ax.add_patch(arrow) + edge_viz_obj.append(arrow) + + # update view after drawing + padx, pady = 0.05 * w, 0.05 * h + corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady) + ax.update_datalim(corners) + ax.autoscale_view() + + if hide_ticks: + ax.tick_params( + axis="both", + which="both", + bottom=False, + left=False, + labelbottom=False, + labelleft=False, + ) + + return edge_viz_obj + + +def draw_networkx_labels( + G, + pos, + labels=None, + font_size=12, + font_color="k", + font_family="sans-serif", + font_weight="normal", + alpha=None, + bbox=None, + horizontalalignment="center", + verticalalignment="center", + ax=None, + clip_on=True, + hide_ticks=True, +): + """Draw node labels on the graph G. + + Parameters + ---------- + G : graph + A networkx graph + + pos : dictionary + A dictionary with nodes as keys and positions as values. + Positions should be sequences of length 2. + + labels : dictionary (default={n: n for n in G}) + Node labels in a dictionary of text labels keyed by node. + Node-keys in labels should appear as keys in `pos`. + If needed use: `{n:lab for n,lab in labels.items() if n in pos}` + + font_size : int or dictionary of nodes to ints (default=12) + Font size for text labels. + + font_color : color or dictionary of nodes to colors (default='k' black) + Font color string. Color can be string or rgb (or rgba) tuple of + floats from 0-1. + + font_weight : string or dictionary of nodes to strings (default='normal') + Font weight. + + font_family : string or dictionary of nodes to strings (default='sans-serif') + Font family. + + alpha : float or None or dictionary of nodes to floats (default=None) + The text transparency. + + bbox : Matplotlib bbox, (default is Matplotlib's ax.text default) + Specify text box properties (e.g. shape, color etc.) for node labels. + + horizontalalignment : string or array of strings (default='center') + Horizontal alignment {'center', 'right', 'left'}. If an array is + specified it must be the same length as `nodelist`. + + verticalalignment : string (default='center') + Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}. + If an array is specified it must be the same length as `nodelist`. + + ax : Matplotlib Axes object, optional + Draw the graph in the specified Matplotlib axes. + + clip_on : bool (default=True) + Turn on clipping of node labels at axis boundaries + + hide_ticks : bool, optional + Hide ticks of axes. When `True` (the default), ticks and ticklabels + are removed from the axes. To set ticks and tick labels to the pyplot default, + use ``hide_ticks=False``. + + Returns + ------- + dict + `dict` of labels keyed on the nodes + + Examples + -------- + >>> G = nx.dodecahedral_graph() + >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G)) + + Also see the NetworkX drawing examples at + https://networkx.org/documentation/latest/auto_examples/index.html + + See Also + -------- + draw + draw_networkx + draw_networkx_nodes + draw_networkx_edges + draw_networkx_edge_labels + """ + import matplotlib.pyplot as plt + + if ax is None: + ax = plt.gca() + + if labels is None: + labels = {n: n for n in G.nodes()} + + individual_params = set() + + def check_individual_params(p_value, p_name): + if isinstance(p_value, dict): + if len(p_value) != len(labels): + raise ValueError(f"{p_name} must have the same length as labels.") + individual_params.add(p_name) + + def get_param_value(node, p_value, p_name): + if p_name in individual_params: + return p_value[node] + return p_value + + check_individual_params(font_size, "font_size") + check_individual_params(font_color, "font_color") + check_individual_params(font_weight, "font_weight") + check_individual_params(font_family, "font_family") + check_individual_params(alpha, "alpha") + + text_items = {} # there is no text collection so we'll fake one + for n, label in labels.items(): + (x, y) = pos[n] + if not isinstance(label, str): + label = str(label) # this makes "1" and 1 labeled the same + t = ax.text( + x, + y, + label, + size=get_param_value(n, font_size, "font_size"), + color=get_param_value(n, font_color, "font_color"), + family=get_param_value(n, font_family, "font_family"), + weight=get_param_value(n, font_weight, "font_weight"), + alpha=get_param_value(n, alpha, "alpha"), + horizontalalignment=horizontalalignment, + verticalalignment=verticalalignment, + transform=ax.transData, + bbox=bbox, + clip_on=clip_on, + ) + text_items[n] = t + + if hide_ticks: + ax.tick_params( + axis="both", + which="both", + bottom=False, + left=False, + labelbottom=False, + labelleft=False, + ) + + return text_items + + +def draw_networkx_edge_labels( + G, + pos, + edge_labels=None, + label_pos=0.5, + font_size=10, + font_color="k", + font_family="sans-serif", + font_weight="normal", + alpha=None, + bbox=None, + horizontalalignment="center", + verticalalignment="center", + ax=None, + rotate=True, + clip_on=True, + node_size=300, + nodelist=None, + connectionstyle="arc3", + hide_ticks=True, +): + """Draw edge labels. + + Parameters + ---------- + G : graph + A networkx graph + + pos : dictionary + A dictionary with nodes as keys and positions as values. + Positions should be sequences of length 2. + + edge_labels : dictionary (default=None) + Edge labels in a dictionary of labels keyed by edge two-tuple. + Only labels for the keys in the dictionary are drawn. + + label_pos : float (default=0.5) + Position of edge label along edge (0=head, 0.5=center, 1=tail) + + font_size : int (default=10) + Font size for text labels + + font_color : color (default='k' black) + Font color string. Color can be string or rgb (or rgba) tuple of + floats from 0-1. + + font_weight : string (default='normal') + Font weight + + font_family : string (default='sans-serif') + Font family + + alpha : float or None (default=None) + The text transparency + + bbox : Matplotlib bbox, optional + Specify text box properties (e.g. shape, color etc.) for edge labels. + Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}. + + horizontalalignment : string (default='center') + Horizontal alignment {'center', 'right', 'left'} + + verticalalignment : string (default='center') + Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'} + + ax : Matplotlib Axes object, optional + Draw the graph in the specified Matplotlib axes. + + rotate : bool (default=True) + Rotate edge labels to lie parallel to edges + + clip_on : bool (default=True) + Turn on clipping of edge labels at axis boundaries + + node_size : scalar or array (default=300) + Size of nodes. If an array it must be the same length as nodelist. + + nodelist : list, optional (default=G.nodes()) + This provides the node order for the `node_size` array (if it is an array). + + connectionstyle : string or iterable of strings (default="arc3") + Pass the connectionstyle parameter to create curved arc of rounding + radius rad. For example, connectionstyle='arc3,rad=0.2'. + See `matplotlib.patches.ConnectionStyle` and + `matplotlib.patches.FancyArrowPatch` for more info. + If Iterable, index indicates i'th edge key of MultiGraph + + hide_ticks : bool, optional + Hide ticks of axes. When `True` (the default), ticks and ticklabels + are removed from the axes. To set ticks and tick labels to the pyplot default, + use ``hide_ticks=False``. + + Returns + ------- + dict + `dict` of labels keyed by edge + + Examples + -------- + >>> G = nx.dodecahedral_graph() + >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G)) + + Also see the NetworkX drawing examples at + https://networkx.org/documentation/latest/auto_examples/index.html + + See Also + -------- + draw + draw_networkx + draw_networkx_nodes + draw_networkx_edges + draw_networkx_labels + """ + import matplotlib as mpl + import matplotlib.pyplot as plt + import numpy as np + + class CurvedArrowText(mpl.text.Text): + def __init__( + self, + arrow, + *args, + label_pos=0.5, + labels_horizontal=False, + ax=None, + **kwargs, + ): + # Bind to FancyArrowPatch + self.arrow = arrow + # how far along the text should be on the curve, + # 0 is at start, 1 is at end etc. + self.label_pos = label_pos + self.labels_horizontal = labels_horizontal + if ax is None: + ax = plt.gca() + self.ax = ax + self.x, self.y, self.angle = self._update_text_pos_angle(arrow) + + # Create text object + super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs) + # Bind to axis + self.ax.add_artist(self) + + def _get_arrow_path_disp(self, arrow): + """ + This is part of FancyArrowPatch._get_path_in_displaycoord + It omits the second part of the method where path is converted + to polygon based on width + The transform is taken from ax, not the object, as the object + has not been added yet, and doesn't have transform + """ + dpi_cor = arrow._dpi_cor + trans_data = self.ax.transData + if arrow._posA_posB is None: + raise ValueError( + "Can only draw labels for fancy arrows with " + "posA and posB inputs, not custom path" + ) + posA = arrow._convert_xy_units(arrow._posA_posB[0]) + posB = arrow._convert_xy_units(arrow._posA_posB[1]) + (posA, posB) = trans_data.transform((posA, posB)) + _path = arrow.get_connectionstyle()( + posA, + posB, + patchA=arrow.patchA, + patchB=arrow.patchB, + shrinkA=arrow.shrinkA * dpi_cor, + shrinkB=arrow.shrinkB * dpi_cor, + ) + # Return is in display coordinates + return _path + + def _update_text_pos_angle(self, arrow): + # Fractional label position + # Text position at a proportion t along the line in display coords + # default is 0.5 so text appears at the halfway point + t = self.label_pos + tt = 1 - t + path_disp = self._get_arrow_path_disp(arrow) + is_bar_style = isinstance( + arrow.get_connectionstyle(), mpl.patches.ConnectionStyle.Bar + ) + # 1. Calculate x and y + if is_bar_style: + # Bar Connection Style - straight line + _, (cx1, cy1), (cx2, cy2), _ = path_disp.vertices + x = cx1 * tt + cx2 * t + y = cy1 * tt + cy2 * t + else: + # Arc or Angle type Connection Styles - Bezier curve + (x1, y1), (cx, cy), (x2, y2) = path_disp.vertices + x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2 + y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2 + # 2. Calculate Angle + if self.labels_horizontal: + # Horizontal text labels + angle = 0 + else: + # Labels parallel to curve + if is_bar_style: + change_x = (cx2 - cx1) / 2 + change_y = (cy2 - cy1) / 2 + else: + change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx) + change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy) + angle = np.arctan2(change_y, change_x) / (2 * np.pi) * 360 + # Text is "right way up" + if angle > 90: + angle -= 180 + elif angle < -90: + angle += 180 + (x, y) = self.ax.transData.inverted().transform((x, y)) + return x, y, angle + + def draw(self, renderer): + # recalculate the text position and angle + self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow) + self.set_position((self.x, self.y)) + self.set_rotation(self.angle) + # redraw text + super().draw(renderer) + + # use default box of white with white border + if bbox is None: + bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)} + + if isinstance(connectionstyle, str): + connectionstyle = [connectionstyle] + elif np.iterable(connectionstyle): + connectionstyle = list(connectionstyle) + else: + raise nx.NetworkXError( + "draw_networkx_edges arg `connectionstyle` must be" + "string or iterable of strings" + ) + + if ax is None: + ax = plt.gca() + + if edge_labels is None: + kwds = {"keys": True} if G.is_multigraph() else {} + edge_labels = {tuple(edge): d for *edge, d in G.edges(data=True, **kwds)} + # NOTHING TO PLOT + if not edge_labels: + return {} + edgelist, labels = zip(*edge_labels.items()) + + if nodelist is None: + nodelist = list(G.nodes()) + + # set edge positions + edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist]) + + if G.is_multigraph(): + key_count = collections.defaultdict(lambda: itertools.count(0)) + edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist] + else: + edge_indices = [0] * len(edgelist) + + # Used to determine self loop mid-point + # Note, that this will not be accurate, + # if not drawing edge_labels for all edges drawn + h = 0 + if edge_labels: + miny = np.amin(np.ravel(edge_pos[:, :, 1])) + maxy = np.amax(np.ravel(edge_pos[:, :, 1])) + h = maxy - miny + selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max() + fancy_arrow_factory = FancyArrowFactory( + edge_pos, + edgelist, + nodelist, + edge_indices, + node_size, + selfloop_height, + connectionstyle, + ax=ax, + ) + + individual_params = {} + + def check_individual_params(p_value, p_name): + # TODO should this be list or array (as in a numpy array)? + if isinstance(p_value, list): + if len(p_value) != len(edgelist): + raise ValueError(f"{p_name} must have the same length as edgelist.") + individual_params[p_name] = p_value.iter() + + # Don't need to pass in an edge because these are lists, not dicts + def get_param_value(p_value, p_name): + if p_name in individual_params: + return next(individual_params[p_name]) + return p_value + + check_individual_params(font_size, "font_size") + check_individual_params(font_color, "font_color") + check_individual_params(font_weight, "font_weight") + check_individual_params(alpha, "alpha") + check_individual_params(horizontalalignment, "horizontalalignment") + check_individual_params(verticalalignment, "verticalalignment") + check_individual_params(rotate, "rotate") + check_individual_params(label_pos, "label_pos") + + text_items = {} + for i, (edge, label) in enumerate(zip(edgelist, labels)): + if not isinstance(label, str): + label = str(label) # this makes "1" and 1 labeled the same + + n1, n2 = edge[:2] + arrow = fancy_arrow_factory(i) + if n1 == n2: + connectionstyle_obj = arrow.get_connectionstyle() + posA = ax.transData.transform(pos[n1]) + path_disp = connectionstyle_obj(posA, posA) + path_data = ax.transData.inverted().transform_path(path_disp) + x, y = path_data.vertices[0] + text_items[edge] = ax.text( + x, + y, + label, + size=get_param_value(font_size, "font_size"), + color=get_param_value(font_color, "font_color"), + family=get_param_value(font_family, "font_family"), + weight=get_param_value(font_weight, "font_weight"), + alpha=get_param_value(alpha, "alpha"), + horizontalalignment=get_param_value( + horizontalalignment, "horizontalalignment" + ), + verticalalignment=get_param_value( + verticalalignment, "verticalalignment" + ), + rotation=0, + transform=ax.transData, + bbox=bbox, + zorder=1, + clip_on=clip_on, + ) + else: + text_items[edge] = CurvedArrowText( + arrow, + label, + size=get_param_value(font_size, "font_size"), + color=get_param_value(font_color, "font_color"), + family=get_param_value(font_family, "font_family"), + weight=get_param_value(font_weight, "font_weight"), + alpha=get_param_value(alpha, "alpha"), + horizontalalignment=get_param_value( + horizontalalignment, "horizontalalignment" + ), + verticalalignment=get_param_value( + verticalalignment, "verticalalignment" + ), + transform=ax.transData, + bbox=bbox, + zorder=1, + clip_on=clip_on, + label_pos=get_param_value(label_pos, "label_pos"), + labels_horizontal=not get_param_value(rotate, "rotate"), + ax=ax, + ) + + if hide_ticks: + ax.tick_params( + axis="both", + which="both", + bottom=False, + left=False, + labelbottom=False, + labelleft=False, + ) + + return text_items + + +def draw_bipartite(G, **kwargs): + """Draw the graph `G` with a bipartite layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.bipartite_layout(G), **kwargs) + + Parameters + ---------- + G : graph + A networkx graph + + kwargs : optional keywords + See `draw_networkx` for a description of optional keywords. + + Raises + ------ + NetworkXError : + If `G` is not bipartite. + + Notes + ----- + The layout is computed each time this function is called. For + repeated drawing it is much more efficient to call + `~networkx.drawing.layout.bipartite_layout` directly and reuse the result:: + + >>> G = nx.complete_bipartite_graph(3, 3) + >>> pos = nx.bipartite_layout(G) + >>> nx.draw(G, pos=pos) # Draw the original graph + >>> # Draw a subgraph, reusing the same node positions + >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") + + Examples + -------- + >>> G = nx.complete_bipartite_graph(2, 5) + >>> nx.draw_bipartite(G) + + See Also + -------- + :func:`~networkx.drawing.layout.bipartite_layout` + """ + draw(G, pos=nx.bipartite_layout(G), **kwargs) + + +def draw_circular(G, **kwargs): + """Draw the graph `G` with a circular layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.circular_layout(G), **kwargs) + + Parameters + ---------- + G : graph + A networkx graph + + kwargs : optional keywords + See `draw_networkx` for a description of optional keywords. + + Notes + ----- + The layout is computed each time this function is called. For + repeated drawing it is much more efficient to call + `~networkx.drawing.layout.circular_layout` directly and reuse the result:: + + >>> G = nx.complete_graph(5) + >>> pos = nx.circular_layout(G) + >>> nx.draw(G, pos=pos) # Draw the original graph + >>> # Draw a subgraph, reusing the same node positions + >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.draw_circular(G) + + See Also + -------- + :func:`~networkx.drawing.layout.circular_layout` + """ + draw(G, pos=nx.circular_layout(G), **kwargs) + + +def draw_kamada_kawai(G, **kwargs): + """Draw the graph `G` with a Kamada-Kawai force-directed layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs) + + Parameters + ---------- + G : graph + A networkx graph + + kwargs : optional keywords + See `draw_networkx` for a description of optional keywords. + + Notes + ----- + The layout is computed each time this function is called. + For repeated drawing it is much more efficient to call + `~networkx.drawing.layout.kamada_kawai_layout` directly and reuse the + result:: + + >>> G = nx.complete_graph(5) + >>> pos = nx.kamada_kawai_layout(G) + >>> nx.draw(G, pos=pos) # Draw the original graph + >>> # Draw a subgraph, reusing the same node positions + >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.draw_kamada_kawai(G) + + See Also + -------- + :func:`~networkx.drawing.layout.kamada_kawai_layout` + """ + draw(G, pos=nx.kamada_kawai_layout(G), **kwargs) + + +def draw_random(G, **kwargs): + """Draw the graph `G` with a random layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.random_layout(G), **kwargs) + + Parameters + ---------- + G : graph + A networkx graph + + kwargs : optional keywords + See `draw_networkx` for a description of optional keywords. + + Notes + ----- + The layout is computed each time this function is called. + For repeated drawing it is much more efficient to call + `~networkx.drawing.layout.random_layout` directly and reuse the result:: + + >>> G = nx.complete_graph(5) + >>> pos = nx.random_layout(G) + >>> nx.draw(G, pos=pos) # Draw the original graph + >>> # Draw a subgraph, reusing the same node positions + >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") + + Examples + -------- + >>> G = nx.lollipop_graph(4, 3) + >>> nx.draw_random(G) + + See Also + -------- + :func:`~networkx.drawing.layout.random_layout` + """ + draw(G, pos=nx.random_layout(G), **kwargs) + + +def draw_spectral(G, **kwargs): + """Draw the graph `G` with a spectral 2D layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.spectral_layout(G), **kwargs) + + For more information about how node positions are determined, see + `~networkx.drawing.layout.spectral_layout`. + + Parameters + ---------- + G : graph + A networkx graph + + kwargs : optional keywords + See `draw_networkx` for a description of optional keywords. + + Notes + ----- + The layout is computed each time this function is called. + For repeated drawing it is much more efficient to call + `~networkx.drawing.layout.spectral_layout` directly and reuse the result:: + + >>> G = nx.complete_graph(5) + >>> pos = nx.spectral_layout(G) + >>> nx.draw(G, pos=pos) # Draw the original graph + >>> # Draw a subgraph, reusing the same node positions + >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.draw_spectral(G) + + See Also + -------- + :func:`~networkx.drawing.layout.spectral_layout` + """ + draw(G, pos=nx.spectral_layout(G), **kwargs) + + +def draw_spring(G, **kwargs): + """Draw the graph `G` with a spring layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.spring_layout(G), **kwargs) + + Parameters + ---------- + G : graph + A networkx graph + + kwargs : optional keywords + See `draw_networkx` for a description of optional keywords. + + Notes + ----- + `~networkx.drawing.layout.spring_layout` is also the default layout for + `draw`, so this function is equivalent to `draw`. + + The layout is computed each time this function is called. + For repeated drawing it is much more efficient to call + `~networkx.drawing.layout.spring_layout` directly and reuse the result:: + + >>> G = nx.complete_graph(5) + >>> pos = nx.spring_layout(G) + >>> nx.draw(G, pos=pos) # Draw the original graph + >>> # Draw a subgraph, reusing the same node positions + >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") + + Examples + -------- + >>> G = nx.path_graph(20) + >>> nx.draw_spring(G) + + See Also + -------- + draw + :func:`~networkx.drawing.layout.spring_layout` + """ + draw(G, pos=nx.spring_layout(G), **kwargs) + + +def draw_shell(G, nlist=None, **kwargs): + """Draw networkx graph `G` with shell layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs) + + Parameters + ---------- + G : graph + A networkx graph + + nlist : list of list of nodes, optional + A list containing lists of nodes representing the shells. + Default is `None`, meaning all nodes are in a single shell. + See `~networkx.drawing.layout.shell_layout` for details. + + kwargs : optional keywords + See `draw_networkx` for a description of optional keywords. + + Notes + ----- + The layout is computed each time this function is called. + For repeated drawing it is much more efficient to call + `~networkx.drawing.layout.shell_layout` directly and reuse the result:: + + >>> G = nx.complete_graph(5) + >>> pos = nx.shell_layout(G) + >>> nx.draw(G, pos=pos) # Draw the original graph + >>> # Draw a subgraph, reusing the same node positions + >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") + + Examples + -------- + >>> G = nx.path_graph(4) + >>> shells = [[0], [1, 2, 3]] + >>> nx.draw_shell(G, nlist=shells) + + See Also + -------- + :func:`~networkx.drawing.layout.shell_layout` + """ + draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs) + + +def draw_planar(G, **kwargs): + """Draw a planar networkx graph `G` with planar layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.planar_layout(G), **kwargs) + + Parameters + ---------- + G : graph + A planar networkx graph + + kwargs : optional keywords + See `draw_networkx` for a description of optional keywords. + + Raises + ------ + NetworkXException + When `G` is not planar + + Notes + ----- + The layout is computed each time this function is called. + For repeated drawing it is much more efficient to call + `~networkx.drawing.layout.planar_layout` directly and reuse the result:: + + >>> G = nx.path_graph(5) + >>> pos = nx.planar_layout(G) + >>> nx.draw(G, pos=pos) # Draw the original graph + >>> # Draw a subgraph, reusing the same node positions + >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") + + Examples + -------- + >>> G = nx.path_graph(4) + >>> nx.draw_planar(G) + + See Also + -------- + :func:`~networkx.drawing.layout.planar_layout` + """ + draw(G, pos=nx.planar_layout(G), **kwargs) + + +def draw_forceatlas2(G, **kwargs): + """Draw a networkx graph with forceatlas2 layout. + + This is a convenience function equivalent to:: + + nx.draw(G, pos=nx.forceatlas2_layout(G), **kwargs) + + Parameters + ---------- + G : graph + A networkx graph + + kwargs : optional keywords + See networkx.draw_networkx() for a description of optional keywords, + with the exception of the pos parameter which is not used by this + function. + """ + draw(G, pos=nx.forceatlas2_layout(G), **kwargs) + + +def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None): + """Apply an alpha (or list of alphas) to the colors provided. + + Parameters + ---------- + + colors : color string or array of floats (default='r') + Color of element. Can be a single color format string, + or a sequence of colors with the same length as nodelist. + If numeric values are specified they will be mapped to + colors using the cmap and vmin,vmax parameters. See + matplotlib.scatter for more details. + + alpha : float or array of floats + Alpha values for elements. This can be a single alpha value, in + which case it will be applied to all the elements of color. Otherwise, + if it is an array, the elements of alpha will be applied to the colors + in order (cycling through alpha multiple times if necessary). + + elem_list : array of networkx objects + The list of elements which are being colored. These could be nodes, + edges or labels. + + cmap : matplotlib colormap + Color map for use if colors is a list of floats corresponding to points + on a color mapping. + + vmin, vmax : float + Minimum and maximum values for normalizing colors if a colormap is used + + Returns + ------- + + rgba_colors : numpy ndarray + Array containing RGBA format values for each of the node colours. + + """ + from itertools import cycle, islice + + import matplotlib as mpl + import matplotlib.cm # call as mpl.cm + import matplotlib.colors # call as mpl.colors + import numpy as np + + # If we have been provided with a list of numbers as long as elem_list, + # apply the color mapping. + if len(colors) == len(elem_list) and isinstance(colors[0], Number): + mapper = mpl.cm.ScalarMappable(cmap=cmap) + mapper.set_clim(vmin, vmax) + rgba_colors = mapper.to_rgba(colors) + # Otherwise, convert colors to matplotlib's RGB using the colorConverter + # object. These are converted to numpy ndarrays to be consistent with the + # to_rgba method of ScalarMappable. + else: + try: + rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(colors)]) + except ValueError: + rgba_colors = np.array( + [mpl.colors.colorConverter.to_rgba(color) for color in colors] + ) + # Set the final column of the rgba_colors to have the relevant alpha values + try: + # If alpha is longer than the number of colors, resize to the number of + # elements. Also, if rgba_colors.size (the number of elements of + # rgba_colors) is the same as the number of elements, resize the array, + # to avoid it being interpreted as a colormap by scatter() + if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list): + rgba_colors = np.resize(rgba_colors, (len(elem_list), 4)) + rgba_colors[1:, 0] = rgba_colors[0, 0] + rgba_colors[1:, 1] = rgba_colors[0, 1] + rgba_colors[1:, 2] = rgba_colors[0, 2] + rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors))) + except TypeError: + rgba_colors[:, -1] = alpha + return rgba_colors |
