Motif-Driven Contrastive Learning of Graph Representations

https://ojs.aaai.org/index.php/AAAI/article/view/17986/17791

https://arxiv.org/pdf/2012.12533

Motif-Driven Contrastive Learning of Graph Representations,2021,AAAI-21 Student Papers and Demonstrations

总结:文章动机挺好的,希望通过motif来采样语义信息更丰富的子图来执行GCL,用于GNN的预训练。而且本文实验做的很丰富,实验结果看起来也挺好的。不过,文章写得有点乱(好像是本科生写的),看起来有点费劲,个人感觉有些地方描述的不太准确,而且有点冗余。对于这篇文章有两点思考:1. 作者这种motif学习方式可不可行?有没有道理?而且作者用的数据集是分子数据集,在其他类型数据集中,这种方法可不可行?2. 个人感觉利用子图进行对比学习是一个挺好的方向,有没有其他和motif类似的子图生成方法?

1. 简介

1.1 摘要

Pre-training Graph Neural Networks (GNN) via self-supervised contrastive learning has recently drawn lots of attention. However, most existing works focus on node-level contrastive learning, which cannot capture global graph structure. The key challenge to conducting subgraph-level contrastive learning is to sample informative subgraphs that are semantically meaningful. To solve it, we propose to learn graph motifs, which are frequently-occurring subgraph patterns (e.g. functional groups of molecules), for better subgraph sampling. Our framework MotIf-driven Contrastive leaRning Of Graph representations (MICRO-Graph) can: 1) use GNNs to extract motifs from large graph datasets; 2) leverage learned motifs to sample informative subgraphs for contrastive learning of GNN. We formulate motif learning as a differentiable clustering problem, and adopt EM-clustering to group similar and significant subgraphs into several motifs. Guided by these learned motifs, a sampler is trained to generate more informative subgraphs, and these subgraphs are used to train GNNs through graph-to-subgraph contrastive learning. By pretraining on the ogbg-molhiv dataset with MICRO-Graph, the pre-trained GNN achieves 2.04% ROC-AUC average performance enhancement on various downstream benchmark datasets, which is significantly higher than other state-of-the-art self-supervised learning baselines.

最近通过自监督对比学习的预训练GNN得到了很多关注。但是,现有的工作大多关注于node-level对比学习,无法捕捉全局图结构。对于subgraph-level对比学习,其关键是如何采样具有语义信息的子图。为了解决这个问题,我们提出通过学习graph motif来实现更好的子图采样。我们提出的框架MICRO-Graph可以:(1)利用GNNs从大型图数据集中提取motifs;(2)利用学习到的motifs采样蕴含丰富信息的子图用于GNN的对比学习。具体来说,我们将motif学习看做一个可微聚类问题,并利用EM-clustering将相似的、有意义的子图划分成不同的motif。在这些学习到的motifs的指导下,训练一个sampler用于生成蕴含信息更丰富的子图,然后讲这些生成的子图用于GNN的对比学习中。通过ogbg-molhiv数据集预训练模型后,得到的GNN在其他各种标准数据集中的性能得到了平均2.04%的提升。

1.2 本文工作

背景: 最近GNNs在图表示学习领域展现了其强大的能力,为了进一步提高GNNs在没有数据标签情况下的能力,人们提出了很多采用自监督方式预训练GNNs的方法。预训练好的GNNs只需经过少量微调步骤就能用于相同领域下的其他数据集中,并取得很好的表现。

动机: 现有的GCL方法大多都是进行node-level对比,不利于捕捉全局图信息,因此subgraph-level对比是一种比较好的选择,但是如何采样informative子图是一个很大的挑战。作者受motif相关研究的启发,希望利用motif来更好地采样子图用于GNN的对比学习。

本文工作: 作者提出了一种subgraph-level图对比学习方法MICRO-Graph。首先通过motif learning从大型数据集中提取motifs,然后利用这些学习到的motifs指导子图生成,将这些子图用于GNN对比学习,达到对GNN进行预训练的目的。在ogbg-molhiv数据集上,利用MICRO-Graph框架对GNN进行预训练后,pre-trained GNN在各种标准数据集上的性能得到了很大提升。

2. 方法

目标: 以自监督方式(无需标签)预训练一个GNN编码器ENCθ()\mathbf{\text{ENC}}_\theta(·),只需要少量有标签数据进行微调,即可泛化到同领域新数据集上。

作者提出的对比模型是subgraph-level的,因此最关键的问题就是如何生成用于对比的子图。为了解决这个问题,作者的大概思路如下:作者将motif learning问题看做一个可微EM聚类问题,然后用学习到的motifs来指导子图采样。MICRO-Graph框架如下图所示:

其大致执行步骤如下:

  1. 首先将一个batch的graph传入模型,通过GNN编码器ENCθ()\mathbf{\text{ENC}}_\theta(·)计算节点嵌入;
  2. 然后利用EM算法将节点聚类成motif-like子图,再利用池化函数得到子图嵌入;
  3. 然后再将得到的子图嵌入分别传入两个模块:
    • Motif-Learner:利用子图嵌入更新motif嵌入
    • Contrastive Learning:施行GNN的对比学习

整个框架最复杂的就是中间那块motif-learner相关部分,下面作一个详细介绍。

2.1 motif-learner

这一部分其实包含两个小环节:

  1. 给定motif嵌入M\mathbf M,如何将所有节点划分成若干个motif-like子图集合;
  2. 得到子图集合后,如何反过来更新M\mathbf M,得到更好的motif原型。

给定有NN个节点的图G\mathcal G,我们用{h1,...,hN}\mathbf{\{h_1,...,h_N\}}表示节点嵌入,Par\mathbf Par表示对所有节点的某种划分方式,{s1,...,sJ}=G[Par]\mathbf{\{s_1,...,s_J\}}=\mathbf{\mathcal G[Par]}表示该划分方式下得到的所有子图的嵌入。

为了对问题进行建模,我们为每个子图定义一个K-way类别随机变量zj{1,,K}z_{j} \in\{1, \cdots, K\}P(sjzj=k)P\left(s_{j} \mid z_{j}=k\right)表示sjs_j是在k-th motif指导下生成的概率。这样关于单个子图的似然函数可以定义成:

P(sjM,θ)=k=1KP(sjzj=k,M,θ)P(zj=k)(1)P\left(s_{j} \mid M, \theta\right)=\sum_{k=1}^{K} P\left(s_{j} \mid z_{j}=k, M, \theta\right) P\left(z_{j}=k\right)\tag 1

这里对”似然函数“相关概念做一个补充说明:p(xθ)p(x|\theta)

对于函数p(xθ)p(x|\theta),有两个输入:x表示一个具体的数据;θ\theta表示模型参数。

  • 如果θ\theta确定,x是变量,那么这个函数称作概率函数。即在给定θ\theta条件下,变量x出现的概率。
  • 如果x是确定的,θ\theta为变量,那么这个函数就叫做似然函数。即在不同模型参数下,x出现的概率是多少。

对于建模问题,我们通常都是有一组观测数据,需要估计模型的参数,也就是上述第二种情况。非常常用的一种参数估计办法就是极大似然估计,即将似然函数p(xθ)p(x|\theta)取最大值时的θ\theta作为模型参数(其背后思想是:已经发生的就是最可能发生的)。

因为假设所有子图都是独立同分布的,整图G\mathcal G的条件似然就是:

P(G Par ,M,θ)=sjG[Par]k=1KP(sjzj=k,M,θ)P(zj=k)(2)\begin{array}{l} P(\mathcal{G} \mid \text { Par }, M, \theta) \\ =\prod_{s_{j} \in \mathcal{G}[P a r]} \sum_{k=1}^{K} P\left(s_{j} \mid z_{j}=k, M, \theta\right) P\left(z_{j}=k\right) \end{array}\tag 2

此时,前文提到的两个问题都可以通过最大化该似然函数来解决:

  1. 给定M的情况下,利用EM最大化公式2,得到最优划分方式ParPar^*
  2. 给定划分方式ParPar情况下,同样利用EM最大化公式2,更新motif table M\mathbf M

但是对于第1个问题的解决方式:“给定M,最大化上述似然函数来找到最优节点划分方式Par\mathbf {Par^*}”难以实现。因为Par\mathbf{Par}的搜索空间巨大,且随节点数量增加而指数增大,计算量太大。为了解决这一问题,作者将P(sjzj=k)P\left(s_{j} \mid z_{j}=k\right)进一步拆分到节点级。

具体来说,为每个节点定义一个K-way类型变量P(hlcl=k)P\left(\boldsymbol{h}_{l} \mid c_{l}=k\right),这样公式2可以转化成:

P^(GPar,M,θ)=sjG[Par]k=1Kl=1hlsjNP(hlcl=k,M,θ)P(zj=k)(3)\begin{array}{l} \hat{P}(\mathcal{G} \mid P a r, M, \theta)= \\ \prod_{s_{j} \in G[P a r]} \sum_{k=1}^{K} \prod_{l=1 \atop h_{l} \in s_{j}}^{N} P\left(h_{l} \mid c_{l}=k, M, \theta\right) P\left(z_{j}=k\right) \end{array}\tag 3

总结一下就是上述两个优化问题可以分别分别定义成:

Par,θ=argmaxPar,θP^(GPar,M,θ)(4)\begin{aligned} P a r^{*}, \theta^{*} &=\arg \max _{P a r, \theta} \hat{P}(\mathcal{G} \mid P a r, M, \theta) \end{aligned}\tag 4

M=argmaxMP(GPar,M,θ)(5)\begin{aligned} M^{*} &=\arg \max _{\boldsymbol{M}} P\left(\mathcal{G} \mid P a r^{*}, M, \theta^{*}\right) \end{aligned}\tag 5

到这里就建模完毕,下面对这两个优化问题的具体解决方案进行详细介绍。

2.1.1 Modeling and Learning for The Graph Partition

注:这里作者说的所有EM相关算法,因为不存在什么联合分布、隐数据,其实就是极大似然法,和我们常用的softmax分类差不多。所谓E步就是求后验概率,M步就是根据后验概率计算似然损失,再利用损失函数进行梯度下降优化模型参数。

EM算法分为E和M两步,在E-step计算后验概率ql,k=P(cl=khl)q_{l, k}=P\left(c_{l}=k \mid h_{l}\right),在M-step对目标进行优化得到最优解。

这里作者首先将节点嵌入映射到motif嵌入空间,然后计算节点嵌入和motif嵌入之间的相似度作为后验概率ql,kq_{l,k}

ql,k=exp(ϕ(Whhl)Tϕ(mk)/τ)kexp(ϕ(Whhl)Tϕ(mk)/τ)(6)q_{l, k}=\frac{\exp \left(\phi\left(W_{h} \boldsymbol{h}_{l}\right)^{T} \phi\left(m_{k}\right) / \tau\right)}{\sum_{k^{\prime}} \exp \left(\phi\left(W_{h} \boldsymbol{h}_{l}\right)^{T} \phi\left(m_{k^{\prime}}\right) / \tau\right)}\tag 6

其中ϕ(x)=x/x2\phi(x)=x /\|x\|_{2}L2L-2范式,τ\tau为温度参数。对于1个batch的图G={G1,,GB}\mathbb G=\left\{\mathcal{G}_{1}, \cdots, \mathcal{G}_{B}\right\},用Q=[q(1),,q(B)]TQ=\left[q^{(1)}, \cdots, q^{(B)}\right]^{T}表示所有node-to-motif概率。

在E-step,利用公式6可以计算后验概率P(cl=khl)P\left(c_{l}=k \mid h_{l}\right)。这里和前人的研究类似,将该后验概率直接用于M-step存在问题,因为会出现坍塌解(即所有节点都分配给同一个motif)。

为了避免这个问题,作者采用了另一篇文章中的做法:

maxQ^QTr(Q^QT)+1λH(Q^), where Q={Q^R+NB,KQ^1K=1BNB,Q^T1NB=1KK}\begin{array}{c} \max _{\hat{Q} \in Q} \operatorname{Tr}\left(\hat{Q} Q^{T}\right)+\frac{1}{\lambda} H(\hat{Q}), \text { where } \\ Q=\left\{\hat{Q} \in \mathbb{R}_{+}^{N_{B}, K} \mid \hat{Q} 1_{K}=\frac{1_{B}}{N_{B}}, \hat{Q}^{T} 1_{N_{B}}=\frac{1_{K}}{K}\right\} \end{array}

这块没有细看,大概的原理是:通过某种正则化手段,强迫motif assignment平衡,防止所有节点坍塌到同一个motif,并且在将Q离散化。

然后这里和softmax多分类损失类似,定义node-mot损失函数如下:

Lnode-mot =l=1Nk=1Kql,klogql,k(10)\mathcal{L}_{\text {node-mot }}=-\sum_{l=1}^{N} \sum_{k=1}^{K} q_{l, k}^{*} \log q_{l, k}\tag{10}

这里作者考虑到:使用这种子图采样方法忽略了图的结构信息,因为没有限制分到同一个子图中的节点之间存在边,因此得到的子图可能是稀疏的。

为了解决这个问题,作者参考前人研究,定义了一个spectral clustering-based正则化损失:

Lreg=Tr(qTAq)Tr(qTDq)+q~Tq~q~Tq~FIJJF(11)\mathcal{L}_{r e g}=-\frac{\operatorname{Tr}\left(q^{T} A q\right)}{\operatorname{Tr}\left(q^{T} D q\right)}+\left\|\frac{\tilde{q}^{T} \tilde{q}}{\left\|\tilde{q}^{T} \tilde{q}\right\|_{F}}-\frac{I_{J}}{\sqrt{J}}\right\|_{F}\tag{11}

这个损失函数可以让我们得到的Q^\hat Q和spectral clustering结果更近一点。

2.1.2 Modeling and Learning for Motif Embeddings

和前面差不多,首先计算子图和motif之间的相似度,作为后验概率:

pj,k=exp(ϕ(Wssj)Tϕ(mk)/τ)jexp(ϕ(Wssj)Tϕ(mk)/τ)(12)p_{j, k}=\frac{\exp \left(\phi\left(W_{s} s_{j}\right)^{T} \phi\left(m_{k}\right) / \tau\right)}{\sum_{j^{\prime}} \exp \left(\phi\left(W_{s} s_{j^{\prime}}\right)^{T} \phi\left(m_{k}\right) / \tau\right)}\tag{12}

然后和上面一样,计算分类损失:

Lmot-sub =j=1Jk=1Kπj,klogpj,k(13)\mathcal{L}_{\text {mot-sub }}=-\sum_{j=1}^{J} \sum_{k=1}^{K} \pi_{j, k} \log p_{j, k}\tag{13}

这里公式格式和前面基本一样,只不过这里是利用损失函数进行梯度下降,更新MM矩阵,而前面则利用公式10进行梯度下降,更新参数θ\theta,从而更好的对节点进行分组。

至此,我们就得到了所有motif学习相关的损失函数:

Lmotif =λnLnode-mot +λsLmot-sub +λrLreg (14)\mathcal{L}_{\text {motif }}=\lambda_{n} \mathcal{L}_{\text {node-mot }}+\lambda_{s} \mathcal{L}_{\text {mot-sub }}+\lambda_{r} \mathcal{L}_{\text {reg }}\tag{14}

2.2 对比学习

如前文图2所示,这里将整图嵌入作为锚点,将从该锚点采样得到的子图作为正例,将从其他图采样得到的子图作为负例,对比损失如下:

Lcontra =1Bi=1BsjGilogexp(Yi,j/τ)jexp(Yi,j/τ)(15)\mathcal{L}_{\text {contra }}=-\frac{1}{B} \sum_{i=1}^{B} \sum_{s_{j} \in G_{i}} \log \frac{\exp \left(Y_{i, j} / \tau\right)}{\sum_{j^{\prime}} \exp \left(Y_{i, j^{\prime}} / \tau\right)}\tag {15}

这样整个框架的损失定义为:

L=αLmotif +(1α)Lcontra (16)\mathcal{L}=\alpha \mathcal{L}_{\text {motif }}+(1-\alpha) \mathcal{L}_{\text {contra }}\tag{16}

框架代码大致如下:

在初始状态下motif embedding和motif-like子图都是随机的。

3. 实验

3.1 基础实验

作者进行了下列两种实验:

  • Transfer Fine-tune Setting: 在只有少量有标签数据的下游任务中对预训练好的GNN模型进行微调。
  • Feature Extraction Setting: 和前一种设置差不多,只不过将GNN看做特征提取器,在其基础上重新训练一个线性分类器。

实验结果如下表1和表2所示:

3.2 消融实验

主要针对:子图采样方法、对比视角、子图编码、motif数量、GNN原型5个方面进行消融实验,结果如下表3和图3所示:

作者对其中一些结果做了分析和假设,感兴趣的可以看原文。

3.3 可视化实验

作者通过closest subgraphs展示了模型学习到的子图:

其中包含生物化学分子里面常见的苯环、 醋酸酯 等。

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信