Source code for city2graph.graph

"""
Module for creating heterogeneous graph representations of urban environments.

Converts geodataframes containing spatial data into PyTorch Geometric HeteroData objects.
"""

try:
    import torch
    from torch_geometric.data import Data
    from torch_geometric.data import HeteroData
    from torch_geometric.utils import to_networkx as pyg_to_networkx

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

    # Create placeholder classes to prevent import errors
    class HeteroData:
        pass

    class Data:
        pass


import logging
from typing import Union

import geopandas as gpd
import networkx as nx
import numpy as np
import pandas as pd

from .utils import _validate_gdf
from .utils import _validate_nx

logger = logging.getLogger(__name__)

# Define the public API for this module
__all__ = [
    "from_morphological_graph",
    "gdf_to_pyg",
    "heterogeneous_graph",
    "homogeneous_graph",
    "is_torch_available",
    "nx_to_pyg",
    "pyg_to_gdf",
    "pyg_to_nx",
]


[docs] def is_torch_available() -> bool: """ Check if PyTorch and PyTorch Geometric are available. Returns ------- bool True if PyTorch and PyTorch Geometric are available, False otherwise. """ return TORCH_AVAILABLE
def _get_device( device: Union[str, "torch.device", None] = None, ) -> Union["torch.device", str]: """ Get the appropriate torch device (CUDA if available, otherwise CPU). Parameters ---------- device : str or torch.device, default None Device to use for tensors. Must be 'cuda', 'cpu', torch.device or None. If None, will use CUDA if available, otherwise CPU. Returns ------- torch.device The device to use for tensors Raises ------ ValueError If device is not None, 'cuda', 'cpu', or torch.device ImportError If PyTorch is not installed """ if not TORCH_AVAILABLE: msg = ( "PyTorch and PyTorch Geometric are required for this function. " "Please install them using: poetry install --with torch or " "pip install city2graph[torch]" ) raise ImportError( msg, ) if device is None: return torch.device("cuda" if torch.cuda.is_available() else "cpu") if isinstance(device, torch.device): return device if device in ["cuda", "cpu"]: return torch.device(device) msg = "Device must be 'cuda', 'cpu', a torch.device object, or None" raise ValueError(msg) def _detect_edge_columns(edge_gdf: gpd.GeoDataFrame, id_col: str | None = None, source_hint: list[str] | None = None, target_hint: list[str] | None = None) -> tuple[str | None, str | None]: """ Detect appropriate source and target columns in an edge GeoDataFrame. Parameters ---------- edge_gdf : pandas.DataFrame DataFrame containing edge data id_col : str, default None Column name used to identify nodes, used as a hint source_hint : list, default None Additional keywords to look for in source column names target_hint : list, default None Additional keywords to look for in target column names Returns ------- tuple (source_col, target_col) - detected column names """ if edge_gdf.empty or len(edge_gdf.columns) < 2: return None, None # Default hint keywords source_keywords = ["from", "source", "start", "u"] target_keywords = ["to", "target", "end", "v"] # Add custom hints if provided if source_hint: source_keywords.extend([hint.lower() for hint in source_hint]) if target_hint: target_keywords.extend([hint.lower() for hint in target_hint]) # Add id_col as hint if provided if id_col: source_keywords.append(id_col.lower()) target_keywords.append(id_col.lower()) # Find columns matching source keywords from_candidates = [ col for col in edge_gdf.columns if any(keyword in col.lower() for keyword in source_keywords) ] # Find columns matching target keywords to_candidates = [ col for col in edge_gdf.columns if any(keyword in col.lower() for keyword in target_keywords) ] # Select best candidates if from_candidates and to_candidates: return from_candidates[0], to_candidates[0] # Fall back to first two columns if needed if len(edge_gdf.columns) >= 2 and "geometry" not in edge_gdf.columns[:2]: return edge_gdf.columns[0], edge_gdf.columns[1] if len(edge_gdf.columns) >= 3 and "geometry" in edge_gdf.columns[0]: return edge_gdf.columns[1], edge_gdf.columns[2] return None, None def _get_edge_columns( edge_gdf: gpd.GeoDataFrame, source_col: str | None, target_col: str | None, source_mapping: dict[str, int], target_mapping: dict[str, int], id_col: str | None = None, ) -> tuple[str | None, str | None]: """Consolidate logic for detecting or confirming source/target columns.""" if source_col is None or target_col is None: detected_source, detected_target = _detect_edge_columns( edge_gdf, id_col=id_col, source_hint=list(source_mapping.keys())[:1] if source_mapping else None, target_hint=list(target_mapping.keys())[:1] if target_mapping else None, ) if source_col is None: source_col = detected_source if target_col is None: target_col = detected_target return source_col, target_col def _extract_node_id_mapping(node_gdf: gpd.GeoDataFrame, id_col: str | None = None) -> tuple[dict[str, int], str]: """ Extract a mapping from node IDs to indices. Parameters ---------- node_gdf : geopandas.GeoDataFrame GeoDataFrame containing node data. id_col : str, optional Column name that uniquely identifies each node. If provided and missing from node_gdf, a ValueError is raised to prevent unintended errors. Returns ------- tuple (id_mapping, used_id_col) - mapping from IDs to indices and the ID column used """ # If id_col is provided but not found, raise a ValueError to alert the user. if id_col is not None and id_col not in node_gdf.columns: msg = f"Provided id_col '{id_col}' not found in node GeoDataFrame columns." raise ValueError( msg, ) if id_col is None: # Use index if id_col is None. id_mapping = {str(idx): i for i, idx in enumerate(node_gdf.index)} return id_mapping, "index" # Use specified column if found. id_mapping = { str(node_id): idx for idx, node_id in enumerate(node_gdf[id_col].unique()) } return id_mapping, id_col # Modified _create_node_features: renamed parameter "attribute_cols" to "feature_cols" def _create_node_features(node_gdf: gpd.GeoDataFrame, feature_cols: list[str] | None = None, device: Union[str, "torch.device"] | None = None) -> "torch.Tensor": """ Create node feature tensors from attribute columns. Parameters ---------- node_gdf : geopandas.GeoDataFrame GeoDataFrame containing node data feature_cols : list, default None List of column names to use as node features device : str, default None Device to use for tensors. Must be 'cuda' or 'cpu' if provided. Returns ------- torch.Tensor Tensor of node features """ device = _get_device(device) if feature_cols is None: return torch.zeros((len(node_gdf), 0), dtype=torch.float, device=device) # Vectorized column validation using set intersection valid_cols = list(set(feature_cols) & set(node_gdf.columns)) if valid_cols: # Direct numpy array conversion for better performance features_array = node_gdf[valid_cols].to_numpy().astype(np.float32) return torch.from_numpy(features_array).to(device=device, dtype=torch.float) return torch.zeros((len(node_gdf), 0), dtype=torch.float, device=device) def _create_edge_features(edge_gdf: gpd.GeoDataFrame, feature_cols: list[str] | None = None, device: Union[str, "torch.device"] | None = None) -> "torch.Tensor": """ Create edge feature tensors from attribute columns in edge_gdf. Parameters ---------- edge_gdf : geopandas.GeoDataFrame GeoDataFrame containing edge data. feature_cols : list, default None List of column names to use as edge features. device : str or torch.device, default None Device to use for tensors. Returns ------- torch.Tensor Tensor of edge features. """ device = _get_device(device) if feature_cols is None: return torch.empty((edge_gdf.shape[0], 0), dtype=torch.float, device=device) # Vectorized column validation using set intersection valid_cols = list(set(feature_cols) & set(edge_gdf.columns)) if not valid_cols: return torch.empty((edge_gdf.shape[0], 0), dtype=torch.float, device=device) # Direct numpy conversion for better performance features_array = edge_gdf[valid_cols].to_numpy().astype(np.float32) return torch.from_numpy(features_array).to(device=device, dtype=torch.float) def _map_edge_strings(edge_gdf: gpd.GeoDataFrame, source_col: str, target_col: str) -> gpd.GeoDataFrame: """Convert source/target columns to string once for vectorized lookups.""" edge_gdf[f"__{source_col}_str"] = edge_gdf[source_col].astype(str) edge_gdf[f"__{target_col}_str"] = edge_gdf[target_col].astype(str) return edge_gdf def _create_edge_idx_pairs(edge_gdf: gpd.GeoDataFrame, source_mapping: dict[str, int], target_mapping: dict[str, int] | None = None, source_col: str | None = None, target_col: str | None = None) -> list[list[int]]: """ Process edges to create edge indices using vectorized operations. Parameters ---------- edge_gdf : pandas.DataFrame DataFrame containing edge data source_mapping : dict Mapping from source IDs to indices target_mapping : dict, default None Mapping from target IDs to indices. If None, will use source_mapping. source_col : str, default None Column name for source node IDs target_col : str, default None Column name for target node IDs Returns ------- list List of [source_idx, target_idx] pairs """ if target_mapping is None: target_mapping = source_mapping # Detect columns if not provided source_col, target_col = _get_edge_columns( edge_gdf, source_col, target_col, source_mapping, target_mapping, ) # Skip if we couldn't determine columns if ( source_col is None or target_col is None or source_col not in edge_gdf.columns or target_col not in edge_gdf.columns ): return [] # Convert IDs to strings once for better performance source_ids = edge_gdf[source_col].astype(str) target_ids = edge_gdf[target_col].astype(str) # Vectorized filtering for valid edges using pandas Series operations valid_src_mask = source_ids.isin(source_mapping.keys()) valid_dst_mask = target_ids.isin(target_mapping.keys()) valid_edges_mask = valid_src_mask & valid_dst_mask # Count missing IDs missing_src_count = (~valid_src_mask).sum() missing_dst_count = (~valid_dst_mask).sum() if missing_src_count > 0 or missing_dst_count > 0: logger.warning( "Missing source IDs: %d, missing target IDs: %d", missing_src_count, missing_dst_count, ) # Process valid edges vectorized if not valid_edges_mask.any(): return [] valid_sources = source_ids[valid_edges_mask] valid_targets = target_ids[valid_edges_mask] # Vectorized mapping using pandas map from_indices = valid_sources.map(source_mapping).to_numpy() to_indices = valid_targets.map(target_mapping).to_numpy() # Create edge list return np.column_stack([from_indices, to_indices]).tolist() # New helper to check if an edge GeoDataFrame is valid def _is_valid_edge_df(edge_gdf: gpd.GeoDataFrame | None) -> bool: return edge_gdf is not None and not edge_gdf.empty # Remove the is_hetero parameter and always expect dictionaries for nodes and edges. # Updated _build_graph_data to accept node_y_attribute_cols parameter # Modified _build_graph_data: renamed parameters and references def _process_node_type(node_type: str, node_gdf: gpd.GeoDataFrame, node_id_cols: dict[str, str], node_feature_cols: dict[str, list[str]], node_label_cols: dict[str, list[str]] | None, device: Union[str, "torch.device"], data: HeteroData) -> dict[str, dict]: """Process a single node type and add to HeteroData.""" if not isinstance(node_gdf, gpd.GeoDataFrame): logger.warning("Expected GeoDataFrame for node type %s, got %s", node_type, type(node_gdf)) # Convert regular DataFrame to GeoDataFrame if it has x/y columns or lat/lon columns import pandas as pd if isinstance(node_gdf, pd.DataFrame): if "x" in node_gdf.columns and "y" in node_gdf.columns: # Vectorized Point creation geometry = gpd.points_from_xy(node_gdf.x, node_gdf.y) node_gdf = gpd.GeoDataFrame(node_gdf, geometry=geometry) elif "lat" in node_gdf.columns and "lon" in node_gdf.columns: # Vectorized Point creation geometry = gpd.points_from_xy(node_gdf.lon, node_gdf.lat) node_gdf = gpd.GeoDataFrame(node_gdf, geometry=geometry) else: return {} else: return {} id_col = node_id_cols.get(node_type) id_mapping, actual_id_col = _extract_node_id_mapping(node_gdf, id_col) feature_cols = node_feature_cols.get(node_type) data[node_type].x = _create_node_features(node_gdf, feature_cols, device) # Vectorized position extraction if geometry is present if "geometry" in node_gdf.columns and hasattr(node_gdf.geometry, "values"): # Use vectorized operations for position extraction geom_series = node_gdf.geometry # Check if all geometries have x, y attributes (Points) is_point_mask = geom_series.geom_type == "Point" if is_point_mask.all(): # All points - direct coordinate extraction pos_data = np.column_stack([geom_series.x.to_numpy(), geom_series.y.to_numpy()]) else: # Mixed geometries - use centroid for non-points pos_data = np.zeros((len(geom_series), 2)) # Points - direct coordinates if is_point_mask.any(): point_coords = np.column_stack([ geom_series[is_point_mask].x.to_numpy(), geom_series[is_point_mask].y.to_numpy(), ]) pos_data[is_point_mask] = point_coords # Non-points - centroids if (~is_point_mask).any(): centroids = geom_series[~is_point_mask].centroid centroid_coords = np.column_stack([ centroids.x.to_numpy(), centroids.y.to_numpy(), ]) pos_data[~is_point_mask] = centroid_coords data[node_type].pos = torch.tensor(pos_data, dtype=torch.float, device=device) # Add label columns if node_label_cols and node_label_cols.get(node_type): data[node_type].y = _create_node_features( node_gdf, node_label_cols[node_type], device, ) elif "y" in node_gdf.columns: data[node_type].y = torch.tensor( node_gdf["y"].to_numpy(), dtype=torch.float, device=device, ) return {"mapping": id_mapping, "id_col": actual_id_col} def _process_edge_type(edge_type: tuple[str, str, str], edge_gdf: gpd.GeoDataFrame, node_id_mappings: dict, edge_source_cols: dict[tuple[str, str, str], str], edge_target_cols: dict[tuple[str, str, str], str], edge_feature_cols: dict[tuple[str, str, str], list[str]] | None, device: Union[str, "torch.device"], data: HeteroData) -> None: """Process a single edge type and add to HeteroData.""" if not isinstance(edge_type, tuple) or len(edge_type) != 3: logger.warning( "Edge type key must be a tuple of (source_type, relation_type, " "target_type). Got %s instead. Skipping.", edge_type, ) return src_type, rel_type, dst_type = edge_type if src_type not in node_id_mappings or dst_type not in node_id_mappings: logger.warning( "Edge type %s references node type(s) not present in nodes. Skipping.", edge_type, ) return src_mapping = node_id_mappings[src_type]["mapping"] dst_mapping = node_id_mappings[dst_type]["mapping"] source_col = edge_source_cols.get(edge_type) target_col = edge_target_cols.get(edge_type) if _is_valid_edge_df(edge_gdf): pairs = _create_edge_idx_pairs( edge_gdf, source_mapping=src_mapping, target_mapping=dst_mapping, source_col=source_col, target_col=target_col, ) if pairs: data[src_type, rel_type, dst_type].edge_index = torch.tensor( np.array(pairs).T, dtype=torch.long, device=device, ) else: data[src_type, rel_type, dst_type].edge_index = torch.zeros( (2, 0), dtype=torch.long, device=device, ) feature_cols = ( edge_feature_cols.get(edge_type) if edge_feature_cols else None ) data[src_type, rel_type, dst_type].edge_attr = _create_edge_features( edge_gdf, feature_cols, device, ) else: data[src_type, rel_type, dst_type].edge_index = torch.zeros( (2, 0), dtype=torch.long, device=device, ) data[src_type, rel_type, dst_type].edge_attr = torch.empty( (0, 0), dtype=torch.float, device=device, ) def _build_graph_data(nodes: dict[str, gpd.GeoDataFrame], edges: dict[tuple[str, str, str], gpd.GeoDataFrame], node_id_cols: dict[str, str], node_feature_cols: dict[str, list[str]], node_label_cols: dict[str, list[str]] | None, edge_source_cols: dict[tuple[str, str, str], str], edge_target_cols: dict[tuple[str, str, str], str], edge_feature_cols: dict[tuple[str, str, str], list[str]] | None, device: Union[str, "torch.device"] | None) -> HeteroData: """ Build a heterogeneous graph (HeteroData) from node and edge GeoDataFrames. Parameters ---------- nodes : dict Dictionary of node GeoDataFrames keyed by node type. edges : dict Dictionary of edge GeoDataFrames keyed by (source_type, relation, target_type). node_id_cols : dict Dictionary mapping node types to the ID column name. node_feature_cols : dict Dictionary mapping node types to lists of feature column names. node_label_cols : dict, optional Dictionary mapping node types to lists of label column names. edge_source_cols : dict Dictionary mapping edge type tuples to source column names. edge_target_cols : dict Dictionary mapping edge type tuples to target column names. edge_feature_cols : dict, optional Dictionary mapping edge type tuples to lists of edge feature columns. device : torch.device or str, optional Device to be used for tensor creation. Returns ------- torch_geometric.data.HeteroData A PyTorch Geometric HeteroData graph object. """ device = _get_device(device) data = HeteroData() node_id_mappings = {} # Process nodes across types for node_type, node_gdf in nodes.items(): mapping_info = _process_node_type( node_type, node_gdf, node_id_cols, node_feature_cols, node_label_cols, device, data, ) if mapping_info: node_id_mappings[node_type] = mapping_info # Process edges across types for edge_type, edge_gdf in edges.items(): _process_edge_type( edge_type, edge_gdf, node_id_mappings, edge_source_cols, edge_target_cols, edge_feature_cols, device, data, ) # Set CRS metadata from node GeoDataFrames crs_values = [gdf.crs for gdf in nodes.values() if hasattr(gdf, "crs") and gdf.crs] if not crs_values: data.crs = {} elif all(crs == crs_values[0] for crs in crs_values): data.crs = crs_values[0] else: msg = "CRS mismatch among node GeoDataFrames." raise ValueError(msg) # Store metadata for reconstruction _store_reconstruction_metadata( data, nodes=nodes, edges=edges, node_id_cols=node_id_cols, node_feature_cols=node_feature_cols, node_label_cols=node_label_cols, edge_source_cols=edge_source_cols, edge_target_cols=edge_target_cols, edge_feature_cols=edge_feature_cols, ) return data
[docs] def homogeneous_graph(nodes_gdf: gpd.GeoDataFrame, edges_gdf: gpd.GeoDataFrame | None = None, node_id_col: str | None = None, node_feature_cols: list[str] | None = None, node_label_cols: list[str] | None = None, edge_source_col: str | None = None, edge_target_col: str | None = None, edge_feature_cols: list[str] | None = None, device: Union[str, "torch.device"] | None = None) -> Data: """ Create a homogeneous graph Data object from nodes and edges GeoDataFrames. Parameters ---------- nodes_gdf : GeoDataFrame GeoDataFrame containing node data. edges_gdf : GeoDataFrame, optional GeoDataFrame containing edge data. node_id_col : str, optional Column name that uniquely identifies each node. node_feature_cols : list of str, optional List of columns to use as node features. node_label_cols : list of str, optional List of columns to use as node labels. edge_source_col : str, optional Column name for source node IDs in the edge GeoDataFrame. edge_target_col : str, optional Column name for target node IDs in the edge GeoDataFrame. edge_feature_cols : list of str, optional List of columns to use as edge features. device : torch.device or str, optional Device for tensor creation. Returns ------- torch_geometric.data.Data A PyTorch Geometric Data graph object. Raises ------ ImportError If PyTorch and PyTorch Geometric are not installed """ if not TORCH_AVAILABLE: msg = ( "PyTorch and PyTorch Geometric are required for this function. " "Please install them using: poetry install --with torch or " "pip install city2graph[torch]" ) raise ImportError(msg) # Preprocess homogeneous graph inputs into dictionaries. # Ensure at least empty edge type entry for homogeneous graphs nodes_dict = {"node": nodes_gdf} # Use explicit None check to avoid ambiguous truth of GeoDataFrame edges_dict = { ("node", "edge", "node"): edges_gdf if edges_gdf is not None else gpd.GeoDataFrame(), } node_id_cols = {"node": node_id_col} if node_id_col else {} node_feature_cols = {"node": node_feature_cols} if node_feature_cols else {} node_label_cols = {"node": node_label_cols} if node_label_cols else None edge_source_cols = {("node", "edge", "node"): edge_source_col} edge_target_cols = {("node", "edge", "node"): edge_target_col} # Wrap edge features into dict mapping for builder edge_feature_map = ( {("node", "edge", "node"): edge_feature_cols} if edge_feature_cols is not None else None ) hetero_data = _build_graph_data( nodes=nodes_dict, edges=edges_dict, node_id_cols=node_id_cols, node_feature_cols=node_feature_cols, node_label_cols=node_label_cols, edge_source_cols=edge_source_cols, edge_target_cols=edge_target_cols, edge_feature_cols=edge_feature_map, device=device, ) data = Data( x=hetero_data["node"].x, edge_index=hetero_data[("node", "edge", "node")].edge_index, edge_attr=hetero_data[("node", "edge", "node")].edge_attr, pos=hetero_data["node"].get("pos", None), ) # Assign "y" node attribute if exists data.y = hetero_data["node"].get("y", None) # Assign CRS metadata from hetero_data data.crs = hetero_data.crs # Store metadata for reconstruction using the helper function _store_reconstruction_metadata( data, nodes=nodes_gdf, edges=edges_gdf, node_id_cols=node_id_col, node_feature_cols=node_feature_cols, node_label_cols=node_label_cols, edge_source_cols=edge_source_col, edge_target_cols=edge_target_col, edge_feature_cols=edge_feature_cols, ) return data
def _process_single_nodes_gdf(nodes_gdf: gpd.GeoDataFrame) -> dict[str, gpd.GeoDataFrame]: """Process a single nodes GeoDataFrame into a dictionary by type.""" if "type" in nodes_gdf.columns: # Split by type column nodes_dict = {} for node_type in nodes_gdf["type"].unique(): subset = nodes_gdf[nodes_gdf["type"] == node_type].copy() # Ensure we maintain GeoDataFrame type if not isinstance(subset, gpd.GeoDataFrame): geometry = subset.geometry if hasattr(subset, "geometry") else None subset = gpd.GeoDataFrame(subset, geometry=geometry) nodes_dict[node_type] = subset return nodes_dict # Default to single node type return {"default": nodes_gdf} def _process_single_edges_gdf(edges_gdf: gpd.GeoDataFrame) -> dict[tuple[str, str, str], gpd.GeoDataFrame]: """Process a single edges GeoDataFrame into a dictionary by edge type.""" if "edge_type" in edges_gdf.columns: # Split by edge_type column - assume format is "source_relation_target" edges_dict = {} for edge_type_str in edges_gdf["edge_type"].unique(): # Try to parse edge type string parts = str(edge_type_str).split("_", 2) edge_key = tuple(parts) if len(parts) == 3 else ("default", "edge", "default") edges_dict[edge_key] = edges_gdf[edges_gdf["edge_type"] == edge_type_str].copy() return edges_dict # Default to single edge type return {("default", "edge", "default"): edges_gdf}
[docs] def heterogeneous_graph(nodes_dict: dict[str, gpd.GeoDataFrame] | gpd.GeoDataFrame, edges_dict: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame, node_id_cols: dict[str, str] | None = None, node_feature_cols: dict[str, list[str]] | None = None, node_label_cols: dict[str, list[str]] | None = None, edge_source_cols: dict[tuple[str, str, str], str] | None = None, edge_target_cols: dict[tuple[str, str, str], str] | None = None, edge_feature_cols: dict[tuple[str, str, str], list[str]] | None = None, device: Union[str, "torch.device"] | None = None) -> HeteroData: """ Create a heterogeneous graph HeteroData object from node and edge dictionaries. Parameters ---------- nodes_dict : dict or GeoDataFrame Dictionary of GeoDataFrames for each node type, or a single GeoDataFrame with a 'type' column to automatically split by node type. edges_dict : dict or GeoDataFrame Dictionary of GeoDataFrames for each edge type, with keys as (source_type, relation, target_type), or a single GeoDataFrame with an 'edge_type' column. node_id_cols : dict, optional Dictionary mapping node types to their ID column. node_feature_cols : dict, optional Dictionary mapping node types to lists of feature columns. node_label_cols : dict, optional Dictionary mapping node types to lists of label columns. edge_source_cols : dict, optional Dictionary mapping edge types to source column names. edge_target_cols : dict, optional Dictionary mapping edge types to target column names. edge_feature_cols : dict, optional Dictionary mapping edge types to lists of edge attribute columns. device : torch.device or str, optional Device for tensor creation. Returns ------- torch_geometric.data.HeteroData A PyTorch Geometric HeteroData graph object. Raises ------ ImportError If PyTorch and PyTorch Geometric are not installed """ if not TORCH_AVAILABLE: msg = ( "PyTorch and PyTorch Geometric are required for this function. " "Please install them using: poetry install --with torch or " "pip install city2graph[torch]" ) raise ImportError(msg) if node_id_cols is None: node_id_cols = {} if node_feature_cols is None: node_feature_cols = {} if edge_source_cols is None: edge_source_cols = {} if edge_target_cols is None: edge_target_cols = {} # Handle case where nodes_dict is a single GeoDataFrame if isinstance(nodes_dict, gpd.GeoDataFrame): nodes_dict = _process_single_nodes_gdf(nodes_dict) # Handle case where edges_dict is a single GeoDataFrame if isinstance(edges_dict, gpd.GeoDataFrame): edges_dict = _process_single_edges_gdf(edges_dict) return _build_graph_data( nodes=nodes_dict, edges=edges_dict, node_id_cols=node_id_cols, node_feature_cols=node_feature_cols, node_label_cols=node_label_cols, edge_source_cols=edge_source_cols, edge_target_cols=edge_target_cols, edge_feature_cols=edge_feature_cols, device=device, )
[docs] def from_morphological_graph(network_output: dict, # noqa: PLR0915 private_id_col: str = "tess_id", public_id_col: str = "id", private_node_feature_cols: list[str] | None = None, public_node_feature_cols: list[str] | None = None, device: Union[str, "torch.device"] | None = None) -> HeteroData | Data: """ Create a graph representation from the output of morphological_graph. Parameters ---------- network_output : dict Output dictionary from morphological_graph containing: - 'tessellation': GeoDataFrame of tessellation cells (private spaces) - 'segments': GeoDataFrame of road segments (public spaces) - 'private_to_private': GeoDataFrame of connections between tessellation cells - 'public_to_public': GeoDataFrame of connections between road segments - 'private_to_public': GeoDataFrame of connections between tessellation cells and road segments private_id_col : str, default='tess_id' Column name in tessellation GeoDataFrame that uniquely identifies each private space. public_id_col : str, default='id' Column name in segments GeoDataFrame that uniquely identifies each public space. private_node_feature_cols : list, default None Attributes in tessellation GeoDataFrame to use as node features. public_node_feature_cols : list, default None Attributes in segments GeoDataFrame to use as node features. device : str, default None Device to use for tensors. Must be 'cuda' or 'cpu' if provided. If None, will use CUDA if available, otherwise CPU. Returns ------- torch_geometric.data.HeteroData or torch_geometric.data.Data Graph representation. HeteroData is returned if both node types exist, otherwise a homogeneous Data object. Raises ------ ImportError If PyTorch and PyTorch Geometric are not installed ValueError If required data is missing from the network_output dictionary """ if not TORCH_AVAILABLE: msg = ( "PyTorch and PyTorch Geometric are required for this function. " "Please install them using: poetry install --with torch or " "pip install city2graph[torch]" ) raise ImportError(msg) # Validate device device = _get_device(device) # Extract data from network_output if not isinstance(network_output, dict): msg = "network_output must be a dictionary returned from morphological_graph" raise TypeError(msg) # Check if we have the new pyg_to_gdf compatible structure if "nodes" in network_output and "edges" in network_output: # New structure - extract from nodes/edges dictionaries nodes_dict = network_output["nodes"] edges_dict = network_output["edges"] private_gdf = nodes_dict.get("private") public_gdf = nodes_dict.get("public") private_to_private_gdf = edges_dict.get(("private", "touched_to", "private")) public_to_public_gdf = edges_dict.get(("public", "connected_to", "public")) private_to_public_gdf = edges_dict.get(("private", "faced_to", "public")) else: # Legacy structure - extract from flat dictionary private_gdf = network_output.get("tessellation") public_gdf = network_output.get("segments") private_to_private_gdf = network_output.get("private_to_private") public_to_public_gdf = network_output.get("public_to_public") private_to_public_gdf = network_output.get("private_to_public") # Validate that required data exists has_private = private_gdf is not None and not private_gdf.empty has_public = public_gdf is not None and not public_gdf.empty # Case 1: We have both private and public nodes - create heterogeneous graph if has_private and has_public: # Create nodes dictionary nodes_dict = {"private": private_gdf, "public": public_gdf} # Create edges dictionary with edge type tuples edges_dict = { ("private", "touched_to", "private"): private_to_private_gdf, ("private", "faced_to", "public"): private_to_public_gdf, ("public", "connected_to", "public"): public_to_public_gdf, } # Create node ID columns dictionary node_id_cols = {"private": private_id_col, "public": public_id_col} # Create node feature columns dictionary node_feature_cols = {} if private_node_feature_cols is not None: node_feature_cols["private"] = private_node_feature_cols if public_node_feature_cols is not None: node_feature_cols["public"] = public_node_feature_cols # Prepare edge source/target column mappings based on morphological_graph output columns edge_source_cols = { ("private", "touched_to", "private"): "from_private_id", ("private", "faced_to", "public"): "private_id", ("public", "connected_to", "public"): "from_public_id", } edge_target_cols = { ("private", "touched_to", "private"): "to_private_id", ("private", "faced_to", "public"): "public_id", ("public", "connected_to", "public"): "to_public_id", } # Create the heterogeneous graph using _build_graph_data return _build_graph_data( nodes=nodes_dict, edges=edges_dict, node_id_cols=node_id_cols, node_feature_cols=node_feature_cols, node_label_cols=None, edge_source_cols=edge_source_cols, edge_target_cols=edge_target_cols, edge_feature_cols=None, device=device, ) # Case 2: We only have private nodes - create homogeneous graph if has_private: nodes_dict = {"node": private_gdf} edges_dict = {("node", "edge", "node"): private_to_private_gdf or gpd.GeoDataFrame()} node_id_cols_dict = {"node": private_id_col} node_feature_cols_dict = {"node": private_node_feature_cols} if private_node_feature_cols else {} edge_source_cols_dict = {("node", "edge", "node"): "from_private_id"} edge_target_cols_dict = {("node", "edge", "node"): "to_private_id"} hetero_data = _build_graph_data( nodes=nodes_dict, edges=edges_dict, node_id_cols=node_id_cols_dict, node_feature_cols=node_feature_cols_dict, node_label_cols=None, edge_source_cols=edge_source_cols_dict, edge_target_cols=edge_target_cols_dict, edge_feature_cols=None, device=device, ) return Data( x=hetero_data["node"].x, edge_index=hetero_data[("node", "edge", "node")].edge_index, edge_attr=hetero_data[("node", "edge", "node")].edge_attr, pos=hetero_data["node"].get("pos", None), y=hetero_data["node"].get("y", None), crs=hetero_data.crs, ) # Case 3: We only have public nodes - create homogeneous graph if has_public: nodes_dict = {"node": public_gdf} edges_dict = {("node", "edge", "node"): public_to_public_gdf or gpd.GeoDataFrame()} node_id_cols_dict = {"node": public_id_col} node_feature_cols_dict = {"node": public_node_feature_cols} if public_node_feature_cols else {} edge_source_cols_dict = {("node", "edge", "node"): "from_public_id"} edge_target_cols_dict = {("node", "edge", "node"): "to_public_id"} hetero_data = _build_graph_data( nodes=nodes_dict, edges=edges_dict, node_id_cols=node_id_cols_dict, node_feature_cols=node_feature_cols_dict, node_label_cols=None, edge_source_cols=edge_source_cols_dict, edge_target_cols=edge_target_cols_dict, edge_feature_cols=None, device=device, ) return Data( x=hetero_data["node"].x, edge_index=hetero_data[("node", "edge", "node")].edge_index, edge_attr=hetero_data[("node", "edge", "node")].edge_attr, pos=hetero_data["node"].get("pos", None), y=hetero_data["node"].get("y", None), crs=hetero_data.crs, ) # Case 4: No valid nodes - raise an error to prevent unintended empty graphs. msg = "No valid node data provided; no nodes found." raise ValueError(msg)
def _extract_tensor_features( tensor: "torch.Tensor", column_names: list[str] | None = None, ) -> dict[str, np.ndarray]: """Extract features from tensor into a dictionary with column names.""" if tensor is None or tensor.numel() == 0: return {} # Convert to numpy for faster operations features_array = tensor.detach().cpu().numpy() if column_names is None: # Generate default column names num_features = ( features_array.shape[1] if len(features_array.shape) > 1 else 1 ) column_names = [f"feature_{i}" for i in range(num_features)] # Create dictionary with vectorized operations if len(features_array.shape) == 1: return {column_names[0]: features_array} return { name: features_array[:, i] for i, name in enumerate(column_names[: features_array.shape[1]]) } def _create_geometries_from_pos( pos_tensor: "torch.Tensor", ) -> gpd.array.GeometryArray: """Create Point geometries from position tensor using vectorized operations.""" if pos_tensor is None or pos_tensor.numel() == 0: return gpd.array.from_shapely([]) pos_array = pos_tensor.detach().cpu().numpy() # Vectorized Point creation using geopandas if len(pos_array.shape) == 2 and pos_array.shape[1] >= 2: return gpd.points_from_xy(pos_array[:, 0], pos_array[:, 1]) return gpd.array.from_shapely([]) def _reconstruct_node_gdf( node_type: str, data: Data | HeteroData, is_hetero: bool = False, ) -> gpd.GeoDataFrame: """Reconstruct node GeoDataFrame from PyTorch Geometric data.""" # Get node data based on graph type node_data = data[node_type] if is_hetero else data # Initialize data dictionary gdf_data = {} # Extract node features with proper column names if hasattr(node_data, "x") and node_data.x is not None: # Get stored feature column names feature_cols = getattr(data, "_node_feature_columns", {}).get(node_type, None) # If no stored names, try alternative storage location if feature_cols is None and hasattr(data, "_node_columns"): stored_cols = getattr(data, "_node_columns", {}).get(node_type, []) # Filter out non-feature columns feature_cols = [col for col in stored_cols if col not in ["geometry", "pos"]] features_dict = _extract_tensor_features(node_data.x, feature_cols) gdf_data.update(features_dict) # Extract node labels with proper column names if hasattr(node_data, "y") and node_data.y is not None: # Get stored label column names label_cols = getattr(data, "_node_label_columns", {}).get(node_type, None) labels_dict = _extract_tensor_features(node_data.y, label_cols) gdf_data.update(labels_dict) # Create geometry from positions geometry = None if hasattr(node_data, "pos") and node_data.pos is not None: geometry = _create_geometries_from_pos(node_data.pos) # Create GeoDataFrame if gdf_data: gdf = gpd.GeoDataFrame(gdf_data, geometry=geometry) else: # Create minimal GeoDataFrame with geometry if available num_nodes = ( node_data.x.size(0) if hasattr(node_data, "x") and node_data.x is not None else 0 ) if num_nodes == 0 and geometry is not None: num_nodes = len(geometry) gdf = gpd.GeoDataFrame( {"node_id": range(num_nodes)} if num_nodes > 0 else {}, geometry=geometry, ) # Set CRS if available if hasattr(data, "crs") and data.crs: gdf.crs = data.crs return gdf def _reconstruct_edge_gdf( edge_type: tuple[str, str, str] | str, data: Data | HeteroData, is_hetero: bool = False, ) -> pd.DataFrame: """Reconstruct edge DataFrame from PyTorch Geometric data.""" # Get edge data based on graph type edge_data = data[edge_type] if is_hetero else data # Initialize data dictionary edge_data_dict = {} # Extract edge indices if hasattr(edge_data, "edge_index") and edge_data.edge_index is not None: edge_index_array = edge_data.edge_index.detach().cpu().numpy() if edge_index_array.shape[0] == 2: edge_data_dict["source"] = edge_index_array[0] edge_data_dict["target"] = edge_index_array[1] # Extract edge features with proper column names if hasattr(edge_data, "edge_attr") and edge_data.edge_attr is not None: # Get stored feature column names feature_cols = getattr(data, "_edge_feature_columns", {}).get(edge_type, None) # If no stored names, try alternative storage location if feature_cols is None and hasattr(data, "_edge_columns"): stored_cols = getattr(data, "_edge_columns", {}).get(edge_type, []) # Filter out non-feature columns like geometry feature_cols = [col for col in stored_cols if col not in ["geometry"]] features_dict = _extract_tensor_features(edge_data.edge_attr, feature_cols) edge_data_dict.update(features_dict) # Create DataFrame (edges typically don't have geometry) return pd.DataFrame(edge_data_dict) if edge_data_dict else pd.DataFrame()
[docs] def pyg_to_gdf( data: Data | HeteroData, ) -> dict[str, dict[str, gpd.GeoDataFrame | pd.DataFrame]] | tuple[ gpd.GeoDataFrame, pd.DataFrame | None, ]: """ Convert PyTorch Geometric Data or HeteroData to GeoDataFrames and DataFrames. Parameters ---------- data : Data or HeteroData PyTorch Geometric graph object to convert. Returns ------- dict or tuple For HeteroData: Returns a dictionary with keys 'nodes' and 'edges', where 'nodes' contains GeoDataFrames and 'edges' contains DataFrames. For Data: Returns a tuple of (nodes_gdf, edges_df). Raises ------ ImportError If PyTorch and PyTorch Geometric are not installed """ if not TORCH_AVAILABLE: msg = ( "PyTorch and PyTorch Geometric are required for this function. " "Please install them using: poetry install --with torch or " "pip install city2graph[torch]" ) raise ImportError(msg) # Check if it's heterogeneous data is_hetero = hasattr(data, "node_types") and hasattr(data, "edge_types") if is_hetero: # Handle HeteroData nodes_dict = {} edges_dict = {} # Reconstruct node GeoDataFrames for each node type for node_type in data.node_types: nodes_dict[node_type] = _reconstruct_node_gdf( node_type, data, is_hetero=True, ) # Reconstruct edge DataFrames for each edge type for edge_type in data.edge_types: edges_dict[edge_type] = _reconstruct_edge_gdf( edge_type, data, is_hetero=True, ) return {"nodes": nodes_dict, "edges": edges_dict} # Handle homogeneous Data nodes_gdf = _reconstruct_node_gdf("node", data, is_hetero=False) # Check if edges exist if ( hasattr(data, "edge_index") and data.edge_index is not None and data.edge_index.numel() > 0 ): edges_df = _reconstruct_edge_gdf("edge", data, is_hetero=False) else: edges_df = None return nodes_gdf, edges_df
[docs] def pyg_to_nx(data: Data | HeteroData) -> nx.Graph: """ Convert PyTorch Geometric Data or HeteroData to NetworkX graph. Parameters ---------- data : Data or HeteroData PyTorch Geometric graph object to convert. Returns ------- nx.Graph NetworkX graph representation. For HeteroData, returns a MultiDiGraph with node and edge type information. For homogeneous Data, returns a standard Graph. Raises ------ ImportError If PyTorch and PyTorch Geometric are not installed """ if not TORCH_AVAILABLE: msg = ( "PyTorch and PyTorch Geometric are required for this function. " "Please install them using: poetry install --with torch or " "pip install city2graph[torch]" ) raise ImportError(msg) # Check if it's heterogeneous data is_hetero = hasattr(data, "node_types") and hasattr(data, "edge_types") # Get stored attribute column names for reconstruction node_feature_cols = getattr(data, "_node_feature_columns", {}) node_label_cols = getattr(data, "_node_label_columns", {}) edge_feature_cols = getattr(data, "_edge_feature_columns", {}) # Determine node and edge attributes to include if is_hetero: # For heterogeneous data, collect all unique attribute names across node types all_node_attrs = set() for node_type in data.node_types: if node_type in node_feature_cols: all_node_attrs.update(node_feature_cols[node_type]) if node_type in node_label_cols: all_node_attrs.update(node_label_cols[node_type]) # Collect all unique edge attribute names across edge types all_edge_attrs = set() for edge_type in data.edge_types: if edge_type in edge_feature_cols: all_edge_attrs.update(edge_feature_cols[edge_type]) node_attrs = list(all_node_attrs) if all_node_attrs else None edge_attrs = list(all_edge_attrs) if all_edge_attrs else None # Use to_multi=True for heterogeneous graphs to preserve multiple edges graph = pyg_to_networkx( data, node_attrs=node_attrs, edge_attrs=edge_attrs, to_undirected=False, to_multi=True, ) else: # For homogeneous data, get attributes for single node/edge type node_attrs = [] if "node" in node_feature_cols: node_attrs.extend(node_feature_cols["node"]) if "node" in node_label_cols: node_attrs.extend(node_label_cols["node"]) edge_attrs = edge_feature_cols.get(("node", "edge", "node"), []) node_attrs = node_attrs if node_attrs else None edge_attrs = edge_attrs if edge_attrs else None # Use standard conversion for homogeneous graphs graph = pyg_to_networkx( data, node_attrs=node_attrs, edge_attrs=edge_attrs, to_undirected=False, ) # Preserve global attributes if hasattr(data, "crs") and data.crs: graph.graph["crs"] = data.crs return graph
def _determine_graph_type( nodes: dict[str, gpd.GeoDataFrame] | gpd.GeoDataFrame, edges: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame | None, ) -> str: """Determine if the graph should be homogeneous or heterogeneous.""" if isinstance(nodes, dict) and len(nodes) > 1: return "heterogeneous" if isinstance(edges, dict) and len(edges) > 1: return "heterogeneous" # Check if edges dict has complex edge types if isinstance(edges, dict) and edges: for edge_type in edges: if isinstance(edge_type, tuple) and len(edge_type) == 3: src_type, relation, dst_type = edge_type if src_type != dst_type or relation != "edge": return "heterogeneous" return "homogeneous"
[docs] def gdf_to_pyg( # noqa: PLR0912 nodes: dict[str, gpd.GeoDataFrame] | gpd.GeoDataFrame, edges: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame | None = None, node_id_cols: dict[str, str] | str | None = None, node_feature_cols: dict[str, list[str]] | list[str] | None = None, node_label_cols: dict[str, list[str]] | list[str] | None = None, edge_source_cols: dict[tuple[str, str, str], str] | str | None = None, edge_target_cols: dict[tuple[str, str, str], str] | str | None = None, edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None, device: Union[str, "torch.device", None] = None, ) -> Data | HeteroData: """ Convert GeoDataFrames to PyTorch Geometric graph objects with automatic type detection. Parameters ---------- nodes : dict or GeoDataFrame Dictionary of GeoDataFrames for each node type, or a single GeoDataFrame. edges : dict, GeoDataFrame, or None Dictionary of GeoDataFrames for each edge type, single GeoDataFrame, or None. node_id_cols : dict, str, or None Node ID column specification. node_feature_cols : dict, list, or None Node feature columns specification. node_label_cols : dict, list, or None Node label columns specification. edge_source_cols : dict, str, or None Edge source column specification. edge_target_cols : dict, str, or None Edge target column specification. edge_feature_cols : dict, list, or None Edge feature columns specification. device : torch.device or str, optional Device for tensor creation. Returns ------- Data or HeteroData PyTorch Geometric graph object. Raises ------ ImportError If PyTorch and PyTorch Geometric are not installed """ if not TORCH_AVAILABLE: msg = ( "PyTorch and PyTorch Geometric are required for this function. " "Please install them using: poetry install --with torch or " "pip install city2graph[torch]" ) raise ImportError(msg) # Validate input data types if isinstance(nodes, dict): for node_gdf in nodes.values(): _validate_gdf(node_gdf, None) else: _validate_gdf(nodes, None) if edges is not None: if isinstance(edges, dict): for edge_gdf in edges.values(): _validate_gdf(None, edge_gdf) else: _validate_gdf(None, edges) # Determine graph type and delegate to appropriate function if _determine_graph_type(nodes, edges) == "heterogeneous": # Handle case where nodes_dict is a single GeoDataFrame if isinstance(nodes, gpd.GeoDataFrame): nodes = _process_single_nodes_gdf(nodes) # Handle case where edges_dict is a single GeoDataFrame if isinstance(edges, gpd.GeoDataFrame): edges = _process_single_edges_gdf(edges) # Convert parameters to appropriate dictionaries if node_id_cols is None: node_id_cols = {} if node_feature_cols is None: node_feature_cols = {} if edge_source_cols is None: edge_source_cols = {} if edge_target_cols is None: edge_target_cols = {} return _build_graph_data( nodes=nodes, edges=edges, node_id_cols=node_id_cols, node_feature_cols=node_feature_cols, node_label_cols=node_label_cols, edge_source_cols=edge_source_cols, edge_target_cols=edge_target_cols, edge_feature_cols=edge_feature_cols, device=device, ) # Extract single values for homogeneous case nodes_gdf = next(iter(nodes.values())) if isinstance(nodes, dict) else nodes edges_gdf = next(iter(edges.values())) if isinstance(edges, dict) and edges else edges # Extract single column specifications node_id_col = ( next(iter(node_id_cols.values())) if isinstance(node_id_cols, dict) else node_id_cols ) node_feature_col_list = ( next(iter(node_feature_cols.values())) if isinstance(node_feature_cols, dict) else node_feature_cols ) node_label_col_list = ( next(iter(node_label_cols.values())) if isinstance(node_label_cols, dict) else node_label_cols ) edge_source_col = ( next(iter(edge_source_cols.values())) if isinstance(edge_source_cols, dict) else edge_source_cols ) edge_target_col = ( next(iter(edge_target_cols.values())) if isinstance(edge_target_cols, dict) else edge_target_cols ) edge_feature_col_list = ( next(iter(edge_feature_cols.values())) if isinstance(edge_feature_cols, dict) else edge_feature_cols ) # Build homogeneous graph using _build_graph_data nodes_dict = {"node": nodes_gdf} edges_dict = { ("node", "edge", "node"): edges_gdf if edges_gdf is not None else gpd.GeoDataFrame(), } node_id_cols_dict = {"node": node_id_col} if node_id_col else {} node_feature_cols_dict = {"node": node_feature_col_list} if node_feature_col_list else {} node_label_cols_dict = {"node": node_label_col_list} if node_label_col_list else None edge_source_cols_dict = {("node", "edge", "node"): edge_source_col} edge_target_cols_dict = {("node", "edge", "node"): edge_target_col} edge_feature_map = ( {("node", "edge", "node"): edge_feature_col_list} if edge_feature_col_list is not None else None ) hetero_data = _build_graph_data( nodes=nodes_dict, edges=edges_dict, node_id_cols=node_id_cols_dict, node_feature_cols=node_feature_cols_dict, node_label_cols=node_label_cols_dict, edge_source_cols=edge_source_cols_dict, edge_target_cols=edge_target_cols_dict, edge_feature_cols=edge_feature_map, device=device, ) # Convert to homogeneous Data object data = Data( x=hetero_data["node"].x, edge_index=hetero_data[("node", "edge", "node")].edge_index, edge_attr=hetero_data[("node", "edge", "node")].edge_attr, pos=hetero_data["node"].get("pos", None), ) data.y = hetero_data["node"].get("y", None) data.crs = hetero_data.crs return data
[docs] def nx_to_pyg( graph: nx.Graph, node_feature_attrs: list[str] | None = None, node_label_attrs: list[str] | None = None, edge_feature_attrs: list[str] | None = None, device: Union[str, "torch.device"] | None = None, ) -> Data | HeteroData: """ Convert NetworkX graph to PyTorch Geometric graph with automatic type detection. Parameters ---------- graph : networkx.Graph NetworkX graph to convert. node_feature_attrs : list of str, optional List of node attributes to use as features. node_label_attrs : list of str, optional List of node attributes to use as labels. edge_feature_attrs : list of str, optional List of edge attributes to use as features. device : torch.device or str, optional Device for tensor creation. Returns ------- Data or HeteroData PyTorch Geometric graph object. Raises ------ ImportError If PyTorch and PyTorch Geometric are not installed ValueError If the graph is empty """ if not TORCH_AVAILABLE: msg = ( "PyTorch and PyTorch Geometric are required for this function. " "Please install them using: poetry install --with torch or " "pip install city2graph[torch]" ) raise ImportError(msg) # Validate NetworkX graph _validate_nx(graph) if len(graph.nodes()) == 0: msg = "Graph has no nodes" raise ValueError(msg) # Check if graph has heterogeneous structure using vectorized operations node_data_list = list(graph.nodes(data=True)) edge_data_list = list(graph.edges(data=True)) # Extract node types efficiently node_types = {data.get("node_type", "default") for _, data in node_data_list} # Extract edge types efficiently edge_types = set() for src, dst, edge_data in edge_data_list: src_type = graph.nodes[src].get("node_type", "default") dst_type = graph.nodes[dst].get("node_type", "default") relation = edge_data.get("relation", edge_data.get("edge_type", "edge")) edge_types.add((src_type, relation, dst_type)) # Determine if heterogeneous is_hetero = len(node_types) > 1 or len(edge_types) > 1 or any( src_type != dst_type or relation != "edge" for src_type, relation, dst_type in edge_types ) if is_hetero: return _nx_to_hetero_pyg(graph, node_feature_attrs, node_label_attrs, edge_feature_attrs, device) return _nx_to_homo_pyg(graph, node_feature_attrs, node_label_attrs, edge_feature_attrs, device)
def _nx_to_homo_pyg(graph: nx.Graph, node_feature_attrs: list[str] | None, node_label_attrs: list[str] | None, edge_feature_attrs: list[str] | None, device: Union[str, "torch.device"] | None) -> Data: """Convert NetworkX graph to homogeneous PyTorch Geometric Data object.""" device = _get_device(device) # Create node mapping using vectorized operations nodes_list = list(graph.nodes()) node_mapping = {node: i for i, node in enumerate(nodes_list)} num_nodes = len(nodes_list) # Vectorized node data extraction nodes_data = [graph.nodes[node] for node in nodes_list] # Extract node features vectorized if node_feature_attrs: feature_matrix = np.array([ [data.get(attr, 0.0) for attr in node_feature_attrs] for data in nodes_data ], dtype=np.float32) x = torch.from_numpy(feature_matrix).to(device=device, dtype=torch.float) else: x = torch.zeros((num_nodes, 0), dtype=torch.float, device=device) # Extract node labels vectorized y = None if node_label_attrs: label_matrix = np.array([ [data.get(attr, 0.0) for attr in node_label_attrs] for data in nodes_data ], dtype=np.float32) y = torch.from_numpy(label_matrix).to(device=device, dtype=torch.float) # Vectorized edge extraction edges_data = list(graph.edges(data=True)) if edges_data: # Extract edge indices vectorized edge_array = np.array([ [node_mapping[src], node_mapping[dst]] for src, dst, _ in edges_data ]) edge_index = torch.from_numpy(edge_array.T).to(device=device, dtype=torch.long) # Extract edge features vectorized if edge_feature_attrs: edge_feature_matrix = np.array([ [data.get(attr, 0.0) for attr in edge_feature_attrs] for _, _, data in edges_data ], dtype=np.float32) edge_attr = torch.from_numpy(edge_feature_matrix).to(device=device, dtype=torch.float) else: edge_attr = torch.zeros((len(edges_data), 0), dtype=torch.float, device=device) else: edge_index = torch.zeros((2, 0), dtype=torch.long, device=device) edge_attr = torch.zeros((0, 0), dtype=torch.float, device=device) # Extract positional information vectorized pos = None if all("pos" in data for data in nodes_data): pos_matrix = np.array([data["pos"] for data in nodes_data], dtype=np.float32) pos = torch.from_numpy(pos_matrix).to(device=device, dtype=torch.float) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos) # Preserve CRS if available if "crs" in graph.graph: data.crs = graph.graph["crs"] # Store metadata for reconstruction _store_nx_metadata(data, graph, node_feature_attrs, node_label_attrs, edge_feature_attrs) return data def _extract_nx_node_data(graph: nx.Graph, node_types: dict, node_feature_attrs: list[str] | None, node_label_attrs: list[str] | None, device: "torch.device") -> tuple[dict, dict]: """Extract node features and labels vectorized by type.""" node_mappings = {} data_store = {} for node_type, nodes in node_types.items(): node_mappings[node_type] = {node: i for i, node in enumerate(nodes)} if not nodes: continue # Vectorized node data extraction all_node_data = [graph.nodes[node] for node in nodes] # Features if node_feature_attrs: feature_matrix = np.array([ [data.get(attr, 0.0) for attr in node_feature_attrs] for data in all_node_data ], dtype=np.float32) data_store[f"{node_type}_x"] = torch.from_numpy(feature_matrix).to(device=device, dtype=torch.float) else: data_store[f"{node_type}_x"] = torch.zeros((len(nodes), 0), dtype=torch.float, device=device) # Labels if node_label_attrs: label_matrix = np.array([ [data.get(attr, 0.0) for attr in node_label_attrs] for data in all_node_data ], dtype=np.float32) data_store[f"{node_type}_y"] = torch.from_numpy(label_matrix).to(device=device, dtype=torch.float) # Positions if all("pos" in data for data in all_node_data): pos_matrix = np.array([data["pos"] for data in all_node_data], dtype=np.float32) data_store[f"{node_type}_pos"] = torch.from_numpy(pos_matrix).to(device=device, dtype=torch.float) return node_mappings, data_store def _extract_nx_edge_data(graph: nx.Graph, node_mappings: dict, edge_feature_attrs: list[str] | None, device: "torch.device") -> dict: """Extract edge data vectorized by type.""" # Group edges by type using vectorized operations edge_groups = {} edges_list = list(graph.edges(data=True)) # Vectorized edge type extraction for src, dst, edge_data in edges_list: src_type = graph.nodes[src].get("node_type", "default") dst_type = graph.nodes[dst].get("node_type", "default") relation = edge_data.get("relation", edge_data.get("edge_type", "edge")) edge_type = (src_type, relation, dst_type) if edge_type not in edge_groups: edge_groups[edge_type] = [] edge_groups[edge_type].append((src, dst, edge_data)) edge_tensors = {} for edge_type, edges in edge_groups.items(): src_type, relation, dst_type = edge_type if not edges: continue # Vectorized edge processing edge_indices = np.array([ [node_mappings[src_type][src], node_mappings[dst_type][dst]] for src, dst, _ in edges ]) edge_tensors[edge_type] = { "edge_index": torch.from_numpy(edge_indices.T).to(device=device, dtype=torch.long), } # Edge features vectorized if edge_feature_attrs: edge_feature_matrix = np.array([ [data.get(attr, 0.0) for attr in edge_feature_attrs] for _, _, data in edges ], dtype=np.float32) edge_tensors[edge_type]["edge_attr"] = torch.from_numpy(edge_feature_matrix).to( device=device, dtype=torch.float, ) else: edge_tensors[edge_type]["edge_attr"] = torch.zeros( (len(edges), 0), dtype=torch.float, device=device, ) return edge_tensors def _nx_to_hetero_pyg(graph: nx.Graph, node_feature_attrs: list[str] | None, node_label_attrs: list[str] | None, edge_feature_attrs: list[str] | None, device: Union[str, "torch.device"] | None) -> HeteroData: """Convert NetworkX graph to heterogeneous PyTorch Geometric HeteroData object.""" device = _get_device(device) data = HeteroData() # Group nodes by type vectorized node_types = {} for node, node_data in graph.nodes(data=True): node_type = node_data.get("node_type", "default") if node_type not in node_types: node_types[node_type] = [] node_types[node_type].append(node) # Extract node data vectorized node_mappings, node_data_store = _extract_nx_node_data( graph, node_types, node_feature_attrs, node_label_attrs, device, ) # Assign node data to HeteroData for node_type in node_types: if f"{node_type}_x" in node_data_store: data[node_type].x = node_data_store[f"{node_type}_x"] if f"{node_type}_y" in node_data_store: data[node_type].y = node_data_store[f"{node_type}_y"] if f"{node_type}_pos" in node_data_store: data[node_type].pos = node_data_store[f"{node_type}_pos"] # Extract edge data vectorized edge_data_store = _extract_nx_edge_data(graph, node_mappings, edge_feature_attrs, device) # Assign edge data to HeteroData for edge_type, edge_data in edge_data_store.items(): data[edge_type].edge_index = edge_data["edge_index"] data[edge_type].edge_attr = edge_data["edge_attr"] # Preserve CRS if available if "crs" in graph.graph: data.crs = graph.graph["crs"] # Store metadata for reconstruction _store_nx_metadata(data, graph, node_feature_attrs, node_label_attrs, edge_feature_attrs) return data def _collect_nx_node_columns(node_attrs: dict, feature_attrs: list[str] | None, label_attrs: list[str] | None) -> list[str]: """Collect all reconstructable attribute names for NetworkX nodes.""" columns = [] # Add feature attributes if feature_attrs: valid_features = [attr for attr in feature_attrs if attr in node_attrs] columns.extend(valid_features) # Add label attributes if label_attrs: valid_labels = [attr for attr in label_attrs if attr in node_attrs] columns.extend(valid_labels) # Add position attribute (stored in data.pos) if "pos" in node_attrs: columns.append("pos") return list(set(columns)) def _collect_nx_edge_columns(edge_attrs: dict, feature_attrs: list[str] | None) -> list[str]: """Collect all reconstructable attribute names for NetworkX edges.""" columns = [] # Add feature attributes if feature_attrs: valid_features = [attr for attr in feature_attrs if attr in edge_attrs] columns.extend(valid_features) return list(set(columns)) def _store_nx_metadata(data: Data | HeteroData, graph: nx.Graph | None, node_feature_attrs: list[str] | None, node_label_attrs: list[str] | None, edge_feature_attrs: list[str] | None) -> None: """Store column names from NetworkX graphs that can be reconstructed from tensors.""" if graph is None: return # Initialize column storage if not hasattr(data, "_node_columns"): data._node_columns = {} if not hasattr(data, "_edge_columns"): data._edge_columns = {} # Group nodes by type node_types = {} for node, node_data in graph.nodes(data=True): node_type = node_data.get("node_type", "default") if node_type not in node_types: node_types[node_type] = [] node_types[node_type].append(node) # Store node attribute names that were used as features/labels for node_type, nodes_list in node_types.items(): if nodes_list: sample_attrs = graph.nodes[nodes_list[0]] columns = _collect_nx_node_columns(sample_attrs, node_feature_attrs, node_label_attrs) if columns: data._node_columns[node_type] = columns # Store edge attribute names that were used as features edge_types = {} for src, dst, edge_data in graph.edges(data=True): src_type = graph.nodes[src].get("node_type", "default") dst_type = graph.nodes[dst].get("node_type", "default") relation = edge_data.get("relation", edge_data.get("edge_type", "edge")) edge_type = (src_type, relation, dst_type) if edge_type not in edge_types: edge_types[edge_type] = [] edge_types[edge_type].append((src, dst, edge_data)) for edge_type, edges_list in edge_types.items(): if edges_list: sample_edge_data = edges_list[0][2] columns = _collect_nx_edge_columns(sample_edge_data, edge_feature_attrs) if columns: data._edge_columns[edge_type] = columns def _store_node_gdf_metadata(data: Data | HeteroData, nodes: dict[str, gpd.GeoDataFrame] | gpd.GeoDataFrame, node_feature_cols: dict[str, list[str]] | list[str] | None, node_label_cols: dict[str, list[str]] | list[str] | None) -> None: """Store node metadata from GeoDataFrames.""" if isinstance(nodes, dict): # Store feature and label column names for each node type for node_type, node_gdf in nodes.items(): reconstructable_cols = [] # Add feature columns that were used (stored in data.x) if isinstance(node_feature_cols, dict) and node_type in node_feature_cols: features = node_feature_cols[node_type] if features and hasattr(node_gdf, "columns"): valid_features = [col for col in features if col in node_gdf.columns] reconstructable_cols.extend(valid_features) # Add label columns that were used (stored in data.y) if isinstance(node_label_cols, dict) and node_type in node_label_cols: labels = node_label_cols[node_type] if labels and hasattr(node_gdf, "columns"): valid_labels = [col for col in labels if col in node_gdf.columns] reconstructable_cols.extend(valid_labels) # Add geometry column (stored in data.pos) if hasattr(node_gdf, "geometry") and "geometry" in node_gdf.columns: reconstructable_cols.append("geometry") # Store all reconstructable columns if reconstructable_cols: if not hasattr(data, "_node_columns"): data._node_columns = {} data._node_columns[node_type] = list(set(reconstructable_cols)) else: # Single GeoDataFrame - store for "node" type reconstructable_cols = [] if isinstance(node_feature_cols, list) and hasattr(nodes, "columns"): valid_features = [col for col in node_feature_cols if col in nodes.columns] reconstructable_cols.extend(valid_features) if isinstance(node_label_cols, list) and hasattr(nodes, "columns"): valid_labels = [col for col in node_label_cols if col in nodes.columns] reconstructable_cols.extend(valid_labels) # Add geometry column if hasattr(nodes, "geometry") and "geometry" in nodes.columns: reconstructable_cols.append("geometry") if reconstructable_cols: data._node_columns = {"node": list(set(reconstructable_cols))} def _store_edge_gdf_metadata(data: Data | HeteroData, edges: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame | None, edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None) -> None: # noqa: E501 """Store edge metadata from GeoDataFrames.""" if edges is None: return if isinstance(edges, dict): # Store edge feature column names for each edge type for edge_type, edge_gdf in edges.items(): reconstructable_cols = [] # Add feature columns that were used (stored in edge_attr) if isinstance(edge_feature_cols, dict) and edge_type in edge_feature_cols: features = edge_feature_cols[edge_type] if features and hasattr(edge_gdf, "columns"): valid_features = [col for col in features if col in edge_gdf.columns] reconstructable_cols.extend(valid_features) # Add geometry column if present if hasattr(edge_gdf, "geometry") and "geometry" in edge_gdf.columns: reconstructable_cols.append("geometry") if reconstructable_cols: if not hasattr(data, "_edge_columns"): data._edge_columns = {} data._edge_columns[edge_type] = list(set(reconstructable_cols)) elif isinstance(edge_feature_cols, list) and hasattr(edges, "columns"): # Single GeoDataFrame - store for default edge type reconstructable_cols = [] valid_features = [col for col in edge_feature_cols if col in edges.columns] reconstructable_cols.extend(valid_features) # Add geometry column if hasattr(edges, "geometry") and "geometry" in edges.columns: reconstructable_cols.append("geometry") if reconstructable_cols: data._edge_columns = {("node", "edge", "node"): list(set(reconstructable_cols))} def _store_gdf_metadata(data: Data | HeteroData, nodes: dict[str, gpd.GeoDataFrame] | gpd.GeoDataFrame | None, edges: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame | None, node_feature_cols: dict[str, list[str]] | list[str] | None, node_label_cols: dict[str, list[str]] | list[str] | None, edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None) -> None: """Store column names from GeoDataFrames that can be reconstructed from tensors.""" if nodes is not None: _store_node_gdf_metadata(data, nodes, node_feature_cols, node_label_cols) if edges is not None: _store_edge_gdf_metadata(data, edges, edge_feature_cols) def _store_id_mappings(data: Data | HeteroData, node_id_cols: dict[str, str] | str | None, edge_source_cols: dict[tuple[str, str, str], str] | str | None, edge_target_cols: dict[tuple[str, str, str], str] | str | None) -> None: """Store ID and source/target column mappings.""" if isinstance(node_id_cols, dict): data._node_id_cols = node_id_cols elif isinstance(node_id_cols, str): data._node_id_cols = {"node": node_id_cols} if isinstance(edge_source_cols, dict): data._edge_source_cols = edge_source_cols elif isinstance(edge_source_cols, str): data._edge_source_cols = {("node", "edge", "node"): edge_source_cols} if isinstance(edge_target_cols, dict): data._edge_target_cols = edge_target_cols elif isinstance(edge_target_cols, str): data._edge_target_cols = {("node", "edge", "node"): edge_target_cols} def _store_nx_node_metadata(data: Data | HeteroData, graph: nx.Graph, node_feature_cols: dict | list | None, node_label_cols: dict | list | None) -> None: """Store node metadata from NetworkX graphs.""" # Group nodes by type using vectorized operations node_types = {} for node, node_data in graph.nodes(data=True): node_type = node_data.get("node_type", "default") if node_type not in node_types: node_types[node_type] = [] node_types[node_type].append(node) # Store node attributes for node_type, nodes_list in node_types.items(): if not nodes_list: continue sample_attrs = graph.nodes[nodes_list[0]] # Store feature attributes if they were specified if node_feature_cols: feature_attrs = ( node_feature_cols.get(node_type, []) if isinstance(node_feature_cols, dict) else node_feature_cols if isinstance(node_feature_cols, list) else [] ) valid_features = [attr for attr in feature_attrs if attr in sample_attrs] if valid_features: if not hasattr(data, "_node_feature_columns"): data._node_feature_columns = {} data._node_feature_columns[node_type] = valid_features # Store label attributes if they were specified if node_label_cols: label_attrs = ( node_label_cols.get(node_type, []) if isinstance(node_label_cols, dict) else node_label_cols if isinstance(node_label_cols, list) else [] ) valid_labels = [attr for attr in label_attrs if attr in sample_attrs] if valid_labels: if not hasattr(data, "_node_label_columns"): data._node_label_columns = {} data._node_label_columns[node_type] = valid_labels def _store_nx_edge_metadata(data: Data | HeteroData, graph: nx.Graph, edge_feature_cols: dict | list | None) -> None: """Store edge metadata from NetworkX graphs.""" if not edge_feature_cols: return # Store edge attributes edge_types = {} for src, dst, edge_data in graph.edges(data=True): src_type = graph.nodes[src].get("node_type", "default") dst_type = graph.nodes[dst].get("node_type", "default") relation = edge_data.get("relation", edge_data.get("edge_type", "edge")) edge_type = (src_type, relation, dst_type) if edge_type not in edge_types: edge_types[edge_type] = [] edge_types[edge_type].append((src, dst, edge_data)) for edge_type, edges_list in edge_types.items(): if not edges_list: continue sample_edge_data = edges_list[0][2] feature_attrs = ( edge_feature_cols.get(edge_type, []) if isinstance(edge_feature_cols, dict) else edge_feature_cols if isinstance(edge_feature_cols, list) else [] ) valid_features = [attr for attr in feature_attrs if attr in sample_edge_data] if valid_features: if not hasattr(data, "_edge_feature_columns"): data._edge_feature_columns = {} data._edge_feature_columns[edge_type] = valid_features def _store_nx_metadata(data: Data | HeteroData, graph: nx.Graph | None, node_feature_cols: dict | list | None, node_label_cols: dict | list | None, edge_feature_cols: dict | list | None) -> None: """Store metadata from NetworkX graphs.""" if graph is None: return _store_nx_node_metadata(data, graph, node_feature_cols, node_label_cols) _store_nx_edge_metadata(data, graph, edge_feature_cols) def _store_reconstruction_metadata( data: Data | HeteroData, nodes: dict[str, gpd.GeoDataFrame] | gpd.GeoDataFrame | None = None, edges: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame | None = None, node_id_cols: dict[str, str] | str | None = None, node_feature_cols: dict[str, list[str]] | list[str] | None = None, node_label_cols: dict[str, list[str]] | list[str] | None = None, edge_source_cols: dict[tuple[str, str, str], str] | str | None = None, edge_target_cols: dict[tuple[str, str, str], str] | str | None = None, edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None, graph: nx.Graph | None = None, ) -> None: """Store metadata in Data/HeteroData object for reconstruction purposes.""" # Initialize metadata storage data._node_feature_columns = {} data._node_label_columns = {} data._edge_feature_columns = {} data._node_id_cols = {} data._edge_source_cols = {} data._edge_target_cols = {} # Store feature and label column names based on the inputs if node_feature_cols is not None: if isinstance(node_feature_cols, dict): data._node_feature_columns.update(node_feature_cols) elif isinstance(node_feature_cols, list): data._node_feature_columns["node"] = node_feature_cols if node_label_cols is not None: if isinstance(node_label_cols, dict): data._node_label_columns.update(node_label_cols) elif isinstance(node_label_cols, list): data._node_label_columns["node"] = node_label_cols if edge_feature_cols is not None: if isinstance(edge_feature_cols, dict): data._edge_feature_columns.update(edge_feature_cols) elif isinstance(edge_feature_cols, list): data._edge_feature_columns[("node", "edge", "node")] = edge_feature_cols # Store GeoDataFrame metadata _store_gdf_metadata( data, nodes, edges, node_feature_cols, node_label_cols, edge_feature_cols, ) # Store ID mappings _store_id_mappings(data, node_id_cols, edge_source_cols, edge_target_cols) # Store NetworkX metadata _store_nx_metadata(data, graph, node_feature_cols, node_label_cols, edge_feature_cols)