SAIL: Self-Augmented Graph Contrastive Learning

https://arxiv.org/pdf/2009.00934

SAIL: Self-Augmented Graph Contrastive Learning,2022,AAAI

总结:作者提出了一种自增强GCL方法SAIL,和普通GCL方法相比,SAIL没有使用任何增强技术,而是将节点原始输入特征和GNNs学习到的节点表示看作两个视角进行对比。再此基础上,作者认为模型还存在问题,这里正负样本对的选取是通过edge connectivity来进行的,可能无法代表整个节点表示分布的全貌,并且可能导致模型偏向于链路预测任务(在其他pretext task,比如节点分类中表现一般)。为了解决这个问题,作者提出了一种self-distilling方法,包含intra-和inter-distilling两个模块。intra-distilling可以保证学习到的节点表示和原始输入特征具有相同的分布。inter-distilliing可以让浅层GNNs模拟出深层GNNs的效果,并且不存在过平滑问题。

整体来说,文章还可以,实验比较丰富,实验结果也挺可观的。不过个人感觉,SAIL中的对比策略和知识蒸馏组合在一起有点A+B的意思,结构不是那么连贯,故事不够完整。

1. 简介

1.1 摘要

This paper studies learning node representations with graph neural networks (GNNs) for unsupervised scenario. Specifically, we derive a theoretical analysis and provide an empirical demonstration about the non-steady performance of GNNs over different graph datasets, when the supervision signals are not appropriately defined. The performance of GNNs depends on both the node feature smoothness and the locality of graph structure. To smooth the discrepancy of node proximity measured by graph topology and node feature, we proposed SAIL - a novel Self-Augmented graph contrastive Learning framework, with two complementary self-distilling
regularization modules, i.e., intra- and inter-graph knowledge distillation. We demonstrate the competitive performance of SAIL on a variety of graph applications. Even with a single GNN layer, SAIL has consistently competitive or even better performance on various benchmark datasets, comparing with state-of-the-art baselines.

本文研究无监督场景下利用GNNs学习节点表示。具体来说,我们首先进行了理论分析,并提供了一个经验性证明。当没有定义一个合适的监督信号的情况下,GNNs在不同数据集上表现不稳定。GNNs的表现同时取决于节点特征的平滑度和图结构的局部性。为了让通过图拓扑和节点特征测量得到的节点邻近性的差异变得平滑,我们提出了SAIL模型,一种新的自增强图对比学习框架。SAIL包含两个辅助self-distilling正则化模块,即intra- and inter-graph knowledge distillation(图内和图间知识蒸馏)。我们证明了SAIL在很多图应用中都取得了具有竞争力的表现。即使只有一层GNN,在各种标准数据集下,和SOTA方法相比,SAIL依旧表现的具有竞争力甚至更好。

1.2 本文工作

背景: GNNs的关键在于通过不断聚合local neighbors,过滤原始node features中的噪声,来得到平滑的节点表示,这一过程通常依赖于监督信号(比如节点标签、图标签)。但是很多场景下,标签信息无法得到。因此自监督学习得到了越来越多的关注,以无标签方式预训练GNNs已经成为学习GNN模型的另一种方法。

动机: can we advance the expressivity of GNNs with the knowledge extracted by themselves in an unsupervised way? (我们能否以无监督方式从GNN自身提取知识,以提高GNN的表达能力)

本文工作: (1)作者首先进行了一个理论分析,得出这样一个结论:提高GNN shallow layer的节点表示,可以帮助deep layer得到更好的节点表示。(2)作者基于这个发现提出了自增强图对比学习方法——SAIL,将原始输入的节点特征和最后一层输出进行对比。(3)在第2点的基础上,作者引入了self-distilling(自蒸馏)模块,具体包含intra- 和inter-distilling两部分。inter-distilling可以让shallow模型模拟出deep模型的操作,同时避免过平滑问题。intra-distilling可以学习到的节点表示和节点原始特征具有consistent uniformity distribution。

2. 方法

2.1 理论分析

作者首先对GNNs的底层原理进行了一个理论分析,得出结论:“ The quality of each GNN layer has close relation to previous layer. As the initial layer, the quality of input layer feature X~\tilde X will propagate from the bottom to the top layer of a given GNN model. ”。(然后作者基于这个发现引出他们的self-augment方法,感觉两者没啥关系,有点勉强。)

总的来说,个人感觉这篇文章理论分析部分和文章关联性不是特别大,有点拿来凑数的感觉,感兴趣的可以自己看原文。关联性比较大的一点可能是在证明定理1的时候得出一个结论:“if a pair of node (vi,vj)(v_i, v_j) similar neighbors (i.e. αicαjc\alpha_{ic}\approx\alpha_{jc}), the obtained node representations to be similar. ”,这个和后文一直强调的smothness相关。

2.2 模型架构

SAIL架构如上图所示,看起来有点复杂,我们一点一点拆开看。

既然SAIL是一种GCL方法,那么我们先看看SAIL是如何设计它的对比视角和对比损失的。

前面提到SAIL是self-augmented方法,和普通GCL方法相比,它没有使用其它任何增强策略,单纯的自己和自己比较。

具体来说,SAIL将节点的原始输入特征和GNN最后一层得到的节点表示看作两个view进行对比。对于一个样本集合{vi,vj,vk},eijE,eikE\{v_i,v_j,v_k\}, e_{ij}\in \mathcal E, e_{ik}\notin\mathcal Ehih_i表示viv_i的节点表示,x~i\tilde x_i表示viv_i的原始特征,ljkil_{jk}^i表示一个(hi,x~j)(h_i,\tilde x_j)(hi,x~k)(h_i,\tilde x_k)的pairwise comparison。最终的损失函数定义为:

Lssl=eijEeikE(ψ(hi,x~j),ψ(hi,x~k))+λR(G)\mathcal{L}_{s s l}=\sum_{e_{i j} \in \mathcal{E}} \sum_{e_{i k} \notin \mathcal{E}}-\ell\left(\psi\left(h_i, \widetilde{x}_j\right), \psi\left(h_i, \widetilde{x}_k\right)\right)+\lambda \mathcal{R}(\mathcal{G})

其中l()l(\cdot)表示任意一种对比损失函数,ψ\psi表示打分函数,R\mathcal R表示正则化函数。本文作者使用lnσ(ψ(hi,x~j)ψ(hi,x~k))\ln \sigma\left(\psi\left(h_i, \widetilde{x}_j\right)-\psi\left(h_i, \widetilde{x}_k\right)\right)作为对比损失,其中σ(x)=11+exp(x)\sigma(x)=\frac{1}{1+\exp (-x)}

观察上面对比损失函数,我们可以发现这里(hi,x~j)(h_i,\tilde x_j)(hi,x~k)(h_i,\tilde x_k)其实就相当于常规GCL方法中的正样本对和负样本对。不过我们发现常规GCL模型的对比损失只有上述公式的前半部分,公式后半部分的正则化函数是干啥用的?

作者认为:这里正负样本对的选取是通过edge connectivity来进行的,可能无法代表整个节点表示分布的全貌,并且可能导致模型偏向于链路预测任务(在其他pretext task,比如节点分类中表现一般)。

为了解决这个问题,作者提出了一种self-distilling方法,包含intra-和inter-distilling两个模块。

2.2.1 Intra-distilling

目的: Ensure the distribution consistency on the relations between the learned node representations H and the node features $ \tilde X$ over a set of randomly sampled nodes.

下面看一下它的具体实现。

LS={LS1,LS2,,LSN}L S=\left\{L S_1, L S_2, \cdots, L S_N\right\}表示随机选取的伪关系图,其中LSiVLS_i\subset\mathcal V并且LSi=d|LS_i|=d表示以viv_i为中心随机选取的d个邻居。

Sijt=exp(ψ(hi,x~j))vjLSiexp(ψ(hi,x~j))Sijs=exp(ψ(x~i,x~j))vjLSiexp(ψ(x~i,x~j))\begin{aligned} &S_{i j}^t=\frac{\exp \left(\psi\left(h_i, \widetilde{x}_j\right)\right)}{\sum_{v_j \in L S_i} \exp \left(\psi\left(h_i, \widetilde{x}_j\right)\right)} \\ &S_{i j}^s=\frac{\exp \left(\psi\left(\widetilde{x}_i, \widetilde{x}_j\right)\right)}{\sum_{v_j \in L S_i} \exp \left(\psi\left(\widetilde{x}_i, \widetilde{x}_j\right)\right)} \end{aligned}

SijtS_{ij}^tSijsS_{ij}^sviv_ivjv_j之间两种不同的相似度,前者计算的是viv_i的表示和vjv_j的原始特征之间的相似度,后者计算的是viv_i的原始特征和vjv_j的原始特征之间的相似度。

作者将SijtS_{ij}^t看做teacher signal,引导节点特征X~={x~1,x~2,,x~N}\tilde{\mathbf{X}}=\left\{\widetilde{x}_1, \widetilde{x}_2, \cdots, \tilde{x}_N\right\}在随机抽样图上和自己达成一致的关系分布。作者使用交叉熵计算两个分布之间的相似度Si=CrossEntropy(S[i,]]t,S[i,]s)S_i=CrossEntropy \left(S_{[i,]]}^t, S_{[i, \cdot]}^s\right)。这样我们就可以计算所有节点的关系相似度分布:

Rintra =i=1NSi\mathcal{R}_{\text {intra }}=\sum_{i=1}^N \mathcal{S}_i

Rintra\mathcal R_{intra}将作为前文提到的正则化函数的一部分,让学习到的节点表示H\mathbf H和节点特征X~\widetilde{\mathbf X}在subgraph-level保持一致。

2.2.2 Inter-distilling

目的:Through multiple implementations of the inter-distilling module, we implicitly mimic the deep smoothing operation with a shallow GNN (e.g. a single GNN layer), while avoiding to bring noisy information from high-order neighbors 。

下面看一下它的具体实现。

如前文图1所示,作者首先拷贝目标模型的参数,构建一个教师模型Φt\Phi_t

通过几轮迭代后,向目标模型参数中注入噪声。

通过和教师模型一起工作,学生模型Φs(X,A)={Hs,Xs~}\Phi_s(\mathbf X, \mathbf A)=\{\mathbf H_s,\widetilde{\mathbf X_s}\}可以从教师模型Φt\Phi_t中蒸馏知识。

由于没有标签信息,作者采用“ contrastive representation distillation ”一文中使用的方法进行知识蒸馏,包含如下两个部分:

Rinter =KD(Ht,X~sG)+KD(Ht,HsG)\mathcal{R}_{\text {inter }}=K D\left(\mathbf{H}_t, \widetilde{\mathbf{X}}_s \mid \mathcal{G}\right)+K D\left(\mathbf{H}_t, \mathbf{H}_s \mid \mathcal{G}\right)

然后通过一系列推导可以得到如下公式(感兴趣可以看原文):

其中x~is\tilde x_i^s表示学生模型中节点viv_i的特征,hish_i^s表示学生模型学习到的节点viv_i的表示。βij\beta_{ij}表示注意力权重,有很多选择,比如mean pooling。本文作者使用的就是mean-pooling。

至此,我们就得到了SAIL完整的损失函数:

Lssl=eijEeikE(ψ(his,x~js),ψ(his,x~ks))+λ(Rintra +Rinter )\begin{aligned} \mathcal{L}_{s s l}=\sum_{e_{i j} \in \mathcal{E}} \sum_{e_{i k} \notin \mathcal{E}} &-\ell\left(\psi\left(h_i^s, \widetilde{x}_j^s\right), \psi\left(h_i^s, \widetilde{x}_k^s\right)\right) +\lambda\left(\mathcal{R}_{\text {intra }}+\mathcal{R}_{\text {inter }}\right) \end{aligned}

SAIL伪代码如下图所示:

3. 实验

3.1 基础实验

  1. 节点分类

  2. 节点聚类

  3. 链路预测

3.2 节点表示的平滑度

  1. Before encoding

  2. After encoding

    作者使用目标节点和邻居节点之间的“mean average distance (MAD)”来度量节点表示的smoothness。实验结果如下表:

其中MADgap=MADrmtMADneiMAD_{gap}=MAD_{rmt}-MAD_{nei}MADneiMAD_{nei}表示目标节点和其邻居之间的MAD值,MADrmtMAD_{rmt}表示目标节点和remote节点间的MAD值。如果我们得到的MADgapMAD_{gap}比较大,表示学习到的节点表示没有过平滑问题。

作者还定义了MADratio=MADgapMADneiMAD_{ratio}=\frac{MAD_{gap}}{MAD_{nei}}来度量邻居节点和远端节点之间带来的 “information-to-noise ratio(信噪比)”。

从上表5我们可以发现:(1)SAIL实现了最好的smoothing表现(MADneiMAD_{nei}最小)(2)对于过平滑问题,SAIL的MADgapMAD_{gap}相对更大(可以对比各个方法的MADratioMAD_{ratio}值)。

3.3 消融实验

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信