Graph
Module for creating homogeneous and heterogeneous graph representations of urban environments.
This module provides comprehensive functionality for converting spatial data (GeoDataFrames and NetworkX objects) into PyTorch Geometric Data and HeteroData objects, supporting both homogeneous and heterogeneous graphs. It handles the complex mapping between geographical coordinates, node/edge features, and the tensor representations required by graph neural networks.
The module serves as a bridge between geospatial data analysis tools and deep learning frameworks, enabling seamless integration of spatial urban data with Graph Neural Networks (GNNs) for GeoAI tasks such as urban modeling, traffic prediction, and spatial analysis.
- gdf_to_pyg(nodes, edges=None, node_feature_cols=None, node_label_cols=None, edge_feature_cols=None, device=None, dtype=None)
Convert GeoDataFrames (nodes/edges) to a PyTorch Geometric object.
This function serves as the main entry point for converting spatial data into PyTorch Geometric graph objects. It automatically detects whether to create homogeneous or heterogeneous graphs based on input structure. Node identifiers are taken from the GeoDataFrame index. Edge relationships are defined by a MultiIndex on the edge GeoDataFrame (source ID, target ID).
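To make the index conventions concrete, the sketch below shows one plausible way that node index labels and a (source ID, target ID) MultiIndex can be turned into the contiguous integer `edge_index` layout that PyTorch Geometric expects. It uses plain Python lists instead of GeoDataFrames and tensors, and the helper names `build_id_mapping` and `build_edge_index` are hypothetical. It illustrates the technique and is not city2graph's implementation.

```python
from collections.abc import Hashable, Iterable, Sequence


def build_id_mapping(node_ids: Sequence[Hashable]) -> dict[Hashable, int]:
    """Map each node identifier (e.g. a GeoDataFrame index label) to a row position."""
    mapping: dict[Hashable, int] = {}
    for position, node_id in enumerate(node_ids):
        if node_id in mapping:
            raise ValueError(f"duplicate node id: {node_id!r}")
        mapping[node_id] = position
    return mapping


def build_edge_index(
    edge_pairs: Iterable[tuple[Hashable, Hashable]],
    id_mapping: dict[Hashable, int],
) -> list[list[int]]:
    """Convert (source_id, target_id) pairs into a 2 x E list of row positions."""
    sources: list[int] = []
    targets: list[int] = []
    for source_id, target_id in edge_pairs:
        # A KeyError here means the edge references a node missing from the node table.
        sources.append(id_mapping[source_id])
        targets.append(id_mapping[target_id])
    return [sources, targets]


# Node identifiers as they might appear in a node GeoDataFrame index.
node_ids = ["n10", "n20", "n30"]
# Edge identifiers as they might appear in a (source_id, target_id) MultiIndex.
edge_pairs = [("n10", "n20"), ("n20", "n30"), ("n30", "n10")]

id_mapping = build_id_mapping(node_ids)
edge_index = build_edge_index(edge_pairs, id_mapping)
print(edge_index)  # [[0, 1, 2], [1, 2, 0]]
```

Keeping the mapping around is what makes the reverse conversion possible: the same dictionary, inverted, turns integer positions back into the original identifiers.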
- Parameters:
nodes (dict[str, geopandas.GeoDataFrame] or geopandas.GeoDataFrame) – Node data. For homogeneous graphs, provide a single GeoDataFrame. For heterogeneous graphs, provide a dictionary mapping node type names to their respective GeoDataFrames. The index of these GeoDataFrames will be used as node identifiers.
edges (dict[tuple[str, str, str], geopandas.GeoDataFrame] or geopandas.GeoDataFrame, optional) – Edge data. For homogeneous graphs, provide a single GeoDataFrame. For heterogeneous graphs, provide a dictionary mapping edge type tuples (source_type, relation_type, target_type) to their GeoDataFrames. The GeoDataFrame must have a MultiIndex where the first level represents source node IDs and the second level represents target node IDs.
node_feature_cols (dict[str, list[str]] or list[str], optional) – Column names to use as node features. For heterogeneous graphs, provide a dictionary mapping node types to their feature columns.
node_label_cols (dict[str, list[str]] or list[str], optional) – Column names to use as node labels for supervised learning tasks. For heterogeneous graphs, provide a dictionary mapping node types to their label columns.
edge_feature_cols (dict[str, list[str]] or list[str], optional) – Column names to use as edge features. For heterogeneous graphs, provide a dictionary mapping relation types to their feature columns.
device (str or torch.device, optional) – Target device for tensor placement (‘cpu’, ‘cuda’, or torch.device). If None, automatically selects CUDA if available, otherwise CPU.
dtype (torch.dtype, optional) – Data type for float tensors (e.g., torch.float32, torch.float16). If None, uses torch.float32 (default PyTorch float type).
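For heterogeneous inputs, each node type has its own index space, so an edge type key (source_type, relation_type, target_type) determines which node table is used to resolve each end of an edge. A minimal sketch of that lookup follows, with plain dictionaries standing in for GeoDataFrames; the function name `hetero_edge_index` is hypothetical.

```python
from collections.abc import Hashable


def hetero_edge_index(
    node_ids_by_type: dict[str, list[Hashable]],
    edges_by_type: dict[tuple[str, str, str], list[tuple[Hashable, Hashable]]],
) -> dict[tuple[str, str, str], list[list[int]]]:
    """Build one 2 x E position list per edge type, using per-type ID mappings."""
    # Each node type gets its own 0-based positions, as node stores do in HeteroData.
    mappings = {
        node_type: {node_id: pos for pos, node_id in enumerate(ids)}
        for node_type, ids in node_ids_by_type.items()
    }
    result: dict[tuple[str, str, str], list[list[int]]] = {}
    for (src_type, relation, dst_type), pairs in edges_by_type.items():
        src_map = mappings[src_type]  # source IDs resolve against the source node type
        dst_map = mappings[dst_type]  # target IDs resolve against the target node type
        result[(src_type, relation, dst_type)] = [
            [src_map[s] for s, _ in pairs],
            [dst_map[t] for _, t in pairs],
        ]
    return result


nodes = {"building": ["b1", "b2"], "road": ["r7", "r8", "r9"]}
edges = {("building", "connects", "road"): [("b1", "r9"), ("b2", "r7")]}
print(hetero_edge_index(nodes, edges))
# {('building', 'connects', 'road'): [[0, 1], [2, 0]]}
```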
- Returns:
PyTorch Geometric Data object for homogeneous graphs or HeteroData object for heterogeneous graphs. The returned object contains:
Node features (x), positions (pos), and labels (y) if available
Edge connectivity (edge_index) and features (edge_attr) if available
Metadata for reconstruction including ID mappings and column names
- Return type:
torch_geometric.data.Data or torch_geometric.data.HeteroData
- Raises:
ImportError – If PyTorch Geometric is not installed.
ValueError – If input GeoDataFrames are invalid or incompatible.
See also
pyg_to_gdf
Convert PyTorch Geometric data back to GeoDataFrames.
nx_to_pyg
Convert NetworkX graph to PyTorch Geometric object.
city2graph.utils.validate_gdf
Validate GeoDataFrame structure.
Notes
This function automatically detects the graph type based on input structure. For heterogeneous graphs, provide dictionaries mapping types to GeoDataFrames. Node positions are automatically extracted from geometry centroids when available.
Preserves original coordinate reference systems (CRS)
Maintains index structure for bidirectional conversion
Handles both Point and non-Point geometries (using centroids)
Creates empty tensors for missing features/edges
For heterogeneous graphs, ensures consistent node/edge type mapping
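Non-Point geometries are reduced to a single position through their centroid. For a simple polygon, the area-weighted centroid can be computed with the shoelace formula; the plain-Python sketch below does this for illustration only. In practice GeoPandas and Shapely compute `geometry.centroid` for you, and the helper name `polygon_centroid` is hypothetical.

```python
def polygon_centroid(vertices: list[tuple[float, float]]) -> tuple[float, float]:
    """Area-weighted centroid of a simple (non-self-intersecting) polygon.

    `vertices` lists the exterior ring without repeating the first point.
    Works for both clockwise and counter-clockwise rings.
    """
    if len(vertices) < 3:
        raise ValueError("a polygon needs at least three vertices")
    twice_area = 0.0
    cx = 0.0
    cy = 0.0
    for i, (x0, y0) in enumerate(vertices):
        x1, y1 = vertices[(i + 1) % len(vertices)]  # wrap around to close the ring
        cross = x0 * y1 - x1 * y0
        twice_area += cross
        cx += (x0 + x1) * cross
        cy += (y0 + y1) * cross
    if twice_area == 0.0:
        raise ValueError("degenerate polygon with zero area")
    # Centroid = sum / (6 * area) = sum / (3 * twice_area); the sign of the area cancels.
    return cx / (3.0 * twice_area), cy / (3.0 * twice_area)


# A 4 x 2 rectangle, e.g. a building footprint in a projected CRS.
print(polygon_centroid([(0.0, 0.0), (4.0, 0.0), (4.0, 2.0), (0.0, 2.0)]))  # (2.0, 1.0)
```

Because the centroid is computed in the geometry's own coordinates, it is most meaningful in a projected CRS; in a geographic CRS (degrees) it is only an approximation.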
Examples
Create a homogeneous graph from single GeoDataFrames:
>>> import geopandas as gpd
>>> from city2graph.graph import gdf_to_pyg
>>>
>>> # Load node and edge data and set their identifier indexes
>>> nodes_gdf = gpd.read_file("nodes.geojson").set_index("node_id")
>>> edges_gdf = gpd.read_file("edges.geojson").set_index(["source_id", "target_id"])
>>>
>>> # Convert to PyTorch Geometric
>>> data = gdf_to_pyg(nodes_gdf, edges_gdf,
...                   node_feature_cols=['population', 'area'])
Create a heterogeneous graph from dictionaries:
>>> # Prepare heterogeneous data
>>> buildings_gdf = buildings_gdf.set_index("building_id")
>>> roads_gdf = roads_gdf.set_index("road_id")
>>> connections_gdf = connections_gdf.set_index(["building_id", "road_id"])
>>>
>>> # Define node and edge types
>>> nodes_dict = {'building': buildings_gdf, 'road': roads_gdf}
>>> edges_dict = {('building', 'connects', 'road'): connections_gdf}
>>>
>>> # Convert to heterogeneous graph with labels
>>> data = gdf_to_pyg(nodes_dict, edges_dict,
...                   node_label_cols={'building': ['type'], 'road': ['category']})
- nx_to_pyg(graph, node_feature_cols=None, node_label_cols=None, edge_feature_cols=None, device=None, dtype=None)
Convert a NetworkX graph to a PyTorch Geometric object.
Converts a NetworkX graph to a PyTorch Geometric object by first converting it to GeoDataFrames and then applying the main conversion pipeline. This provides a bridge between NetworkX’s rich graph analysis tools and PyTorch Geometric’s deep learning capabilities.
- Parameters:
graph (networkx.Graph) – NetworkX graph to convert.
node_feature_cols (list[str], optional) – List of node attribute names to use as features.
node_label_cols (list[str], optional) – List of node attribute names to use as labels.
edge_feature_cols (list[str], optional) – List of edge attribute names to use as features.
device (torch.device or str, optional) – Target device for tensor placement (‘cpu’, ‘cuda’, or torch.device). If None, automatically selects CUDA if available, otherwise CPU.
dtype (torch.dtype, optional) – Data type for float tensors (e.g., torch.float32, torch.float16). If None, uses torch.float32 (default PyTorch float type).
- Returns:
PyTorch Geometric Data object for homogeneous graphs or HeteroData object for heterogeneous graphs. The returned object contains:
Node features (x), positions (pos), and labels (y) if available
Edge connectivity (edge_index) and features (edge_attr) if available
Metadata for reconstruction including ID mappings and column names
- Return type:
torch_geometric.data.Data or torch_geometric.data.HeteroData
- Raises:
ImportError – If PyTorch Geometric is not installed.
ValueError – If the NetworkX graph is invalid or empty.
See also
pyg_to_nx
Convert PyTorch Geometric data to NetworkX graph.
gdf_to_pyg
Convert GeoDataFrames to PyTorch Geometric object.
city2graph.utils.nx_to_gdf
Convert NetworkX graph to GeoDataFrames.
Notes
Uses intermediate GeoDataFrame conversion for consistency
Preserves all graph attributes and metadata
Handles spatial coordinates if present in node attributes
Maintains compatibility with existing city2graph workflows
Automatically creates geometry from ‘x’, ‘y’ coordinates if available
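The handling of `x`/`y` attributes and feature columns can be illustrated without NetworkX itself. Given node records shaped like the `(node, attributes)` pairs yielded by `G.nodes(data=True)`, the sketch below extracts coordinate pairs and a feature matrix for the requested `node_feature_cols`. The helper name `extract_nodes` and its treatment of missing attributes (no position, NaN features) are assumptions for illustration, not city2graph's behavior.

```python
from collections.abc import Hashable, Iterable
from typing import Any, Optional


def extract_nodes(
    node_records: Iterable[tuple[Hashable, dict[str, Any]]],
    feature_cols: list[str],
) -> tuple[list[Hashable], list[Optional[tuple[float, float]]], list[list[float]]]:
    """Split (node, attrs) records into ids, (x, y) positions, and feature rows."""
    ids: list[Hashable] = []
    positions: list[Optional[tuple[float, float]]] = []
    features: list[list[float]] = []
    for node_id, attrs in node_records:
        ids.append(node_id)
        # Only build a coordinate when both 'x' and 'y' are present.
        if "x" in attrs and "y" in attrs:
            positions.append((float(attrs["x"]), float(attrs["y"])))
        else:
            positions.append(None)
        # Assumption: missing feature attributes become NaN in this sketch.
        features.append([float(attrs.get(col, float("nan"))) for col in feature_cols])
    return ids, positions, features


# Records mirroring the nodes in the nx_to_pyg example.
records = [
    (0, {"x": 0.0, "y": 0.0, "population": 1000}),
    (1, {"x": 1.0, "y": 1.0, "population": 1500}),
]
ids, positions, features = extract_nodes(records, ["population"])
print(positions)  # [(0.0, 0.0), (1.0, 1.0)]
print(features)   # [[1000.0], [1500.0]]
```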
Examples
Convert a NetworkX graph with spatial data:
>>> import networkx as nx
>>> from city2graph.graph import nx_to_pyg
>>>
>>> # Create NetworkX graph with spatial attributes
>>> G = nx.Graph()
>>> G.add_node(0, x=0.0, y=0.0, population=1000)
>>> G.add_node(1, x=1.0, y=1.0, population=1500)
>>> G.add_edge(0, 1, weight=0.5, road_type='primary')
>>>
>>> # Convert to PyTorch Geometric
>>> data = nx_to_pyg(G,
...                  node_feature_cols=['population'],
...                  edge_feature_cols=['weight'])
Convert from graph analysis results:
>>> # Use NetworkX for analysis, then convert for ML
>>> communities = nx.community.greedy_modularity_communities(G)
>>> # Add community labels to nodes
>>> for i, community in enumerate(communities):
...     for node in community:
...         G.nodes[node]['community'] = i
>>>
>>> # Convert with community labels
>>> data = nx_to_pyg(G, node_label_cols=['community'])
- pyg_to_gdf(data, node_types=None, edge_types=None)
Convert PyTorch Geometric data to GeoDataFrames.
Reconstructs the original GeoDataFrame structure from PyTorch Geometric Data or HeteroData objects. This function provides bidirectional conversion capability, preserving spatial information, feature data, and metadata.
- Parameters:
data (torch_geometric.data.Data or torch_geometric.data.HeteroData) – PyTorch Geometric data object to convert back to GeoDataFrames.
node_types (str or list[str], optional) – For heterogeneous graphs, specify which node types to reconstruct. If None, reconstructs all available node types.
edge_types (str or list[tuple[str, str, str]], optional) – For heterogeneous graphs, specify which edge types to reconstruct. Edge types are specified as (source_type, relation_type, target_type) tuples. If None, reconstructs all available edge types.
- Returns:
- For HeteroData input: Returns a tuple containing:
First element: dict[str, geopandas.GeoDataFrame] mapping node type names to GeoDataFrames
Second element: dict[tuple[str, str, str], geopandas.GeoDataFrame] mapping edge types to GeoDataFrames
- For Data input: Returns a tuple containing:
First element: geopandas.GeoDataFrame containing nodes
Second element: geopandas.GeoDataFrame containing edges (or None if no edges)
- Return type:
tuple
See also
gdf_to_pyg
Convert GeoDataFrames to PyTorch Geometric object.
pyg_to_nx
Convert PyTorch Geometric data to NetworkX graph.
Notes
Preserves original index structure and names when available
Reconstructs geometry from stored position tensors
Maintains coordinate reference system (CRS) information
Converts feature tensors back to named DataFrame columns
Handles both homogeneous and heterogeneous graph structures
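Reversing the conversion relies on the stored ID mapping and column names. As a plain-Python illustration (the helper `rows_to_records` and its argument layout are hypothetical and are not city2graph's `GraphMetadata`), the sketch below turns a feature matrix back into records keyed by the original node IDs and named columns, and maps an `edge_index` back to (source_id, target_id) pairs, matching the MultiIndex convention used by `gdf_to_pyg`.

```python
from collections.abc import Hashable
from typing import Any


def rows_to_records(
    node_ids: list[Hashable],
    feature_cols: list[str],
    x: list[list[float]],
    edge_index: list[list[int]],
) -> tuple[dict[Hashable, dict[str, Any]], list[tuple[Hashable, Hashable]]]:
    """Rebuild named node records and (source_id, target_id) edge pairs."""
    if len(x) != len(node_ids):
        raise ValueError("feature rows must match the number of node ids")
    # Row i of the feature matrix belongs to the node stored at position i.
    nodes = {
        node_id: dict(zip(feature_cols, row)) for node_id, row in zip(node_ids, x)
    }
    sources, targets = edge_index
    # Positions in edge_index are translated back to the original identifiers.
    edges = [(node_ids[s], node_ids[t]) for s, t in zip(sources, targets)]
    return nodes, edges


nodes, edges = rows_to_records(
    node_ids=["n10", "n20", "n30"],
    feature_cols=["population", "area"],
    x=[[1000.0, 2.5], [1500.0, 3.0], [800.0, 1.2]],
    edge_index=[[0, 1], [1, 2]],
)
print(edges)  # [('n10', 'n20'), ('n20', 'n30')]
```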
Examples
Convert homogeneous PyTorch Geometric data back to GeoDataFrames:
>>> from city2graph.graph import pyg_to_gdf
>>>
>>> # Convert back to GeoDataFrames
>>> nodes_gdf, edges_gdf = pyg_to_gdf(data)
Convert heterogeneous data with specific node types:
>>> # Convert only specific node types
>>> node_gdfs, edge_gdfs = pyg_to_gdf(hetero_data,
...                                   node_types=['building', 'road'])
- pyg_to_nx(data)
Convert a PyTorch Geometric object to a NetworkX graph.
Converts PyTorch Geometric Data or HeteroData objects to NetworkX graphs, preserving node and edge features as graph attributes. This enables compatibility with the extensive NetworkX ecosystem for graph analysis.
- Parameters:
data (torch_geometric.data.Data or torch_geometric.data.HeteroData) – PyTorch Geometric data object to convert.
- Returns:
NetworkX graph with node and edge attributes from the PyG object. For heterogeneous graphs, node and edge types are stored as attributes.
- Return type:
- Raises:
ImportError – If PyTorch Geometric is not installed.
See also
nx_to_pyg
Convert NetworkX graph to PyTorch Geometric object.
pyg_to_gdf
Convert PyTorch Geometric data to GeoDataFrames.
Notes
Node features, positions, and labels are stored as node attributes
Edge features are stored as edge attributes
For heterogeneous graphs, type information is preserved
Geometry information is converted from tensor positions
Maintains compatibility with NetworkX analysis algorithms
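The attribute layout described in these notes can be sketched as a dict-of-dicts adjacency in plain Python, assuming an undirected graph: node features are stored under their column names, and each edge's feature row under the edge feature names. The function `to_adjacency` is hypothetical; with NetworkX installed, `add_nodes_from` and `add_edges_from` on an `nx.Graph` would be the practical equivalent.

```python
from typing import Any


def to_adjacency(
    x: list[list[float]],
    feature_cols: list[str],
    edge_index: list[list[int]],
    edge_attr: list[list[float]],
    edge_cols: list[str],
) -> tuple[dict[int, dict[str, Any]], dict[int, dict[int, dict[str, Any]]]]:
    """Return (node_attributes, adjacency) for an undirected graph."""
    node_attrs = {i: dict(zip(feature_cols, row)) for i, row in enumerate(x)}
    adjacency: dict[int, dict[int, dict[str, Any]]] = {i: {} for i in node_attrs}
    for (u, v), row in zip(zip(*edge_index), edge_attr):
        attrs = dict(zip(edge_cols, row))
        # Undirected: store one attribute dict under both directions,
        # mirroring how networkx.Graph shares edge data between u and v.
        adjacency[u][v] = attrs
        adjacency[v][u] = attrs
    return node_attrs, adjacency


node_attrs, adjacency = to_adjacency(
    x=[[1000.0], [1500.0]],
    feature_cols=["population"],
    edge_index=[[0], [1]],
    edge_attr=[[0.5]],
    edge_cols=["weight"],
)
print(adjacency)  # {0: {1: {'weight': 0.5}}, 1: {0: {'weight': 0.5}}}
```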
Examples
Convert PyTorch Geometric data to NetworkX:
>>> from city2graph.graph import pyg_to_nx
>>> import networkx as nx
>>>
>>> # Convert to NetworkX graph
>>> nx_graph = pyg_to_nx(data)
>>>
>>> # Use NetworkX algorithms
>>> centrality = nx.betweenness_centrality(nx_graph)
>>> communities = nx.community.greedy_modularity_communities(nx_graph)
- is_torch_available()
Check if PyTorch Geometric is available.
This utility function checks whether the required PyTorch and PyTorch Geometric packages are installed and can be imported. It is useful for enabling torch-dependent functionality conditionally and for producing helpful error messages.
- Returns:
True if PyTorch Geometric can be imported, False otherwise.
- Return type:
bool
See also
gdf_to_pyg
Convert GeoDataFrames to PyTorch Geometric (requires torch).
pyg_to_gdf
Convert PyTorch Geometric to GeoDataFrames (requires torch).
Notes
Returns False if either PyTorch or PyTorch Geometric is missing
Used internally by torch-dependent functions to provide helpful error messages
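A minimal sketch of such a check, assuming an approach that attempts both imports and treats `ImportError` as "not available". The name `torch_geometric_importable` is hypothetical and this is not city2graph's source. Attempting the import, rather than only calling `importlib.util.find_spec`, also reports False when a package is installed but raises `ImportError` on import, for example because one of its own dependencies is missing.

```python
def torch_geometric_importable() -> bool:
    """Return True only if both torch and torch_geometric can be imported."""
    try:
        import torch  # noqa: F401
        import torch_geometric  # noqa: F401
    except ImportError:  # also covers ModuleNotFoundError
        return False
    return True


if not torch_geometric_importable():
    print("PyTorch Geometric not available.")
```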
Examples
Check availability before using torch-dependent functions:
>>> from city2graph.graph import is_torch_available
>>>
>>> if is_torch_available():
...     from city2graph.graph import gdf_to_pyg
...     data = gdf_to_pyg(nodes_gdf, edges_gdf)
... else:
...     print("PyTorch Geometric not available.")
- validate_pyg(data)
Validate PyTorch Geometric Data or HeteroData objects and return metadata.
This function validates PyG objects through type checking, metadata validation, and structural consistency checks. It is the single point of validation for all PyG objects in city2graph.
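The structural part of this kind of validation can be sketched independently of PyTorch. The checks below (a two-row `edge_index`, node indices within range, and feature row counts matching node and edge counts) are assumptions about what structural consistency typically covers for PyG data; the function name `check_graph_structure` is hypothetical, and `validate_pyg` additionally performs type checks and city2graph-specific metadata validation.

```python
from typing import Optional


def check_graph_structure(
    num_nodes: int,
    edge_index: list[list[int]],
    x: Optional[list[list[float]]] = None,
    edge_attr: Optional[list[list[float]]] = None,
) -> None:
    """Raise ValueError if the arrays describing a graph are inconsistent."""
    if len(edge_index) != 2:
        raise ValueError("edge_index must have exactly two rows (sources, targets)")
    sources, targets = edge_index
    if len(sources) != len(targets):
        raise ValueError("source and target rows must have the same length")
    for node in (*sources, *targets):
        # Every edge endpoint must refer to an existing node position.
        if not 0 <= node < num_nodes:
            raise ValueError(f"edge refers to node {node}, but only {num_nodes} nodes exist")
    if x is not None and len(x) != num_nodes:
        raise ValueError("number of feature rows must equal the number of nodes")
    if edge_attr is not None and len(edge_attr) != len(sources):
        raise ValueError("number of edge feature rows must equal the number of edges")


check_graph_structure(3, [[0, 1], [1, 2]], x=[[1.0], [2.0], [3.0]])  # passes silently
```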
- Parameters:
data (torch_geometric.data.Data or torch_geometric.data.HeteroData) – PyTorch Geometric data object to validate.
- Returns:
Metadata object containing graph information for reconstruction.
- Return type:
GraphMetadata
- Raises:
ImportError – If PyTorch Geometric is not installed.
TypeError – If data is not a valid PyTorch Geometric object.
ValueError – If the data object is missing required metadata or is inconsistent.
See also
pyg_to_gdf
Convert PyG objects to GeoDataFrames.
pyg_to_nx
Convert PyG objects to NetworkX graphs.
Examples
>>> from city2graph.graph import gdf_to_pyg, validate_pyg
>>> data = gdf_to_pyg(nodes_gdf)
>>> metadata = validate_pyg(data)