Graph Prototypical Networks for Few-shot Learning on Attributed Networks

https://arxiv.org/pdf/2006.12739

https://github.com/kaize0409/GPN

Graph Prototypical Networks for Few-shot Learning on Attributed Networks,CIKM,2020

总结:整个思想是原型网络+GNN,将元学习中的原型网络拓展到属性网络中。另外,对原型网络中原型计算方法针对属性网络做了改进,即计算节点对其所属类别的重要程度z作为计算类别原型时的权重。本文实验部分还比较充实,作为图学习领域为数不多的小样本学习方法,总体来说还可以。

1. 简介

1.1 摘要

Attributed networks nowadays are ubiquitous in a myriad of high-impact applications, such as social network analysis, financial frauddetection, and drug discovery. As a central analytical task on at-tributed networks, node classification has received much attentionin the research community. In real-world attributed networks, alarge portion of node classes only contains limited labeled instances,rendering a long-tail node class distribution. Existing node clas-sification algorithms are unequipped to handle thefew-shotnodeclasses. As a remedy, few-shot learning has attracted a surge ofattention in the research community. Yet, few-shot node classifi-cation remains a challenging problem as we need to address thefollowing questions: (i) How to extract meta-knowledge from anattributed network for few-shot node classification? (ii) How toidentify the informativeness of each labeled instance for building arobust and effective model? To answer these questions, in this paper,we propose a graph meta-learning framework – Graph PrototypicalNetworks (GPN). By constructing a pool of semi-supervised nodeclassification tasks to mimic the real test environment, GPN is ableto performmeta-learningon an attributed network and derive ahighly generalizable model for handling the target classificationtask. Extensive experiments demonstrate the superior capability ofGPN in few-shot node classification .

如今属性网络普遍存在于很多应用场景中,比如社交网络分析、金融欺诈检测以及药物发现等等。节点分类作为属性网络上最重要的任务之一吸引了很多人的关注。在现实世界的属性网络中呈长尾分布,即存在很多类别的节点数量很少,现有的节点分类算法很难处理这些样本数量很少的类别。最近小样本学习得到了业内很多人的关注。但是小样本节点分类仍然具有很大挑战,需要解决以下问题:(1)如何从属性网络中提取元知识用于小样本节点分类?(2)如何确定每个有标签节点蕴含的信息,来建立稳定有效的模型?为了解决这些问题,本文作者提出了一种元学习架构——Graph Prototypical Networks(GPN)。通过构建一个半监督节点分类任务池来模拟真实的测试环境,GPN可以在属性网络上进行元学习,得到一个泛化能力很好地模型来处理目标分类任务。大量实验证明了作者提出GPN模型在下样本场景下的优越性。

1.2 本文工作

1.2.1 背景

在许多真实世界的属性网络中,许多节点类别只包含少量有标签节点,呈现一种长尾分布。如下图所示,展示了DBLP数据集中每个类别包含的样本数量:

超过30%的类别拥有的有标签样本数量少于10个。另外,许多实际应用场景也需要模型能够处理这种小样本问题。一个典型的例子就是Traffic网络上的入侵检测问题,这种情况下,攻击者不断开发新的攻击和威胁。但是标记成本高,对于一些特定类型的攻击,只能获取到少量样本。因此只通过少量标记样本了解这些攻击类型,来提出有效的对策是很重要的。因此在小样本设定下研究节点分类模型是十分重要的。

1.2.2 挑战

虽然小样本学习已经取得了不错的成功,但是属性网络上的小样本学习仍然有待探索,面临着下面两个挑战:

  • 元训练任务的构建过程依赖于这样一个假设:数据是独立同分布的,但是这个假设在属性网络中并不成立。直接将现有的FSL方法应用到属性网络中是不可行的,无法捕捉到潜在的数据结构,使得最终学习到的节点表示很差。因此如何在将元学习应用到属性网络,是从数据中提取元知识的必要条件。
  • 现有的大多数FSL方法都假设所有有标签样本重要程度是一样的,但是在真实属性网络中,忽略有标签样本的individual信息会限制模型的性能。一方面这会使得FSL模型对噪声和异常点十分敏感,另一方面属性网络中节点的重要程度差异性可能很大(通常一个community中的中心节点可能更representative),这和传统FSL方法的假设不相符。因此如何在属性网络中捕捉到每个节点的信息是另一个挑战。

1.2.3 作者方法

针对上面两个挑战,作者提出了Graph Prototypical Networks(GPN)模型,一种用于解决属性网络上小样本节点分类问题的元学习架构。GPN尝试学习一个可迁移的度量空间,在整个空间中根据测试节点和各个类别原型之间的距离预测测试节点所属类别。GPN包含两个重要的组成部分:

  • network encoder:通过GNNs学习网络中的节点表示
  • node valuator:GNN-based,利用编码在网络中的额外信息,评估每个有标签样本的informativeness,帮助GPN计算出稳定的表示能力强的类别原型

另外通过构建一个用于元学习的任务池,GPN可以从属性网络中提取出元知识,实现更好地泛化能力,可以用于目标阿小样本分类任务。

1.2.4 问题定义

文章中所有的符号含义如下表所示:

属性网络 G=(V,E,X)G=(\mathcal V,\mathcal E,\mathbf X)V={v1,...,vn}\mathcal V=\{v_1,...,v_n\}表示所有节点,E={e1,...,em}\mathcal E=\{e_1,...,e_m\}表示所有边,X=[x1;...;xn],xiR1×d\mathbf X=[x_1;...;x_n],x_i\in\mathbb R^{1\times d}表示节点属性特征向量。更通用的一种表示为G=(A,X)G=(A,X),其中A={0,1}n×nA=\{0,1\}^{n\times n}为邻接矩阵,Ai,j=0A_{i,j}=0表示viv_ivjv_j之间没有边,Ai,j=1A_{i,j}=1表示viv_ivjv_j之间有边。问题定义为:

  • 给定一个属性网络G={A,X}\mathcal G=\{A,X\},假设对于类别集合CtrainC_{train},每个类别都有充足的有标签节点。模型经过CtrainC_{train}中的数据训练后,希望模型能够准确预测disjoint类别集合CtestC_{test}中节点的类别。需要注意的时每个类别只有少量有标签节点可利用。
  • 按照通用的FSL设定,如果CtestC_{test}中包含N个类别,支持集SS包含K个有标签节点,该问题称之为N-way-K-shot节点分类问题。整个模型的目标是学习一个元分类器,能够适用于新任务(还有少量有标签节点的新类任务)。

因此如何从CtrainC_{train}中提取可迁移的元知识是解决这个问题的关键。

2. 模型

GPN模型主要为了解决下面三个问题而设计的:

  1. 如何将元学习应用到属性网络(不满足i.i.d)中,提取元知识?
  2. 如何同时考虑节点属性和网络拓扑结构学习expressive的节点表示?
  3. 如何区分每个有标签节点的informativeness,即重要程度,学习稳定、具有判别力的类别原型?

整个模型的架构如下图所示:

2.1 Episodic训练

GPN在很多meta-training任务之间不断迭代训练。在每个episode,构建一个N-way-K-shot元训练任务:

St={(v1,y1),(v2,y2),...,(vN×K,yN×K)},Qt={(v1,y1),(v2,y2),...,(vN×K,yN×K)},Tt={St,Qt}(1)\begin{aligned} S_t&=\{(v_1,y_1),(v_2,y_2),...,(v_{N\times K},y_{N\times K})\},\\ Q_t&=\{(v_1^*,y_1^*),(v_2^*,y_2^*),...,(v^*_{N\times K},y^*_{N\times K})\},\\ \mathcal T_t&=\{S_t,Q_t\} \end{aligned}\tag 1

其中StS_tQtQ_t采样自CtrainC_{train},分别是元训练任务Tt\mathcal T_t的支持集和查询集。支持集StS_t中每个类别有K个有标签节点,查询集QtQ_t包含M个查询节点。整个训练过程都是在任务集合Ttrain={Tt}t=1T\mathcal T_{train}=\{\mathcal T_t\}_{t=1}^T,训练目标是最小化每个任务中查询集上的分类误差。模型不断在这些元任务上进行训练,可以逐渐学到元知识,将模型推广到元测试任务Ttest={S,Q}\mathcal T_{test}=\{S,Q\}(样本采样自不可见类别CtestC_{test})。

和传统的episodic训练不同,作者采用的半监督训练方式:在每个episode,作者采样N-way K-shot个节点作为有标签节点,mask其余节点作为unlabeld节点。通过这种方式构建半监督节点分类任务,模型可以学到更具有表征能力的节点表示用于小样本节点分类。

2.2 网络表示学习

一、节点表示

利用GNNs学习节点表示,具体来说通过堆叠多个GNN层学习节点表示:

H1=GNN1(A,X)...Z=GNNL(A,HL1)(3)\begin{aligned} &H^1=GNN^1(A,X)\\ &...\\ &Z=GNN^L(A,H^{L-1}) \end{aligned}\tag 3

作者用fθ()f_\theta(·)表示所有L个GNN层。

二、类别原型

学习到节点表示后,需要为每个类别计算原型向量,即表示这个类别整体特征的向量:

pc=PROTO({ziiSc})(4)p_c=PROTO\Big(\{z_i|\forall i\in S_c\}\Big)\tag 4

其中ScS_c表示类别c中所有有标签节点,PROTOPROTO表示原型计算函数。一种简单有效的方式就是计算该类别下所有节点表示的平均值作为该类别的原型:

pc=1SciSczi(5)p_c=\frac{1}{|S_c|}\sum_{i\in S_c}z_i\tag 5

2.3 节点重要性评估

尽管在计算类别原型的时候采用平均值方法也可以取得不错的效果,但是否认的是该类别下不同节点的重要程度是不同的。另外按照平均值方法计算类别原型会导致算法很不稳定,对噪声数据很敏感。因此作者优化了类别原型计算方法,提高模型的性能和稳定性。

作者参考 Estimating nodeimportance in knowledge graphs using graph neural networks 这篇文章的做法,从节点邻居的重要性来评估节点的重要程度。作者设计了如下图所示的节点评估器gϕ()g_\phi(·)

首先通过Scoring Layer计算每个节点初始重要程度得分: $$ s_i^0=tanh(w_s^Tx_i+b_s)\tag 8 $$ 其中$w_s\in\mathbb R^d$是可学习的权重向量,$b_s\in\mathbb R^1$是偏置。然后通过Score Aggregation Layer对邻居节点重要程度得分进行聚合,聚合结果作为节点$v_i$的最终得分。计算公式如下: $$ s_i^l=\sum_{j\in\mathcal N_i\cup v_i}\alpha_{ij}^ls_j^{l-1}\tag 6 $$ 其中$s_i^l$表示节点$v_i$在l层$l=(1,...,L)$的重要程度得分。$\alpha_{ij}^l$表示聚合时$v_i$和$v_j$之间的注意力权重,计算方法如下: $$ \alpha_{ij}^l=\frac{exp(LeakyReLU(a^T[s_i^{l-1}||s_j^{l-1}]))}{\sum_{k\in\mathcal N_i\cup v_i}exp(LeakyReLU(a^T[s_j^{l-1}||s_k^{l-1}]))}\tag 7 $$ 其中$||$表示拼接操作,$a$是权重向量。根据之前的节点重要性计算方法的建议,节点的重要程度和节点在图中的centrality是相关的。节点centrality的一种常用表示方法是节点的入度$deg(i)$,节点$v_i$初始centrality $C(i)$计算方法如下: $$ \tilde s_i=sigmoid(C(i)·s_i^L)\tag {10} $$ 这样我们就能够利用编码在网络中的其他信息调整支持集中节点的重要程度。

2.4 小样本节点分类

得到每个节点的重要程度得分后,首先对其进行正则化处理:

βi=exp(S~i)kScexp(s~K)(11)\beta_i=\frac{exp(\tilde S_i)}{\sum_{k\in S_c}exp(\tilde s_K)}\tag {11}

然后计算每个类别的原型:

pc=iScβizi(12)p_c=\sum_{i\in S_c}\beta_iz_i\tag {12}

对于查询集中的节点我们只需要计算该节点和所有类别原型之间的距离,然后选取距离最近的类别原型代表的类别作为预测结果即可。对于测试节点viv_i^*,其属于类别c的概率为:

p(cvi)=exp(d(zi,pc))cexp(d(zi,pc)(13)p(c|v_i^*)=\frac{exp(-d(z_i^*,p_c))}{\sum_{c'}exp(-d(z_i^*,p_{c'})}\tag {13}

其中d()d(·)表示距离度量函数,通常采用欧式距离的平方。每个元训练任务的损失定义如下:

L=1N×Mi=1N×Mlog(p(yivi))(14)\mathcal L=-\frac{1}{N\times M}\sum_{i=1}^{N\times M}log(p(y_i^*|v_i^*))\tag{14}

整个算法流程如下图所示:

# 3. 实验

数据集:Amazon-Clothing、Amazon-Electronics、DBLP和Reddit

对比方法:DeepWalk、node2vec(基于随机游走)、GCN、SGC、PN(基于GNN)、MAML、Meta-GNN(小样本学习)

在四个数据集上,作者的方法都取得了最好的结果。
作者针对N-Way和K-Shot进行了消融实验,其中GPN-naive表示去掉了node valuator后的GPN变体。
作者还针对查询集大小进行了笑容实验。
上图展示了Meta-GNN和GPN在DBLP数据集上查询集合支持集节点之间的相似度,可以发现GPN模型下,查询集合支持集间相同类别节点的距离更小。
打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信