Prototypical Graph Contrastive Learning

https://arxiv.org/pdf/2106.09645

Prototypical Graph Contrastive Learning,2021,arxive preprint

总结:作者提出了一种graph-graph图对比学习方法PGCL,用于无监督图分类任务。和其他图对比学习方法相比,作者最大的创新点就是负样本采样策略加权对比损失。现有的GCL方法,对于负样本的选取基本都是均匀采样,不考虑负样本和锚点之间的相似度。作者在PGCL中首先对样本进行聚类,计算每个类别的原型向量,然后将和锚点处于不同簇的其他样本视作真负样本,处于相同簇的样本视为假负样本。另外,作者根据样本所在簇的原型和锚点所在簇原型之间距离计算权重用于对比损失的定义。感觉目前学术界对于GCL的研究,基本都是将CV和NLP中对比学习方法拓展到图邻域,可以多多关注CV中对比学习的最新进展。2021ICLR中有一篇名为“ PROTOTYPICAL CONTRASTIVE LEARNING OF UNSUPERVISED REPRESENTATIONS”的CV论文,我没有看着篇文章,不知道和本文是否有联系。

1. 简介

1.1 摘要

Graph-level representations are critical in various real-world applications, such as predicting the properties of molecules. But in practice, precise graph annotations are generally very expensive and time-consuming. To address this issue, graph contrastive learning constructs instance discrimination task which pulls together positive pairs (augmentation pairs of the same graph) and pushes away negative pairs (augmentation pairs of different graphs) for unsupervised representation learning. However, since for a query, its negatives are uniformly sampled from all graphs, existing methods suffer from the critical sampling bias issue, i.e., the negatives likely having the same semantic structure with the query, leading to performance degradation. To mitigate this sampling bias issue, in this paper, we propose a Prototypical Graph Contrastive Learning (PGCL) approach. Specifically, PGCL models the underlying semantic structure of the graph data via clustering semantically similar graphs into the same group, and simultaneously encourages the clustering consistency for different augmentations of the same graph. Then given a query, it performs negative sampling via drawing the graphs from those clusters that differ from the cluster of query, which ensures the semantic difference between query and its negative samples. Moreover, for a query, PGCL further reweights its negative samples based on the distance between their prototypes (cluster centroids) and the query prototype such that those negatives having moderate prototype distance enjoy relatively large weights. This reweighting strategy is proved to be more effective than the uniform sampling. Experimental results on various graph benchmarks testify the advantages of our PGCL over state-of-the-art methods.

图表示学习在很多现实应用中都非常重要,比如预测分子属性。但是在实际应用中对图数据进行准确标记通常代价十分昂贵。为了解决这个问题,图对比学习通过构造实例判别任务,来拉近正例样本间距离,拉远负样本间距离,用于无监督表示学习。然而对于一个锚点,其负样本是从所有图中通过均匀采样的到的,现有的这些方法存在严重的抽样偏差问题,即负样本可能和锚点样本具有相同的语义结构,从而导致模型性能下降。为了缓解采样偏差问题,本文我们提出了原型图对比学习方法PGCL。具体来说,PGCL通过将语义相似的图进行聚类,来学习图数据潜在的语义结构。然后,给定一个锚点时,从其他clusters中选取负样本,这样可以确保锚点和负样本语义间的差异。另外,对于一个锚点,PGCL进一步根据负样本所在cluster原型和锚点所在cluster原型之间的距离,对所有负样本进行reweight,让原型距离适中的那些负样本簇权重更大。这种reweight操作均匀抽样更有效。不同数据集上的实验表明,PGL比其他SOTA方法优势更大。

1.2 本文工作

背景: 图表示学习在很多实际应用中都有用到,近几年来,研究人员关注最多的方法就是GNNs。但是大部分GNNs都是有监督方法,依赖大量有标签数据。然后,某些领域数据标记需要领域知识,因此标记成本过大。因此,这两年无监督图表示学习方法是一个重要发展方向,研究人员通过最大化local(or global)和global信息间的MI来学习图表示。

动机: 现有的图对比学习方法都存在两个弊端:

  • 全局结构: 现有的GCL方法主要关注于对instance-level结构相似度建模,但是在实际应用中图数据的全局结构更重要。比如MUTAG是一个 致突变芳香族和杂芳香族硝基数据集,可分成7个类别。这些类别分配包含着潜在的全局结构,但是并没有被标记用于提高表示学习性能。

    这里我理解的大概意思就是:现有GCL方法在对比的时候只关注instance这个层面,没有站在顶层看整个数据集本身。这个时候只能捕捉到局部结构信息,无法捕捉全局结构信息。

  • 负样本采样偏差: 如下图1所示,现有的GCL方法对于某个锚点负样本的选取,基本都是按照均匀分布进行采样。但是作者认为负样本有“true”,有"false",还有质量高低之分。我们应该让模型尽量选取质量高的负样本。

本文工作: 作者提出了一种新的图对比学习框架PGCL,首先对所有样本进行聚类,然后基于聚类结果选取负样本,并按照某种策略给负样本分配不同权重,基于这些权重计算对比损失。

2. 方法

首先再明确下,作者本文针对的问题是无监督图分类问题。作者提出的图对比学习方法PGCL有两个重要步骤:

  • 聚类: 聚类的目标有两个,一是将语义相似的图划分到同一组,二是将同一张图的不同视角划分到同一组。
  • 权重分配: 对于每一个锚点样本,选取负样本时不再使用随机均匀采样,而是给第一步聚类得到的所有簇分配一个权重,按照权重大小采样。和目标节点所在簇距离越近的簇权重越大。

上面两个步骤准备完毕后,每一个样本都有其对应的正样本和负样本集合,此时可以进行对比学习了。下图展示了PGCL的框架:

2.1 聚类

给定图GiG_izi=fθ(Gi)z_i=f_\theta(G_i)表示通过GNN得到的图表示向量。将所有样本划分成K组,用CRK×D={c1,...,ck}C\in\mathbb R^{K\times D}=\{c_1,...,c_k\}分别表示每一组数据的原型向量,需要注意的时该原型向量是可训练的(即参与梯度下降的参数)。

这样,给定一个样本GiG_i,及其表示向量ziz_i,我们可以按照下列方式计算该样本属于每一组的概率:

p(yzi)=softmax(Cfθ(Gi))(4)p\left(y \mid z_{i}\right)=\operatorname{softmax}\left(\mathbf{C} \cdot f_{\theta}\left(G_{i}\right)\right)\tag 4

(pi,qi)=y=1Kq(yzi)logp(yzi)(5)\ell\left(p_{i}, q_{i^{\prime}}\right)=-\sum_{y=1}^{K} q\left(y \mid z_{i}^{\prime}\right) \log p\left(y \mid z_{i}\right)\tag 5

交换ziz_iziz_i'位置我们可以得到l(pi,qi)l(p_{i'},q_i),这样最终聚类损失定义为:

Lconsistency =i=1n[(pi,qi)+(pi,qi)](6)\mathcal{L}_{\text {consistency }}=\sum_{i=1}^{n}\left[\ell\left(p_{i}, q_{i^{\prime}}\right)+\ell\left(p_{i^{\prime}}, q_{i}\right)\right]\tag 6

优化公式5会存在一种退化解,即将所有样本分配到同一组。为了解决这个问题,将样本尽可能均匀分配到各组,采用如下优化方法:

minp,qLconsistency  subject to y:q(yzi)[0,1] and i=1Nq(yzi)=NK(7)\min _{p, q} \mathcal{L}_{\text {consistency }} \text { subject to } \quad \forall y: q\left(y \mid z_{i}\right) \in[0,1] \text { and } \sum_{i=1}^{N} q\left(y \mid z_{i}\right)=\frac{N}{K}\tag 7

公式7添加的限制,意味着一个batch中的N个样本要均匀分配到K组。公式7描述的其实是最优传输问题,最终可以转化成如下形式:

minp,qLconsistency =minQTQ,logPlogN(10)\min _{p, q} \mathcal{L}_{\text {consistency }}=\min _{Q \in \mathbf{T}}\langle Q,-\log P\rangle-\log N\tag{10}

公式10的最优解可以通过Sinkhorn-Knopp算法解决。

关于这部分的详细内容可以看论文原文和https://lccurious.github.io/2020/01/30/optimal-transport/这篇博客。

2.2 对比

这部分重点介绍作者如何解决现有GCL方法中存在的负样本偏差问题。

作者认为对比学习中不同负样本有真假之分,有些是“true” negative examples,有些是“false”negative examples。怎么区分负样本的真假呢?作者采用的方法很简单,就是利用前面聚类的结果,和锚点处于同一簇的负样本都是“true”,反之为“false”。这样对比损失就定义为:

L=i=1nlogexp(zizi/τ)exp(zizi/τ)+j=1N1cicjexp(zizj/τ)(12)\mathcal{L}=-\sum_{i=1}^{n} \log \frac{\exp \left(\boldsymbol{z}_{i} \cdot \boldsymbol{z}_{i}^{\prime} / \tau\right)}{\exp \left(\boldsymbol{z}_{i} \cdot \boldsymbol{z}_{i}^{\prime} / \tau\right)+\sum_{j=1}^{N} \mathbb{1}_{\mathbf{c}_{i} \neq \mathbf{c}_{j}} \cdot \exp \left(\boldsymbol{z}_{i} \cdot \boldsymbol{z}_{j}^{\prime} / \tau\right)}\tag{12}

其中cic_icjc_j分别表示GiG_iGjG_j所处聚类簇的原型向量。另外,作者认为“true”负样本也有质量高低之分。

具体来说,作者认为从直觉上来说理想的负样本所在簇和锚点所在簇之间应该有一个合适的距离,不能太近,也不能太远。如下图3所示:

如果两者相距太远(比如上图紫色点),说明该簇中负样本和锚点之间太容易被区分了,对模型没啥用。反之如果太近(比如上图蓝色点),说明两者过于相似,那这个负样本和正样本差不多了,也不利于模型学习。因此作者重新定义如下对比损失:

LReweighted =i=1nlogexp(zizi/τ)exp(zizi/τ)+Mij=1N1cicjwijexp(zizj/τ)(13)\mathcal{L}_{\text {Reweighted }}=-\sum_{i=1}^{n} \log \frac{\exp \left(z_{i} \cdot \boldsymbol{z}_{i}^{\prime} / \tau\right)}{\exp \left(\boldsymbol{z}_{i} \cdot \boldsymbol{z}_{i}^{\prime} / \tau\right)+M_{i} \sum_{j=1}^{N} \mathbb{1}_{\mathbf{c}_{i} \neq \mathbf{c}_{j}} \cdot \boldsymbol{w}_{i j} \cdot \exp \left(\boldsymbol{z}_{i} \cdot \boldsymbol{z}_{j}^{\prime} / \tau\right)}\tag{13}

其中wijw_{ij}表示权重,Mi=Nj=1NwijM_i=\frac{N}{\sum_{j=1}^Nw_{ij}}表示正则化因子。权重计算方式采用两个原型之间的余弦距离D(ci,cj)=1cicjci2cj2\mathcal D(c_i,c_j)=1-\frac{\mathbf{c}_{i} \cdot \mathbf{c}_{j}}{\left\|\mathbf{c}_{i}\right\|_{2}\left\|\mathbf{c}_{j}\right\|_{2}}wijw_{ij}计算方式如下:

wij=exp{[D(ci,cj)μi]22σi2}(14)\boldsymbol{w}_{i j}=\exp \left\{-\frac{\left[\mathcal{D}\left(\mathbf{c}_{i}, \mathbf{c}_{j}\right)-\mu_{i}\right]^{2}}{2 \sigma_{i}^{2}}\right\}\tag{14}

至此将对比损失和聚类损失加到一起就是整个模型的损失函数:

L=LReweighted +λLConsistency (15)\mathcal{L}=\mathcal{L}_{\text {Reweighted }}+\lambda \mathcal{L}_{\text {Consistency }}\tag{15}

3. 实验

3.1 对比实验

可以看到,和无监督baselines相比,作者在所有数据集上都取得了最优结果。PGCL性能和有监督方法相比性能相当,在个别数据集上表现更优。

3.2 消融/可视化

不同损失函数: 作者分析了PGCL中聚类和负采样策略的有效性

原型数目K和batch大小N: 作者分析了不同K值和batch size对模型性能的影响

可视化: 作者用t-SNE对学习到的图表示进行了可视化,每个数据集中K都取10

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信