图遍历
介绍
图遍历,在GNN里的语义有别于经典的图计算。主流深度学习算法的训练模式会按batch迭代。为了满足这种要求,数据要能够按batch访问,我们把这种数据的访问模式称为遍历。在GNN算法中,数据源为图,训练样本通常由图的顶点和边构成。图遍历是指为算法提供按batch获取顶点、边或子图的能力。
目前GL支持顶点和边的batch遍历。这种随机遍历可以是无放回的,也可以是有放回的。在无放回遍历中,每当一个epoch结束后都会触发gl.OutOfRangeError
。被遍历的数据源是划分后的,即当前worker(以分布式TF为例)只遍历与其对应的Server上的数据。
顶点遍历
用法
顶点的数据来源有3种:所有unique的顶点,所有边的源顶点,所有边的目的顶点。顶点遍历依托NodeSampler
算子实现,Graph对象的node_sampler()
接口返回一个NodeSampler
对象,再调用该对象的get()
接口返回Nodes
格式的数据。
def node_sampler(type, batch_size=64, strategy="by_order", node_from=gl.NODE):
"""
Args:
type(string): 当node_from为gl.NODE时,为顶点类型,否则为边类型;
batch_size(int): 每次遍历的顶点数
strategy(string): 可选值为"by_order"和"random",表示无放回遍历和随机遍历。当为"by_order"时,若触底后不足batch_size,则返回实际数量,若实际数量为0,则触发gl.OutOfRangeError
node_from: 数据来源,可选值为gl.NODE、gl.EDGE_SRC、gl.EDGE_DST;
Return:
NodeSampler对象
"""
def NodeSampler.get():
"""
Return:
Nodes对象,若非触底,预期ids的shape为[batch_size]
"""
通过Nodes
对象获取具体的值,如id、weight、attribute等,参考 接口。在GSL中,顶点遍历参考g.V()
。
示例
user顶点表:
id | attributes |
---|---|
10001 | 0:0.1:0 |
10002 | 1:0.2:3 |
10003 | 3:0.3:4 |
buy边表:
src_id | dst_id | attributes |
---|---|---|
10001 | 1 | 0.1 |
10001 | 2 | 0.2 |
10001 | 3 | 0.4 |
10002 | 1 | 0.1 |
# Exmaple1: 随机采样顶点。
sampler1 = g.node_sampler("user", batch_size=3, strategy="random")
for i in range(5):
nodes = sampler1.get()
print(nodes.ids) # shape=(3, )
print(nodes.int_attrs) # shape=(3, 2),有2个int属性
print(nodes.float_attrs) # shape=(3, 1),有1个float属性
# Exmaple2: 遍历图中的user顶点
sampler2 = g.node_sampler("user", batch_size=3, strategy="by_order")
while True:
try:
nodes = sampler1.get()
print(nodes.ids) # 除最后一个batch外,shape为(3, ),最后一个batch的shape为剩余的id数
print(nodes.int_attrs)
print(nodes.float_attrs)
except gl.OutOfRangError:
break
# Exmaple3: 遍历图中的buy边的源顶点,即user顶点,为unique的
sampler2 = g.node_sampler("user", batch_size=3, strategy="by_order", node_from=gl.EDGE_SRC)
while True:
try:
nodes = sampler1.get()
print(nodes.ids) # shape=(2, ),由于buy边表中src_id只有2个unique的值,不满batch_size 3,因此这个循环只进行了一次
print(nodes.int_attrs)
print(nodes.float_attrs)
except gl.OutOfRangError:
break
边遍历
用法
边遍历依托EdgeSampler
算子实现。Graph对象的edge_sampler()
接口返回一个EdgeSampler
对象,再调用该对象的get()
接口返回Edges
格式的数据。
def edge_sampler(edge_type, batch_size=64, strategy="by_order"):
"""
Args:
edge_type(string): 边类型
batch_size(int): 每次遍历的边数
strategy(string): 可选值为"by_order"和"random",表示无放回遍历和随机遍历。当为"by_order"时,若触底后不足batch_size,则返回实际数量,若实际数量为0,则触发gl.OutOfRangeError
Return:
EdgeSampler对象
"""
def EdgeSampler.get():
"""
Return:
Edges对象,若非触底,预期src_ids的shape为[batch_size]
"""
通过Edges
对象获取具体的值,如id、weight、attribute等,参考 接口。在GSL中,边遍历参考g.E()
。
示例
src_id | dst_id | weight | attributes |
---|---|---|---|
20001 | 30001 | 0.1 | 0.10,0.11,0.12,0.13,0.14,0.15,0.16,0.17,0.18,0.19 |
20001 | 30003 | 0.2 | 0.20,0.21,0.22,0.23,0.24,0.25,0.26,0.27,0.28,0.29 |
20003 | 30001 | 0.3 | 0.30,0.31,0.32,0.33,0.34,0.35,0.36,0.37,0.38,0.39 |
20004 | 30002 | 0.4 | 0.40,0.41,0.42,0.43,0.44,0.45,0.46,0.47,0.48,0.49 |
sampler = g.edge_sampler("buy", batch_size=3, strategy="random")
for i in range(5):
edges = sampler.get()
print(edges.src_ids)
print(edges.src_ids)
print(edges.weights)
print(edges.float_attrs)