Inductive Representation Learning on Large Graphs

https://arxiv.org/pdf/1706.02216

https://github.com/williamleif/GraphSAGE

Inductive Representation Learning on Large Graphs,2017,NIPS

总结:现有的图表示学习方法都是transductive的,即测试用的节点在训练时是seen的。本文作者提出了一种inductive图表示学习方法GraphSAGE。inductive的具体实现方法就是学习一个聚合器,每一层采样固定大小的邻居,然后聚合这些邻居的信息用于生成当前层目标节点的嵌入向量,而不是为每一个节点学习一个嵌入向量。

1. 简介

1.1 摘要

Low-dimensional embeddings of nodes in large graphs have proved extremelyuseful in a variety of prediction tasks, from content recommendation to identifyingprotein functions. However, most existing approaches require that all nodes in thegraph are present during training of the embeddings; these previous approaches areinherentlytransductiveand do not naturally generalize to unseen nodes. Here wepresent GraphSAGE, a generalinductiveframework that leverages node featureinformation (e.g., text attributes) to efficiently generate node embeddings forpreviously unseen data. Instead of training individual embeddings for each node,we learn a function that generates embeddings by sampling and aggregating featuresfrom a node’s local neighborhood. Our algorithm outperforms strong baselineson three inductive node-classification benchmarks: we classify the category ofunseen nodes in evolving information graphs based on citation and Reddit postdata, and we show that our algorithm generalizes to completely unseen graphsusing a multi-graph dataset of protein-protein interactions.

大规模图中节点的低维嵌入在各种预测任务中是非常有用的,比如内容推荐或者蛋白质功能确认。但是,现有的大多数方法都是transductive的,即训练的时候所有节点都要是可见的,不能将模型泛化到unseen节点。本文作者提出了GraphSAGE模型——一种通用的indective框架,通过利用节点特征信息,为unseen节点生成有效的embeddings。GraphSAGE不是为每个节点训练单独的嵌入,而是学习一个函数,通过从节点的局部领域采样和聚合特征俩生成嵌入的。我们的算法在三个inductive节点分类标准数据集上的性能都优于baseline方法。

1.2 本文工作

现有的大多数图嵌入算法都是matrix-factorization-based来优化节点嵌入,这些方法都是transductive的,无法直接用于unseen节点。虽然可以将这些方法调整成inductive,但是计算代价过大,每次做新预测前都要重新优化模型。

目前为止,GCNs仅仅被用于固定图下的transductive learning。本文作者将GCNs拓展到inductive无监督学习任务中,提出了一种通用的图嵌入框架GraphSAGE。

和那些利用矩阵分解的图嵌入方法不同,作者利用节点特征(比如文本属性,节点度等等)来学习可以generalize到unseen节点的嵌入函数。

具体来说,模型不是为每个节点训练一个单独的嵌入向量,而是训练一系列aggregator functions,通过聚合节点局部领域的特征信息,为每个节点生成嵌入向量,如下图所示:

2. 方法

整个框架最核心的问题就是:如何从节点的局部邻域聚合特征信息?即聚合函数如何定义。

2.1 嵌入生成算法

假设模型参数已经训练好了,本部分主要介绍框架的前向传播算法。

具体来说,假设模型学习到了K个aggregator函数,用AGGREGATEk,k{1,...,K}AGGREGATE_k,\forall k\in\{1,...,K\}表示。每个聚合函数都有权重参数Wk,k{1,...,K}W^k,\forall k\in\{1,...,K\}。K表示模型层数,即搜索深度。前向传播算法伪代码如下:

算法外层K个循环,表示模型有K层,即节点聚合了K阶邻域的信息。需要注意的时k=0时,hv0h^0_v表示节点的输入特征xvx_v

一、算法1拓展到minibatch setting

为了使用stochastic gradient descent优化模型,算法需要支持minibatches的前向传播和反向传播。下图展示了minibatch下的前项传播算法:

大概的思想是先为所有节点采样好每一层k中需要用的的邻居节点,然后再执行前向传播算法。算法2的2~7行即采样部分:BkB^k集合中存储着该batch下的所有输入节点(即需要计算embedding的节点),Bk1B^{k-1}存储着BkB^k集合中所有节点及这些节点的邻居,以此类推。最后B0B^0中存储着原始输入节点B及这些节点K阶邻域内的所有邻居节点。

节点采样完毕后,和算法1一样,执行前向传播过程,计算节点嵌入。

二、和WLtest之间的关联

如果我们对算法1作几点调整:(1)set K=VK=|V|;(2)set 权重矩阵为identity;(3)使用一个合适的哈希函数作为aggregator,这样算法一就等价于WL同构测试。此时,如果用算法1计算两个图的节点嵌入完全一样,那么WL test会输出两个图为同构图。

因此,算法1可以看做是对WL test的一种近似,即用神经网络替代WL test中的哈希函数。不过GraphSAGE使用来计算节点嵌入的,而不是用来判断图是否同构。不过GraphSAGE和WL test之间的关联从理论上佐证了GraphSAGE可以学习节点邻域的拓扑结构信息。

三、邻域的定义

算法1中如果聚合节点所有邻居的信息,计算代价过大,因此需要重新定义节点的邻域。作者定义N(v)\mathcal N(v)为固定大小集合,从{uV:(u,v)E}\{u\in \mathcal V:(u,v)\in\mathcal E\}中采样得到。如果不固定N(v)\mathcal N(v)大小,那么每个batch的复杂度是不可预测的,最糟糕的情况可能达到O(V)O(|\mathcal V|)

2.2 参数的优化

作者使用一个graph-based损失函数,来优化权重矩阵Wk,k{1,...,K}\mathbf W^k,\forall k\in\{1,...,K\}和聚合器参数:

JG(zu)=log(σ(zuzv))QEvnPn(v)log(σ(zuzvn))(1)J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right)\tag 1

该损失函数鼓励距离近的节点有相似的节点嵌入,而不同节点的嵌入高度不同。其中vv表示在固定长度随机游走中和u共现的节点,σ\sigma表示sigmoid函数,Pn(v)P_n(v)表示负采样分布,Q表示负样本数量。该公式为无监督场景下的损失函数,在下游具体任务中可以替换成其他损失函数,比如交叉熵损失。

2.3 聚合器的选取

和N-D格子数据(比如图像、句子)不同,一个节点的邻居是无序的,因此聚合器的输入是一个无序的向量集合。理想情况下,aggregator函数应该是对称的(即具有置换不变性,和输入顺序无关),同时还要是trainable并且能够比较高的表征能力。作者测试了三种候选aggregator函数:

  1. Mean aggregator

    直接计算邻居向量的平均值,这样算法1的4、5行可以替换为:

    hvkσ(WMEAN({hvk1}{huk1,uN(v)})(2)\mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right.\tag 2

  2. LSTM aggregator

    采用LSTM聚合邻居信息,和Mean aggregator相比计算开销更大。但是LSTM不是symmetric,对输入顺序敏感,作者通过简单地将LSTM应用于节点邻居的随机排列,使LSTM能够操作一个无序集合。

  3. Pooling aggregator

    将节点所有邻居分别输入到一个全连接神经网络中,然后在用一个elementwise最大池化操作聚合邻居信息:

     AGGREGATE kpool =max({σ(Wpool huik+b),uiN(v)}),(3)\text { AGGREGATE }_{k}^{\text {pool }}=\max \left(\left\{\sigma\left(\mathbf{W}_{\text {pool }} \mathbf{h}_{u_{i}}^{k}+\mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}\right),\tag 3

    需要注意的时max池化可以替换成任何symmetric vector函数,作者用mean池化代替max池化,模型性能并没有改变,因此作者下文所有实验都采用最大池化。

3. 实验

**数据集:**Citation,Reddit,PPI。所有实验都是在训练师unseen节点或者graph上进行的。

**baseline:**采用随机分类器、逻辑回归、DeepWalk、DeepWalk+features四个;

GraphSAGE变体:(1)GraphSAGE-GCN,将GraphSAGE用于GCN中;(2)分别使用上述三种聚合器。

3.1 实验结果

3.2 时间复杂度及参数敏感性

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信