bigraph-viz
Advanced tools
| Metadata-Version: 2.1 | ||
| Name: bigraph-viz | ||
| Version: 0.1.6 | ||
| Version: 0.1.7 | ||
| Summary: A graphviz-based plotting tool for compositional bigraph schema | ||
@@ -5,0 +5,0 @@ Home-page: https://github.com/vivarium-collective/bigraph-viz |
+414
-504
| import os | ||
| import difflib | ||
| import re | ||
| from collections import defaultdict | ||
| import inspect | ||
| import graphviz | ||
| from itertools import islice | ||
| import numpy as np | ||
@@ -10,41 +12,29 @@ | ||
| # Constants | ||
| PROCESS_SCHEMA_KEYS = [ | ||
| 'config', | ||
| 'address', | ||
| 'interval', | ||
| 'inputs', | ||
| 'outputs', | ||
| 'instance', | ||
| 'bridge', | ||
| ] | ||
| 'config', 'address', 'interval', 'inputs', 'outputs', 'instance', 'bridge'] | ||
| def chunked(iterable, size): | ||
| """Yield successive chunks from iterable.""" | ||
| it = iter(iterable) | ||
| return iter(lambda: tuple(islice(it, size)), ()) | ||
| # Utility: Label formatting | ||
| def make_label(label): | ||
| # Insert line breaks after every max_length characters | ||
| # max_length = 25 | ||
| # lines = [label[i:i+max_length] for i in range(0, len(label), max_length)] | ||
| # label = '<br/>'.join(lines) | ||
| """Wrap a label in angle brackets for Graphviz HTML rendering.""" | ||
| return f'<{label}>' | ||
| def get_graph_wires(ports_schema, wires, graph_dict, schema_key, edge_path, bridge_wires=None): | ||
| """ | ||
| Traverse the port wiring and append wire edges or disconnected ports to graph_dict. | ||
| def get_graph_wires( | ||
| ports_schema, # the ports schema | ||
| wires, # the wires, from port to path | ||
| graph_dict, # the current graph dict that is being built | ||
| schema_key, # inputs or outputs | ||
| edge_path, # the path up to this process | ||
| bridge_wires=None, | ||
| ): | ||
| Parameters: | ||
| ports_schema (dict): Schema for ports (inputs or outputs) | ||
| wires (dict): Wiring structure from the process | ||
| graph_dict (dict): Accumulated graph | ||
| schema_key (str): Either 'inputs' or 'outputs' | ||
| edge_path (tuple): Path of the process node | ||
| bridge_wires (dict, optional): Optional rewiring via 'bridge' dict | ||
| Returns: | ||
| graph_dict (dict): Updated graph dict | ||
| """ | ||
| TODO -- support subwires with advanced wiring. This currently assumes each port has a simple wire. | ||
| """ | ||
| wires = wires or {} | ||
| ports_schema = ports_schema or {} | ||
| inferred_ports = set(list(ports_schema.keys()) + list(wires.keys())) | ||
| inferred_ports = set(ports_schema.keys()) | set(wires.keys()) | ||
@@ -56,59 +46,54 @@ for port in inferred_ports: | ||
| if not wire: | ||
| # there is no wire for this port, it is disconnected | ||
| if schema_key == 'inputs': | ||
| graph_dict['disconnected_input_edges'].append({ | ||
| 'edge_path': edge_path, | ||
| 'port': port, | ||
| 'type': schema_key}) | ||
| elif schema_key == 'outputs': | ||
| graph_dict['disconnected_output_edges'].append({ | ||
| 'edge_path': edge_path, | ||
| 'port': port, | ||
| 'type': schema_key}) | ||
| # If not connected, mark as disconnected | ||
| edge_type = 'disconnected_input_edges' if schema_key == 'inputs' else 'disconnected_output_edges' | ||
| graph_dict[edge_type].append({ | ||
| 'edge_path': edge_path, | ||
| 'port': port, | ||
| 'type': schema_key | ||
| }) | ||
| elif isinstance(wire, (list, tuple, str)): | ||
| graph_dict = get_single_wire(edge_path, graph_dict, port, schema_key, wire) | ||
| elif isinstance(wire, dict): | ||
| flat_wires = hierarchy_depth(wires) | ||
| for subpath, subwire in flat_wires.items(): | ||
| for subpath, subwire in hierarchy_depth(wires).items(): | ||
| subport = '/'.join(subpath) | ||
| graph_dict = get_single_wire(edge_path, graph_dict, subport, schema_key, subwire) | ||
| else: | ||
| raise ValueError(f"Unexpected wire type: {wires}") | ||
| # Handle optional bridge wiring | ||
| if bridge: | ||
| target_path = absolute_path(edge_path, tuple(bridge)) | ||
| if schema_key == 'inputs': | ||
| graph_dict['input_edges'].append({ | ||
| 'edge_path': edge_path, | ||
| 'target_path': target_path, | ||
| 'port': f'bridge_{port}', | ||
| 'type': f'bridge_{schema_key}'}) | ||
| elif schema_key == 'outputs': | ||
| graph_dict['output_edges'].append({ | ||
| 'edge_path': edge_path, | ||
| 'target_path': target_path, | ||
| 'port': f'bridge_{port}', | ||
| 'type': f'bridge_{schema_key}'}) | ||
| edge_key = 'input_edges' if schema_key == 'inputs' else 'output_edges' | ||
| graph_dict[edge_key].append({ | ||
| 'edge_path': edge_path, | ||
| 'target_path': target_path, | ||
| 'port': f'bridge_{port}', | ||
| 'type': f'bridge_{schema_key}' | ||
| }) | ||
| return graph_dict | ||
| # Append a single port wire connection to graph_dict | ||
| def get_single_wire(edge_path, graph_dict, port, schema_key, wire): | ||
| # the wire is defined, add it to edges | ||
| """ | ||
| Add a connection from a port to its wire target. | ||
| Parameters: | ||
| edge_path (tuple): Path to the process | ||
| graph_dict (dict): Current graph dict | ||
| port (str): Name of the port | ||
| schema_key (str): Either 'inputs' or 'outputs' | ||
| wire (str|list): Wire connection(s) | ||
| Returns: | ||
| Updated graph_dict | ||
| """ | ||
| if isinstance(wire, str): | ||
| wire = [wire] | ||
| elif isinstance(wire, (list, tuple)): | ||
| # only use strings in the wire | ||
| # TODO -- make this more general so it only skips integers if they go into an array | ||
| else: | ||
| wire = [item for item in wire if isinstance(item, str)] | ||
| target_path = absolute_path(edge_path[:-1], tuple(wire)) # TODO -- make sure this resolves ".." | ||
| if schema_key == 'inputs': | ||
| edge_key = 'input_edges' | ||
| elif schema_key == 'outputs': | ||
| edge_key = 'output_edges' | ||
| else: | ||
| raise Exception(f'invalid schema key {schema_key}') | ||
| target_path = absolute_path(edge_path[:-1], tuple(wire)) | ||
| edge_key = 'input_edges' if schema_key == 'inputs' else 'output_edges' | ||
| graph_dict[edge_key].append({ | ||
@@ -118,283 +103,272 @@ 'edge_path': edge_path, | ||
| 'port': port, | ||
| 'type': schema_key}) | ||
| 'type': schema_key | ||
| }) | ||
| return graph_dict | ||
| # Plot a labeled edge from a port to a process | ||
| def plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='false'): | ||
| """ | ||
| Add an edge between a target (state node) and process node. | ||
| If target not already rendered, add it to the graph. | ||
| """ | ||
| process_name = str(edge['edge_path']) | ||
| target_name = str(edge['target_path']) | ||
| label = make_label(edge['port']) if port_labels else '' | ||
| def plot_edges( | ||
| graph, | ||
| edge, | ||
| port_labels, | ||
| port_label_size, | ||
| state_node_spec, | ||
| constraint='false', | ||
| ): | ||
| process_path = edge['edge_path'] | ||
| process_name = str(process_path) | ||
| target_path = edge['target_path'] | ||
| port = edge['port'] | ||
| target_name = str(target_path) | ||
| if target_name not in graph.body: | ||
| label_text = make_label(edge['target_path'][-1]) | ||
| graph.node(target_name, label=label_text, **state_node_spec) | ||
| # place it in the graph | ||
| if target_name not in graph.body: # is the source node already in the graph? | ||
| label = make_label(target_path[-1]) | ||
| graph.node(target_name, label=label, **state_node_spec) | ||
| # port label | ||
| label = '' | ||
| if port_labels: | ||
| label = make_label(port) | ||
| with graph.subgraph(name=process_name) as c: | ||
| c.edge( | ||
| target_name, | ||
| process_name, | ||
| constraint=constraint, | ||
| label=label, | ||
| labelloc="t", | ||
| fontsize=port_label_size) | ||
| with graph.subgraph(name=process_name) as sub: | ||
| sub.edge(target_name, process_name, constraint=constraint, label=label, | ||
| labelloc="t", fontsize=port_label_size) | ||
| # Add a node to the graph with optional value/type | ||
| def add_node_to_graph(graph, node, state_node_spec, show_values, show_types, significant_digits): | ||
| """ | ||
| Add a state node to the Graphviz graph. | ||
| def add_node_to_graph(graph, node, state_node_spec, show_values, show_types, significant_digits): | ||
| Parameters: | ||
| graph: The Graphviz object | ||
| node: Dict representing the node | ||
| state_node_spec: Style options | ||
| show_values (bool): Whether to show the node value | ||
| show_types (bool): Whether to show the node type | ||
| significant_digits (int): Digits to round values | ||
| """ | ||
| node_path = node['path'] | ||
| node_name = str(node_path) | ||
| # make the label | ||
| label = node_path[-1] | ||
| schema_label = None | ||
| if show_values: | ||
| if node.get('value'): | ||
| v = node['value'] | ||
| if isinstance(v, float): | ||
| v = round(v, significant_digits) | ||
| if v.is_integer(): | ||
| v = int(v) | ||
| if not schema_label: | ||
| schema_label = '' | ||
| schema_label += f":{v}" | ||
| if show_types: | ||
| if node.get('type'): | ||
| if not schema_label: | ||
| schema_label = '<br/>' | ||
| ntype = node['type'] | ||
| if len(ntype) > 20: # don't show the full type if it's too long | ||
| ntype = '...' | ||
| schema_label += f"[{ntype}]" | ||
| if schema_label: | ||
| label += schema_label | ||
| label = make_label(label) | ||
| label_info = '' | ||
| if show_values and (val := node.get('value')) is not None: | ||
| if isinstance(val, float): | ||
| val = int(val) if val.is_integer() else round(val, significant_digits) | ||
| label_info += f":{val}" | ||
| if show_types and (typ := node.get('type')): | ||
| label_info += f"<br/>[{typ if len(typ) <= 20 else '...'}]" | ||
| full_label = make_label(label + label_info) if label_info else make_label(label) | ||
| graph.attr('node', **state_node_spec) | ||
| graph.node(str(node_name), label=label) | ||
| graph.node(node_name, label=full_label) | ||
| return node_name | ||
| # make the Graphviz figure | ||
| def get_graphviz_fig( | ||
| graph_dict, | ||
| label_margin='0.05', | ||
| node_label_size='12pt', | ||
| process_label_size=None, | ||
| size='16,10', | ||
| rankdir='TB', | ||
| aspect_ratio='auto', # 'compress', 'expand', 'auto', 'fill' | ||
| dpi='70', | ||
| significant_digits=2, | ||
| undirected_edges=False, | ||
| show_values=False, | ||
| show_types=False, | ||
| port_labels=True, | ||
| port_label_size='10pt', | ||
| invisible_edges=False, | ||
| remove_process_place_edges=False, | ||
| node_border_colors=None, | ||
| node_fill_colors=None, | ||
| node_groups=False, | ||
| max_nodes_per_row=None, | ||
| graph_dict, | ||
| label_margin='0.05', | ||
| node_label_size='12pt', | ||
| process_label_size=None, | ||
| size='16,10', | ||
| rankdir='TB', | ||
| aspect_ratio='auto', | ||
| dpi='70', | ||
| significant_digits=2, | ||
| undirected_edges=False, | ||
| show_values=False, | ||
| show_types=False, | ||
| port_labels=True, | ||
| port_label_size='10pt', | ||
| invisible_edges=None, | ||
| remove_process_place_edges=False, | ||
| node_border_colors=None, | ||
| node_fill_colors=None, | ||
| node_groups=None, | ||
| collapse_redundant_processes=False, | ||
| ): | ||
| """make a graphviz figure from a graph_dict""" | ||
| """ | ||
| Generate a Graphviz Digraph from a graph_dict describing a simulation architecture. | ||
| Parameters: | ||
| graph_dict: dict | ||
| Dictionary describing nodes and edges of a simulation bigraph. | ||
| collapse_redundant_processes: bool | ||
| Collapse processes with identical port wiring into a single node. | ||
| All other parameters configure visual style. | ||
| Returns: | ||
| graphviz.Digraph | ||
| Graphviz representation of the graph. | ||
| """ | ||
| import difflib | ||
| from collections import defaultdict | ||
| invisible_edges = invisible_edges or [] | ||
| node_groups = node_groups or [] | ||
| node_names = [] | ||
| invisible_edges = invisible_edges or [] | ||
| process_label_size = process_label_size or node_label_size | ||
| # node specs | ||
| graph = graphviz.Digraph(name='bigraph', engine='dot') | ||
| graph.attr(size=size, overlap='false', rankdir=rankdir, dpi=dpi, ratio=aspect_ratio, splines='true') | ||
| # Define node styles | ||
| state_node_spec = { | ||
| 'shape': 'circle', 'penwidth': '2', 'constraint': 'false', 'margin': label_margin, 'fontsize': node_label_size} | ||
| 'shape': 'circle', 'penwidth': '2', 'constraint': 'false', | ||
| 'margin': label_margin, 'fontsize': node_label_size | ||
| } | ||
| process_node_spec = { | ||
| 'shape': 'box', 'penwidth': '2', 'constraint': 'false', 'margin': label_margin, 'fontsize': process_label_size} | ||
| input_edge_spec = { | ||
| 'style': 'dashed', 'penwidth': '1', 'arrowhead': 'normal', 'arrowsize': '1.0', 'dir': 'forward'} | ||
| output_edge_spec = { | ||
| 'style': 'dashed', 'penwidth': '1', 'arrowhead': 'normal', 'arrowsize': '1.0', 'dir': 'back'} | ||
| bidirectional_edge_spec = { | ||
| 'style': 'dashed', 'penwidth': '1', 'arrowhead': 'normal', 'arrowsize': '1.0', 'dir': 'both'} | ||
| 'shape': 'box', 'penwidth': '2', 'constraint': 'false', | ||
| 'margin': label_margin, 'fontsize': process_label_size | ||
| } | ||
| # Define edge styles | ||
| edge_styles = { | ||
| 'input': {'style': 'dashed', 'penwidth': '1', 'arrowhead': 'normal', 'arrowsize': '1.0', 'dir': 'forward'}, | ||
| 'output': {'style': 'dashed', 'penwidth': '1', 'arrowhead': 'normal', 'arrowsize': '1.0', 'dir': 'back'}, | ||
| 'bidirectional': {'style': 'dashed', 'penwidth': '1', 'arrowhead': 'normal', 'arrowsize': '1.0', 'dir': 'both'}, | ||
| 'place': {'arrowhead': 'none', 'penwidth': '2'} | ||
| } | ||
| if undirected_edges: | ||
| input_edge_spec['dir'] = 'none' | ||
| output_edge_spec['dir'] = 'none' | ||
| bidirectional_edge_spec['dir'] = 'none' | ||
| for spec in edge_styles.values(): | ||
| spec['dir'] = 'none' | ||
| # initialize graph | ||
| graph = graphviz.Digraph(name='bigraph', engine='dot') | ||
| graph.attr(size=size, overlap='false', rankdir=rankdir, dpi=dpi, | ||
| ratio=aspect_ratio, # "fill", | ||
| splines='true', | ||
| ) | ||
| node_names = [] | ||
| # state nodes | ||
| graph.attr('node', **state_node_spec) | ||
| state_nodes = graph_dict['state_nodes'] | ||
| if max_nodes_per_row: | ||
| previous_node = None | ||
| for i, chunk in enumerate(chunked(state_nodes, max_nodes_per_row)): | ||
| with graph.subgraph(name=f'state_row_{i}') as row: | ||
| row.attr(rank='same') | ||
| chunk_node_names = [] | ||
| for node in chunk: | ||
| node_name = add_node_to_graph(graph, node, state_node_spec, show_values, show_types, | ||
| significant_digits) | ||
| node_names.append(node_name) | ||
| chunk_node_names.append(node_name) | ||
| def get_name_template(names): | ||
| """Create a generalized name with wildcards for collapsed process names.""" | ||
| if len(names) == 1: | ||
| return names[0] | ||
| import re | ||
| prefix = os.path.commonprefix(names) | ||
| suffix = os.path.commonprefix([n[::-1] for n in names])[::-1] | ||
| wildcard_middle = '*' if prefix != names[0] or suffix != names[0] else '' | ||
| return f"{prefix}{wildcard_middle}{suffix}" | ||
| # Add invisible edge to stack rows | ||
| if previous_node and chunk_node_names: | ||
| graph.edge(previous_node, chunk_node_names[0], style='invis', weight='10') | ||
| if chunk_node_names: | ||
| previous_node = chunk_node_names[-1] | ||
| else: | ||
| for node in state_nodes: | ||
| node_name = add_node_to_graph(graph, node, state_node_spec, show_values, show_types, significant_digits) | ||
| node_names.append(node_name) | ||
| def add_state_nodes(): | ||
| graph.attr('node', **state_node_spec) | ||
| for node in graph_dict['state_nodes']: | ||
| name = add_node_to_graph(graph, node, state_node_spec, show_values, show_types, significant_digits) | ||
| node_names.append(name) | ||
| # process nodes | ||
| process_paths = [] | ||
| graph.attr('node', **process_node_spec) | ||
| process_nodes = graph_dict['process_nodes'] | ||
| if max_nodes_per_row: | ||
| previous_node = None | ||
| for i, chunk in enumerate(chunked(process_nodes, max_nodes_per_row)): | ||
| with graph.subgraph(name=f'process_row_{i}') as row: | ||
| row.attr(rank='same') | ||
| chunk_node_names = [] | ||
| for node in chunk: | ||
| node_path = node['path'] | ||
| process_paths.append(node_path) | ||
| node_name = str(node_path) | ||
| node_names.append(node_name) | ||
| chunk_node_names.append(node_name) | ||
| label = make_label(node_path[-1]) | ||
| row.node(node_name, label=label) | ||
| def add_process_nodes(): | ||
| """Add process nodes to the graph, with optional collapse of redundant processes.""" | ||
| graph.attr('node', **process_node_spec) | ||
| process_fingerprints = defaultdict(list) | ||
| if previous_node and chunk_node_names: | ||
| graph.edge(previous_node, chunk_node_names[0], style='invis', weight='10') | ||
| if chunk_node_names: | ||
| previous_node = chunk_node_names[-1] | ||
| else: | ||
| for node in process_nodes: | ||
| # Build fingerprints for each process based on edge connectivity | ||
| for node in graph_dict['process_nodes']: | ||
| node_path = node['path'] | ||
| process_paths.append(node_path) | ||
| node_name = str(node_path) | ||
| node_names.append(node_name) | ||
| label = make_label(node_path[-1]) | ||
| graph.node(node_name, label=label) | ||
| path_str = str(node_path) | ||
| node_name = node_path[-1] | ||
| # place edges | ||
| graph.attr('edge', arrowhead='none', penwidth='2') | ||
| for edge in graph_dict['place_edges']: | ||
| # show edge or not | ||
| show_edge = True | ||
| if remove_process_place_edges and edge['child'] in process_paths: | ||
| show_edge = False | ||
| elif edge in invisible_edges: | ||
| show_edge = False | ||
| if show_edge: | ||
| graph.attr('edge', style='filled') | ||
| else: | ||
| graph.attr('edge', style='invis') | ||
| parent_node = str(edge['parent']) | ||
| child_node = str(edge['child']) | ||
| graph.edge(parent_node, child_node, | ||
| dir='forward', constraint='true' | ||
| ) | ||
| fingerprint = [] | ||
| for group, tag in [('input_edges', 'in'), ('output_edges', 'out'), ('bidirectional_edges', 'both')]: | ||
| for edge in graph_dict.get(group, []): | ||
| if edge['edge_path'] == node_path: | ||
| fingerprint.append((tag, edge['port'], str(edge.get('target_path')))) | ||
| fingerprint = tuple(sorted(fingerprint)) | ||
| process_fingerprints[fingerprint].append((node_path, path_str, node_name)) | ||
| # input edges | ||
| for edge in graph_dict['input_edges']: | ||
| if edge['type'] == 'bridge_inputs': | ||
| graph.attr('edge', **output_edge_spec) # reverse arrow direction to go from composite to store | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='false') | ||
| collapse_map = {} | ||
| # Only collapse if flag is enabled | ||
| if collapse_redundant_processes: | ||
| for fingerprint, entries in process_fingerprints.items(): | ||
| names = [entry[2] for entry in entries] | ||
| template = get_name_template(names) | ||
| count = len(entries) | ||
| label = template if count == 1 else f"{template} (x{count})" | ||
| representative = str(entries[0][0]) | ||
| graph.node(representative, label=label) | ||
| node_names.append(representative) | ||
| for path, path_str, _ in entries: | ||
| if str(path) != representative: | ||
| collapse_map[str(path)] = representative | ||
| else: | ||
| graph.attr('edge', **input_edge_spec) | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='true') | ||
| # output edges | ||
| for edge in graph_dict['output_edges']: | ||
| if edge['type'] == 'bridge_outputs': | ||
| graph.attr('edge', **input_edge_spec) # reverse arrow direction to go from store to composite | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='false') | ||
| else: | ||
| graph.attr('edge', **output_edge_spec) | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='true') | ||
| # bidirectional edges | ||
| for edge in graph_dict['bidirectional_edges']: | ||
| if 'bridge_outputs' not in edge['type'] and 'bridge_inputs' not in edge['type']: | ||
| graph.attr('edge', **bidirectional_edge_spec) | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='true') | ||
| else: | ||
| if 'bridge_outputs' in edge['type']: | ||
| graph.attr('edge', **input_edge_spec) # reverse arrow direction to go from store to composite | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='false') | ||
| if 'bridge_inputs' in edge['type']: | ||
| graph.attr('edge', **output_edge_spec) # reverse arrow direction to go from composite to store | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='false') | ||
| # Add all process nodes without collapsing | ||
| for entries in process_fingerprints.values(): | ||
| for path, path_str, name in entries: | ||
| graph.node(path_str, label=name) | ||
| node_names.append(path_str) | ||
| # state nodes again | ||
| # TODO -- this is a hack to make sure the state nodes show up as circles | ||
| graph.attr('node', **state_node_spec) | ||
| for node in graph_dict['state_nodes']: | ||
| node_path = node['path'] | ||
| node_name = add_node_to_graph(graph, node, state_node_spec, show_values, show_types, significant_digits) | ||
| node_names.append(node_name) | ||
| return [entry[0] for entries in process_fingerprints.values() for entry in entries], collapse_map | ||
| # disconnected input edges | ||
| for edge in graph_dict['disconnected_input_edges']: | ||
| process_path = edge['edge_path'] | ||
| port = edge['port'] | ||
| # add invisible node for port | ||
| node_name2 = str(absolute_path(process_path, port)) + '_input' | ||
| graph.node(node_name2, label='', style='invis', width='0') | ||
| edge['target_path'] = node_name2 | ||
| graph.attr('edge', **input_edge_spec) | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='true') | ||
| # disconnected output edges | ||
| for edge in graph_dict['disconnected_output_edges']: | ||
| process_path = edge['edge_path'] | ||
| port = edge['port'] | ||
| # add invisible node for port | ||
| node_name2 = str(absolute_path(process_path, port)) + '_output' | ||
| graph.node(node_name2, label='', style='invis', width='0') | ||
| edge['target_path'] = node_name2 | ||
| graph.attr('edge', **output_edge_spec) | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='true') | ||
| def rewrite_collapsed_edges(collapse_map): | ||
| removed_keys = set(collapse_map.keys()) | ||
| for group in ['input_edges', 'output_edges', 'bidirectional_edges', 'disconnected_input_edges', 'disconnected_output_edges']: | ||
| edges = graph_dict.get(group, []) | ||
| new_edges = [] | ||
| for edge in edges: | ||
| key = str(edge['edge_path']) | ||
| if key in collapse_map: | ||
| edge['edge_path'] = collapse_map[key] | ||
| if edge not in new_edges: | ||
| new_edges.append(edge) | ||
| elif key not in removed_keys: | ||
| new_edges.append(edge) | ||
| graph_dict[group] = new_edges | ||
| # grouped nodes | ||
| for group in node_groups: | ||
| # convert lists to tuples | ||
| group = [tuple(item) for item in group] | ||
| # Remove any place_edges associated with collapsed processes | ||
| new_place_edges = [] | ||
| for edge in graph_dict.get('place_edges', []): | ||
| parent_str = str(edge['parent']) | ||
| child_str = str(edge['child']) | ||
| if parent_str in removed_keys or child_str in removed_keys: | ||
| continue | ||
| new_place_edges.append(edge) | ||
| graph_dict['place_edges'] = new_place_edges | ||
| group_name = str(group) | ||
| with graph.subgraph(name=group_name) as c: | ||
| c.attr(rank='same') | ||
| previous_node = None | ||
| for path in group: | ||
| node_name = str(path) | ||
| if node_name in node_names: | ||
| c.node(node_name) | ||
| if previous_node: | ||
| # out them in the order declared in the group | ||
| c.edge(previous_node, node_name, style='invis', ordering='out') | ||
| previous_node = node_name | ||
| def add_edges(edge_groups): | ||
| for group, style_key in edge_groups: | ||
| for edge in graph_dict.get(group, []): | ||
| if 'bridge_outputs' in edge['type']: | ||
| style, constraint = 'input', 'false' | ||
| elif 'bridge_inputs' in edge['type']: | ||
| style, constraint = 'output', 'false' | ||
| else: | ||
| print(f'node {node_name} not in graph') | ||
| # formatting | ||
| if node_border_colors: | ||
| for node_name, color in node_border_colors.items(): | ||
| graph.node(str(node_name), color=color) | ||
| if node_fill_colors: | ||
| for node_name, color in node_fill_colors.items(): | ||
| graph.node(str(node_name), color=color, style='filled') | ||
| style, constraint = style_key, 'true' | ||
| graph.attr('edge', **edge_styles[style]) | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint=constraint) | ||
| def add_place_edges(process_paths): | ||
| for edge in graph_dict['place_edges']: | ||
| visible = not ((remove_process_place_edges and edge['child'] in process_paths) or (edge in invisible_edges)) | ||
| graph.attr('edge', style='filled' if visible else 'invis') | ||
| graph.edge(str(edge['parent']), str(edge['child']), **edge_styles['place'], constraint='true') | ||
| def add_disconnected_edges(): | ||
| for direction, style_key in [('disconnected_input_edges', 'input'), ('disconnected_output_edges', 'output')]: | ||
| for edge in graph_dict[direction]: | ||
| path = edge['edge_path'] | ||
| port = edge['port'] | ||
| suffix = '_input' if 'input' in direction else '_output' | ||
| dummy = str(absolute_path(path, port)) + suffix | ||
| graph.node(dummy, label='', style='invis', width='0') | ||
| edge['target_path'] = dummy | ||
| graph.attr('edge', **edge_styles[style_key]) | ||
| plot_edges(graph, edge, port_labels, port_label_size, state_node_spec, constraint='true') | ||
| def rank_node_groups(): | ||
| for group in node_groups: | ||
| group = [tuple(g) for g in group] | ||
| with graph.subgraph(name=str(group)) as sg: | ||
| sg.attr(rank='same') | ||
| prev = None | ||
| for path in group: | ||
| name = str(path) | ||
| if name in node_names: | ||
| sg.node(name) | ||
| if prev: | ||
| sg.edge(prev, name, style='invis', ordering='out') | ||
| prev = name | ||
| def apply_custom_colors(): | ||
| if node_border_colors: | ||
| for name, color in node_border_colors.items(): | ||
| graph.node(str(name), color=color) | ||
| if node_fill_colors: | ||
| for name, color in node_fill_colors.items(): | ||
| graph.node(str(name), color=color, style='filled') | ||
| add_state_nodes() | ||
| process_paths, collapse_map = add_process_nodes() | ||
| if collapse_redundant_processes: | ||
| rewrite_collapsed_edges(collapse_map) | ||
| add_place_edges(process_paths) | ||
| add_edges([('input_edges', 'input'), ('output_edges', 'output'), ('bidirectional_edges', 'bidirectional')]) | ||
| add_state_nodes() | ||
| add_disconnected_edges() | ||
| rank_node_groups() | ||
| apply_custom_colors() | ||
| return graph | ||
@@ -404,24 +378,31 @@ | ||
| def plot_bigraph( | ||
| state, | ||
| schema=None, | ||
| core=None, | ||
| out_dir=None, | ||
| filename=None, | ||
| file_format='png', | ||
| **kwargs | ||
| state, | ||
| schema=None, | ||
| core=None, | ||
| out_dir=None, | ||
| filename=None, | ||
| file_format='png', | ||
| **kwargs | ||
| ): | ||
| # inspect the signature of plot_bigraph | ||
| get_graphviz_fig_signature = inspect.signature(get_graphviz_fig) | ||
| """ | ||
| Create and render a bigraph visualization using Graphviz from a given state and optional schema. | ||
| # Filter kwargs to only include those accepted by get_graphviz_fig | ||
| get_graphviz_kwargs = { | ||
| k: v for k, v in kwargs.items() | ||
| if k in get_graphviz_fig_signature.parameters} | ||
| Parameters: | ||
| state (dict): The simulation state. | ||
| schema (dict): Optional schema defining the structure of the state. | ||
| core (VisualizeTypes): Visualization engine. | ||
| out_dir (str): Directory to write output. | ||
| filename (str): Name of the output file. | ||
| file_format (str): Output format (e.g., 'png', 'svg'). | ||
| **kwargs: Additional arguments for styling or traversal. | ||
| # get the remaining kwargs | ||
| viztype_kwargs = { | ||
| k: v for k, v in kwargs.items() | ||
| if k not in get_graphviz_kwargs} | ||
| Returns: | ||
| graphviz.Digraph: Rendered graph object. | ||
| """ | ||
| # Separate kwargs into rendering and traversal arguments | ||
| graphviz_sig = inspect.signature(get_graphviz_fig) | ||
| render_kwargs = {k: v for k, v in kwargs.items() if k in graphviz_sig.parameters} | ||
| traversal_kwargs = {k: v for k, v in kwargs.items() if k not in graphviz_sig.parameters} | ||
| # set defaults if none provided | ||
| # Defaults | ||
| core = core or VisualizeTypes() | ||
@@ -431,8 +412,3 @@ schema = schema or {} | ||
| graph_dict = core.generate_graph_dict( | ||
| schema, | ||
| state, | ||
| (), | ||
| options=viztype_kwargs # TODO | ||
| ) | ||
| graph_dict = core.generate_graph_dict(schema, state, (), options=traversal_kwargs) | ||
@@ -444,3 +420,4 @@ return core.plot_graph( | ||
| file_format=file_format, | ||
| options=get_graphviz_kwargs) | ||
| options=render_kwargs | ||
| ) | ||
@@ -450,19 +427,16 @@ | ||
| def graphviz_any(core, schema, state, path, options, graph): | ||
| """Visualize any type (generic node).""" | ||
| schema = schema or {} | ||
| if len(path) > 0: | ||
| if path: | ||
| node_spec = { | ||
| 'name': path[-1], | ||
| 'path': path, | ||
| 'value': None, | ||
| 'type': core.representation(schema)} | ||
| if not isinstance(state, dict): | ||
| node_spec['value'] = state | ||
| 'value': state if not isinstance(state, dict) else None, | ||
| 'type': core.representation(schema) | ||
| } | ||
| graph['state_nodes'].append(node_spec) | ||
| if len(path) > 1: | ||
| graph['place_edges'].append({ | ||
| 'parent': path[:-1], | ||
| 'child': path}) | ||
| graph['place_edges'].append({'parent': path[:-1], 'child': path}) | ||
@@ -472,15 +446,14 @@ if isinstance(state, dict): | ||
| if not is_schema_key(key): | ||
| subpath = path + (key,) | ||
| graph = core.get_graph_dict( | ||
| schema.get(key, {}), | ||
| value, | ||
| subpath, | ||
| path + (key,), | ||
| options, | ||
| graph) | ||
| graph | ||
| ) | ||
| return graph | ||
| def graphviz_edge(core, schema, state, path, options, graph): | ||
| # add process node to graph | ||
| """Visualize a process node with input/output/bridge wiring.""" | ||
| node_spec = { | ||
@@ -490,5 +463,5 @@ 'name': path[-1], | ||
| 'value': None, | ||
| 'type': core.representation(schema)} | ||
| 'type': core.representation(schema) | ||
| } | ||
| # check if this is actually a composite node | ||
| if state.get('address') == 'local:composite' and node_spec not in graph['process_nodes']: | ||
@@ -500,61 +473,21 @@ graph['process_nodes'].append(node_spec) | ||
| # get the wires and ports | ||
| input_wires = state.get('inputs', {}) | ||
| output_wires = state.get('outputs', {}) | ||
| input_ports = state.get('_inputs', schema.get('_inputs', {})) | ||
| output_ports = state.get('_outputs', schema.get('_outputs', {})) | ||
| # Wiring | ||
| graph = get_graph_wires(schema.get('_inputs', {}), state.get('inputs', {}), graph, 'inputs', path, state.get('bridge', {}).get('inputs', {})) | ||
| graph = get_graph_wires(schema.get('_outputs', {}), state.get('outputs', {}), graph, 'outputs', path, state.get('bridge', {}).get('outputs', {})) | ||
| # bridge | ||
| bridge_wires = state.get('bridge', {}) | ||
| bridge_inputs = bridge_wires.get('inputs', {}) | ||
| bridge_outputs = bridge_wires.get('outputs', {}) | ||
| # Merge bidirectional edges | ||
| def key(edge): return (tuple(edge['edge_path']), tuple(edge['target_path']), edge['port']) | ||
| input_set = {key(e): e for e in graph['input_edges']} | ||
| output_set = {key(e): e for e in graph['output_edges']} | ||
| shared_keys = input_set.keys() & output_set.keys() | ||
| for k in shared_keys: | ||
| graph['bidirectional_edges'].append({ | ||
| 'edge_path': k[0], 'target_path': k[1], 'port': k[2], | ||
| 'type': (input_set[k]['type'], output_set[k]['type']) | ||
| }) | ||
| graph['input_edges'] = [e for k, e in input_set.items() if k not in shared_keys] | ||
| graph['output_edges'] = [e for k, e in output_set.items() if k not in shared_keys] | ||
| # get the input wires | ||
| graph = get_graph_wires( | ||
| input_ports, input_wires, graph, | ||
| schema_key='inputs', edge_path=path, | ||
| bridge_wires=bridge_inputs) | ||
| # get the output wires | ||
| graph = get_graph_wires( | ||
| output_ports, output_wires, graph, | ||
| schema_key='outputs', edge_path=path, | ||
| bridge_wires=bridge_outputs) | ||
| # get bidirectional wires | ||
| input_edges_to_remove = [] | ||
| output_edges_to_remove = [] | ||
| for input_edge in graph['input_edges']: | ||
| for output_edge in graph['output_edges']: | ||
| if (input_edge['target_path'] == output_edge['target_path']) and \ | ||
| (input_edge['port'] == output_edge['port']) and \ | ||
| (input_edge['edge_path'] == output_edge['edge_path']): | ||
| graph['bidirectional_edges'].append({ | ||
| 'edge_path': input_edge['edge_path'], | ||
| 'target_path': input_edge['target_path'], | ||
| 'port': input_edge['port'], | ||
| 'type': (input_edge['type'], output_edge['type']), | ||
| # 'type': 'bidirectional' | ||
| }) | ||
| input_edges_to_remove.append(input_edge) | ||
| output_edges_to_remove.append(output_edge) | ||
| break # prevent matching the same input_edge with multiple output_edges | ||
| # Remove matched edges after iteration | ||
| for edge in input_edges_to_remove: | ||
| graph['input_edges'].remove(edge) | ||
| for edge in output_edges_to_remove: | ||
| graph['output_edges'].remove(edge) | ||
| # get the input and output bridge wires | ||
| if bridge_wires: | ||
| # check that the bridge wires connect to valid ports | ||
| assert set(bridge_wires.keys()).issubset({'inputs', 'outputs'}) | ||
| # add the process node path | ||
| if len(path) > 1: | ||
| graph['place_edges'].append({ | ||
| 'parent': path[:-1], | ||
| 'child': path}) | ||
| graph['place_edges'].append({'parent': path[:-1], 'child': path}) | ||
@@ -565,2 +498,3 @@ return graph | ||
| def graphviz_none(core, schema, state, path, options, graph): | ||
| """No-op visualizer for nodes with no visualization.""" | ||
| return graph | ||
@@ -570,29 +504,21 @@ | ||
| def graphviz_composite(core, schema, state, path, options, graph): | ||
| # add the composite edge | ||
| """Visualize composite nodes by recursing into their internal structure.""" | ||
| graph = graphviz_edge(core, schema, state, path, options, graph) | ||
| # get the inner state and schema | ||
| inner_state = state.get('config', {}).get('state') | ||
| inner_schema = state.get('config', {}).get('composition') | ||
| if inner_state is None: | ||
| inner_state = state | ||
| inner_schema = schema | ||
| inner_state = state.get('config', {}).get('state') or state | ||
| inner_schema = state.get('config', {}).get('composition') or schema | ||
| inner_schema, inner_state = core.generate(inner_schema, inner_state) | ||
| # add the process node path | ||
| if len(path) > 1: | ||
| graph['place_edges'].append({ | ||
| 'parent': path[:-1], | ||
| 'child': path}) | ||
| graph['place_edges'].append({'parent': path[:-1], 'child': path}) | ||
| # add the inner nodes and edges | ||
| for key, value in inner_state.items(): | ||
| if not is_schema_key(key) and key not in PROCESS_SCHEMA_KEYS: | ||
| subpath = path + (key,) | ||
| graph = core.get_graph_dict( | ||
| inner_schema.get(key), | ||
| value, | ||
| subpath, | ||
| path + (key,), | ||
| options, | ||
| graph) | ||
| graph | ||
| ) | ||
@@ -625,3 +551,2 @@ return graph | ||
| # TODO: we want to visualize things that are not yet complete | ||
@@ -1118,88 +1043,55 @@ | ||
| def generate_spec_and_schema(n_rows, n_cols): | ||
| spec = {'cells': {}} | ||
| fields = { | ||
| 'acetate': np.zeros((n_rows, n_cols)), | ||
| 'biomass': np.zeros((n_rows, n_cols)), | ||
| 'glucose': np.zeros((n_rows, n_cols)), | ||
| } | ||
| def test_array_paths(): | ||
| core = VisualizeTypes() | ||
| for i in range(n_rows): | ||
| for j in range(n_cols): | ||
| name = f'dFBA[{i},{j}]' | ||
| cell_spec = { | ||
| '_type': 'process', | ||
| 'address': 'local:DynamicFBA', | ||
| 'inputs': { | ||
| 'substrates': { | ||
| 'acetate': ['..', 'fields', 'acetate', i, j], | ||
| 'biomass': ['..', 'fields', 'biomass', i, j], | ||
| 'glucose': ['..', 'fields', 'glucose', i, j], | ||
| } | ||
| }, | ||
| 'outputs': { | ||
| 'substrates': { | ||
| 'acetate': ['..', 'fields', 'acetate', i, j], | ||
| 'biomass': ['..', 'fields', 'biomass', i, j], | ||
| 'glucose': ['..', 'fields', 'glucose', i, j], | ||
| } | ||
| } | ||
| } | ||
| spec['cells'][name] = cell_spec | ||
| spec = { | ||
| 'dFBA[0,0]': { | ||
| '_type': 'process', | ||
| 'address': 'local:DynamicFBA', | ||
| # 'config': {}, | ||
| 'inputs': { | ||
| 'substrates': { | ||
| 'acetate': ['fields', 'acetate', | ||
| 0, 0 | ||
| ], | ||
| 'biomass': ['fields', 'biomass', | ||
| 0, 0 | ||
| ], | ||
| 'glucose': ['fields', 'glucose', | ||
| 0, 0 | ||
| ]}}, | ||
| 'outputs': { | ||
| 'substrates': { | ||
| 'acetate': ['fields', 'acetate', | ||
| 0, 0 | ||
| ], | ||
| 'biomass': ['fields', 'biomass', | ||
| 0, 0 | ||
| ], | ||
| 'glucose': ['fields', 'glucose', | ||
| 0, 0 | ||
| ]}}, | ||
| }, | ||
| 'dFBA[1,0]': { | ||
| '_type': 'process', | ||
| 'address': 'local:DynamicFBA', | ||
| # 'config': {}, | ||
| 'inputs': { | ||
| 'substrates': { | ||
| 'acetate': ['fields', 'acetate', | ||
| 1, 0 | ||
| ], | ||
| 'biomass': ['fields', 'biomass', | ||
| 1, 0 | ||
| ], | ||
| 'glucose': ['fields', 'glucose', | ||
| 1, 0 | ||
| ]}}, | ||
| 'outputs': { | ||
| 'substrates': { | ||
| 'acetate': ['fields', 'acetate', | ||
| 1, 0 | ||
| ], | ||
| 'biomass': ['fields', 'biomass', | ||
| 1, 0 | ||
| ], | ||
| 'glucose': ['fields', 'glucose', | ||
| 1, 0 | ||
| ]}}, | ||
| }, | ||
| 'fields': { | ||
| 'acetate': np.array([[1.0], [2.0]]), | ||
| 'biomass': np.array([[3.0], [4.0]]), | ||
| 'glucose': np.array([[5.0], [6.0]]) | ||
| } | ||
| } | ||
| # Add fields to spec | ||
| spec['fields'] = fields | ||
| # Generate schema | ||
| schema = { | ||
| 'fields': { | ||
| 'acetate': { | ||
| mol: { | ||
| '_type': 'array', | ||
| '_shape': (2, 1), | ||
| '_shape': (n_rows, n_cols), | ||
| '_data': 'float' | ||
| }, | ||
| 'biomass': { | ||
| '_type': 'array', | ||
| '_shape': (2, 1), | ||
| '_data': 'float', | ||
| }, | ||
| 'glucose': { | ||
| '_type': 'array', | ||
| '_shape': (2, 1), | ||
| '_data': 'float', | ||
| } | ||
| } for mol in ['acetate', 'biomass', 'glucose'] | ||
| } | ||
| } | ||
| return spec, schema | ||
| def test_array_paths(): | ||
| core = VisualizeTypes() | ||
| n_rows, n_cols = 2, 1 # or any desired shape | ||
| spec, schema = generate_spec_and_schema(n_rows, n_cols) | ||
| plot_bigraph( | ||
@@ -1213,2 +1105,19 @@ spec, | ||
| def test_complex_bigraph(): | ||
| core = VisualizeTypes() | ||
| n_rows, n_cols = 6, 6 # or any desired shape | ||
| spec, schema = generate_spec_and_schema(n_rows, n_cols) | ||
| plot_settings['dpi'] = '500' | ||
| plot_bigraph( | ||
| spec, | ||
| schema=schema, | ||
| core=core, | ||
| filename='complex_bigraph', | ||
| collapse_redundant_processes=True, | ||
| # dpi='200', | ||
| **plot_settings) | ||
| if __name__ == '__main__': | ||
@@ -1229,1 +1138,2 @@ test_simple_store() | ||
| test_array_paths() | ||
| test_complex_bigraph() |
+1
-1
| Metadata-Version: 2.1 | ||
| Name: bigraph-viz | ||
| Version: 0.1.6 | ||
| Version: 0.1.7 | ||
| Summary: A graphviz-based plotting tool for compositional bigraph schema | ||
@@ -5,0 +5,0 @@ Home-page: https://github.com/vivarium-collective/bigraph-viz |
+1
-1
@@ -7,3 +7,3 @@ [build-system] | ||
| name = "bigraph-viz" | ||
| version = "0.1.6" | ||
| version = "0.1.7" | ||
| description = "A visualization method for displaying the structure of process bigraphs" | ||
@@ -10,0 +10,0 @@ readme = "README.md" |
+1
-1
@@ -5,3 +5,3 @@ import re | ||
| VERSION = '0.1.6' | ||
| VERSION = '0.1.7' | ||
@@ -8,0 +8,0 @@ |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
80423
-2.89%1424
-6.44%