Deep Graph Infomax

https://arxiv.org/pdf/1809.10341

https://github.com/PetarV-/DGI

Deep Graph Infomax ,2019,ICLR

总结:图对比学习的鼻祖。作者首次提出基于互信息的无监督图表示学习方法DGI,用于节点分类任务。DGI通过最大化local表示和global表示之间的互信息来节点表示。作者还从理论上证明了最大化DGI的目标函数和图表示学习中的最大化互信息是等价的。

1. 简介

1.1 摘要

We present Deep Graph Infomax (DGI), a general approach for learning node representations within graph-structured data in an unsupervised manner. DGI relies on maximizing mutual information between patch representations and corresponding high-level summaries of graphs—both derived using established graph convolutional network architectures. The learnt patch representations summarize subgraphs centered around nodes of interest, and can thus be reused for downstream node-wise learning tasks. In contrast to most prior approaches to unsupervised learning with GCNs, DGI does not rely on random walk objectives, and is readily applicable to both transductive and inductive learning setups. We demonstrate competitive performance on a variety of node classification benchmarks, which at times even exceeds the performance of supervised learning.

我们展示了DGI模型,一种通用的无监督图表示学习方法。DGI依赖于最大化patch表示和对应高层级图表示之间的互信息。学习到的patch表示融合了节点周围感兴趣的子图,因此可以在下游node-wise学习任务中复用。和现有的GCNs无监督方法相比,DGI不依赖于随机游走目标,并且同时适用于transductive和inductive设定。我们在各种节点分类标准数据集上证明了其出色的表现,并且有时候其性能还要优于有监督学习。

1.2 本文工作

背景: 最近几年,虽然GNNs被广泛用于各种图学习任务,并且取得了很大成功,但是这些方法大多数都是有监督的,依赖于大量有标签数据。因此,发展无监督图学习方法对许多任务来说是很重要的。

动机: 现有的无监督图学习方法都依赖于基于随机游走的目标函数,有时候甚至直接通过重构邻接矩阵来学习图表示。虽然随机游走很强大,但是也存在很大局限性,即过度强调邻近信息,并且模型性能过于依赖超参数的选择。

本文工作: 作者首次提出一种基于互信息的无监督图学习方法DGI(以前的都是基于随机游走或者邻接矩阵重构)。

2. 方法

DGI: 可以看作local-global对比学习方法,通过最大化node representation和graph representation之间的互信息来学习最终的节点嵌入。

2.1 符号定义

h\vec h表示节点嵌入:

  • 为了得到图嵌入,需要定义一个readout函数R:RN×FRF\mathcal R:\mathbb R^{N\times F}\rightarrow\mathbb R^F,这样图表示可以定义为s=R(E(X,A))\vec{s}=\mathcal{R}(\mathcal{E}(\mathbf{X}, \mathbf{A}))

  • 为了最大化节点和图之间的互信息,需要定义一个判别器:D:RFR\mathcal D:\mathbb R^F\rightarrow\mathbb RD(hi,s)\mathcal D(\vec h_i,\vec s)表示两者之间的互信息。

  • 单图设定下,为了获取负样本需要定义一个扰动函数C:RN×F×RN×NRM×F×RM×M\mathcal{C}: \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{M \times F} \times \mathbb{R}^{M \times M},扰动后的图表示为(X~,A~)=C(X,A)(\widetilde{\mathbf{X}}, \widetilde{\mathbf{A}})=\mathcal{C}(\mathbf{X}, \mathbf{A})

得到节点嵌入和正负样本图嵌入后,模型的优化目标定义为:

L=1N+M(i=1NE(X,A)[logD(hi,s)]+j=1ME(X~,A~)[log(1D(h~j,s))])\mathcal{L}=\frac{1}{N+M}\left(\sum_{i=1}^{N} \mathbb{E}_{(\mathbf{X}, \mathbf{A})}\left[\log \mathcal{D}\left(\vec{h}_{i}, \vec{s}\right)\right]+\sum_{j=1}^{M} \mathbb{E}_{(\tilde{\mathbf{X}}, \widetilde{\mathbf{A}})}\left[\log \left(1-\mathcal{D}\left(\overrightarrow{\widetilde{h}}_{j}, \vec{s}\right)\right)\right]\right)

2.2 DGI流程

  1. 通过图扰动函数得到一个负样本图:(X~,A~)C(X,A)(\widetilde{\mathbf{X}}, \widetilde{\mathbf{A}}) \sim \mathcal{C}(\mathbf{X}, \mathbf{A})
  2. 图编码器计算原始图节点嵌入:H=E(X,A)={h1,h2,,hN}\mathbf{H}=\mathcal{E}(\mathbf{X}, \mathbf{A})=\left\{\vec{h}_{1}, \vec{h}_{2}, \ldots, \vec{h}_{N}\right\}
  3. 图编码器计算扰动图节点嵌入:H~=E(X~,A~)={h~1,h~2,,h~M}\widetilde{\mathbf{H}}=\mathcal{E}(\widetilde{\mathbf{X}}, \widetilde{\mathbf{A}})=\left\{\overrightarrow{\widetilde{h}}_{1}, \overrightarrow{\widetilde{h}}_{2}, \ldots, \overrightarrow{\widetilde{h}}_{M}\right\}
  4. 通过Readout函数计算图嵌入:s=R(H)\vec{s}=\mathcal{R}(\mathbf{H})
  5. 通过最大化L=1N+M(i=1NE(X,A)[logD(hi,s)]+j=1ME(X~,A~)[log(1D(h~j,s))])\mathcal{L}=\frac{1}{N+M}\left(\sum_{i=1}^{N} \mathbb{E}_{(\mathbf{X}, \mathbf{A})}\left[\log \mathcal{D}\left(\vec{h}_{i}, \vec{s}\right)\right]+\sum_{j=1}^{M} \mathbb{E}_{(\tilde{\mathbf{X}}, \widetilde{\mathbf{A}})}\left[\log \left(1-\mathcal{D}\left(\overrightarrow{\widetilde{h}}_{j}, \vec{s}\right)\right)\right]\right)更新图编码器参数。

模型框架如下图1所示:

## 2.3 理论证明

这部分,作者展示了DGI目标函数和图表示学习中互信息在理论上的关联。

实力有限,大家可以看原文!

3. 实验

3.1 实验设定

3.1.1 Transductive

使用Cora、Citeseer和Pubmed三个数据集,图编码器采用单层GCN :

E(X,A)=σ(D^12A^D^12XΘ)\mathcal{E}(\mathbf{X}, \mathbf{A})=\sigma\left(\hat{\mathbf{D}}^{-\frac{1}{2}} \hat{\mathbf{A}} \hat{\mathbf{D}}^{-\frac{1}{2}} \mathbf{X} \Theta\right)

其中A^=A+IN\hat{\mathbf{A}}=\mathbf{A}+\mathbf{I}_{N}表示带自环的邻接矩阵,D^\hat{\mathbf D}表示对应的度矩阵,非线性函数σ\sigma使用ReLU,隐藏层维度512(pubmed中由于内存原因,维度为256)。

扰动函数C\mathcal C:保留原始图邻接矩阵不变,即A~=A\widetilde A=A,将特征矩阵XX进行随机row-wise shuffle。

3.1.2 Inductive

一、单个图(Reddit)

和GraphSAGE-GCN类似,采用平均池化传播规则:

MP(X,A)=D^1A^XΘ\mathrm{MP}(\mathbf{X}, \mathbf{A})=\hat{\mathbf{D}}^{-1} \hat{\mathbf{A}} \mathbf{X} \Theta

编码器采用带skip connections的三层平均池化模型:

MP~(X,A)=σ(XΘMP(X,A))E(X,A)=MP^3(MP2^(MP^1(X,A),A),A)\widetilde{\mathrm{MP}}(\mathrm{X}, \mathrm{A})=\sigma\left(\mathrm{X} \Theta^{\prime} \| \operatorname{MP}(\mathrm{X}, \mathrm{A})\right) \quad \mathcal{E}(\mathrm{X}, \mathrm{A})=\widehat{\mathrm{MP}}_{3}\left(\widehat{\mathrm{MP}_{2}}\left(\widehat{\mathrm{MP}}_{1}(\mathrm{X}, \mathbf{A}), \mathrm{A}\right), \mathbf{A}\right)

另外,在大型数据集上,作者采用minibatch方式训练模型,每次按照10,10,25个邻居采样子图,每个子图包含1+10+100+2500=2611个节点。

二、多个图(PPI)

受GAT的启发,编码器采用dense skip 连接:

H1=σ(MP1(X,A))H2=σ(MP2(H1+XWskip ,A))E(X,A)=σ(MP3(H2+H1+XWskip ,A))\begin{aligned} \mathbf{H}_{1} &=\sigma\left(\mathrm{MP}_{1}(\mathbf{X}, \mathbf{A})\right) \\ \mathbf{H}_{2} &=\sigma\left(\mathrm{MP}_{2}\left(\mathbf{H}_{1}+\mathbf{X W}_{\text {skip }}, \mathbf{A}\right)\right) \\ \mathcal{E}(\mathbf{X}, \mathbf{A}) &=\sigma\left(\mathrm{MP}_{3}\left(\mathbf{H}_{2}+\mathbf{H}_{1}+\mathbf{X} \mathbf{W}_{\text {skip }}, \mathbf{A}\right)\right) \end{aligned}

其中WskipW_{skip}表示可学习映射矩阵,MPMP和前文中定义的一样。在多图场景下,不在使用扰动函数创建负样本,而是随机选取训练集中其他图作为负样本。

3.1.3 其他

Readout函数: 所有节点平均值,R(H)=σ(1Ni=1Nhi)\mathcal{R}(\mathbf{H})=\sigma\left(\frac{1}{N} \sum_{i=1}^{N} \vec{h}_{i}\right)

判别器: 采用一个简单的双线性评分函数,D(hi,s)=σ(hiTWs)\mathcal{D}\left(\vec{h}_{i}, \vec{s}\right)=\sigma\left(\vec{h}_{i}^{T} \mathbf{W} \vec{s}\right)

其他: 学习率0.001,采用Adam SGC优化器,patience设置为20。

3.2 结果

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信