Module gnnnas.models

Expand source code
# This Source Code Form is subject to the terms of the
# BSD 2-Clause "Simplified" License. If a copy of the same
# was not distributed with this file, You can obtain one at
# https://github.com/akhilpandey95/gnnNAS/blob/master/LICENSE.

import numpy as np
import torch
import torch_geometric as pyg


class MPNN(torch.nn.Module):
    """Creates an MPNN model in pytorch geometric"""

    def __init__(
        self,
        n_node_features: int,
        n_edge_features: int,
        n_hidden: int,
        n_output: int,
        MPNN_inp: torch.nn.Module,
        MPNN_hidden: torch.nn.Module,
        n_conv_blocks: int,
        skip_connection: str = "plain",
    ) -> None:
        """
        Build the MPNN model
        Parameters
        ----------
        arg1 | n_node_features: int
            Number of features at node level
        arg2 | n_edge_features: int
            Number of features at edge level
        arg3 | n_hidden: int
            Number of hidden activations
        arg4 | n_output: int
            Number of output activations
        arg5 | n_conv_blocks: int
            Number of convolutional kernels
        Returns
        -------
        Nothing
            None
        """
        # super class the class structure
        super().__init__()

        # set the growth dimension
        self.growth_dimension = n_hidden

        # encode the node information
        self.node_encoder = MPNN_inp(n_node_features, n_hidden)

        # add the ability to add one or more conv layers
        conv_blocks = []

        # ability to add one or more conv blocks
        for block in range(n_conv_blocks):
            if skip_connection == "dense":
                self.growth_dimension = n_hidden + (n_hidden * block)
            conv = MPNN_hidden(self.growth_dimension, n_hidden)
            norm = torch.nn.LayerNorm(n_hidden, elementwise_affine=True)
            act = torch.nn.ReLU(inplace=True)
            layer = pyg.nn.DeepGCNLayer(conv, norm, act, block=skip_connection)
            conv_blocks.append(layer)

        # group all the conv layers
        self.conv_layers = torch.nn.ModuleList(conv_blocks)

        # add the linear layers for flattening the output from MPNN
        self.flatten = torch.nn.Sequential(
            torch.nn.Linear(self.growth_dimension, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_output),
        )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch_idx: torch.Tensor
    ) -> torch.Tensor:
        """
        Process the MPNN model
        Parameters
        ----------
        arg1 | x: torch.Tensor
            Input features at node level
        arg2 | edge_index: torch.Tensor
            Index pairs of verticies
        arg3 | batch_idx: torch.Tensor
            Batch index
        Returns
        -------
        Tensor
            torch.Tensor
        """
        # obtaint the input
        if isinstance(self.node_encoder, pyg.nn.MessagePassing):
            x = self.node_encoder(x, edge_index)
        else:
            x = self.node_encoder(x)

        # pass the node information to the conv layer
        x = self.conv_layers[0].conv(x, edge_index)

        # process the layers
        for layer in range(len(self.conv_layers[1:])):
            x = self.conv_layers[layer](x, edge_index)

        # obtain the output from the MPNN final layer
        y = pyg.nn.global_add_pool(x, batch=batch_idx)

        # pass the output to the linear output layer
        out = self.flatten(y)

        # return the output
        return out

Classes

class MPNN (n_node_features: int, n_edge_features: int, n_hidden: int, n_output: int, MPNN_inp: torch.nn.modules.module.Module, MPNN_hidden: torch.nn.modules.module.Module, n_conv_blocks: int, skip_connection: str = 'plain')

Creates an MPNN model in pytorch geometric

Build the MPNN model Parameters


arg1 | n_node_features: int Number of features at node level arg2 | n_edge_features: int Number of features at edge level arg3 | n_hidden: int Number of hidden activations arg4 | n_output: int Number of output activations arg5 | n_conv_blocks: int Number of convolutional kernels Returns


Nothing
None
Expand source code
class MPNN(torch.nn.Module):
    """Creates an MPNN model in pytorch geometric"""

    def __init__(
        self,
        n_node_features: int,
        n_edge_features: int,
        n_hidden: int,
        n_output: int,
        MPNN_inp: torch.nn.Module,
        MPNN_hidden: torch.nn.Module,
        n_conv_blocks: int,
        skip_connection: str = "plain",
    ) -> None:
        """
        Build the MPNN model
        Parameters
        ----------
        arg1 | n_node_features: int
            Number of features at node level
        arg2 | n_edge_features: int
            Number of features at edge level
        arg3 | n_hidden: int
            Number of hidden activations
        arg4 | n_output: int
            Number of output activations
        arg5 | n_conv_blocks: int
            Number of convolutional kernels
        Returns
        -------
        Nothing
            None
        """
        # super class the class structure
        super().__init__()

        # set the growth dimension
        self.growth_dimension = n_hidden

        # encode the node information
        self.node_encoder = MPNN_inp(n_node_features, n_hidden)

        # add the ability to add one or more conv layers
        conv_blocks = []

        # ability to add one or more conv blocks
        for block in range(n_conv_blocks):
            if skip_connection == "dense":
                self.growth_dimension = n_hidden + (n_hidden * block)
            conv = MPNN_hidden(self.growth_dimension, n_hidden)
            norm = torch.nn.LayerNorm(n_hidden, elementwise_affine=True)
            act = torch.nn.ReLU(inplace=True)
            layer = pyg.nn.DeepGCNLayer(conv, norm, act, block=skip_connection)
            conv_blocks.append(layer)

        # group all the conv layers
        self.conv_layers = torch.nn.ModuleList(conv_blocks)

        # add the linear layers for flattening the output from MPNN
        self.flatten = torch.nn.Sequential(
            torch.nn.Linear(self.growth_dimension, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_output),
        )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch_idx: torch.Tensor
    ) -> torch.Tensor:
        """
        Process the MPNN model
        Parameters
        ----------
        arg1 | x: torch.Tensor
            Input features at node level
        arg2 | edge_index: torch.Tensor
            Index pairs of verticies
        arg3 | batch_idx: torch.Tensor
            Batch index
        Returns
        -------
        Tensor
            torch.Tensor
        """
        # obtaint the input
        if isinstance(self.node_encoder, pyg.nn.MessagePassing):
            x = self.node_encoder(x, edge_index)
        else:
            x = self.node_encoder(x)

        # pass the node information to the conv layer
        x = self.conv_layers[0].conv(x, edge_index)

        # process the layers
        for layer in range(len(self.conv_layers[1:])):
            x = self.conv_layers[layer](x, edge_index)

        # obtain the output from the MPNN final layer
        y = pyg.nn.global_add_pool(x, batch=batch_idx)

        # pass the output to the linear output layer
        out = self.flatten(y)

        # return the output
        return out

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x: torch.Tensor, edge_index: torch.Tensor, batch_idx: torch.Tensor) ‑> torch.Tensor

Process the MPNN model Parameters


arg1 | x: torch.Tensor Input features at node level arg2 | edge_index: torch.Tensor Index pairs of verticies arg3 | batch_idx: torch.Tensor Batch index Returns


Tensor
torch.Tensor
Expand source code
def forward(
    self, x: torch.Tensor, edge_index: torch.Tensor, batch_idx: torch.Tensor
) -> torch.Tensor:
    """
    Process the MPNN model
    Parameters
    ----------
    arg1 | x: torch.Tensor
        Input features at node level
    arg2 | edge_index: torch.Tensor
        Index pairs of verticies
    arg3 | batch_idx: torch.Tensor
        Batch index
    Returns
    -------
    Tensor
        torch.Tensor
    """
    # obtaint the input
    if isinstance(self.node_encoder, pyg.nn.MessagePassing):
        x = self.node_encoder(x, edge_index)
    else:
        x = self.node_encoder(x)

    # pass the node information to the conv layer
    x = self.conv_layers[0].conv(x, edge_index)

    # process the layers
    for layer in range(len(self.conv_layers[1:])):
        x = self.conv_layers[layer](x, edge_index)

    # obtain the output from the MPNN final layer
    y = pyg.nn.global_add_pool(x, batch=batch_idx)

    # pass the output to the linear output layer
    out = self.flatten(y)

    # return the output
    return out