图神经网络DGL库之消息传递
- 1 消息传递
- 1.1 图解
- 1.2 语法格式
- 1.2.1 message函数
- 1.2.2 reduce函数
- 1.2.3 update函数
- 1.2.4 apply_nodes函数
- 1.2.5 apply_edges函数
- 2 具体例子
- 2.1 建图
- 2.2 消息传递
- 2.2.1 函数构造
- 2.2.2 边更新
- 2.2.3 节点更新
- 2.2.4 消息聚合
- 1 未使用更新函数
- 2 使用更新函数
1 消息传递
1.1 图解
对上图的绿色框的函数进行解释:
- 消息函数(message function):消息函数可以接收源节点的e.src.data,边的e.data以及目标节点的e.dst.data,之后将三者数据进行一些操作(例如加和),最终将数据存放在Mailbox
- apply_nodes函数:可以使用目标节点的e.dst.data数据进行一些操作(例如e.dst.data+1)
- 聚合函数(reduce function):可以获取目标节点以及Mailbox数据。将Mailbox数据提取出来,并清空Mailbox,之后更新目标节点。
对上图未提及的函数进行说明:
- apply_edges函数:可以将消息函数操作后的数据附加在边上
- update_all函数(更新):启用消息函数和聚合函数,即开始更新节点的流程(消息传递+消息聚合)。
1.2 语法格式
1.2.1 message函数
message函数采用单个参数edges(具有三个成员src,dst和data)分别用于访问源节点,目标节点和边的特征,如下:
def message_func(edges):
1.2.2 reduce函数
reduce函数采用单个参数节点nodes。 节点的成员属性mailbox可以用来访问节点收到的信息,然后做一些运算
- 一些最常见的聚合运算包括sum,max,min等
如下:
def reducer(nodes):
1.2.3 update函数
调用节点计算的接口是update_all(),它在单个API调用里合并了消息生成、消息聚合和节点特征更新。update_all的参数是消息函数,reduce函数和更新函数。
- 更新函数是可选择参数,用户可以不使用,DGL不推荐在 update_all 中指定更新函数
- 该函数相当于开始更新节点的流程(消息传递+消息聚合+节点特征更新)。
1.2.4 apply_nodes函数
语法格式:
DGLGraph.apply_nodes(func, v='__ALL__', ntype=None, inplace=False)
参数解释如下:
- func:用于更新节点特征的函数。
- v:默认是更新所有节点。
- ntype:可选,节点类型名称。如果图中只有一个类型的节点,则可以省略。
- 最后一个已弃用
1.2.5 apply_edges函数
DGLGraph.apply_edges(func, edges='__ALL__', etype=None, inplace=False)
参数解释如下:
- func:用于生成新的边特征。
- v:默认是更新所有边。
- ntype:可选,边类型名称。如果图中只有一个类型的边,则可以省略。
- 最后一个已弃用
2 具体例子
2.1 建图
示例图如下:
建图代码如下:
import dgl
import torchg = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])
2.2 消息传递
2.2.1 函数构造
该消息传递方式将源节点的特征和边的特征进行聚合
def message_func(edges):# 常规属性操作如下:# print('edges.data:',edges.data)# print('edges.src:',edges.src)# print('edges.dst:',edges.dst)tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}return tmp
2.2.2 边更新
将消息传递函数应用在边上,更新边的特征,代码如下:
import dgl
import torchg = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])
print('origin edge feat')
print(g.edata)
print('-------------------------------')
def message_func(edges):# 常规属性操作如下:# print('edges.data:',edges.data)# print('edges.src:',edges.src)# print('edges.dst:',edges.dst)tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}return tmpg.apply_edges(message_func)
print('updata edge')
print(g.edata)
运行时,以(0,1)边为例,m=节点0的n_feat + 该边的e_feat,即20+2000=2020,以此类推,结果如下:
2.2.3 节点更新
import dgl
import torch# 创建图
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.edata['e_feat'] = torch.tensor([2000, 3000, 4000]) # 边特征
g.ndata['n_feat'] = torch.tensor([20, 21, 22]) # 节点特征
print('Original node features:')
print(g.ndata)
print('-------------------------------')g.apply_nodes(lambda nodes: {'n_feat': nodes.data['n_feat'] * 2})# 打印更新后的节点特征
print('Updated node features:')
print(g.ndata)
将节点信息×2,结果如图所示:
2.2.4 消息聚合
1 未使用更新函数
import dgl
import torchg = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])
print('origin node feat')
print(g.ndata)
print('-------------------------------')
def message_func(edges):# 常规属性操作如下:# print('edges.data:',edges.data)# print('edges.src:',edges.src)# print('edges.dst:',edges.dst)tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}return tmp
def reducer(nodes):# DGL中,批次中的节点是按照图的划分和计算需求确定的print('batch nodes: ',nodes.nodes())# nodes.mailbox 只包含在 message_func 中生成并发送到节点的消息print('mailbox: ',nodes.mailbox)print('--------------------------')# 每一行进行求和,目的是将数据转成列表格式tmp = {'h': torch.sum(nodes.mailbox['m'],dim=1)}return tmp
g.update_all(message_func, reducer)
print('updata node')
print(g.ndata)
print('edge')
print(g.edata)
经过了消息生成、消息聚合和节点特征更新过程,将新特征h更新到节点的特征字典中。
- 注意:这个过程并不会把特征m加到边的特征字典中
2 使用更新函数
很少这么用,不建议
import dgl
import torch# 创建图
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.edata['e_feat'] = torch.tensor([2000, 3000, 4000]) # 边特征
g.ndata['n_feat'] = torch.tensor([20, 21, 22]) # 节点特征
print('Original node features:')
print(g.ndata)
print('-------------------------------')# 消息传递函数
def message_func(edges):# 计算消息:边特征 + 源节点特征return {'m': edges.data['e_feat'] + edges.src['n_feat']}# 聚合函数
def reducer(nodes):# 打印批次节点和邮件箱内容print('Batch nodes: ', nodes.nodes())print('Mailbox: ', nodes.mailbox)print('--------------------------')# 对消息进行求和return {'h': torch.sum(nodes.mailbox['m'], dim=1)}# 更新节点特征的函数
def update_node_features(nodes):# 使用聚合后的特征更新节点特征# nodes.data['h'] 是聚合后的消息# nodes.data['n_feat'] 是节点的原始特征updated_feat = nodes.data['n_feat'] + nodes.data['h']return {'h': updated_feat}# 执行消息传递和聚合
g.update_all(message_func, reducer)# 在消息传递后,使用 apply_nodes 更新节点特征
g.update_all(message_func, reducer,update_node_features) # 获取聚合结果# 打印更新后的节点特征
print('Updated node features:')
print(g.ndata)
print('Edge features:')
print(g.edata)
结果如图所示: