# Source code for city2graph.graph

"""
Module for creating heterogeneous graph representations of urban environments.

This module provides comprehensive functionality for converting spatial data
(GeoDataFrames and NetworkX objects) into PyTorch Geometric Data and HeteroData objects,
supporting both homogeneous and heterogeneous graphs. It handles the complex mapping between
geographical coordinates, node/edge features, and the tensor representations
required by graph neural networks.

The module serves as a bridge between geospatial data analysis tools and deep
learning frameworks, enabling seamless integration of spatial urban data with
Graph Neural Networks (GNNs) for tasks of GeoAI such as urban modeling, traffic prediction,
and spatial analysis.
"""

# Future annotations for type hints
from __future__ import annotations

# Standard library imports
import logging
from typing import TYPE_CHECKING

# Third-party imports
import geopandas as gpd
import networkx as nx
import numpy as np
import pandas as pd
from shapely.geometry import LineString

# Local imports
from city2graph.utils import GraphMetadata
from city2graph.utils import nx_to_gdf
from city2graph.utils import validate_gdf
from city2graph.utils import validate_nx

# PyTorch Geometric imports with availability checking.
# TORCH_AVAILABLE records whether torch and torch_geometric could both be
# imported; gdf_to_pyg checks it before doing any work and
# is_torch_available() returns it.
try:
    import torch
    from torch_geometric.data import Data
    from torch_geometric.data import HeteroData

    TORCH_AVAILABLE = True
except ImportError:  # pragma: no cover - makes life easier for docs build.
    TORCH_AVAILABLE = False

    # Create stubs for documentation and fallback functionality
    if TYPE_CHECKING:
        # Static type checkers still see the real classes so that the
        # annotations in this module resolve.
        from torch_geometric.data import Data
        from torch_geometric.data import HeteroData
    else:
        # At runtime the name ``torch`` is bound to None, so any attribute
        # access on it fails; ``Data``/``HeteroData`` become empty placeholder
        # classes so that the names exist at module level.
        torch = None

        class HeteroData:
            """Fallback stub when torch is unavailable."""

        class Data:
            """Fallback stub when torch is unavailable."""


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Names exported as the public API of this module.
__all__ = [
    "gdf_to_pyg",
    "is_torch_available",
    "nx_to_pyg",
    "pyg_to_gdf",
    "pyg_to_nx",
    "validate_pyg",
]

# Constants for error messages
# TORCH_ERROR_MSG: raised as ImportError text when TORCH_AVAILABLE is False.
# DEVICE_ERROR_MSG: raised as TypeError/ValueError text by _get_device.
TORCH_ERROR_MSG = "PyTorch and PyTorch Geometric required for graph conversion functionality."
DEVICE_ERROR_MSG = "Device must be 'cuda', 'cpu', a torch.device object, or None"
GRAPH_NO_NODES_MSG = "Graph has no nodes"


# ============================================================================
# GRAPH CONVERSION FUNCTIONS
# ============================================================================


def gdf_to_pyg(
    nodes: dict[str, gpd.GeoDataFrame] | gpd.GeoDataFrame,
    edges: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame | None = None,
    node_feature_cols: dict[str, list[str]] | list[str] | None = None,
    node_label_cols: dict[str, list[str]] | list[str] | None = None,
    edge_feature_cols: dict[str, list[str]] | list[str] | None = None,
    device: str | torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Data | HeteroData:
    """
    Build a PyTorch Geometric graph object from node/edge GeoDataFrames.

    Main entry point for turning spatial data into PyTorch Geometric objects.
    The inputs are validated with ``validate_gdf``, which also reports whether
    the graph is homogeneous or heterogeneous; the matching builder is then
    invoked and the result is checked with ``validate_pyg`` before it is
    returned.

    Node identifiers are taken from the GeoDataFrame index. Edge relationships
    are defined by a MultiIndex on the edge GeoDataFrame (source ID, target ID).

    Parameters
    ----------
    nodes : dict[str, geopandas.GeoDataFrame] or geopandas.GeoDataFrame
        Node data. Pass a single GeoDataFrame for a homogeneous graph, or a
        dictionary mapping node type names to GeoDataFrames for a
        heterogeneous graph. The index of each GeoDataFrame supplies the node
        identifiers.
    edges : dict[tuple[str, str, str], geopandas.GeoDataFrame] or geopandas.GeoDataFrame, optional
        Edge data. Pass a single GeoDataFrame for a homogeneous graph, or a
        dictionary keyed by (source_type, relation_type, target_type) tuples
        for a heterogeneous graph. Each GeoDataFrame must carry a MultiIndex
        whose first level holds source node IDs and whose second level holds
        target node IDs.
    node_feature_cols : dict[str, list[str]] or list[str], optional
        Columns to use as node features. Must be a list for homogeneous graphs
        and a dictionary (node type -> columns) for heterogeneous graphs.
    node_label_cols : dict[str, list[str]] or list[str], optional
        Columns to use as node labels for supervised learning. Must be a list
        for homogeneous graphs and a dictionary (node type -> columns) for
        heterogeneous graphs.
    edge_feature_cols : dict[str, list[str]] or list[str], optional
        Columns to use as edge features. Must be a list for homogeneous graphs
        and a dictionary (relation type -> columns) for heterogeneous graphs.
    device : str or torch.device, optional
        Target device for tensor placement ('cpu', 'cuda', or torch.device).
        If None, CUDA is selected when available, otherwise CPU.
    dtype : torch.dtype, optional
        Data type for float tensors (e.g., torch.float32, torch.float16).
        If None, torch.float32 is used.

    Returns
    -------
    torch_geometric.data.Data or torch_geometric.data.HeteroData
        ``Data`` for homogeneous input, ``HeteroData`` for heterogeneous input.
        The object contains:

        - Node features (x), positions (pos), and labels (y) if available
        - Edge connectivity (edge_index) and features (edge_attr) if available
        - Metadata for reconstruction including ID mappings and column names

    Raises
    ------
    ImportError
        If PyTorch / PyTorch Geometric is not installed.
    TypeError
        If a ``*_cols`` argument is a list where a dictionary is required
        (heterogeneous graphs) or a dictionary where a list is required
        (homogeneous graphs).
    ValueError
        If input GeoDataFrames are invalid or incompatible.

    See Also
    --------
    pyg_to_gdf : Convert PyTorch Geometric data back to GeoDataFrames.
    nx_to_pyg : Convert NetworkX graph to PyTorch Geometric object.
    city2graph.utils.validate_gdf : Validate GeoDataFrame structure.

    Notes
    -----
    The graph type is detected from the input structure; provide dictionaries
    to obtain a heterogeneous graph. Node positions are extracted from geometry
    centroids when available.

    - Preserves original coordinate reference systems (CRS)
    - Maintains index structure for bidirectional conversion
    - Handles both Point and non-Point geometries (using centroids)
    - Creates empty tensors for missing features/edges
    - For heterogeneous graphs, ensures consistent node/edge type mapping

    Examples
    --------
    Create a homogeneous graph from single GeoDataFrames:

    >>> import geopandas as gpd
    >>> from city2graph.graph import gdf_to_pyg
    >>>
    >>> # Load and prepare node data
    >>> nodes_gdf = gpd.read_file("nodes.geojson").set_index("node_id")
    >>> edges_gdf = gpd.read_file("edges.geojson").set_index(["source_id", "target_id"])
    >>>
    >>> # Convert to PyTorch Geometric
    >>> data = gdf_to_pyg(nodes_gdf, edges_gdf,
    ...                   node_feature_cols=['population', 'area'])

    Create a heterogeneous graph from dictionaries:

    >>> # Prepare heterogeneous data
    >>> buildings_gdf = buildings_gdf.set_index("building_id")
    >>> roads_gdf = roads_gdf.set_index("road_id")
    >>> connections_gdf = connections_gdf.set_index(["building_id", "road_id"])
    >>>
    >>> # Define node and edge types
    >>> nodes_dict = {'building': buildings_gdf, 'road': roads_gdf}
    >>> edges_dict = {('building', 'connects', 'road'): connections_gdf}
    >>>
    >>> # Convert to heterogeneous graph with labels
    >>> data = gdf_to_pyg(nodes_dict, edges_dict,
    ...                   node_label_cols={'building': ['type'], 'road': ['category']})
    """
    # ------------------------------------------------------------------
    # 0. Input validation & dispatch
    # ------------------------------------------------------------------
    if not TORCH_AVAILABLE:
        raise ImportError(TORCH_ERROR_MSG)

    # validate_gdf hands back the (possibly normalised) inputs together with
    # the homogeneous/heterogeneous verdict.
    nodes, edges, is_hetero = validate_gdf(nodes_gdf=nodes, edges_gdf=edges)

    device = _get_device(device)

    data: Data | HeteroData
    if not is_hetero:
        # Narrow the container types for the homogeneous builder.
        assert nodes is None or isinstance(nodes, gpd.GeoDataFrame)
        assert edges is None or isinstance(edges, gpd.GeoDataFrame)

        # Column selections must be flat lists (or omitted) in this mode.
        if node_feature_cols is not None and not isinstance(node_feature_cols, list):
            msg = "node_feature_cols must be a list for homogeneous graphs"
            raise TypeError(msg)
        if node_label_cols is not None and not isinstance(node_label_cols, list):
            msg = "node_label_cols must be a list for homogeneous graphs"
            raise TypeError(msg)
        if edge_feature_cols is not None and not isinstance(edge_feature_cols, list):
            msg = "edge_feature_cols must be a list for homogeneous graphs"
            raise TypeError(msg)

        data = _build_homogeneous_graph(
            nodes,
            edges,
            node_feature_cols,
            node_label_cols,
            edge_feature_cols,
            device,
            dtype,
        )
    else:
        # Narrow the container types for the heterogeneous builder.
        assert isinstance(nodes, dict)
        assert edges is None or isinstance(edges, dict)

        # Column selections must be per-type dictionaries (or omitted) here.
        if node_feature_cols is not None and not isinstance(node_feature_cols, dict):
            msg = "node_feature_cols must be a dict for heterogeneous graphs"
            raise TypeError(msg)
        if node_label_cols is not None and not isinstance(node_label_cols, dict):
            msg = "node_label_cols must be a dict for heterogeneous graphs"
            raise TypeError(msg)
        if edge_feature_cols is not None and not isinstance(edge_feature_cols, dict):
            msg = "edge_feature_cols must be a dict for heterogeneous graphs"
            raise TypeError(msg)

        data = _build_heterogeneous_graph(
            nodes,
            edges,
            node_feature_cols,
            node_label_cols,
            edge_feature_cols,
            device,
            dtype,
        )

    # Sanity-check the freshly built object before handing it out.
    validate_pyg(data)
    return data


def pyg_to_gdf(
    data: Data | HeteroData,
    node_types: str | list[str] | None = None,
    edge_types: str | list[tuple[str, str, str]] | None = None,
) -> (
    tuple[dict[str, gpd.GeoDataFrame], dict[tuple[str, str, str], gpd.GeoDataFrame]]
    | tuple[gpd.GeoDataFrame | None, gpd.GeoDataFrame | None]
):
    """
    Convert PyTorch Geometric data to GeoDataFrames.

    Reconstructs the original GeoDataFrame structure from PyTorch Geometric
    Data or HeteroData objects. This function provides bidirectional conversion
    capability, preserving spatial information, feature data, and metadata.

    Parameters
    ----------
    data : torch_geometric.data.Data or torch_geometric.data.HeteroData
        PyTorch Geometric data object to convert back to GeoDataFrames.
    node_types : str or list[str], optional
        For heterogeneous graphs, the node type(s) to reconstruct. A single
        string selects exactly that one node type. If None (or empty),
        all node types recorded in the metadata are reconstructed.
        Ignored for homogeneous graphs.
    edge_types : str or list[tuple[str, str, str]], optional
        For heterogeneous graphs, the edge type(s) to reconstruct, given as
        (source_type, relation_type, target_type) tuples. A single string is
        interpreted as a relation name and selects every edge type in the
        metadata whose relation (middle element) equals it. A single bare
        (source, relation, target) tuple of strings is accepted as well.
        If None (or empty), all edge types recorded in the metadata are
        reconstructed. Ignored for homogeneous graphs.

    Returns
    -------
    tuple
        **For HeteroData input:** Returns a tuple containing:

        - First element: dict[str, geopandas.GeoDataFrame] mapping node type
          names to GeoDataFrames
        - Second element: dict[tuple[str, str, str], geopandas.GeoDataFrame]
          mapping edge types to GeoDataFrames

        **For Data input:** Returns a tuple containing:

        - First element: geopandas.GeoDataFrame containing nodes
        - Second element: geopandas.GeoDataFrame containing edges
          (or None if no edges)

    Raises
    ------
    ValueError
        If ``edge_types`` is a string and no edge type in the graph metadata
        uses it as its relation name.

    See Also
    --------
    gdf_to_pyg : Convert GeoDataFrames to PyTorch Geometric object.
    pyg_to_nx : Convert PyTorch Geometric data to NetworkX graph.

    Notes
    -----
    - Preserves original index structure and names when available
    - Reconstructs geometry from stored position tensors
    - Maintains coordinate reference system (CRS) information
    - Converts feature tensors back to named DataFrame columns
    - Handles both homogeneous and heterogeneous graph structures

    Examples
    --------
    Convert homogeneous PyTorch Geometric data back to GeoDataFrames:

    >>> from city2graph.graph import pyg_to_gdf
    >>>
    >>> # Convert back to GeoDataFrames
    >>> nodes_gdf, edges_gdf = pyg_to_gdf(data)

    Convert heterogeneous data with specific node types:

    >>> # Convert only specific node types
    >>> node_gdfs, edge_gdfs = pyg_to_gdf(hetero_data,
    ...                                   node_types=['building', 'road'])

    A single node type may be given as a plain string:

    >>> node_gdfs, edge_gdfs = pyg_to_gdf(hetero_data, node_types='building')
    """
    metadata = validate_pyg(data)

    if metadata.is_hetero:
        # ------------------------------------------------------------------
        # HeteroData → pandas
        # ------------------------------------------------------------------
        # A bare string is itself iterable; without this wrap the loop below
        # would walk its characters ('b', 'u', 'i', ...) instead of treating
        # it as one node type.
        selected_node_types = [node_types] if isinstance(node_types, str) else node_types

        if isinstance(edge_types, str):
            # Resolve a lone relation name to the full edge-type triples that
            # carry it, so the result keys keep the documented tuple form.
            selected_edge_types = [
                et for et in (metadata.edge_types or []) if et[1] == edge_types
            ]
            if not selected_edge_types:
                msg = f"No edge type with relation '{edge_types}' found in graph metadata"
                raise ValueError(msg)
        elif (
            isinstance(edge_types, tuple)
            and len(edge_types) == 3
            and all(isinstance(part, str) for part in edge_types)
        ):
            # One (source, relation, target) triple passed without the
            # enclosing list; iterating it would yield its three strings.
            selected_edge_types = [edge_types]
        else:
            selected_edge_types = edge_types

        # None / empty selections fall back to everything the metadata lists.
        node_types_to_process = selected_node_types or metadata.node_types
        edge_types_to_process = selected_edge_types or metadata.edge_types

        node_gdfs = {
            nt: _reconstruct_node_gdf(data, metadata, node_type=nt)
            for nt in node_types_to_process
        }
        edge_gdfs = {
            et: _reconstruct_edge_gdf(data, metadata, et) for et in edge_types_to_process
        }
        return node_gdfs, edge_gdfs

    # ------------------------------------------------------------------
    # Data → pandas
    # ------------------------------------------------------------------
    nodes_gdf = _reconstruct_node_gdf(data, metadata, None)
    edges_gdf = _reconstruct_edge_gdf(data, metadata, None)
    return nodes_gdf, edges_gdf


# ============================================================================
# NETWORKX CONVERSION FUNCTIONS
# ============================================================================
def pyg_to_nx(data: Data | HeteroData) -> nx.Graph:
    """
    Turn a PyTorch Geometric object into a NetworkX graph.

    The input is first validated with ``validate_pyg``; the metadata it
    returns decides whether the heterogeneous or the homogeneous converter
    is used. Node and edge features are carried over as graph attributes so
    the result can be used with the NetworkX ecosystem for graph analysis.

    Parameters
    ----------
    data : torch_geometric.data.Data or torch_geometric.data.HeteroData
        PyTorch Geometric data object to convert.

    Returns
    -------
    networkx.Graph
        NetworkX graph with node and edge attributes from the PyG object.
        For heterogeneous graphs, node and edge types are stored as attributes.

    Raises
    ------
    ImportError
        If PyTorch Geometric is not installed.

    See Also
    --------
    nx_to_pyg : Convert NetworkX graph to PyTorch Geometric object.
    pyg_to_gdf : Convert PyTorch Geometric data to GeoDataFrames.

    Notes
    -----
    - Node features, positions, and labels are stored as node attributes
    - Edge features are stored as edge attributes
    - For heterogeneous graphs, type information is preserved
    - Geometry information is converted from tensor positions
    - Maintains compatibility with NetworkX analysis algorithms

    Examples
    --------
    Convert PyTorch Geometric data to NetworkX:

    >>> from city2graph.graph import pyg_to_nx
    >>> import networkx as nx
    >>>
    >>> # Convert to NetworkX graph
    >>> nx_graph = pyg_to_nx(data)
    >>>
    >>> # Use NetworkX algorithms
    >>> centrality = nx.betweenness_centrality(nx_graph)
    >>> communities = nx.community.greedy_modularity_communities(nx_graph)
    """
    graph_metadata = validate_pyg(data)

    # Pick the converter that matches the graph flavour, then run it once.
    converter = (
        _convert_hetero_pyg_to_nx if graph_metadata.is_hetero else _convert_homo_pyg_to_nx
    )
    return converter(data, graph_metadata)


def nx_to_pyg(
    graph: nx.Graph,
    node_feature_cols: list[str] | None = None,
    node_label_cols: list[str] | None = None,
    edge_feature_cols: list[str] | None = None,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> Data | HeteroData:
    """
    Convert a NetworkX graph to a PyTorch Geometric object.

    The graph is validated with ``validate_nx``, converted to node and edge
    GeoDataFrames with ``nx_to_gdf``, and then passed through ``gdf_to_pyg``,
    so the result follows the same pipeline as GeoDataFrame input.

    Parameters
    ----------
    graph : networkx.Graph
        NetworkX graph to convert.
    node_feature_cols : list[str], optional
        List of node attribute names to use as features.
    node_label_cols : list[str], optional
        List of node attribute names to use as labels.
    edge_feature_cols : list[str], optional
        List of edge attribute names to use as features.
    device : torch.device or str, optional
        Target device for tensor placement ('cpu', 'cuda', or torch.device).
        If None, CUDA is selected when available, otherwise CPU.
    dtype : torch.dtype, optional
        Data type for float tensors (e.g., torch.float32, torch.float16).
        If None, torch.float32 is used.

    Returns
    -------
    torch_geometric.data.Data or torch_geometric.data.HeteroData
        Whatever ``gdf_to_pyg`` produces for the intermediate GeoDataFrames.
        The returned object contains:

        - Node features (x), positions (pos), and labels (y) if available
        - Edge connectivity (edge_index) and features (edge_attr) if available
        - Metadata for reconstruction including ID mappings and column names

    Raises
    ------
    ImportError
        If PyTorch Geometric is not installed.
    ValueError
        If the NetworkX graph is invalid or empty.

    See Also
    --------
    pyg_to_nx : Convert PyTorch Geometric data to NetworkX graph.
    gdf_to_pyg : Convert GeoDataFrames to PyTorch Geometric object.
    city2graph.utils.nx_to_gdf : Convert NetworkX graph to GeoDataFrames.

    Notes
    -----
    - Uses intermediate GeoDataFrame conversion for consistency
    - Preserves all graph attributes and metadata
    - Handles spatial coordinates if present in node attributes
    - Maintains compatibility with existing city2graph workflows
    - Automatically creates geometry from 'x', 'y' coordinates if available

    Examples
    --------
    Convert a NetworkX graph with spatial data:

    >>> import networkx as nx
    >>> from city2graph.graph import nx_to_pyg
    >>>
    >>> # Create NetworkX graph with spatial attributes
    >>> G = nx.Graph()
    >>> G.add_node(0, x=0.0, y=0.0, population=1000)
    >>> G.add_node(1, x=1.0, y=1.0, population=1500)
    >>> G.add_edge(0, 1, weight=0.5, road_type='primary')
    >>>
    >>> # Convert to PyTorch Geometric
    >>> data = nx_to_pyg(G,
    ...                  node_feature_cols=['population'],
    ...                  edge_feature_cols=['weight'])

    Convert from graph analysis results:

    >>> # Use NetworkX for analysis, then convert for ML
    >>> communities = nx.community.greedy_modularity_communities(G)
    >>> # Add community labels to nodes
    >>> for i, community in enumerate(communities):
    ...     for node in community:
    ...         G.nodes[node]['community'] = i
    >>>
    >>> # Convert with community labels
    >>> data = nx_to_pyg(G, node_label_cols=['community'])
    """
    # Reject anything that is not a usable NetworkX graph up front.
    validate_nx(graph)

    # Intermediate representation: one frame for nodes, one for edges.
    node_frame, edge_frame = nx_to_gdf(graph, nodes=True, edges=True)

    # Delegate the tensor construction to the GeoDataFrame pipeline.
    return gdf_to_pyg(
        node_frame,
        edge_frame,
        node_feature_cols=node_feature_cols,
        node_label_cols=node_label_cols,
        edge_feature_cols=edge_feature_cols,
        device=device,
        dtype=dtype,
    )


# ============================================================================
# TORCH UTILITIES FUNCTIONS
# ============================================================================
def is_torch_available() -> bool:
    """
    Report whether PyTorch and PyTorch Geometric were importable.

    Returns the module-level ``TORCH_AVAILABLE`` flag, which is set once when
    this module is imported. Useful for guarding optional functionality and
    for producing helpful error messages.

    Returns
    -------
    bool
        True if PyTorch Geometric can be imported, False otherwise.

    See Also
    --------
    gdf_to_pyg : Convert GeoDataFrames to PyTorch Geometric (requires torch).
    pyg_to_gdf : Convert PyTorch Geometric to GeoDataFrames (requires torch).

    Notes
    -----
    - Returns False if either PyTorch or PyTorch Geometric is missing
    - Used by torch-dependent functions to provide helpful error messages

    Examples
    --------
    Check availability before using torch-dependent functions:

    >>> from city2graph.graph import is_torch_available
    >>>
    >>> if is_torch_available():
    ...     from city2graph.graph import gdf_to_pyg
    ...     data = gdf_to_pyg(nodes_gdf, edges_gdf)
    ... else:
    ...     print("PyTorch Geometric not available.")
    """
    torch_stack_importable: bool = TORCH_AVAILABLE
    return torch_stack_importable


def _get_device(device: str | torch.device | None) -> torch.device:
    """
    Turn a user-supplied device specification into a ``torch.device``.

    Parameters
    ----------
    device : str, torch.device, or None
        Requested device. ``None`` selects ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise. Strings
        and ``torch.device`` objects are passed to ``torch.device``.

    Returns
    -------
    torch.device
        The resolved device.

    Raises
    ------
    TypeError
        If ``device`` is neither None, a string, nor a ``torch.device``.
    ValueError
        If ``torch.device`` rejects the specification with a RuntimeError, or
        if the resolved device is a CUDA device while CUDA is unavailable.

    See Also
    --------
    torch.device : PyTorch device specification.

    Examples
    --------
    >>> device = _get_device("cpu")
    >>> device = _get_device(None)  # Auto-selects CUDA when available
    """
    requested = device
    if requested is None:
        # Automatic selection: prefer the GPU when one is usable
        requested = "cuda" if torch.cuda.is_available() else "cpu"
    elif not isinstance(requested, (str, torch.device)):
        raise TypeError(DEVICE_ERROR_MSG)

    try:
        resolved = torch.device(requested)
    except RuntimeError as exc:
        # torch reports malformed device strings as RuntimeError; callers of
        # this module get a ValueError instead
        raise ValueError(DEVICE_ERROR_MSG) from exc

    if resolved.type != "cuda" or torch.cuda.is_available():
        return resolved

    msg = f"CUDA selected, but not available. {DEVICE_ERROR_MSG}"
    raise ValueError(msg)


# ============================================================================
# EDGE COLUMN DETECTION FUNCTIONS
# ============================================================================
# Removed: _get_source_target_keywords, _find_column_candidates,
# _fallback_column_detection, _detect_edge_columns
# These functions are no longer needed as edge relationships are derived from MultiIndex.
# ============================================================================
# NODE PREPARATION FUNCTIONS
# ============================================================================


def _create_node_id_mapping(
    node_gdf: gpd.GeoDataFrame,
) -> tuple[dict[str | int, int], str, list[str | int]]:
    """
    Create mapping from node IDs (from index) to sequential integer indices.

    PyTorch Geometric identifies nodes by sequential integers starting from 0.
    The original identifiers are taken from the GeoDataFrame index, in row order.

    Parameters
    ----------
    node_gdf : geopandas.GeoDataFrame
        GeoDataFrame containing node data. The index is used for node IDs.

    Returns
    -------
    dict[str | int, int]
        Dictionary mapping original IDs to integer positions. If the index
        contains duplicates, the last occurrence wins.
    str
        Always "index", indicating the DataFrame index was used.
    list[str | int]
        List of original IDs in row order.

    See Also
    --------
    _create_node_features : Convert node attributes to tensors.

    Examples
    --------
    >>> import geopandas as gpd
    >>> gdf = gpd.GeoDataFrame({'id': [1, 2, 3]})
    >>> mapping, id_col, ids = _create_node_id_mapping(gdf)
    """
    original_ids = node_gdf.index.tolist()
    id_mapping = {node_id: position for position, node_id in enumerate(original_ids)}
    return id_mapping, "index", original_ids


def _select_feature_columns(
    frame: pd.DataFrame,
    feature_cols: list[str] | None,
) -> list[str]:
    """
    Return the requested feature columns that exist in ``frame``, in request order.

    The order of ``feature_cols`` is preserved (duplicates are dropped, keeping
    the first occurrence) so that tensor columns line up with the column lists
    recorded in the graph metadata. A set intersection must not be used here:
    set iteration order is arbitrary and, for strings, can differ between runs.

    Parameters
    ----------
    frame : pandas.DataFrame
        Table whose ``columns`` are checked for membership.
    feature_cols : list[str] or None
        Requested column names.

    Returns
    -------
    list[str]
        Existing columns in the order requested; empty if nothing was requested
        or nothing matched.
    """
    if not feature_cols:
        return []
    available = set(frame.columns)
    return [col for col in dict.fromkeys(feature_cols) if col in available]


def _create_node_features(
    node_gdf: gpd.GeoDataFrame,
    feature_cols: list[str] | None = None,
    device: str | torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """
    Convert node attributes to PyTorch feature tensors.

    Requested columns that are missing from ``node_gdf`` are ignored. The
    remaining columns are stacked in the order given by ``feature_cols``.

    Parameters
    ----------
    node_gdf : geopandas.GeoDataFrame
        GeoDataFrame containing node data.
    feature_cols : list[str], optional
        List of column names to use as features (None creates empty tensor).
    device : str or torch.device, optional
        Target device for tensor creation.
    dtype : torch.dtype, optional
        Data type for the tensor. If None, uses torch.float32.

    Returns
    -------
    torch.Tensor
        Tensor of shape (num_nodes, num_features) containing node features;
        shape (num_nodes, 0) when no requested column exists.

    See Also
    --------
    _create_node_positions : Extract spatial coordinates from geometries.

    Examples
    --------
    >>> import geopandas as gpd
    >>> gdf = gpd.GeoDataFrame({'feature1': [1, 2], 'feature2': [3, 4]})
    >>> tensor = _create_node_features(gdf, ['feature1', 'feature2'])
    """
    device = _get_device(device)
    dtype = dtype or torch.float32

    valid_cols = _select_feature_columns(node_gdf, feature_cols)
    if not valid_cols:
        # Nothing requested, or none of the requested columns exist
        return torch.zeros((len(node_gdf), 0), dtype=dtype, device=device)

    # Map torch dtype to numpy dtype so the array is cast once, before conversion
    numpy_dtype = torch.tensor(0, dtype=dtype).numpy().dtype
    features_array = node_gdf[valid_cols].to_numpy().astype(numpy_dtype)
    return torch.from_numpy(features_array).to(device=device, dtype=dtype)


def _create_node_positions(
    node_gdf: gpd.GeoDataFrame,
    device: str | torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor | None:
    """
    Extract spatial coordinates from node geometries.

    The centroid of every geometry is used as the node position.

    Parameters
    ----------
    node_gdf : geopandas.GeoDataFrame
        GeoDataFrame with geometry column containing spatial data.
    device : str or torch.device, optional
        Target device for tensor creation.
    dtype : torch.dtype, optional
        Data type for position tensors. If None, uses torch.float32.

    Returns
    -------
    torch.Tensor or None
        Tensor of shape (num_nodes, 2) containing [x, y] centroid coordinates.
        None if accessing ``node_gdf.geometry`` raises AttributeError
        (no geometry column available).

    See Also
    --------
    _create_node_features : Convert node attributes to tensors.

    Notes
    -----
    - Uses centroid coordinates for all geometry types.
    - Coordinates are expressed in the original CRS of the GeoDataFrame. For a
      geographic CRS the centroid is computed in an estimated UTM CRS and then
      transformed back.

    Examples
    --------
    >>> import geopandas as gpd
    >>> from shapely.geometry import Point
    >>> gdf = gpd.GeoDataFrame(geometry=[Point(0, 0), Point(1, 1)])
    >>> coords = _create_node_positions(gdf)
    """
    device = _get_device(device)
    dtype = dtype or torch.float32

    try:
        geom_series = node_gdf.geometry
    except AttributeError:
        # Honour the documented contract: no geometry means no positions
        return None

    if geom_series.crs and geom_series.crs.is_geographic:
        # Centroids computed directly on lon/lat are inaccurate; project to a
        # suitable UTM zone first, then return to the original CRS
        utm_crs = geom_series.estimate_utm_crs()
        centroids = geom_series.to_crs(utm_crs).centroid.to_crs(geom_series.crs)
    else:
        centroids = geom_series.centroid

    # Map torch dtype to numpy dtype for consistency
    numpy_dtype = torch.tensor(0, dtype=dtype).numpy().dtype
    pos_data = np.column_stack(
        [
            centroids.x.to_numpy(),
            centroids.y.to_numpy(),
        ],
    ).astype(numpy_dtype)

    return torch.tensor(pos_data, dtype=dtype, device=device)


# ============================================================================
# EDGE PREPARATION FUNCTIONS
# ============================================================================


def _create_edge_features(
    edge_gdf: gpd.GeoDataFrame,
    feature_cols: list[str] | None = None,
    device: str | torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """
    Convert edge attributes to PyTorch feature tensors.

    Requested columns that are missing from ``edge_gdf`` or that are not
    numeric are ignored. The remaining columns keep the order of ``feature_cols``.

    Parameters
    ----------
    edge_gdf : geopandas.GeoDataFrame
        GeoDataFrame containing edge data.
    feature_cols : list[str], optional
        List of column names to use as features.
    device : str or torch.device, optional
        Target device for tensor creation.
    dtype : torch.dtype, optional
        Data type for the tensor. If None, uses torch.float32.

    Returns
    -------
    torch.Tensor
        Tensor of shape (num_edges, num_features) containing edge features;
        shape (num_edges, 0) when no usable column exists.

    See Also
    --------
    _create_node_features : Convert node attributes to tensors.

    Examples
    --------
    >>> import geopandas as gpd
    >>> gdf = gpd.GeoDataFrame({'weight': [1.0, 2.0]})
    >>> tensor = _create_edge_features(gdf, ['weight'])
    """
    device = _get_device(device)
    dtype = dtype or torch.float32

    valid_cols = _select_feature_columns(edge_gdf, feature_cols)
    if not valid_cols:
        return torch.empty((edge_gdf.shape[0], 0), dtype=dtype, device=device)

    # Keep only numeric columns to prevent conversion errors in astype below
    numeric_cols = edge_gdf[valid_cols].select_dtypes(include=np.number).columns.tolist()

    # Map torch dtype to numpy dtype for consistency
    numpy_dtype = torch.tensor(0, dtype=dtype).numpy().dtype
    features_array = edge_gdf[numeric_cols].to_numpy().astype(numpy_dtype)
    return torch.from_numpy(features_array).to(device=device, dtype=dtype)


def _create_edge_indices(
    edge_gdf: gpd.GeoDataFrame,
    source_mapping: dict[str | int, int],
    target_mapping: dict[str | int, int] | None = None,
) -> list[list[int]]:
    """
    Create edge connectivity pairs from edge data using MultiIndex.

    Source and target node IDs are read from the first two levels of the edge
    GeoDataFrame index and translated to integer node positions.

    Parameters
    ----------
    edge_gdf : geopandas.GeoDataFrame
        GeoDataFrame with MultiIndex containing (source_id, target_id) pairs.
    source_mapping : dict[str | int, int]
        Mapping from original source node IDs to integer indices.
    target_mapping : dict[str | int, int], optional
        Mapping from original target node IDs to integer indices.
        If None, uses source_mapping. An empty dict is respected as
        "no valid targets" rather than treated as missing.

    Returns
    -------
    list[list[int]]
        One ``[source_index, target_index]`` pair per edge whose endpoints are
        both present in the mappings.

    See Also
    --------
    _create_node_id_mapping : Create node ID mappings.

    Examples
    --------
    >>> import geopandas as gpd
    >>> import pandas as pd
    >>> idx = pd.MultiIndex.from_tuples([(0, 1), (1, 2)])
    >>> gdf = gpd.GeoDataFrame(index=idx)
    >>> mapping = {0: 0, 1: 1, 2: 2}
    >>> edges = _create_edge_indices(gdf, mapping)
    """
    # Only fall back when no mapping was given at all: an empty target mapping
    # (node type without nodes) must not be replaced by the source mapping
    if target_mapping is None:
        target_mapping = source_mapping

    source_ids, target_ids = _extract_edge_ids(edge_gdf)

    # get_level_values yields Index objects; the mapping step works on Series
    if isinstance(source_ids, pd.Index):
        source_ids = pd.Series(source_ids)
    if isinstance(target_ids, pd.Index):
        target_ids = pd.Series(target_ids)

    return _map_edge_ids_to_indices(source_ids, target_ids, source_mapping, target_mapping)


def _extract_edge_ids(edge_gdf: gpd.GeoDataFrame) -> tuple[pd.Series, pd.Series]:
    """
    Extract source and target IDs from MultiIndex DataFrame.

    Parameters
    ----------
    edge_gdf : geopandas.GeoDataFrame
        GeoDataFrame with MultiIndex containing (source_id, target_id) pairs.

    Returns
    -------
    tuple[pd.Series, pd.Series]
        Values of index level 0 (source IDs) and index level 1 (target IDs),
        exactly as returned by ``get_level_values``.

    See Also
    --------
    _create_edge_indices : Create edge connectivity pairs.

    Examples
    --------
    >>> import geopandas as gpd
    >>> import pandas as pd
    >>> idx = pd.MultiIndex.from_tuples([(0, 1), (1, 2)])
    >>> gdf = gpd.GeoDataFrame(index=idx)
    >>> src, tgt = _extract_edge_ids(gdf)
    """
    return (
        edge_gdf.index.get_level_values(0),  # First level = source
        edge_gdf.index.get_level_values(1),  # Second level = target
    )


def _map_edge_ids_to_indices(
    source_ids: pd.Series,
    target_ids: pd.Series,
    source_mapping: dict[str | int, int],
    target_mapping: dict[str | int, int],
) -> list[list[int]]:
    """
    Map edge IDs to indices.

    Edges whose source ID is not a key of ``source_mapping`` or whose target ID
    is not a key of ``target_mapping`` are dropped.

    Parameters
    ----------
    source_ids : pd.Series
        Series of source node IDs.
    target_ids : pd.Series
        Series of target node IDs.
    source_mapping : dict[str | int, int]
        Mapping from source node IDs to indices.
    target_mapping : dict[str | int, int]
        Mapping from target node IDs to indices.

    Returns
    -------
    list[list[int]]
        One ``[source_index, target_index]`` pair per retained edge.

    See Also
    --------
    _extract_edge_ids : Extract IDs from MultiIndex.

    Examples
    --------
    >>> import pandas as pd
    >>> src = pd.Series([0, 1])
    >>> tgt = pd.Series([1, 2])
    >>> mapping = {0: 0, 1: 1, 2: 2}
    >>> edges = _map_edge_ids_to_indices(src, tgt, mapping, mapping)
    """
    # Keep only edges whose two endpoints are known nodes
    valid_edges_mask = source_ids.isin(source_mapping.keys()) & target_ids.isin(
        target_mapping.keys(),
    )
    valid_sources = source_ids[valid_edges_mask]
    valid_targets = target_ids[valid_edges_mask]

    # Translate original node IDs to integer positions
    from_indices = valid_sources.map(source_mapping).to_numpy(dtype=int)
    to_indices = valid_targets.map(target_mapping).to_numpy(dtype=int)

    result: list[list[int]] = np.column_stack([from_indices, to_indices]).astype(int).tolist()
    return result


def _create_linestring_geometries(
    edge_index_array: np.ndarray[tuple[int, ...], np.dtype[np.int64]],
    src_pos: np.ndarray[tuple[int, ...], np.dtype[np.float64]],
    dst_pos: np.ndarray[tuple[int, ...], np.dtype[np.float64]],
) -> list[LineString | None]:
    """
    Generate LineString geometries from node positions and edge connectivity.

    Parameters
    ----------
    edge_index_array : np.ndarray
        Array whose first row holds source indices and second row holds
        target indices.
    src_pos : np.ndarray
        Array of source node coordinates, indexed by source index.
    dst_pos : np.ndarray
        Array of target node coordinates, indexed by target index.

    Returns
    -------
    list[LineString | None]
        One entry per edge: a LineString from the source to the target
        coordinate, or None when either index is out of bounds.

    See Also
    --------
    _create_edge_indices : Create edge connectivity pairs.

    Notes
    -----
    - Indices must satisfy ``0 <= index < len(positions)``; negative indices
      are treated as invalid rather than wrapping around.
    - Only uses first 2 dimensions of position data (x, y).

    Examples
    --------
    >>> import numpy as np
    >>> edge_index = np.array([[0, 1], [0, 1]])
    >>> src_pos = np.array([[0, 0], [1, 1]])
    >>> dst_pos = np.array([[1, 1], [2, 2]])
    >>> lines = _create_linestring_geometries(edge_index, src_pos, dst_pos)
    """
    if edge_index_array.size == 0:
        return []

    src_indices = edge_index_array[0]
    dst_indices = edge_index_array[1]

    # Vectorized bounds checking (both ends of the valid range)
    valid_mask = (
        (src_indices >= 0)
        & (src_indices < len(src_pos))
        & (dst_indices >= 0)
        & (dst_indices < len(dst_pos))
    )

    src_coords = src_pos[src_indices[valid_mask]][:, :2]
    dst_coords = dst_pos[dst_indices[valid_mask]][:, :2]

    # Shape (num_valid, 2, 2): one (start, end) coordinate pair per valid edge
    coord_pairs = np.stack([src_coords, dst_coords], axis=1)
    valid_geometries = list(map(LineString, coord_pairs))

    # Scatter geometries back to their edge positions; invalid edges stay None
    geometries = np.full(len(src_indices), None, dtype=object)
    geometries[valid_mask] = valid_geometries
    return geometries.tolist()


# ============================================================================
# GRAPH BUILDING FUNCTIONS
# ============================================================================


def _build_homogeneous_graph(
    nodes_gdf: gpd.GeoDataFrame,
    edges_gdf: gpd.GeoDataFrame | None = None,
    node_feature_cols: list[str] | None = None,
    node_label_cols: list[str] | None = None,
    edge_feature_cols: list[str] | None = None,
    device: str | torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Data:
    """
    Construct a homogeneous PyTorch Geometric Data object.

    Node IDs are taken from the nodes_gdf index. Edge relationships are taken
    from the edges_gdf MultiIndex (source_id, target_id).

    Processing Pipeline:
    1. Create node ID mapping (original IDs from index → integer indices)
    2. Extract node features and positions from geometry
    3. Process node labels if requested
    4. Create edge connectivity matrix
    5. Extract edge features
    6. Package everything into PyG Data object
    7. Store metadata for bidirectional conversion

    Parameters
    ----------
    nodes_gdf : geopandas.GeoDataFrame
        GeoDataFrame containing node data (index used for IDs).
    edges_gdf : geopandas.GeoDataFrame, optional
        GeoDataFrame containing edge data (MultiIndex used for relationships).
    node_feature_cols : list[str], optional
        Columns to use as node features.
    node_label_cols : list[str], optional
        Columns to use as node labels.
    edge_feature_cols : list[str], optional
        Columns to use as edge features.
    device : str or torch.device, optional
        Target device for tensor creation.
    dtype : torch.dtype, optional
        Data type for float tensors.

    Returns
    -------
    Data
        PyTorch Geometric Data object carrying ``x``, ``edge_index``,
        ``edge_attr``, ``y``, ``pos``, ``crs`` and ``graph_metadata``.

    See Also
    --------
    _build_heterogeneous_graph : Create multi-type graphs.

    Notes
    -----
    - Preserves original index names and values for reconstruction.
    - Missing or empty edges produce empty edge tensors.
    - The CRS of nodes_gdf is recorded when it is set.

    Examples
    --------
    >>> import geopandas as gpd
    >>> from shapely.geometry import Point
    >>> nodes = gpd.GeoDataFrame({'feature': [1, 2]}, geometry=[Point(0, 0), Point(1, 1)])
    >>> data = _build_homogeneous_graph(nodes, node_feature_cols=['feature'])
    """
    device = _get_device(device)

    # Node processing
    id_mapping, id_col_name, original_ids = _create_node_id_mapping(nodes_gdf)
    x = _create_node_features(nodes_gdf, node_feature_cols, device, dtype)
    pos = _create_node_positions(nodes_gdf, device, dtype)

    # Labels reuse the feature conversion; they are only built when requested
    y = None
    if node_label_cols:
        y = _create_node_features(nodes_gdf, node_label_cols, device, dtype)

    # Edge processing: start from empty tensors and fill them if edges exist
    edge_index = torch.zeros((2, 0), dtype=torch.long, device=device)
    edge_attr = torch.empty((0, 0), dtype=dtype or torch.float32, device=device)

    if edges_gdf is not None and not edges_gdf.empty:
        edge_pairs = _create_edge_indices(
            edges_gdf,
            id_mapping,
            id_mapping,
        )
        if edge_pairs:
            # Pairs are (num_edges, 2); PyG expects (2, num_edges)
            edge_index = torch.tensor(
                np.array(edge_pairs).T,
                dtype=torch.long,
                device=device,
            )
        # NOTE(review): features are built from every row of edges_gdf, while
        # _create_edge_indices drops edges with unknown endpoints; if any edge
        # is dropped the two tensors disagree on the number of edges. Confirm
        # that inputs are validated upstream.
        edge_attr = _create_edge_features(edges_gdf, edge_feature_cols, device, dtype)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos)

    # Store metadata
    metadata = GraphMetadata(is_hetero=False)
    metadata.node_mappings = {
        "default": {
            "mapping": id_mapping,
            "id_col": id_col_name,
            "original_ids": original_ids,
        },
    }
    metadata.node_feature_cols = node_feature_cols or []
    metadata.node_label_cols = node_label_cols or []
    metadata.edge_feature_cols = edge_feature_cols or []

    # Store index names and values for preservation
    metadata.node_index_names = nodes_gdf.index.names if hasattr(nodes_gdf.index, "names") else None
    if edges_gdf is not None and hasattr(edges_gdf.index, "names"):
        metadata.edge_index_names = edges_gdf.index.names
        # Store original edge index values for reconstruction
        metadata.edge_index_values = [
            edges_gdf.index.get_level_values(i).tolist() for i in range(edges_gdf.index.nlevels)
        ]
    else:
        metadata.edge_index_names = None
        metadata.edge_index_values = None

    # Set CRS
    if hasattr(nodes_gdf, "crs") and nodes_gdf.crs:
        metadata.crs = nodes_gdf.crs

    data.crs = metadata.crs
    data.graph_metadata = metadata
    return data


def _build_heterogeneous_graph(
    nodes_dict: dict[str, gpd.GeoDataFrame],
    edges_dict: dict[tuple[str, str, str], gpd.GeoDataFrame] | None = None,
    node_feature_cols: dict[str, list[str]] | None = None,
    node_label_cols: dict[str, list[str]] | None = None,
    edge_feature_cols: dict[str, list[str]] | None = None,
    device: str | torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> HeteroData:
    """
    Build heterogeneous PyTorch Geometric HeteroData object.

    Nodes, edges and metadata are populated by ``_process_hetero_nodes``,
    ``_process_hetero_edges`` and ``_store_hetero_metadata`` respectively.

    Parameters
    ----------
    nodes_dict : dict[str, geopandas.GeoDataFrame]
        Dictionary mapping node type names to their corresponding GeoDataFrames.
    edges_dict : dict[tuple[str, str, str], geopandas.GeoDataFrame], optional
        Dictionary mapping edge type tuples (source_type, relation, target_type)
        to their corresponding GeoDataFrames.
    node_feature_cols : dict[str, list[str]], optional
        Dictionary mapping node types to lists of feature column names.
    node_label_cols : dict[str, list[str]], optional
        Dictionary mapping node types to lists of label column names.
    edge_feature_cols : dict[str, list[str]], optional
        Dictionary mapping relation names to lists of feature column names.
    device : str or torch.device, optional
        Target device for tensor creation.
    dtype : torch.dtype, optional
        Data type for float tensors.

    Returns
    -------
    HeteroData
        PyTorch Geometric HeteroData object with all graph components.

    See Also
    --------
    _build_homogeneous_graph : Create single-type graphs.

    Examples
    --------
    >>> import geopandas as gpd
    >>> from shapely.geometry import Point
    >>> nodes = {'buildings': gpd.GeoDataFrame({'area': [100, 200]},
    ...                                        geometry=[Point(0, 0), Point(1, 1)])}
    >>> data = _build_heterogeneous_graph(nodes)
    """
    device = _get_device(device)
    data = HeteroData()

    # Default empty dicts
    edges_dict = edges_dict or {}

    # Process nodes and get mappings
    node_mappings = _process_hetero_nodes(
        data,
        nodes_dict,
        node_feature_cols,
        node_label_cols,
        device,
        dtype,
    )

    # Process edges
    _process_hetero_edges(
        data,
        edges_dict,
        node_mappings,
        edge_feature_cols,
        device,
        dtype,
    )

    # Store metadata
    _store_hetero_metadata(
        data,
        node_mappings,
        nodes_dict,
        edges_dict,
        node_feature_cols,
        node_label_cols,
        edge_feature_cols,
    )

    return data


def _process_hetero_nodes(
    data: HeteroData,
    nodes_dict: dict[str, gpd.GeoDataFrame],
    node_feature_cols: dict[str, list[str]] | None,
    node_label_cols: dict[str, list[str]] | None,
    device: str | torch.device | None,
    dtype: torch.dtype | None,
) -> dict[str, dict[str, dict[str | int, int] | str | list[str | int]]]:
    """
    Process all node types for heterogeneous graph.

    For each node type, writes ``x`` and ``pos`` (and ``y`` when label columns
    are configured for that type) into ``data`` and records the ID mapping.

    Parameters
    ----------
    data : HeteroData
        The HeteroData object to populate with node information (mutated in place).
    nodes_dict : dict[str, geopandas.GeoDataFrame]
        Dictionary mapping node type names to their GeoDataFrames.
    node_feature_cols : dict[str, list[str]], optional
        Dictionary mapping node types to feature column lists.
    node_label_cols : dict[str, list[str]], optional
        Dictionary mapping node types to label column lists.
    device : str or torch.device, optional
        Target device for tensor creation.
    dtype : torch.dtype, optional
        Data type for float tensors.

    Returns
    -------
    dict[str, dict[str, dict[str | int, int] | str | list[str | int]]]
        Per node type, a dict with keys ``"mapping"``, ``"id_col"`` and
        ``"original_ids"``.

    See Also
    --------
    _process_hetero_edges : Process edge types for heterogeneous graphs.

    Examples
    --------
    >>> data = HeteroData()
    >>> nodes = {'buildings': nodes_gdf}
    >>> mappings = _process_hetero_nodes(data, nodes, None, None, 'cpu', torch.float32)
    """
    node_mappings: dict[str, dict[str, dict[str | int, int] | str | list[str | int]]] = {}
    device = _get_device(device)

    for node_type, node_gdf in nodes_dict.items():
        id_mapping, id_col_name, original_ids = _create_node_id_mapping(node_gdf)

        # Store mapping with metadata in unified structure
        node_mappings[node_type] = {
            "mapping": id_mapping,
            "id_col": id_col_name,
            "original_ids": original_ids,
        }

        # Features
        feature_cols = node_feature_cols.get(node_type) if node_feature_cols else None
        data[node_type].x = _create_node_features(node_gdf, feature_cols, device, dtype)

        # Positions
        data[node_type].pos = _create_node_positions(node_gdf, device, dtype)

        # Labels
        label_cols = node_label_cols.get(node_type) if node_label_cols else None
        if label_cols:
            data[node_type].y = _create_node_features(node_gdf, label_cols, device, dtype)

    return node_mappings


def _process_hetero_edges(
    data: HeteroData,
    edges_dict: dict[tuple[str, str, str], gpd.GeoDataFrame],
    node_mappings: dict[str, dict[str, dict[str | int, int] | str | list[str | int]]],
    edge_feature_cols: dict[str, list[str]] | None,
    device: str | torch.device | None,
    dtype: torch.dtype | None,
) -> None:
    """
    Process all edge types for heterogeneous graph.

    For each edge type, writes ``edge_index`` and ``edge_attr`` into ``data``.
    Missing or empty edge tables yield empty tensors.

    Parameters
    ----------
    data : HeteroData
        The HeteroData object to populate with edge information (mutated in place).
    edges_dict : dict[tuple[str, str, str], geopandas.GeoDataFrame]
        Dictionary mapping edge type tuples to their GeoDataFrames.
    node_mappings : dict[str, dict[str, dict[str | int, int] | str | list[str | int]]]
        Node mappings and metadata from node processing.
    edge_feature_cols : dict[str, list[str]], optional
        Dictionary of feature column lists, looked up by the relation name
        (the middle element of the edge type tuple).
    device : str or torch.device, optional
        Target device for tensor creation.
    dtype : torch.dtype, optional
        Data type for float tensors.

    Raises
    ------
    KeyError
        If an edge type refers to a node type absent from ``node_mappings``.

    See Also
    --------
    _process_hetero_nodes : Process node types for heterogeneous graphs.

    Examples
    --------
    >>> data = HeteroData()
    >>> edges = {('building', 'near', 'building'): edges_gdf}
    >>> _process_hetero_edges(data, edges, node_mappings, None, 'cpu', torch.float32)
    """
    device = _get_device(device)

    for edge_type, edge_gdf in edges_dict.items():
        # Extract source, relation, and destination types from edge_type tuple
        src_type, rel_type, dst_type = edge_type

        # Get the mapping dictionaries (not the full metadata)
        src_mapping_raw = node_mappings[src_type]["mapping"]
        dst_mapping_raw = node_mappings[dst_type]["mapping"]

        # Type assertion for mypy - these are guaranteed to be dicts by construction
        assert isinstance(src_mapping_raw, dict), f"Expected dict mapping for {src_type}"
        assert isinstance(dst_mapping_raw, dict), f"Expected dict mapping for {dst_type}"

        src_mapping: dict[str | int, int] = src_mapping_raw
        dst_mapping: dict[str | int, int] = dst_mapping_raw

        if edge_gdf is not None and not edge_gdf.empty:
            edge_pairs = _create_edge_indices(
                edge_gdf,
                src_mapping,
                dst_mapping,
            )
            # Pairs are (num_edges, 2); PyG expects (2, num_edges)
            edge_index = (
                torch.tensor(np.array(edge_pairs).T, dtype=torch.long, device=device)
                if edge_pairs
                else torch.zeros((2, 0), dtype=torch.long, device=device)
            )
            data[edge_type].edge_index = edge_index

            feature_cols = edge_feature_cols.get(rel_type) if edge_feature_cols else None
            data[edge_type].edge_attr = _create_edge_features(edge_gdf, feature_cols, device, dtype)
        else:
            data[edge_type].edge_index = torch.zeros((2, 0), dtype=torch.long, device=device)
            data[edge_type].edge_attr = torch.empty(
                (0, 0),
                dtype=dtype or torch.float32,
                device=device,
            )


def _store_hetero_metadata(
    data: HeteroData,
    node_mappings: dict[str, dict[str, dict[str | int, int] | str | list[str | int]]],
    nodes_dict: dict[str, gpd.GeoDataFrame],
    edges_dict: dict[tuple[str, str, str], gpd.GeoDataFrame],
    node_feature_cols: dict[str, list[str]] | None,
    node_label_cols: dict[str, list[str]] | None,
    edge_feature_cols: dict[str, list[str]] | None,
) -> None:
    """
    Store metadata for heterogeneous graph.

    Builds a ``GraphMetadata(is_hetero=True)`` describing node/edge types,
    ID mappings, column lists, index names/values and CRS, and attaches it to
    ``data`` as ``graph_metadata`` (together with ``data.crs``).

    Parameters
    ----------
    data : HeteroData
        The HeteroData object to store metadata in (mutated in place).
    node_mappings : dict[str, dict[str, dict[str | int, int] | str | list[str | int]]]
        Node mappings and metadata from node processing.
    nodes_dict : dict[str, geopandas.GeoDataFrame]
        Dictionary mapping node type names to their GeoDataFrames.
    edges_dict : dict[tuple[str, str, str], geopandas.GeoDataFrame]
        Dictionary mapping edge type tuples to their GeoDataFrames.
    node_feature_cols : dict[str, list[str]], optional
        Dictionary mapping node types to feature column lists.
    node_label_cols : dict[str, list[str]], optional
        Dictionary mapping node types to label column lists.
    edge_feature_cols : dict[str, list[str]], optional
        Dictionary mapping relation names to feature column lists.

    See Also
    --------
    _process_hetero_nodes : Process node types for heterogeneous graphs.
    _process_hetero_edges : Process edge types for heterogeneous graphs.

    Notes
    -----
    The CRS is recorded only when every node GeoDataFrame that has a CRS
    shares the same one.

    Examples
    --------
    >>> data = HeteroData()
    >>> _store_hetero_metadata(data, mappings, nodes, edges, None, None, None)
    """
    # Store mappings and column metadata
    metadata = GraphMetadata(is_hetero=True)
    metadata.node_types = list(nodes_dict.keys())
    metadata.edge_types = list(edges_dict.keys())
    metadata.node_mappings = node_mappings
    metadata.node_feature_cols = node_feature_cols or {}
    metadata.node_label_cols = node_label_cols or {}
    metadata.edge_feature_cols = edge_feature_cols or {}

    # Store index names for reconstruction
    metadata.node_index_names = {}
    for node_type, node_gdf in nodes_dict.items():
        if hasattr(node_gdf.index, "names"):
            metadata.node_index_names[node_type] = node_gdf.index.names

    # Store edge index names and values for reconstruction
    metadata.edge_index_names = {}
    metadata.edge_index_values = {}
    for edge_type, edge_gdf in edges_dict.items():
        if edge_gdf is not None and hasattr(edge_gdf.index, "names"):
            # Store edge index names
            metadata.edge_index_names[edge_type] = edge_gdf.index.names
            # Store original edge index values for reconstruction
            metadata.edge_index_values[edge_type] = [
                edge_gdf.index.get_level_values(i).tolist() for i in range(edge_gdf.index.nlevels)
            ]

    # Set CRS
    crs_values = [gdf.crs for gdf in nodes_dict.values() if hasattr(gdf, "crs") and gdf.crs]
    if crs_values and all(crs == crs_values[0] for crs in crs_values):
        metadata.crs = crs_values[0]

    data.crs = metadata.crs
    data.graph_metadata = metadata


# ============================================================================
# GRAPH VALIDATION FUNCTIONS
# ============================================================================
def validate_pyg(data: Data | HeteroData) -> GraphMetadata:
    """
    Validate a PyTorch Geometric object created by city2graph and return its metadata.

    The checks run in this order: torch availability, object type, presence and
    type of the ``graph_metadata`` attribute, agreement between the object type
    and ``metadata.is_hetero``, and finally the structural checks delegated to
    ``_validate_hetero_structure`` or ``_validate_homo_structure``.

    Parameters
    ----------
    data : torch_geometric.data.Data or torch_geometric.data.HeteroData
        PyTorch Geometric data object to validate.

    Returns
    -------
    GraphMetadata
        The ``graph_metadata`` object attached to ``data``.

    Raises
    ------
    ImportError
        If PyTorch / PyTorch Geometric could not be imported.
    TypeError
        If ``data`` is not a Data or HeteroData instance, or if its
        ``graph_metadata`` attribute is not a GraphMetadata instance.
    ValueError
        If ``graph_metadata`` is missing, or if ``metadata.is_hetero``
        contradicts the type of ``data``. The delegated structural validators
        may raise as well.

    See Also
    --------
    pyg_to_gdf : Convert PyG objects to GeoDataFrames.
    pyg_to_nx : Convert PyG objects to NetworkX graphs.

    Examples
    --------
    >>> data = gdf_to_pyg(nodes_gdf, edges_gdf)
    >>> metadata = validate_pyg(data)
    """
    # Nothing below can work without the optional torch stack
    if not TORCH_AVAILABLE:
        raise ImportError(TORCH_ERROR_MSG)

    # Reject anything that is not a PyG container
    if not isinstance(data, (Data, HeteroData)):
        actual_type = type(data).__name__
        msg = (
            "Input must be a PyTorch Geometric Data or HeteroData object, "
            f"got {actual_type}. Ensure you have PyTorch Geometric installed "
            "and are passing a valid PyG object."
        )
        raise TypeError(msg)

    # The metadata attribute must exist ...
    if not hasattr(data, "graph_metadata"):
        msg = (
            "PyG object is missing 'graph_metadata' attribute. "
            "This object may not have been created by city2graph. "
            "Use city2graph.graph.gdf_to_pyg() or city2graph.graph.nx_to_pyg() "
            "to create compatible PyG objects."
        )
        raise ValueError(msg)

    # ... and be of the expected type
    metadata = data.graph_metadata
    if not isinstance(metadata, GraphMetadata):
        actual_metadata_type = type(metadata).__name__
        msg = (
            f"PyG object has 'graph_metadata' of incorrect type: {actual_metadata_type}. "
            "Expected GraphMetadata. This object may not have been created by city2graph."
        )
        raise TypeError(msg)

    # The container type and the metadata flag must agree before the
    # type-specific structural checks are run
    if isinstance(data, HeteroData):
        if not metadata.is_hetero:
            msg = (
                "Inconsistency detected: PyG object is HeteroData but metadata.is_hetero is False. "
                "This indicates corrupted metadata or an incorrectly constructed object."
            )
            raise ValueError(msg)
        _validate_hetero_structure(data, metadata)
    else:
        if metadata.is_hetero:
            msg = (
                "Inconsistency detected: PyG object is Data but metadata.is_hetero is True. "
                "This indicates corrupted metadata or an incorrectly constructed object."
            )
            raise ValueError(msg)
        _validate_homo_structure(data, metadata)

    return metadata
def _validate_hetero_structure(data: HeteroData, metadata: GraphMetadata) -> None: """ Validate structural consistency of heterogeneous PyG data. Performs comprehensive validation of heterogeneous graph structure, ensuring that node types, edge types, and tensor dimensions are consistent. Parameters ---------- data : HeteroData The heterogeneous PyTorch Geometric data object to validate. metadata : GraphMetadata Metadata containing expected graph structure information. See Also -------- _validate_homo_structure : Validate homogeneous graph structure. Examples -------- >>> data = create_heterogeneous_graph(nodes_dict) >>> metadata = data._metadata >>> _validate_hetero_structure(data, metadata) """ # Check that node types in metadata match actual node types in data if metadata.node_types: actual_node_types = set(data.node_types) expected_node_types = set(metadata.node_types) if actual_node_types != expected_node_types: msg = ( f"Node types mismatch: metadata expects {expected_node_types}, " f"but PyG object has {actual_node_types}" ) raise ValueError(msg) # Check that edge types in metadata match actual edge types in data if metadata.edge_types: actual_edge_types = set(data.edge_types) expected_edge_types = set(metadata.edge_types) if actual_edge_types != expected_edge_types: msg = ( f"Edge types mismatch: metadata expects {expected_edge_types}, " f"but PyG object has {actual_edge_types}" ) raise ValueError(msg) # Validate tensor shape consistency for each node type for node_type in data.node_types: node_data = data[node_type] if hasattr(node_data, "x") and node_data.x is not None: num_nodes = node_data.x.size(0) # Check position tensor consistency if ( hasattr(node_data, "pos") and node_data.pos is not None and node_data.pos.size(0) != num_nodes ): msg = ( f"Node type '{node_type}': position tensor size ({node_data.pos.size(0)}) " f"doesn't match node feature tensor size ({num_nodes})" ) raise ValueError(msg) # Check label tensor consistency if ( hasattr(node_data, 
"y") and node_data.y is not None and node_data.y.size(0) != num_nodes ): msg = ( f"Node type '{node_type}': label tensor size ({node_data.y.size(0)}) " f"doesn't match node feature tensor size ({num_nodes})" ) raise ValueError(msg) def _validate_homo_structure(data: Data, metadata: GraphMetadata) -> None: """ Validate structural consistency of homogeneous PyG data. Performs comprehensive validation of homogeneous graph structure, ensuring that tensor dimensions are consistent and metadata is properly structured. Parameters ---------- data : Data The homogeneous PyTorch Geometric data object to validate. metadata : GraphMetadata Metadata containing expected graph structure information. See Also -------- _validate_hetero_structure : Validate heterogeneous graph structure. Examples -------- >>> data = create_homogeneous_graph(nodes_gdf) >>> metadata = data._metadata >>> _validate_homo_structure(data, metadata) """ # Validate that metadata has the expected structure for homogeneous graphs if metadata.node_types and len(metadata.node_types) > 0: msg = "Homogeneous graph metadata should not have node_types specified" raise ValueError(msg) if metadata.edge_types and len(metadata.edge_types) > 0: msg = "Homogeneous graph metadata should not have edge_types specified" raise ValueError(msg) # Validate that node mappings use the "default" key for homogeneous graphs if metadata.node_mappings and "default" not in metadata.node_mappings: msg = "Homogeneous graph metadata should use 'default' key in node_mappings" raise ValueError(msg) # Validate that feature/label columns are lists, not dicts if metadata.node_feature_cols and not isinstance(metadata.node_feature_cols, list): msg = "Homogeneous graph metadata should have node_feature_cols as list, not dict" raise ValueError(msg) if metadata.node_label_cols and not isinstance(metadata.node_label_cols, list): msg = "Homogeneous graph metadata should have node_label_cols as list, not dict" raise ValueError(msg) if 
metadata.edge_feature_cols and not isinstance(metadata.edge_feature_cols, list): msg = "Homogeneous graph metadata should have edge_feature_cols as list, not dict" raise ValueError(msg) # Validate tensor shape consistency if hasattr(data, "x") and data.x is not None: num_nodes = data.x.size(0) # Check position tensor consistency if hasattr(data, "pos") and data.pos is not None and data.pos.size(0) != num_nodes: msg = ( f"Node position tensor size ({data.pos.size(0)}) " f"doesn't match node feature tensor size ({num_nodes})" ) raise ValueError(msg) # Check label tensor consistency if hasattr(data, "y") and data.y is not None and data.y.size(0) != num_nodes: msg = ( f"Node label tensor size ({data.y.size(0)}) " f"doesn't match node feature tensor size ({num_nodes})" ) raise ValueError(msg) # Validate edge tensor consistency if hasattr(data, "edge_index") and data.edge_index is not None: num_edges = data.edge_index.size(1) # Check edge attribute tensor consistency if ( hasattr(data, "edge_attr") and data.edge_attr is not None and data.edge_attr.size(0) != num_edges ): msg = ( f"Edge attribute tensor size ({data.edge_attr.size(0)}) " f"doesn't match number of edges ({num_edges})" ) raise ValueError(msg) # ============================================================================ # GRAPH RECONSTRUCTION FUNCTIONS # ============================================================================ def _extract_tensor_data( tensor: torch.Tensor | None, column_names: list[str] | None = None, ) -> dict[str, np.ndarray[tuple[int, ...], np.dtype[np.float32]]]: """ Extract data from tensor with proper column names. Converts PyTorch tensors to numpy arrays and maps them to column names for reconstruction of GeoDataFrame columns. Parameters ---------- tensor : torch.Tensor, optional Input tensor containing feature data. column_names : list[str], optional List of column names to map tensor columns to. 
Returns ------- dict[str, np.ndarray] Dictionary mapping column names to numpy arrays. See Also -------- _get_node_data_info : Get node data and count information. Examples -------- >>> tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) >>> cols = ['feature1', 'feature2'] >>> data = _extract_tensor_data(tensor, cols) """ if tensor is None or tensor.numel() == 0 or column_names is None: return {} features_array = tensor.detach().cpu().numpy() num_cols = min(len(column_names), features_array.shape[1]) return {column_names[i]: features_array[:, i] for i in range(num_cols)} def _get_node_data_info( data: Data | HeteroData, node_type: str | None, metadata: GraphMetadata, ) -> tuple[Data | HeteroData, int]: """ Get node data and number of nodes. Extracts node-specific data from PyG objects, handling both homogeneous and heterogeneous graphs appropriately. Parameters ---------- data : Data or HeteroData PyTorch Geometric data object. node_type : str, optional Node type for heterogeneous graphs. metadata : GraphMetadata Metadata containing graph structure information. Returns ------- tuple[Data | HeteroData, int] Node data object and number of nodes. See Also -------- _extract_tensor_data : Extract data from tensors. Examples -------- >>> node_data, num_nodes = _get_node_data_info(data, 'building', metadata) """ node_data = data[node_type] if metadata.is_hetero and node_type else data return node_data, int(node_data.num_nodes) def _get_mapping_info( node_type: str | None, metadata: GraphMetadata, ) -> dict[str, dict[str | int, int] | str | list[str | int]] | None: """ Get mapping info for the given node type. This function retrieves mapping information from the metadata for a specific node type, handling both homogeneous and heterogeneous graphs. Parameters ---------- node_type : str or None The type of node to get mapping info for. If None, uses default mapping. metadata : GraphMetadata The graph metadata containing node mappings. 
Returns ------- dict or None Dictionary containing mapping information with keys like 'original_ids', or None if no mapping exists for the given node type. See Also -------- _extract_index_values : Extract index values from mapping info. Examples -------- >>> metadata = GraphMetadata(is_hetero=True, node_mappings={'building': {...}}) >>> mapping = _get_mapping_info('building', metadata) """ mapping_key = "default" if not metadata.is_hetero or not node_type else node_type return metadata.node_mappings.get(mapping_key) def _extract_index_values( mapping_info: dict[str, dict[str | int, int] | str | list[str | int]], num_nodes: int, ) -> list[str | int]: """ Extract index values from mapping info. This function extracts the original node IDs from mapping information, ensuring the returned list has the correct length. Parameters ---------- mapping_info : dict Dictionary containing mapping information with 'original_ids' key. num_nodes : int Number of nodes to extract IDs for. Returns ------- list of str or int List of original node IDs, truncated to num_nodes length. See Also -------- _get_mapping_info : Get mapping info for a given node type. Examples -------- >>> mapping_info = {'original_ids': ['a', 'b', 'c', 'd']} >>> ids = _extract_index_values(mapping_info, 3) >>> print(ids) # ['a', 'b', 'c'] """ original_ids = mapping_info.get("original_ids", list(range(num_nodes))) # Convert to list if not already, then slice to num_nodes ids_list = original_ids if isinstance(original_ids, list) else list(range(num_nodes)) return ids_list[:num_nodes] def _create_geometry_from_positions(node_data: Data | HeteroData) -> gpd.array.GeometryArray | None: """ Create geometry from node positions. This function converts node position tensors into GeoPandas geometry objects for spatial analysis and visualization. Parameters ---------- node_data : Data or HeteroData PyTorch Geometric data object containing node positions. 
Returns ------- gpd.array.GeometryArray or None Array of Point geometries created from node positions, or None if no position data is available. See Also -------- _extract_node_features_and_labels : Extract features and labels from node data. Examples -------- >>> import torch >>> from torch_geometric.data import Data >>> data = Data(pos=torch.tensor([[0.0, 1.0], [2.0, 3.0]])) >>> geom = _create_geometry_from_positions(data) """ if not hasattr(node_data, "pos") or node_data.pos is None: return None pos_array: np.ndarray[tuple[int, ...], np.dtype[np.float32]] = ( node_data.pos.detach().cpu().numpy() ) return gpd.points_from_xy(pos_array[:, 0], pos_array[:, 1]) def _extract_node_features_and_labels( node_data: Data | HeteroData, node_type: str | None, metadata: GraphMetadata, ) -> dict[str, np.ndarray[tuple[int, ...], np.dtype[np.float32]]]: """ Extract features and labels from node data. This function extracts node features and labels from PyTorch Geometric data objects, handling both homogeneous and heterogeneous graphs. Parameters ---------- node_data : Data or HeteroData PyTorch Geometric data object containing node features and labels. node_type : str or None The type of nodes to extract data for. Required for heterogeneous graphs. metadata : GraphMetadata Graph metadata containing information about feature and label mappings. Returns ------- dict Dictionary mapping column names to numpy arrays containing the extracted features and labels. See Also -------- _create_geometry_from_positions : Create geometry from node positions. 
Examples -------- >>> import torch >>> from torch_geometric.data import Data >>> data = Data(x=torch.tensor([[1.0, 2.0], [3.0, 4.0]])) >>> features = _extract_node_features_and_labels(data, None, metadata) """ gdf_data = {} is_hetero = metadata.is_hetero # Extract features if hasattr(node_data, "x") and node_data.x is not None and metadata.node_feature_cols: feature_cols = metadata.node_feature_cols feature_cols_list: list[str] | None = None if is_hetero and node_type and isinstance(feature_cols, dict): feature_cols_list = feature_cols.get(node_type) elif not is_hetero and isinstance(feature_cols, list): feature_cols_list = feature_cols features_dict = _extract_tensor_data(node_data.x, feature_cols_list) gdf_data.update(features_dict) # Extract labels if hasattr(node_data, "y") and node_data.y is not None and metadata.node_label_cols: label_cols = metadata.node_label_cols label_cols_list: list[str] | None = None if is_hetero and node_type and isinstance(label_cols, dict): label_cols_list = label_cols.get(node_type) elif not is_hetero and isinstance(label_cols, list): label_cols_list = label_cols labels_dict = _extract_tensor_data(node_data.y, label_cols_list) gdf_data.update(labels_dict) return gdf_data def _set_gdf_index_and_crs( gdf: gpd.GeoDataFrame, node_type: str | None, metadata: GraphMetadata, ) -> None: """ Set index names and CRS on GeoDataFrame. This function configures the index names and coordinate reference system for a GeoDataFrame based on metadata information. Parameters ---------- gdf : gpd.GeoDataFrame The GeoDataFrame to configure. node_type : str or None The type of nodes in the GeoDataFrame. metadata : GraphMetadata Graph metadata containing index names and CRS information. See Also -------- _reconstruct_node_gdf : Reconstruct node GeoDataFrame from PyTorch data. 
Examples -------- >>> import geopandas as gpd >>> gdf = gpd.GeoDataFrame({'col1': [1, 2]}) >>> _set_gdf_index_and_crs(gdf, 'building', metadata) """ # Set index names if metadata.node_index_names: index_names: list[str] | None = None # Get index names based on heterogeneity and node type if metadata.is_hetero and node_type and isinstance(metadata.node_index_names, dict): index_names = metadata.node_index_names.get(node_type) elif not metadata.is_hetero and isinstance(metadata.node_index_names, list): index_names = metadata.node_index_names # Set index name if available if ( index_names and hasattr(gdf.index, "names") and isinstance(index_names, list) and len(index_names) > 0 ): gdf.index.name = index_names[0] # Set CRS if metadata.crs and hasattr(gdf, "geometry") and gdf.geometry is not None: # Check if the geometry column is truly empty or all null if gdf.empty or gdf.geometry.isna().all(): gdf.crs = metadata.crs else: # Use set_crs for non-empty geometries gdf.set_crs(metadata.crs, allow_override=True, inplace=True) # If no geometry column, we can't set CRS - skip silently def _reconstruct_node_gdf( data: Data | HeteroData, metadata: GraphMetadata, node_type: str | None = None, ) -> gpd.GeoDataFrame: """ Reconstruct node GeoDataFrame from PyTorch Geometric data. This function reconstructs a GeoDataFrame containing node information from PyTorch Geometric data objects and metadata. Parameters ---------- data : Data or HeteroData PyTorch Geometric data object containing node information. metadata : GraphMetadata Graph metadata with mapping and feature information. node_type : str, optional The type of nodes to reconstruct. Required for heterogeneous graphs. Returns ------- gpd.GeoDataFrame GeoDataFrame containing reconstructed node data with geometry, features, and proper indexing. See Also -------- _extract_node_features_and_labels : Extract features and labels from node data. _create_geometry_from_positions : Create geometry from node positions. 
Examples -------- >>> from torch_geometric.data import Data >>> import torch >>> data = Data(x=torch.tensor([[1.0, 2.0]]), pos=torch.tensor([[0.0, 1.0]])) >>> gdf = _reconstruct_node_gdf(data, metadata) """ node_data, num_nodes = _get_node_data_info(data, node_type, metadata) mapping_info = _get_mapping_info(node_type, metadata) # Extract node IDs and features/labels gdf_data = {} features_labels = _extract_node_features_and_labels(node_data, node_type, metadata) gdf_data.update(features_labels) # Create geometry geometry = _create_geometry_from_positions(node_data) index_values = _extract_index_values(mapping_info, num_nodes) if mapping_info else None gdf = gpd.GeoDataFrame(gdf_data, geometry=geometry, index=index_values) _set_gdf_index_and_crs(gdf, node_type, metadata) return gdf def _reconstruct_edge_index( edge_type: str | tuple[str, str, str] | None, is_hetero: bool, edge_data_dict: dict[str, np.ndarray[tuple[int, ...], np.dtype[np.float32]]], metadata: GraphMetadata, ) -> pd.Index | pd.MultiIndex | None: """ Reconstruct edge index from stored values. This function reconstructs pandas Index or MultiIndex objects for edges from stored values in the metadata. Parameters ---------- edge_type : str, tuple, or None The type of edges to reconstruct index for. is_hetero : bool Whether the graph is heterogeneous. edge_data_dict : dict Dictionary containing edge data arrays. metadata : GraphMetadata Graph metadata containing stored edge index values. Returns ------- pd.Index, pd.MultiIndex, or None Reconstructed index for the edges, or None if no stored values exist. See Also -------- _extract_edge_features : Extract edge features from data. 
Examples -------- >>> edge_data = {'feature1': np.array([1, 2, 3])} >>> index = _reconstruct_edge_index('road', False, edge_data, metadata) """ stored_values: list[list[str | int]] | None = None if is_hetero and edge_type and isinstance(metadata.edge_index_values, dict): if isinstance(edge_type, tuple): stored_values = metadata.edge_index_values.get(edge_type) elif not is_hetero and isinstance(metadata.edge_index_values, list): stored_values = metadata.edge_index_values if not stored_values: return None # Determine number of rows based on edge data or stored values num_rows = len(next(iter(edge_data_dict.values()))) if edge_data_dict else len(stored_values[0]) # Handle MultiIndex case arrays = [stored_values[i][:num_rows] for i in range(len(stored_values))] return pd.MultiIndex.from_arrays(arrays) def _extract_edge_features( edge_data: Data | HeteroData, edge_type: str | tuple[str, str, str] | None, is_hetero: bool, metadata: GraphMetadata, ) -> dict[str, np.ndarray[tuple[int, ...], np.dtype[np.float32]]]: """ Extract edge features from edge data. This function extracts edge features from PyTorch Geometric data objects, handling both homogeneous and heterogeneous graphs. Parameters ---------- edge_data : Data or HeteroData PyTorch Geometric data object containing edge information. edge_type : str, tuple, or None The type of edges to extract features for. is_hetero : bool Whether the graph is heterogeneous. metadata : GraphMetadata Graph metadata containing edge feature column information. Returns ------- dict Dictionary mapping feature column names to numpy arrays containing the extracted edge features. See Also -------- _create_edge_geometries : Create edge geometries from edge indices. 
Examples -------- >>> import torch >>> from torch_geometric.data import Data >>> data = Data(edge_attr=torch.tensor([[1.0, 2.0], [3.0, 4.0]])) >>> features = _extract_edge_features(data, None, False, metadata) """ edge_data_dict: dict[str, np.ndarray[tuple[int, ...], np.dtype[np.float32]]] = {} if not (hasattr(edge_data, "edge_attr") and edge_data.edge_attr is not None): return edge_data_dict feature_cols = metadata.edge_feature_cols # Determine column names based on graph type cols = None if is_hetero and isinstance(edge_type, tuple) and isinstance(feature_cols, dict): rel_type = edge_type[1] cols = feature_cols.get(rel_type) elif not is_hetero and isinstance(feature_cols, list): cols = feature_cols features_dict = _extract_tensor_data(edge_data.edge_attr, cols) edge_data_dict.update(features_dict) return edge_data_dict def _create_edge_geometries( edge_data: Data, edge_type: str | tuple[str, str, str] | None, is_hetero: bool, data: Data | HeteroData, ) -> gpd.array.GeometryArray | None: """ Create edge geometries from edge indices and node positions. This function creates LineString geometries for edges by connecting the positions of source and destination nodes. Parameters ---------- edge_data : Data PyTorch Geometric data object containing edge information. edge_type : str, tuple, or None The type of edges to create geometries for. is_hetero : bool Whether the graph is heterogeneous. data : Data or HeteroData Complete PyTorch Geometric data object containing node positions. Returns ------- gpd.array.GeometryArray or None Array of LineString geometries for the edges, or None if node positions are not available. See Also -------- _extract_edge_features : Extract edge features from data. Examples -------- >>> import torch >>> from torch_geometric.data import Data >>> data = Data(edge_index=torch.tensor([[0, 1], [1, 0]]), ... 
pos=torch.tensor([[0.0, 0.0], [1.0, 1.0]])) >>> geom = _create_edge_geometries(data, None, False, data) """ # Get edge index array edge_index_array = edge_data.edge_index.detach().cpu().numpy() # Set default positions as None src_pos_array: np.ndarray[tuple[int, ...], np.dtype[np.float64]] | None = None dst_pos_array: np.ndarray[tuple[int, ...], np.dtype[np.float64]] | None = None # If hetero and specific edge type, get source and destination positions if is_hetero and isinstance(edge_type, tuple) and len(edge_type) == 3: src_type, _, dst_type = edge_type if hasattr(data[src_type], "pos") and data[src_type].pos is not None: src_pos_array = data[src_type].pos.detach().cpu().numpy() if hasattr(data[dst_type], "pos") and data[dst_type].pos is not None: dst_pos_array = data[dst_type].pos.detach().cpu().numpy() # If not hetero or no specific edge type, use default positions elif hasattr(data, "pos") and data.pos is not None: pos_array: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = ( data.pos.detach().cpu().numpy() ) src_pos_array = pos_array dst_pos_array = pos_array if src_pos_array is None or dst_pos_array is None: return None geometries = _create_linestring_geometries(edge_index_array, src_pos_array, dst_pos_array) return gpd.array.from_shapely(geometries) def _set_edge_index_names( gdf: gpd.GeoDataFrame, edge_type: str | tuple[str, str, str] | None, is_hetero: bool, metadata: GraphMetadata, ) -> None: """ Set index names on edge GeoDataFrame. This function configures the index names for an edge GeoDataFrame based on metadata information. Parameters ---------- gdf : gpd.GeoDataFrame The edge GeoDataFrame to configure. edge_type : str, tuple, or None The type of edges in the GeoDataFrame. is_hetero : bool Whether the graph is heterogeneous. metadata : GraphMetadata Graph metadata containing edge index name information. See Also -------- _reconstruct_edge_gdf : Reconstruct edge GeoDataFrame from PyTorch data. 
Examples -------- >>> import geopandas as gpd >>> gdf = gpd.GeoDataFrame({'col1': [1, 2]}) >>> _set_edge_index_names(gdf, 'road', False, metadata) """ index_names: list[str] | None = None if is_hetero and edge_type and isinstance(metadata.edge_index_names, dict): if isinstance(edge_type, tuple): index_names = metadata.edge_index_names.get(edge_type) elif not is_hetero and isinstance(metadata.edge_index_names, list): index_names = metadata.edge_index_names if ( hasattr(gdf.index, "names") and isinstance(index_names, list) and len(index_names) > 1 and isinstance(gdf.index, pd.MultiIndex) ): gdf.index.names = index_names def _reconstruct_edge_gdf( data: Data | HeteroData, metadata: GraphMetadata, edge_type: str | tuple[str, str, str] | None = None, ) -> gpd.GeoDataFrame: """ Reconstruct edge GeoDataFrame from PyTorch Geometric data. This function reconstructs a GeoDataFrame containing edge information from PyTorch Geometric data objects and metadata. Parameters ---------- data : Data or HeteroData PyTorch Geometric data object containing edge information. metadata : GraphMetadata Graph metadata with mapping and feature information. edge_type : str, tuple, or None, optional The type of edges to reconstruct. Required for heterogeneous graphs. Returns ------- gpd.GeoDataFrame GeoDataFrame containing reconstructed edge data with geometry, features, and proper indexing. See Also -------- _extract_edge_features : Extract edge features from data. _create_edge_geometries : Create edge geometries from edge indices. 
Examples -------- >>> from torch_geometric.data import Data >>> import torch >>> data = Data(edge_index=torch.tensor([[0, 1], [1, 0]])) >>> gdf = _reconstruct_edge_gdf(data, metadata) """ is_hetero = metadata.is_hetero edge_data = data[edge_type] if is_hetero and edge_type else data # Extract edge features edge_data_dict = _extract_edge_features(edge_data, edge_type, is_hetero, metadata) # Create geometries from edge indices and node positions geometry = _create_edge_geometries(edge_data, edge_type, is_hetero, data) # Reconstruct index from stored values edge_data_dict = _extract_edge_features(edge_data, edge_type, is_hetero, metadata) # Create geometries from edge indices and node positions geometry = _create_edge_geometries(edge_data, edge_type, is_hetero, data) # Reconstruct index from stored values index_values = _reconstruct_edge_index(edge_type, is_hetero, edge_data_dict, metadata) # Create GeoDataFrame with geometry if geometry is not None: gdf = gpd.GeoDataFrame(edge_data_dict, geometry=geometry, index=index_values) else: # If no geometry, create an empty GeoSeries for the geometry column # and explicitly set its CRS if metadata.crs is available. 
empty_geom = gpd.GeoSeries([], crs=metadata.crs if metadata.crs else None) gdf = gpd.GeoDataFrame(edge_data_dict, geometry=empty_geom, index=index_values) # Set index names if available _set_edge_index_names(gdf, edge_type, is_hetero, metadata) # Set CRS if metadata.crs: # Check if the geometry column is truly empty or all null if gdf.empty or (gdf.geometry is not None and gdf.geometry.isna().all()): gdf.crs = metadata.crs else: # Use set_crs for non-empty geometries gdf.set_crs(metadata.crs, allow_override=True, inplace=True) return gdf # ============================================================================ # NETWORKX CONVERSION HELPERS # ============================================================================ def _add_homo_nodes_to_graph(graph: nx.Graph, data: Data) -> None: """ Add homogeneous nodes to NetworkX graph. This function adds nodes from homogeneous PyTorch Geometric data to a NetworkX graph with their attributes. Parameters ---------- graph : nx.Graph NetworkX graph to add nodes to. data : Data PyTorch Geometric data object containing node information. See Also -------- _add_homo_edges_to_graph : Add homogeneous edges to NetworkX graph. 
Examples -------- >>> import networkx as nx >>> import torch >>> from torch_geometric.data import Data >>> graph = nx.Graph() >>> data = Data(x=torch.tensor([[1.0, 2.0]])) >>> _add_homo_nodes_to_graph(graph, data) """ metadata = data.graph_metadata node_mapping_info = metadata.node_mappings.get("default", {}) original_ids = node_mapping_info.get("original_ids", []) num_nodes = data.x.size(0) # Prepare base attributes attrs_df = pd.DataFrame( { "_original_index": [ original_ids[i] if i < len(original_ids) else i for i in range(num_nodes) ], }, ) # Add positions using vectorized operations if hasattr(data, "pos") and data.pos is not None: pos_np: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = data.pos.detach().cpu().numpy() attrs_df["pos"] = [tuple(pos_np[i]) for i in range(min(num_nodes, len(pos_np)))] # Add features using vectorized operations if hasattr(data, "x") and data.x is not None: x_np: np.ndarray[tuple[int, ...], np.dtype[np.float32]] = data.x.detach().cpu().numpy() feature_cols = metadata.node_feature_cols or [f"feat_{j}" for j in range(x_np.shape[1])] for j, col_name in enumerate(feature_cols[: x_np.shape[1]]): attrs_df[col_name] = x_np[:, j] # Add labels using vectorized operations if hasattr(data, "y") and data.y is not None: y_np: np.ndarray[tuple[int, ...], np.dtype[np.float32]] = data.y.detach().cpu().numpy() label_cols = metadata.node_label_cols or [f"label_{j}" for j in range(y_np.shape[1])] for j, col_name in enumerate(label_cols[: y_np.shape[1]]): attrs_df[col_name] = y_np[:, j] # Add nodes in bulk graph.add_nodes_from([(i, attrs_df.iloc[i].to_dict()) for i in range(num_nodes)]) def _add_homo_edges_to_graph(graph: nx.Graph, data: Data) -> None: """ Add homogeneous edges to NetworkX graph. This function adds edges from homogeneous PyTorch Geometric data to a NetworkX graph with their attributes. Parameters ---------- graph : nx.Graph NetworkX graph to add edges to. data : Data PyTorch Geometric data object containing edge information. 
See Also -------- _add_homo_nodes_to_graph : Add homogeneous nodes to NetworkX graph. Examples -------- >>> import networkx as nx >>> import torch >>> from torch_geometric.data import Data >>> graph = nx.Graph() >>> data = Data(edge_index=torch.tensor([[0, 1], [1, 0]])) >>> _add_homo_edges_to_graph(graph, data) """ metadata = data.graph_metadata edge_feature_cols = metadata.edge_feature_cols original_edge_indices = metadata.edge_index_values edge_index = data.edge_index.detach().cpu().numpy() num_edges = edge_index.shape[1] # Initialize attributes DataFrame attrs_df = pd.DataFrame(index=range(num_edges)) # Add edge attributes if available if hasattr(data, "edge_attr") and data.edge_attr is not None: edge_attrs_np = data.edge_attr.detach().cpu().numpy() columns = edge_feature_cols or [f"edge_feat_{j}" for j in range(edge_attrs_np.shape[1])] edge_attrs_df = pd.DataFrame(edge_attrs_np, columns=columns) attrs_df = pd.concat([attrs_df, edge_attrs_df], axis=1) # Add original edge indices if available if original_edge_indices: attrs_df["_original_edge_index"] = list(zip(*original_edge_indices, strict=True)) # Convert to list of dictionaries and add edges in bulk attrs_list = attrs_df.to_dict("records") src_nodes = edge_index[0] dst_nodes = edge_index[1] graph.add_edges_from(zip(src_nodes, dst_nodes, attrs_list, strict=True)) def _add_hetero_nodes_to_graph(graph: nx.Graph, data: HeteroData) -> dict[str, int]: """ Add heterogeneous nodes to NetworkX graph and return node offsets. This function adds nodes from heterogeneous PyTorch Geometric data to a NetworkX graph and tracks node type offsets. Parameters ---------- graph : nx.Graph NetworkX graph to add nodes to. data : HeteroData PyTorch Geometric heterogeneous data object containing node information. Returns ------- dict[str, int] Dictionary mapping node types to their starting offsets in the graph. See Also -------- _add_hetero_edges_to_graph : Add heterogeneous edges to NetworkX graph. 
Examples -------- >>> import networkx as nx >>> from torch_geometric.data import HeteroData >>> graph = nx.Graph() >>> data = HeteroData() >>> offsets = _add_hetero_nodes_to_graph(graph, data) """ node_offset = {} current_offset = 0 metadata = data.graph_metadata for node_type in metadata.node_types: node_offset[node_type] = current_offset node_data = data[node_type] num_nodes = node_data.num_nodes # Get original node IDs and prepare base attributes node_mapping_info = metadata.node_mappings.get(node_type, {}) original_ids = node_mapping_info.get("original_ids", list(range(num_nodes))) attrs_df = pd.DataFrame( { "node_type": node_type, "_original_index": [ original_ids[i] if i < len(original_ids) else i for i in range(num_nodes) ], }, ) # Add positions using vectorized operations if hasattr(node_data, "pos") and node_data.pos is not None: pos_np = node_data.pos.detach().cpu().numpy() attrs_df["pos"] = [tuple(pos_np[i]) for i in range(min(num_nodes, len(pos_np)))] # Add features using vectorized operations if hasattr(node_data, "x") and node_data.x is not None: x_np = node_data.x.detach().cpu().numpy() # Handle the type union for node_feature_cols feature_cols = metadata.node_feature_cols.get(node_type) or [ f"feat_{j}" for j in range(x_np.shape[1]) ] for j, col_name in enumerate(feature_cols[: x_np.shape[1]]): attrs_df[col_name] = x_np[:, j] # Add labels using vectorized operations if hasattr(node_data, "y") and node_data.y is not None: y_np = node_data.y.detach().cpu().numpy() # Handle the type union for node_label_cols label_cols = metadata.node_label_cols.get(node_type) or [ f"label_{j}" for j in range(y_np.shape[1]) ] for j, col_name in enumerate(label_cols[: y_np.shape[1]]): attrs_df[col_name] = y_np[:, j] # Add nodes in bulk graph.add_nodes_from( [(current_offset + i, attrs_df.iloc[i].to_dict()) for i in range(num_nodes)], ) current_offset += num_nodes return node_offset def _add_hetero_edges_to_graph( graph: nx.Graph, data: HeteroData, node_offset: dict[str, 
int], ) -> None: """ Add heterogeneous edges to NetworkX graph. This function adds edges from heterogeneous PyTorch Geometric data to a NetworkX graph using node offsets for proper indexing. Parameters ---------- graph : nx.Graph NetworkX graph to add edges to. data : HeteroData PyTorch Geometric heterogeneous data object containing edge information. node_offset : dict[str, int] Dictionary mapping node types to their starting offsets in the graph. See Also -------- _add_hetero_nodes_to_graph : Add heterogeneous nodes to NetworkX graph. Examples -------- >>> import networkx as nx >>> from torch_geometric.data import HeteroData >>> graph = nx.Graph() >>> offsets = {'building': 0, 'road': 100} >>> _add_hetero_edges_to_graph(graph, data, offsets) """ metadata = data.graph_metadata for edge_type in metadata.edge_types: src_type, rel_type, dst_type = edge_type edge_store = data[edge_type] edge_index = edge_store.edge_index.detach().cpu().numpy() num_edges = edge_index.shape[1] # Create attributes DataFrame using helper functions attrs_df = _create_edge_attrs_dataframe( edge_store, metadata, rel_type, edge_type, num_edges, ) # Add relation type and convert to records attrs_df["edge_type"] = rel_type attrs_list = attrs_df.to_dict("records") # Add edges with offset adjustments src_nodes = edge_index[0] + node_offset[src_type] dst_nodes = edge_index[1] + node_offset[dst_type] graph.add_edges_from(zip(src_nodes, dst_nodes, attrs_list, strict=True)) def _create_edge_attrs_dataframe( edge_store: Data, metadata: GraphMetadata, rel_type: str, edge_type: tuple[str, str, str], num_edges: int, ) -> pd.DataFrame: """ Create edge attributes DataFrame with features and original indices. This function extracts edge attributes from a PyTorch Geometric edge store and creates a pandas DataFrame with feature columns and original edge indices. Parameters ---------- edge_store : Data PyTorch Geometric Data object containing edge information. 
metadata : GraphMetadata Metadata object containing graph structure information. rel_type : str Relationship type identifier for the edges. edge_type : tuple[str, str, str] Tuple specifying the edge type (source_type, relation, target_type). num_edges : int Number of edges in the edge store. Returns ------- pd.DataFrame DataFrame containing edge attributes and original indices. See Also -------- _get_edge_attrs_array : Extract edge attributes array from edge store. _get_edge_feature_columns : Get feature column names. Examples -------- >>> edge_store = Data(edge_attr=torch.randn(100, 5)) >>> metadata = GraphMetadata(...) >>> df = _create_edge_attrs_dataframe(edge_store, metadata, "connects", ... ("node", "connects", "node"), 100) """ # Start with base DataFrame attrs_df = pd.DataFrame(index=range(num_edges)) # Add edge features if available edge_attrs_array = _get_edge_attrs_array(edge_store) if edge_attrs_array is not None: feature_columns = _get_edge_feature_columns(metadata, rel_type, edge_attrs_array.shape[1]) feature_df = pd.DataFrame(edge_attrs_array, columns=feature_columns) attrs_df = pd.concat([attrs_df, feature_df], axis=1) # Add original edge indices if available original_indices = None if isinstance(metadata.edge_index_values, dict): original_indices = metadata.edge_index_values.get(edge_type) if original_indices: attrs_df["_original_edge_index"] = list(zip(*original_indices, strict=True)) return attrs_df def _get_edge_attrs_array( edge_store: Data, ) -> np.ndarray[tuple[int, ...], np.dtype[np.float32]] | None: """ Extract edge attributes array from edge store, or None if not available. This function safely extracts the edge attribute tensor from a PyTorch Geometric Data object and converts it to a NumPy array. Returns None if no edge attributes are present. Parameters ---------- edge_store : Data PyTorch Geometric Data object that may contain edge attributes. 
Returns ------- np.ndarray or None Edge attributes as a NumPy array of shape (num_edges, num_features), or None if no edge attributes are available. See Also -------- _create_edge_attrs_dataframe : Create edge attributes DataFrame. Examples -------- >>> edge_store = Data(edge_attr=torch.randn(100, 5)) >>> attrs = _get_edge_attrs_array(edge_store) >>> attrs.shape (100, 5) """ return ( edge_store.edge_attr.detach().cpu().numpy() if hasattr(edge_store, "edge_attr") and edge_store.edge_attr is not None else None ) def _get_edge_feature_columns( metadata: GraphMetadata, rel_type: str, num_features: int, ) -> list[str]: """ Get feature column names, using metadata or generating defaults. This function retrieves edge feature column names from metadata if available, or generates default column names based on the number of features. Parameters ---------- metadata : GraphMetadata Metadata object containing graph structure information. rel_type : str Relationship type identifier for the edges. num_features : int Number of edge features. Returns ------- list[str] List of column names for edge features. See Also -------- _create_edge_attrs_dataframe : Create edge attributes DataFrame. Examples -------- >>> metadata = GraphMetadata(...) >>> cols = _get_edge_feature_columns(metadata, "connects", 5) >>> cols ['edge_feat_0', 'edge_feat_1', 'edge_feat_2', 'edge_feat_3', 'edge_feat_4'] """ feature_cols = None if isinstance(metadata.edge_feature_cols, dict): feature_cols = metadata.edge_feature_cols.get(rel_type) # For heterogeneous graphs, edge_feature_cols should be dict or None, not list # If it's a list, we ignore it as it indicates homogeneous usage return feature_cols or [f"edge_feat_{j}" for j in range(num_features)] def _convert_homo_pyg_to_nx(data: Data, metadata: GraphMetadata) -> nx.Graph: """ Convert homogeneous PyG data to NetworkX graph. 
This function converts a homogeneous PyTorch Geometric Data object to a NetworkX Graph, preserving node and edge attributes along with metadata. Parameters ---------- data : Data Homogeneous PyTorch Geometric Data object to convert. metadata : GraphMetadata Metadata object containing graph structure information. Returns ------- nx.Graph NetworkX graph with nodes, edges, and attributes from the PyG data. See Also -------- _convert_hetero_pyg_to_nx : Convert heterogeneous PyG data to NetworkX. _add_homo_nodes_to_graph : Add homogeneous nodes to graph. _add_homo_edges_to_graph : Add homogeneous edges to graph. Examples -------- >>> data = Data(x=torch.randn(100, 10), edge_index=torch.randint(0, 100, (2, 200))) >>> metadata = GraphMetadata(...) >>> graph = _convert_homo_pyg_to_nx(data, metadata) >>> len(graph.nodes) 100 """ graph = nx.Graph() # Add metadata graph.graph["crs"] = metadata.crs graph.graph["is_hetero"] = False # Add nodes and edges _add_homo_nodes_to_graph(graph, data) _add_homo_edges_to_graph(graph, data) # Store index information for reconstruction graph.graph["node_index_names"] = metadata.node_index_names graph.graph["edge_index_names"] = metadata.edge_index_names return graph def _convert_hetero_pyg_to_nx(data: HeteroData, metadata: GraphMetadata) -> nx.Graph: """ Convert heterogeneous PyG data to NetworkX graph. This function converts a heterogeneous PyTorch Geometric HeteroData object to a NetworkX Graph, flattening the heterogeneous structure while preserving node and edge attributes along with type information. Parameters ---------- data : HeteroData Heterogeneous PyTorch Geometric HeteroData object to convert. metadata : GraphMetadata Metadata object containing graph structure information. Returns ------- nx.Graph NetworkX graph with nodes, edges, and attributes from the hetero PyG data. See Also -------- _convert_homo_pyg_to_nx : Convert homogeneous PyG data to NetworkX. _add_hetero_nodes_to_graph : Add heterogeneous nodes to graph. 
_add_hetero_edges_to_graph : Add heterogeneous edges to graph. Examples -------- >>> data = HeteroData() >>> data['node'].x = torch.randn(100, 10) >>> data['edge'].x = torch.randn(50, 5) >>> metadata = GraphMetadata(...) >>> graph = _convert_hetero_pyg_to_nx(data, metadata) >>> graph.graph['is_hetero'] True """ graph = nx.Graph() # Add metadata graph.graph["crs"] = metadata.crs graph.graph["is_hetero"] = True graph.graph["node_types"] = metadata.node_types graph.graph["edge_types"] = metadata.edge_types # Store metadata for reconstruction graph.graph["metadata"] = metadata # Add nodes and edges node_offset = _add_hetero_nodes_to_graph(graph, data) _add_hetero_edges_to_graph(graph, data, node_offset) graph.graph["node_offset"] = node_offset return graph