Adaptive-Step Graph Meta-Learner for Few-Shot Graph Classification

https://arxiv.org/pdf/2003.08246

Adaptive-Step Graph Meta-Learner for Few-Shot Graph Classification,CIKM,2020

1. 简介

1.1 摘要

Graph classification aims to extract accurate information fromgraph-structured data for classification and is becoming more andmore important in graph learning community. Although GraphNeural Networks (GNNs) have been successfully applied to graph classification tasks, most of them overlook the scarcity of labeledgraph data in many applications. For example, in bioinformatics, obtaining protein graph labels usually needs laborious experiments. Recently, few-shot learning has been explored to alleviate this prob-lem with only a few labeled graph samples of test classes. The shared sub-structures between training classes and test classes are essential in few-shot graph classification. Existing methods assume that the test classes belong to the same set of super-classes clus-tered from training classes. However, according to our observations, the label spaces of training classes and test classes usually do not overlap in real-world scenario. As a result, the existing methods don’t well capture the local structures of unseen test classes. To overcome the limitation, in this paper, we propose a direct method to capture the sub-structures with well initialized meta-learner within a few adaptation steps. More specifically, (1) we propose a novel framework consisting of a graph meta-learner, which uses GNNs based modules for fast adaptation on graph data, and a step controller for the robustness and generalization of meta-learner; (2) we provide quantitative analysis for the framework and give a graph-dependent upper bound of the generalization error based on our framework; (3) the extensive experiments on real-world datasets demonstrate that our framework gets state-of-the-art results on several few-shot graph classification tasks compared to baselines.

图分类旨在从图结构数据中提取准确信息用于分类,在图学习领域越来越重要了。尽管GNNs已经成功应用于图分类任务中,但是大多数方法都忽略了许多场景下存在有标签数据稀疏问题。例如,在生物信息领域,获取蛋白质图的标签通常需要进行实际实验,难以获取大量有标签数据。最近提出的小样本学习正是用来解决这一问题的,只需要少量有标签样本就能取得不错的效果。现有的方法都是假设用于测试的类别和用于训练的类别都属于同一个超类(什么意思?)但是根据我们的观察,真实场景下,训练类别的标签空间和测试类别的标签空间通常是不重叠的。因此,现有的方法不能很好地捕捉到不可见的测试类别的局部结构。为了解决这个问题,本文作者提出了一个新的方法,利用已经初始化过的meta-learner只需要优化很少次就能捕捉到sub-structures。具体来说:(1)作者提出了一个架构,包含一个graph meta-learner(用于图数据场景下的快速适应)以及一个step controller(用于保证meta-learner的稳定性和通用性);(2)作者对提出的架构进行了分析,给出了一个通用的错误上界;(3)真实数据集上的大量实验表明作者提出的架构在一些小样本图分类任务下性能优于baseline方法。

1.2 本文工作

虽然很多GNNs方法已经成功应用于图分类任务,但是这些方法都忽略了有标签图数据获取代价很高这一现实问题。如何只通过少量有标签样本完成图分类任务十分具有挑战。目前只有Chauhan等人正对小样本下的图分类任务,基于graph spectral measures提出了一种模型并且取得了不错的效果。从数据集的全局结构来看,Chauhan通过假设测试类别所属的超类属于从训练类别聚类得到的超类,来关联测试类别和训练类别。这种基于图谱度量的方法存在一些弊端:(1)在小样本设定下,训练类别和测试类别的标签空间通常不重合;(2)这种关联方式会削弱模型捕捉测试数据局部结构信息的能力。

从图的局部结构来看,作者发现训练类别和测试类别具有相似的子结构。例如,不同的社交网络通常有一些相似的群组;不同的蛋白质分子通常有相似的刺突蛋白。作者发现通过一个well initialized的meta-learner,经过少数次调整后就能捕捉到这种相似性。下图展示了Chauhan的模型和作者提出的模型:

现有GNNs通过卷积和池化操作捕捉局部结构的能力已经很强了,但是当处理从来没有见过的graph类别时无法快速适应。受MAML的启发,作者利用GNNs作为图嵌入骨架、元学习作为训练策略,在图分类任务中快速捕捉task-specific知识并将捕捉到的知识传递到新的任务中。但是直接使用MAML来实现快速适应是sub-optimal的:

  • 和图片不同,图的节点数和子结构是任意的,会给自适应带来不确定性
  • MAML需要复杂的超参搜索,提高模型稳定性和泛化能力

已经有一些MAML的变体来解决上述问题,但是它们都不是针对图结构数据的。本文作者提出了一种adaptive step controller,来学习一个自适应的opimal step,提高模型稳定性和泛化能力。controller根据两个输入来决定什么时候终止adaptation:

  1. 图嵌入的质量,即元特征,反映了平均节点信息
  2. meta-learner的训练状态,反映了训练分类损失

作者提出的模型命名为AS-MAML,即Adaptive Step MAML,是首个从图局部结构考虑小样本图分类问题的模型,并通过元学习实现图上的快速适应。

2. 方法

2.1 问题定义

定义N-way-K-shot图分类问题。

  1. 给定图数据G={(G1,y1),(G2,y2),,(Gn,yn)}\mathcal G=\left\{\left(G_{1}, \mathbf{y}_{1}\right),\left(G_{2}, \mathbf{y}_{2}\right), \cdots,\left(G_{n}, \mathbf{y}_{n}\right)\right\},其中Gi=(Vi,Ei,Xi)G_{i}=\left(\mathcal{V}_{i}, \mathcal{E}_{i}, \mathbf{X}_{i}\right),用ni=Vin_i=|\mathcal V_i|表示节点数量,故AiRni×ni\mathbf{A}_{i} \in \mathbb{R}^{n_{i} \times n_{i}}XiRni×d\mathbf{X}_{i} \in \mathbb{R}^{n_{i} \times d},d表示节点属性的维度。

  2. 根据图标签y\mathbb y,将G\mathcal G分成{(Gtrain ,ytrain )}\left\{\left(G^{\text {train }}, \mathbf{y}^{\text {train }}\right)\right\} and {(Gtest ,ytest )}\left\{\left(\mathcal{G}^{\text {test }}, \mathbf{y}^{\text {test }}\right)\right\}分别表示训练集和测试集,ytrain\mathbb y^{train}ytest\mathbb y^{test}不重合。

采用episodic方法训练模型,在训练阶段每次采样一个任务T\mathcal T,每个任务包含支持集Dsuptrain={(Gitrain,yitrain)}i=1sD_{s u p}^{t r a i n}=\left\{\left(G_{i}^{t r a i n}, \mathbf{y}_{i}^{t r a i n}\right)\right\}_{i=1}^{s}和查询集Dquetrain ={(Gitrain ,yitrain )}i=1qD_{q u e}^{\text {train }}=\left\{\left(G_{i}^{\text {train }}, \mathbf{y}_{i}^{\text {train }}\right)\right\}_{i=1}^{q},其中s和q分别表示支持集和查询集大小。

  • **训练阶段:**给定有标签的支持集,模型目标是最小化查询集的分类损失。如果s=N×Ks=N\times K即支持集包含N个类别,每个类别有K个有标签样本,我们称该任务为N-way-K-shot小样本分类任务。

  • **测试阶段:**使用支持集样本对模型进行微调,计算模型在查询集中的分类表现

2.2 架构

整个小样本图分类架构包含两部分:

  • GNNs为骨架的meta-learner:采用MAML作为快速适应机制
  • step controller:Du等人提出一种基于RL的step controller来指导meta-learner用于链路预测,作者本文提出了一种新的step controller来加速训练,避免过拟合。

2.2.1 Graph Embedding

这部分由图卷积模块和池化模块构成嵌入骨架。采用GraphSAGE获取图中的节点嵌入:

hvl=σ(Wmean({hvl1}{hul1,uN(v)}).\mathbf{h}_{v}^{l}=\sigma\Big(\mathbf{W} \cdot \operatorname{mean}(\{\mathbf{h}_{v}^{l-1}\} \cup\{\mathbf{h}_{u}^{l-1}, \forall u \in \mathcal{N}(v)\}\Big).

其中hvlh_v^l表示第l层节点v的表示,σ\sigma是sigmoid函数,W\mathbf W是参数,V(v)\mathcal V(v)表示节点v的所有邻居。

对于池化操作,鉴于在小样本设定下,元学习器需要一种灵活的池化策略来加强模型的泛化能力。本文作者采用SAGPool作为池化层,因为其具有灵活的注意力参数。SAGPool的主要步骤就是计算一个注意力得分矩阵:

Si=σ(D~i12Λ~iD~i12XiΘatt)\mathrm{S}_{i}=\sigma\left(\tilde{\mathbf{D}}_{i}^{-\frac{1}{2}} \tilde{\Lambda}_{i} \tilde{\mathbf{D}}_{i}^{-\frac{1}{2}} \mathbf{X}_{i} \Theta_{a t t}\right)

其中SiS_i表示自注意力得分,nin_i表示图中节点数量;σ\sigma表示激活函数比如tanh;A~iRni×ni\tilde A_i\in\mathbb R^{n_i\times n_i}表示带有self-connection的邻接矩阵;D~iRni×ni\tilde D_i\in\mathbb R^{n_i\times n_i}表示度对角矩阵;XiRni×dX_i\in\mathbb R^{n_i\times d}表示维度为d的输入特征;Θatt\Theta_{att}表示可学习的注意力参数。基于注意力得分,选取等分排名前c的节点,保持其边不变。

为了保证得到维度固定的图嵌入,我们还需要一个Read-Out操作,为每个图生成一个固定维度的特征向量。参照Zhang等人的做法,作者通过拼接平均池化和最大池化来计算每一层图嵌入:

ril=R(Hil)=σ(1nilp=1Hil(p,:)maxq=1dHil(:,q))\mathbf{r}_{i}^{l}=\mathcal{R}\left(\mathbf{H}_{i}^{l}\right)=\sigma\left(\frac{1}{n_{i}^{l}} \sum_{p=1} \mathbf{H}_{i}^{l}(p,:) \| \max _{q=1}^{d} \mathbf{H}_{i}^{l}(:, q)\right)

其中rilR2d\mathbf{r}_{i}^{l} \in \mathbb{R}^{2 d}表示第l层的嵌入;niln_i^l表示第l层的节点数量;HilH_i^l表示第l层的节点表示矩阵;||表示拼接操作;σ\sigma是激活函数比如ReLU。得到每一层的图嵌入后,按照如下方式计算最终的图嵌入:

zi=ri1+ri2++riL\mathbf{z}_{i}=\mathbf{r}_{i}^{1}+\mathbf{r}_{i}^{2}+\cdots+\mathbf{r}_{i}^{L}

得到最终图嵌入后,将其放进MLP分类其中进行分类,计算交叉熵损失。

2.2.2 Fast Adaptation

θe\theta_eθc\theta_c分别表示图嵌入和MLP分类器的参数,为了使其快速适应新图,作者采用MAML中的调优方法训练模型。算法流程如下图所示:

2~17行为一个外循环,3 ~16行为一个内循环。

2.2.3 Adaptation Controller

对于MAML模型,找到learning rate和step size之间的最优组合十分困难,尤其在图数据中图的结构和大小都是任意的。本文作者设计了一个基于RL的controller来获取最优的step size。

作者利用ANI(average node information)反映节点嵌入的质量,GiG_i计算方法如下:

ANIil=1nilj=1[(Iil(Dil)1Ail)Hil]j1(5)A N I_{i}^{l}=\frac{1}{n_{i}^{l}} \sum_{j=1}\left\|\left[\left(\mathbf{I}_{i}^{l}-\left(\mathbf{D}_{i}^{l}\right)^{-1} \mathbf{A}_{i}^{l}\right) \mathbf{H}_{i}^{l}\right]_{j}\right\|_{1}\tag 5

其中l表示嵌入的层数;niln_i^l表示节点数;j表示行index或者是j-th节点;1\|\cdot\|_{1}表示L1L_1范式;AilA_i^l表示邻接矩阵;DilD_i^l表示度矩阵;HilH_i^l表示第l层隐藏表示矩阵。作者采用一个标量值作为ANI,计算方法如下:

ANI=1/ni=1nANIiL(6)A N I=1 / n * \sum_{i=1}^{n} A N I_{i}^{L}\tag 6

其中n表示batch的大小;L表示第L层的embedding。

TiT_i表示初始的step大小;MRTi×1\mathbf M\in\mathbb R^{T_i\times 1}表示TiT_isteps内所有的ANI值。然后计算在步骤t时刻的停止概率:

h(t)=LSTM([L(t),M(t)],h(t1)),p(t)=σ(Wh(t)+b)(7)\boldsymbol{h}^{(t)}=\operatorname{LSTM}\left(\left[\mathbf{L}^{(t)}, \mathbf{M}^{(t)}\right], \boldsymbol{h}^{(t-1)}\right), p^{(t)}=\sigma\left(\mathbf{W h}^{(t)}+\mathbf{b}\right)\tag 7

σ\sigma表示sigmoid函数;h(t)h^{(t)}表示LSTM模块的输出。需要注意的是当前任务不会提前终止,知道TiT_i步骤执行完毕了,即是否终止adaptation和p(t)p^{(t)}无关。利用p(t)p^{(t)}计算下一个任务的step size:

Ti+1=1p(Ti)(8)T_{i+1}=\left\lfloor\frac{1}{p^{\left(T_{i}\right)}}\right\rfloor\tag 8

可以得到controller的损失如下:

Q(t)=t=1Tr(t)=t=1T(eTetηt)(9)Q^{(t)}=\sum_{t=1}^{T} r^{(t)}=\sum_{t=1}^{T}\left(e_{T}-e_{t}-\eta * t\right)\tag 9

T表示总的步数;ete_t表示步骤t时查询集上的分类准确度;ηt\eta * t为惩罚项。采用如下方式进行梯度更新:

θs=θs+α3Q(t)θslnp(t)\theta_s=\theta_s+\alpha_3\mathcal Q^{(t)}\bigtriangledown_{\theta_s}lnp(t)

2.3

3. 实验

3.1 准备

两个目的:(1)模型在小样本图分类任务中的表现如何?(2)controller是如何工作的?

数据集:

COIL-DEL、R52、Letter-High和TRIANGLES四个公共数据集。

3.2实验结果

和Graph Kernel、Finetuning、GNNs-Pro对比

  1. 和基于GraphSAGE和SAGPool的微调方法相比,作者的模型性能更好,说明meta-learner工作良好。

  2. 在Graph-R52数据集上,基于kernel的方法性能很好,作者认为可能有两个原因:

    • 数据集的文本图中有许多定义良好的子图,这些子图有文本主题及其邻近的次组成,这对kernel方法有利
    • 基于kernel的方法参数远远少于GNNs,这可以有效防止过拟合。而过拟合在小样本场景下是GNNs模型常见的问题之一。

    虽然kernal方法效果不错,但是很难找到一种合适的kernel。

和GSM对比

从t-SNE结果来看,作者的方法确实由于GSM方法。因为作者假设测试类别的超类属于从训练类别中聚类得到的超类集合,但这在小样本场景下是不成立的(因为训练类别和测试类别通常不重叠)。

3.3 消融实验

根据表2的实验结果可以发现作者提出的step controller在模型中有很大作用。

上图展示了在COIL-DEL数据集中,5-way-10-shot设定下的学习过程。

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信