Development process
Developing a GNN algorithm with GraphLearn on PyTorch involves the following steps:
1. Construct the node table and edge table data and build the graph.
2. Write a GSL query to describe subgraph sampling.
3. Use Dataset and DataLoader to process the data produced by the GSL query.
4. Write the model.
5. Train or predict.
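As a rough illustration of step 2, a GSL query chains sampling operators along a meta-path. The sketch below is pseudocode; the vertex type, edge type, aliases, and sampling parameters are illustrative assumptions, not taken from this document:

```
# Pseudocode sketch of a 1-hop GSL query (names are illustrative).
query = g.V("user").batch(64).alias("src")       # seed nodes, 64 per batch
         .outV("buy").sample(10).by("random")    # sample 10 neighbors per seed
         .alias("src_hop1")
         .values()
```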
GraphLearn is compatible with the open-source framework pyG (https://github.com/rusty1s/pytorch_geometric). To develop a GNN algorithm based on pyG, you need to implement an induce_func after writing the GSL query; it converts the dict of gl.nn.Data objects generated by GSL into pyG Data objects. You can then use the PyGDataLoader provided by GraphLearn to merge the Data objects into Batch objects, and implement the rest of the model directly with pyG.
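The induce step can be sketched in plain Python. The aliases ("src", "src_hop1") and the dict-based subgraph representation below are illustrative assumptions; with pyG installed you would build torch_geometric.data.Data objects instead of plain dicts:

```python
# Hedged sketch of an induce_func: it receives the data dict produced by
# one GSL batch and returns one subgraph per seed node. Field names and
# the fixed-fanout layout are assumptions for illustration only.

def induce_func(data_dict):
    seeds = data_dict["src"]["ids"]          # seed node ids of this batch
    nbrs = data_dict["src_hop1"]["ids"]      # sampled neighbors, grouped per seed
    fanout = len(nbrs) // len(seeds)         # assumes a fixed fanout
    subgraphs = []
    for i, seed in enumerate(seeds):
        neighbors = nbrs[i * fanout:(i + 1) * fanout]
        # Local edge_index: each neighbor (indices 1..fanout) points to
        # the seed (local index 0).
        edge_index = [[j + 1 for j in range(fanout)], [0] * fanout]
        subgraphs.append({"nodes": [seed] + neighbors,
                          "edge_index": edge_index})
    return subgraphs

subs = induce_func({"src": {"ids": [10, 11]},
                    "src_hop1": {"ids": [1, 2, 3, 4]}})
print(len(subs), subs[0]["nodes"])  # 2 [10, 1, 2]
```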
Data layer
Corresponds to nn/pytorch/data
When writing a model with GraphLearn, you first construct the graph and write a GSL query. The GSL query describes the meta-path to be sampled and generates a data stream that returns numpy ndarrays. To make this convenient for the model layer, GraphLearn implements a PyTorch Dataset that converts the GSL query output into a dict of gl.nn.Data objects in tensor format, or induces it into pyG Data objects. You can then traverse the data through PyTorch's DataLoader.
Dataset
Dataset provides two basic functions:
- Generic data conversion: converts GSL output into a dict of gl.nn.Data objects in tensor format, and provides the as_dict interface so that PyTorch's torch.utils.data.DataLoader can convert each Data object into a dict. Traversal then returns one large dict whose elements come from the converted Data dicts.
- Conversion to pyG data: constructs subgraphs via a custom induce_func, i.e., converts the numpy gl.nn.Data dict into a list of pyG Data objects of size batch_size (specified in GSL), and then merges the Data list into a pyG Batch object via the PyGDataLoader.
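The generic conversion can be sketched in plain Python. The "alias_attribute" key-naming scheme below is an illustrative assumption about how the per-alias Data dicts are flattened, not the actual GraphLearn implementation:

```python
# Hedged sketch of the as_dict-style flattening: each entry of the
# {alias: Data} dict contributes its fields to one flat dict, keyed by
# "<alias>_<attribute>". Plain lists stand in for tensors here.

def as_dict(data_dict):
    """Flatten {alias: {attr: tensor}} into {"alias_attr": tensor}."""
    flat = {}
    for alias, data in data_dict.items():
        for attr, tensor in data.items():
            flat[f"{alias}_{attr}"] = tensor
    return flat

batch = {"user": {"ids": [0, 1], "floats": [[0.1], [0.2]]},
         "item": {"ids": [7, 8], "floats": [[0.3], [0.4]]}}
print(sorted(as_dict(batch)))
# ['item_floats', 'item_ids', 'user_floats', 'user_ids']
```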
```python
class Dataset(th.utils.data.IterableDataset):
  def __init__(self, query, window=5, induce_func=None):
    """Dataset reformats the sampled batch from a GSL query as `Data`
    objects consisting of PyTorch tensors.
    Args:
      query: a GSL query.
      window: the buffer size for the query.
      induce_func: a function that takes the query result `Data` dict and
        returns a list of subgraphs. For pyG, these subgraphs are pyG
        `Data` objects.
    """
    self._rds = RawDataset(query, window=window)
    self._format = lambda x: x
    self._induce_func = induce_func

  def as_dict(self):
    """Convert each `Data` to a dict of torch tensors.
    This function is used by PyTorch's raw `DataLoader`.
    """
```
PyGDataLoader
To merge the Data list induced by Dataset, GraphLearn provides a pyG-oriented PyGDataLoader. Note that because batching is already performed by the GSL query, one iteration of Dataset returns a full batch of data; therefore batch_size is forced to 1 in the PyGDataLoader implementation.
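The effect of forcing batch_size=1 can be sketched independently of pyG. Since each Dataset item is already a list of subgraphs, the DataLoader hands the collate function a singleton outer list, which the collate unwraps before merging (the dict-based "merge" below is an illustrative stand-in for Batch.from_data_list):

```python
# Minimal sketch: with batch_size=1 the loader passes [[g1, g2, ...]]
# to collate; unwrap the outer list, then merge the inner one.

def collate(batch):
    subgraphs = batch[0]  # unwrap the singleton outer list
    return {"num_graphs": len(subgraphs), "graphs": subgraphs}

merged = collate([["g1", "g2", "g3"]])
print(merged["num_graphs"])  # 3
```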
```python
class Collater(object):
  def __init__(self):
    pass

  def collate(self, batch):
    batch = batch[0]
    elem = batch[0]
    if isinstance(elem, Data):
      return Batch.from_data_list(batch)

  def __call__(self, batch):
    return self.collate(batch)


class PyGDataLoader(torch.utils.data.DataLoader):
  """pyG data loader which merges a list of pyG `Data` objects induced
  from a `graphlearn.python.nn.pytorch.data.Dataset` into a pyG `Batch`
  object.
  Args:
    dataset (Dataset): the dataset that converts GSL output and induces
      a list of pyG `Data` objects.
  """
  def __init__(self, dataset, **kwargs):
    if "batch_size" in kwargs:
      del kwargs["batch_size"]
    if "collate_fn" in kwargs:
      del kwargs["collate_fn"]
    super(PyGDataLoader, self).__init__(
        dataset, batch_size=1, collate_fn=Collater(), **kwargs)
```
Model layers
pyG
After using a Dataset with an induce_func together with the PyGDataLoader, the returned data is a pyG Batch object, so you can directly reuse pyG models and layers.
Other
If you don't want to use pyG, you can also manipulate the dict of gl.nn.Data objects returned by Dataset and write the model directly in PyTorch. Please contact us if you have suggestions.
Example
The full example is available at examples/pytorch.