Development process

Developing a GNN algorithm on PyTorch with GraphLearn involves the following steps:

  1. Prepare the node table and edge table data and build the graph.

  2. Write a GSL query to perform subgraph sampling.

  3. Use Dataset and DataLoader to process the data returned by GSL.

  4. Write the model.

  5. Train/Predict.

GraphLearn is compatible with the open-source framework pyG (https://github.com/rusty1s/pytorch_geometric). To develop a GNN algorithm with pyG, implement an induce_func after writing the GSL query; it converts the gl.nn.Data dict produced by GSL into pyG Data objects. The PyGDataLoader provided by GraphLearn then merges these Data objects into a pyG Batch object, so the rest of the model can be written directly with pyG.
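The following is a minimal sketch of what an induce_func might do: build one star-shaped subgraph per seed node from a 1-hop sample. Everything here is an assumption for illustration: the dict layout (aliases "seed" and "hop1" mapping to dicts with a "feats" array, K neighbors per seed stored seed by seed) is hypothetical rather than the real gl.nn.Data structure, and the `SubGraph` dataclass is a stand-in for pyG's Data so the example runs without pyG installed.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SubGraph:
    """Stand-in for torch_geometric.data.Data with only the fields used here."""
    x: np.ndarray           # node features, shape [num_nodes, F]
    edge_index: np.ndarray  # COO edges, shape [2, num_edges]


def induce_func(data_dict):
    """Hypothetical induce_func: one star subgraph per seed node.

    Assumes data_dict["seed"]["feats"] has shape [B, F] and
    data_dict["hop1"]["feats"] has shape [B * K, F], i.e. K sampled
    neighbors per seed, laid out seed by seed.
    """
    seed_feats = data_dict["seed"]["feats"]
    batch_size, feat_dim = seed_feats.shape
    nbr_feats = data_dict["hop1"]["feats"].reshape(batch_size, -1, feat_dim)
    k = nbr_feats.shape[1]
    # Local node 0 is the seed and nodes 1..K are its neighbors.
    # Edges point from each neighbor (row 0) to the seed (row 1).
    edge_index = np.stack([np.arange(1, k + 1, dtype=np.int64),
                           np.zeros(k, dtype=np.int64)])
    subgraphs = []
    for i in range(batch_size):
        x = np.concatenate([seed_feats[i:i + 1], nbr_feats[i]], axis=0)
        subgraphs.append(SubGraph(x=x, edge_index=edge_index))
    return subgraphs


rng = np.random.default_rng(0)
batch = {"seed": {"feats": rng.random((4, 8))},
         "hop1": {"feats": rng.random((4 * 3, 8))}}
graphs = induce_func(batch)
print(len(graphs), graphs[0].x.shape, graphs[0].edge_index.tolist())
# 4 (4, 8) [[1, 2, 3], [0, 0, 0]]
```

With pyG installed, each `SubGraph(...)` would instead be `torch_geometric.data.Data(x=torch.from_numpy(x), edge_index=torch.from_numpy(edge_index))`, and the list length equals the batch size set in the GSL query.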

Data layer

Corresponds to nn/pytorch/data

When writing a model with GraphLearn, you first construct the graph and write a GSL query. The GSL query describes the meta-path to be sampled and produces a data stream of numpy ndarrays. To make this data easy to consume in the model layer, GraphLearn implements a PyTorch Dataset that converts the GSL query results into a dict of gl.nn.Data objects in tensor format, or induces them into pyG Data objects. You can then iterate over the data with PyTorch's DataLoader.
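The conversion step described above can be sketched with plain PyTorch: an IterableDataset that wraps a stream of numpy batches and turns every array into a tensor. The generator `fake_gsl_stream` and the class `TensorStream` are hypothetical stand-ins, not GraphLearn APIs; a real GSL query would supply the numpy batches.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset


def fake_gsl_stream(num_batches, batch_size, feat_dim):
    """Hypothetical stand-in for a GSL query: yields dicts of numpy arrays."""
    rng = np.random.default_rng(0)
    for _ in range(num_batches):
        yield {
            "ids": np.arange(batch_size, dtype=np.int64),
            "feats": rng.random((batch_size, feat_dim), dtype=np.float32),
        }


class TensorStream(IterableDataset):
    """Wraps a numpy batch stream and converts every array to a torch tensor."""

    def __init__(self, make_stream):
        self._make_stream = make_stream  # callable returning a fresh iterator

    def __iter__(self):
        for batch in self._make_stream():
            yield {k: torch.from_numpy(v) for k, v in batch.items()}


ds = TensorStream(lambda: fake_gsl_stream(num_batches=3, batch_size=4, feat_dim=8))
# batch_size=None disables automatic batching: each item is already a full
# batch produced by the (fake) query, so DataLoader must not re-batch it.
loader = DataLoader(ds, batch_size=None)
for batch in loader:
    print(batch["ids"].shape, batch["feats"].dtype)
```

Taking a factory instead of a single iterator lets the dataset be iterated once per epoch, since each call to `__iter__` starts a fresh stream.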

Dataset

Dataset provides two basic functions.

  • Generic data conversion: converts the GSL data into a dict of gl.nn.Data objects in tensor format. The as_dict interface then converts each Data object into a plain dict so that PyTorch's torch.utils.data.DataLoader can handle it; each iteration returns one large dict that merges the converted dicts of all Data objects.

  • Conversion to pyG data: builds subgraphs via a custom induce_func, i.e., converts the numpy gl.nn.Data dict into a list of pyG Data objects whose length is the batch_size specified in GSL. The PyGDataLoader then merges this list into a pyG Batch object.

import torch as th

# RawDataset is GraphLearn's framework-agnostic dataset that runs the GSL
# query; its import is omitted from this excerpt.

class Dataset(th.utils.data.IterableDataset):
  def __init__(self, query, window=5, induce_func=None):
    """Dataset reformats the sampled batch from a GSL query as a `Data` object
    consisting of PyTorch tensors.
    Args:
      query: a GSL query.
      window: the buffer size for the query.
      induce_func: a function that takes the query result `Data` dict and
        returns a list of subgraphs. For pyG, these subgraphs are pyG `Data`
        objects.
    """
    self._rds = RawDataset(query, window=window)
    self._format = lambda x: x  # identity formatting by default
    self._induce_func = induce_func


  def as_dict(self):
    """Convert each `Data` to a dict of torch tensors.
    This function is used with PyTorch's raw `DataLoader`.
    """
    # Implementation omitted in this excerpt.
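To make the as_dict idea concrete, here is a minimal sketch of merging several per-alias Data objects into one flat dict of tensors. The `ToyData` class and the `"<alias>_<field>"` key scheme are hypothetical choices for this illustration; they are not the actual gl.nn.Data API or GraphLearn's real key names.

```python
import torch


class ToyData:
    """Hypothetical minimal stand-in for gl.nn.Data: a container of named tensors."""

    def __init__(self, **fields):
        self.fields = fields

    def as_dict(self):
        return dict(self.fields)


def flatten(data_by_alias):
    """Merge {alias: ToyData} into one flat dict keyed "<alias>_<field>".

    The key scheme is an assumption made for this sketch.
    """
    out = {}
    for alias, data in data_by_alias.items():
        for name, value in data.as_dict().items():
            out["{}_{}".format(alias, name)] = value
    return out


batch = {
    "seed": ToyData(ids=torch.tensor([0, 1]), feats=torch.zeros(2, 4)),
    "hop1": ToyData(ids=torch.tensor([2, 3, 4, 5]), feats=torch.ones(4, 4)),
}
flat = flatten(batch)
print(sorted(flat))
# ['hop1_feats', 'hop1_ids', 'seed_feats', 'seed_ids']
```

A flat dict of tensors is a plain container that PyTorch's DataLoader passes through without custom collation, which is the motivation for converting Data objects to dicts.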

PyGDataLoader

To simplify merging the list of Data objects induced by Dataset, GraphLearn provides a pyG-oriented PyGDataLoader. Note that batching already happens in the GSL query, so one iteration of Dataset returns a whole batch of data; for this reason, the PyGDataLoader implementation forces batch_size=1.
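The effect of forcing batch_size=1 can be checked with plain PyTorch: the collate function receives a one-element list wrapping a single dataset item, so it only needs to unwrap `batch[0]`, which is what Collater does below. `PreBatched` and `unwrap_collate` are illustrative names, not GraphLearn APIs.

```python
from torch.utils.data import DataLoader, IterableDataset


class PreBatched(IterableDataset):
    """Each item is already a whole batch (here, a list of per-sample dicts)."""

    def __iter__(self):
        for start in range(0, 6, 3):
            yield [{"id": i} for i in range(start, start + 3)]


def unwrap_collate(batch):
    # With batch_size=1, DataLoader passes a list holding exactly one dataset item.
    assert len(batch) == 1
    return batch[0]


loader = DataLoader(PreBatched(), batch_size=1, collate_fn=unwrap_collate)
print([[d["id"] for d in b] for b in loader])
# [[0, 1, 2], [3, 4, 5]]
```

Any batch_size greater than 1 would make the loader group several pre-built batches together, which is why the wrapper ignores a user-supplied value.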

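import torch
from torch_geometric.data import Batch, Data


class Collater(object):
  def __init__(self):
    pass

  def collate(self, batch):
    # With batch_size=1, DataLoader passes a one-element list holding the
    # list of pyG `Data` objects that Dataset induced for one GSL batch.
    batch = batch[0]
    elem = batch[0]
    if isinstance(elem, Data):
      # Merge the subgraphs into a single pyG `Batch` (disjoint union).
      return Batch.from_data_list(batch)
    raise TypeError(
        "Collater expects pyG `Data` objects, got {}".format(type(elem)))

  def __call__(self, batch):
    return self.collate(batch)


class PyGDataLoader(torch.utils.data.DataLoader):
  """pyG data loader which merges a list of pyG `Data` objects induced
  from a `graphlearn.python.nn.pytorch.data.Dataset` into a pyG `Batch` object.

  Args:
    dataset (Dataset): the dataset that converts GSL results and induces a
      list of pyG `Data` objects.
  """
  def __init__(self, dataset, **kwargs):
    # Each Dataset item is already a full batch, so any user-supplied
    # batch_size and collate_fn are discarded.
    kwargs.pop("batch_size", None)
    kwargs.pop("collate_fn", None)
    super(PyGDataLoader, self).__init__(
        dataset, batch_size=1, collate_fn=Collater(), **kwargs)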
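What `Batch.from_data_list` does to the induced subgraphs can be sketched in numpy: it forms a disjoint union by concatenating node features, shifting each graph's edge_index by the number of nodes that precede it, and recording a `batch` vector that maps every node to its source graph. `merge_subgraphs` is an illustrative helper, not part of pyG or GraphLearn, and it omits the extra bookkeeping pyG performs (such as the `ptr` attribute).

```python
import numpy as np


def merge_subgraphs(xs, edge_indices):
    """Sketch of the disjoint-union merge performed by pyG's Batch.from_data_list."""
    # Node offset of each graph = total node count of all graphs before it.
    offsets = np.cumsum([0] + [x.shape[0] for x in xs[:-1]])
    x = np.concatenate(xs, axis=0)
    edge_index = np.concatenate(
        [ei + off for ei, off in zip(edge_indices, offsets)], axis=1)
    # batch[n] is the index of the graph that node n belongs to.
    batch = np.concatenate(
        [np.full(x_i.shape[0], i, dtype=np.int64) for i, x_i in enumerate(xs)])
    return x, edge_index, batch


x, edge_index, batch = merge_subgraphs(
    [np.zeros((2, 4)), np.ones((3, 4))],
    [np.array([[1], [0]]), np.array([[1, 2], [0, 0]])])
print(x.shape, edge_index.tolist(), batch.tolist())
# (5, 4) [[1, 3, 4], [0, 2, 2]] [0, 0, 1, 1, 1]
```

Because the graphs share no nodes after the shift, a single message-passing call over the merged edge_index processes all subgraphs of the batch at once.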

Model layers

pyG

When the Dataset is constructed with an induce_func and iterated with PyGDataLoader, each iteration yields a pyG Batch object, so you can reuse pyG models and layers directly.

Other

If you don't want to use pyG, you can work directly with the dict of gl.nn.Data objects returned by Dataset and write the model in plain PyTorch. Please contact us if you have suggestions for improvement.
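As a rough illustration of the plain-PyTorch route, the sketch below defines a GraphSAGE-style mean aggregator and runs one training step. It assumes you have already pulled two tensors out of the Dataset dict: seed features of shape [B, F] and 1-hop neighbor features of shape [B * K, F] with a fixed sample count K, stored seed by seed. `MeanAggModel` and this tensor layout are hypothetical; random tensors replace real sampled data.

```python
import torch
from torch import nn


class MeanAggModel(nn.Module):
    """GraphSAGE-style mean aggregator over fixed-size sampled neighbors (a sketch)."""

    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.proj = nn.Linear(2 * in_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, num_classes)

    def forward(self, seed_feats, nbr_feats):
        b, f = seed_feats.shape
        # [B * K, F] -> [B, K, F] -> mean over the K neighbors -> [B, F]
        nbr_mean = nbr_feats.reshape(b, -1, f).mean(dim=1)
        h = torch.relu(self.proj(torch.cat([seed_feats, nbr_mean], dim=1)))
        return self.out(h)  # [B, num_classes]


torch.manual_seed(0)
model = MeanAggModel(in_dim=8, hidden_dim=16, num_classes=3)
opt = torch.optim.Adam(model.parameters(), lr=0.01)

seed = torch.randn(4, 8)       # B = 4 seed nodes
nbrs = torch.randn(4 * 5, 8)   # K = 5 sampled neighbors per seed
labels = torch.tensor([0, 1, 2, 1])

logits = model(seed, nbrs)
loss = nn.functional.cross_entropy(logits, labels)
opt.zero_grad()
loss.backward()
opt.step()
print(logits.shape, loss.item())
```

Concatenating the seed's own features with the neighbor mean keeps the node's identity separate from its neighborhood summary, which is the usual GraphSAGE design.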

Example

The full example is available at examples/pytorch.