Node Classification on Graphs with Few-Shot Novel Labels via Meta Transformed Network Embedding

https://arxiv.org/pdf/2007.02914

https://github.com/llan-ml/MetaTNE

Node Classification on Graphs with Few-Shot Novel Labels via Meta Transformed Network Embedding,NIPS,2020

总结:最大的创新点就是嵌入转换函数,利用自注意力机制对节点嵌入进行调整使其task-specific。针对小样本设置还是利用元学习思想,学习不同类别间统一的内在模式,提高目标任务上的分类准确度。

目前图上的小样本问题大多都是针对有标签节点数量较少,且基本都是利用元学习思想来做,可以尝试着用小样本学习中的其他方法比如基于生成等等。另外可以探索一下关系或者属性在小样本场景下该如何做。

1.简介

1.1 摘要

We study the problem of node classification on graphs with few-shot novel labels,which has two distinctive properties: (1) There are novel labels to emerge in thegraph; (2) The novel labels have only a few representative nodes for training a clas-sifier. The study of this problem is instructive and corresponds to many applicationssuch as recommendations for newly formed groups with only a few users in onlinesocial networks. To cope with this problem, we propose a novel Meta TransformedNetwork Embedding framework (MetaTNE), which consists of three modules: (1)Astructural moduleprovides each node a latent representation according to thegraph structure. (2) Ameta-learning modulecaptures the relationships between thegraph structure and the node labels as prior knowledge in a meta-learning manner.Additionally, we introduce anembedding transformation functionthat remedies thedeficiency of the straightforward use of meta-learning. Inherently, the meta-learnedprior knowledge can be used to facilitate the learning of few-shot novel labels.(3) Anoptimization moduleemploys a simple yet effective scheduling strategy totrain the above two modules with a balance between graph structure learning andmeta-learning. Experiments on four real-world datasets show that MetaTNE bringsa huge improvement over the state-of-the-art methods.

本文我们研究了具有少量新类标签下的图的节点分类问题,新类标签具有两个特性:(1)新类标签出现在图中;(2)新类标签只有少量代表性节点用于训练分类器。该研究对许多应用场景(如对在线社交网络中只有少数用户的新组建的群体进行推荐)具有重要意义。为了解决这个问题,我们提出了一个原生的Meta Transformed Network Embedding(MetaTNE)架构,它包含三个模块:(1)结构化模块:根据图结构学习每个节点的潜在表示;(2)元学习模块:按照元学习的方式捕捉图结构和节点标签之间的关系作为先验知识,另外我们引入了一个embedding transformation function来弥补直接使用元学习的不足。本质上元学习学到的先验知识可以用来增强在只有少量样本情况下的学习性能。(3)优化模块:采用一种简单但是有效的策略来训练上述两个模块,并在图结构学习和元学习之间取得平衡。4个真实数据集上的实验表明MetaTNE和当前最优方法相比带来了很大的提升。

1.2 本文工作

图被广泛应用于很多领域来表示数据,包括社交网络分析、生物信息、推荐系统以及计算机网络安全中。节点分类、链路预测、社区发现等图分析任务对我们的现实生活具有重大意义。本文我们主要针对节点分类任务,具体来说是小样本情景下的分类任务,即一些新类标签只有少量样本可以用来训练模型。小样本 的图节点分类可以指导很多实际任务,比如:

  1. 一些在线社交网络组织如Facebook、Twitter和Flickr,会有很多广告,我们希望知道用户对这些广告是否感兴趣。通过小样本学习,这些组织就可以通过少量用户的反馈更好地预测其他用户的喜好,提供更好的服务或者推荐。
  2. 对于生物里面的蛋白质网络,一些研究人员可能发现某些蛋白质之间有一种新的生物反应。给定一些蛋白质(有或者没有这一反应),小样本学习可以预测其他蛋白质之间是否有这一反应,这可以为实验人员提供新的研究方向。

本文作者考虑:图中的不同类别标签可能共享一些内在的演化模式(比如标签信息在图上的传播方式)。假定有部分类别已经有足够多的的支持集节点(有标签),我们希望能从这些节点中提取出共同的内在信息,然后利用这些提取到的信息帮助新类上(只有少量节点有标签)的分类。但是图结构和节点标签之间的关系十分复杂,节点间信息传播模式也多种多样,因此如何建模捕捉这些信息并将其应用到小样本新类上仍然具有很大挑战。

作者借鉴元学习思想,提出了一种原生的Meta Transformed Network Embedding架构,称之为MetaTNE,来捕捉共同的内在模式。如上图所示,MetaTNE共有3个模块:structural module,meta-learning module和optimization module。

2. 方法

MetaTNE包含三部分:结构模块、元学习模块和优化模块。给定一个图和一些已知标签:

  1. structural模块首先根据图结构学习每个节点的标签;
  2. 元学习模块学习一个转换函数,为每个元学习任务调整节点表示(只和图结构有关),然后使用基于距离的分类器判别节点标签;
  3. 最后按照某种概率分别优化结构模块和元学习模块。

2.1 Structural模块

该模块主要利用图结构信息来学习每个节点的潜在表示。从数学角度来看就是:对于每个节点viVv_i\in\mathcal V,我们要最大化其邻居共现的对数概率即minviVvjN(vi)logP(vjvi)min\sum_{v_i\in\mathcal V}\sum_{v_j\in\mathcal N(v_i)}log\mathbb P(v_j|v_i)N(vi)\mathcal N(v_i)表示viv_i的邻居。

关于节点邻居集合N()\mathcal N(·)的重构方法有很多,本文采用的是1-hop邻居(Line),通过优化上面的目标函数,我们可以得到一个嵌入矩阵URV×dU\in\mathbb R^{|V|\times d},其第i行uiu_i表示节点viv_i的表示。

2.2 Meta-Learning模块

作者采用基于度量的元学习架构来解决小样本问题,提出了一种转换函数,将与任务无关的嵌入转换成task-specific的嵌入,以便更好地处理多标签问题。

2.2.1 数据设置

和传统的半监督学习不一样,我们根据已知标签Yknown\mathcal Y_{known}(样本数量比较少)建立一个小样本节点分类任务池。和基于元学习的小样本图像分类任务类似,小样本节点分类任务Ti=(Si,Qi,yi)\mathcal T_i=(S_i,Q_i,y_i)由支持集SiS_i,查询集QiQ_i以及从Yknown\mathcal Y_{known}随机采样的标签yiy_iyiy_i表示某个类别标签,判别查询集中样本是否有yiy_i标签,每个任务相当于是2-Way k-shot小样本任务)。其中支持集Si=Si+SiS_i=S_i^+\cup S_i^-Si+Dyi+S_i^+\subset\mathcal D_{y_i}^+表示随机采样的正类样本,SiDyiS_i^-\subset\mathcal D_{y_i}^-表示随机采样的负类样本。Qi=Qi+Qi\mathcal Q_i=Q_i^+\cup Q_i^-,但是$\mathcal S_i\cap Q_i=\varnothing 。每个任务的目标是:给定支持集nodelabel对,寻找一个分类器。每个**任务的目标是**:给定支持集node-label对,寻找一个分类器f_{\mathcal T_i}$尽可能正确的预测查询集中节点标签。

2.2.2 元学习方法

对于每一个任务Ti=(Si,Qi,yi)p(TYknown)\mathcal T_i=(S_i,Q_i,y_i)\sim p(\mathcal T|\mathcal Y_{known}),我们希望为label yiy_i构建一个分类器fTif_{\mathcal T_i},给定支持集SiS_i,分类器能够正确判别QiQ_i中节点的类别。对于每一个(vq,lvq,yi)Qi(v_q,l_{v_q,y_i})\in\mathcal Q_i,模型损失定义为:

L(l^vq,yi,lvq,yi)=lvq,yilogl^vq,yi(1lvq,yi)log(1l^vq,yi)\mathcal L(\hat l_{v_q,y_i},l_{v_q,y_i})=-l_{v_q,y_i}log\hat l_{v_q,y_i}-(1-l_{v_q,y_i})log(1-\hat l_{v_q,y_i})

其中l^vq,yi\hat l_{v_q,y_i}表示预测vqv_q标签为yiy_i的概率(这其实就是二分类逻辑回归的损失)。

对于每一个任务TiT_i,分类器fTif_{T_i}有两个参数(d维向量):c+(i)c_+^{(i)}表示正类原型,c(i)c_-^{(i)}负类原型。我们基于节点表示和这两个原型之间的距离预测节点标签。从数学角度来看就是:给定每个查询节点vqv_q的嵌入向量uqu_q,我们按照下面方式计算概率:

l^vq,yi=fTi(vqc+(i),c(i))=exp(dist(uq,c+(i)))m{+,}exp(dist(uq,cm(i)))(2)\hat l_{v_q,y_i}=f_{T_i}(v_q|c_+^{(i)},c_-^{(i)})=\frac{exp(-dist(u_q,c_+^{(i)}))}{\sum_{m\in\{+,-\}}exp(-dist(u_q,c_m^{(i)}))} \tag 2

即节点表示和正类原型的相似度除以节点表示和正类原型与负类原型相似度之和。其中dist(,):Rd×Rd[0,+)dist(·,·):\mathbb R^d\times\mathbb R^d\rightarrow[0,+\infty)表示欧氏距离的开方,正类原型和负类原型通常为相应所有节点表示的均值。

2.2.2.1 嵌入转换

公式2是在每个节点的嵌入向量是相同的或者是任务无关的,与节点的标签或者当前任务无关这一条件下计算预测概率的。这种方法在单标签小样本分类(每个样本只有一个标签)中是合理的,但是在多标签场景(每个节点可能被分配多个标签)下存在问题。

比如在社交网络中,T1\mathcal T_1T2\mathcal T_2是两个分类任务,其对应的标签分别为"Sports"Yknown"Sports"\in\mathcal Y_{known}"Music"Ynovel"Music"\in\mathcal Y_{novel}AABB是这两个任务中涉及到的两个用户,假如A和B对于“Sports”的反馈都是positive,对于“Music”的反馈分别是positive和negative。此时,经过T1\mathcal T_1任务训练,这种task-agnostic方式得到的A和B的表示会很相似,这种节点表示不适用与T2\mathcal T_2

为了解决上述问题,作者提出了一个转换函数Tr()Tr(·):可以将task-agnostic嵌入转换成task-specific嵌入。

  1. 不同的查询节点和支持集中节点的关联模式不同,为了全面挖掘查询节点和支持节点之间的关系,对于每个查询节点我们都要对节点嵌入做针对性调整
  2. 为了对查询节点进行分类,我们对查询节点和支持节点(正类或负类)之间的距离的关系更感兴趣,而不是查询节点和支持节点(正类和负类)之间的关系。换句话说就是,查询节点和支持节点之间的关系要将支持节点分为正类和负类单独讨论。因此在转换过程中,分别使用正类支持节点和负类支持节点对查询节点进行调整。

基于这两点,对于每个查询节点,我们首先构建两个集合:一个包含查询节点的task-agnostic嵌入和正类支持节点,另一个包含查询节点的task-agnostic嵌入和负类支持节点。然后分别将这两个集合喂给转换函数,即给定一个任务Ti=(Si,Qi,yi)\mathcal T_i=(S_i,Q_i,y_i),对于每个查询节点vqVQiv_q\in\mathcal V_{\mathcal Q_i}

\begin{align} \{\tilde u_{q,+}^{(i)}\}\cup\{\tilde u_{k,q}^{(i)}|v_k\in\mathcal V_{S_i^+}\}=Tr(\{u_q\}\cup\{u_k|v_k\in\mathcal V_{S_i^+}\})\\ \{\tilde u_{q,-}^{(i)}\}\cup\{\tilde u_{k,q}^{(i)}|v_k\in\mathcal V_{S_i^-}\}=Tr(\{u_q\}\cup\{u_k|v_k\in\mathcal V_{S_i^-}\}) \end{align}\tag 3

其中u~q,+(i)\tilde u_{q,+}^{(i)}u~q,(i)\tilde u_{q,-}^{(i)}分别表示用正类支持节点和负类支持节点调整后的查询节点的嵌入,u~k,q(i)\tilde u_{k,q}^{(i)}表示针对查询节点vqv_q调整后的支持节点的嵌入。这样每个查询节点都有两种调整后的嵌入,可以捕捉节点间多种关系来适应多标签场景。

2.2.2.2 转换函数的具体实现

作者使用“ Attention is all you need ”一文中的自注意力架构来实现转换函数。self-attention中每个输入元素都扮演三个角色:

  1. 和其他每个元素进行比较,计算一个权重(反映了它对其他元素的影响)
  2. 和其他每个元素进行比较,计算一个权重(反映其他元素对它的影响)
  3. 用作每个元素输出的一部分

每个元素都有三个向量:query、key和value,分别表示三个角色。

对于公式3,输入是两个集合{uq}{ukvkVSim},m{+,}\{u_q\}\cup\{u_k|v_k\in\mathcal V_{S_i^m}\},m\in\{+,-\},对于任意两个节点vi,vj{vq}VSimv_i,v_j\in\{v_q\}\cup\mathcal V_{S_i^m}viv_ivjv_j可能相同),我们首先计算viv_ivjv_j的注意力权重:

wij=exp((WQui)(WKuj)/d1/2)vk{vq}VSimexp((WQui)(WKuk)/d1/2)(4)w_{ij}=\frac{exp((W_Qu_i)·(W_Ku_j)/d'^{1/2})}{\sum_{v_k\in\{v_q\}\cup\mathcal V_{S_i^m}}exp((W_Qu_i)·(W_Ku_k)/d'^{1/2})}\tag 4

其中WQ,WKRd×dW_Q,W_K\in\mathbb R^{d'\times d}分别是将输入向量转换成query向量和key向量的trainable矩阵,dd'表示query、key和value向量的维度,"·"表示点乘操作,1d\frac{1}{\sqrt {d'}}是缩放因子(防止梯度太小)。

wijw_{ij}本质上表示的是:节点vjv_jviv_i之间的相关程度,或者说是vjv_jviv_i的影响程度

这样转换后的节点表示向量(聚合了其他所有节点信息)可以通过加权的方式计算得到:WVRd×dW_V\in\mathbb R^{d'\times d}表示计算value向量的trainable矩阵,WORd×dW_O\in\mathbb R^{d\times d'}表示确保输出向量和输入向量维度相同的另一个trainable矩阵,按照如下方式计算节点vqv_q转换后的向量:

u~q,m(i)=WO(wqqWVuq+vkVSimwqkWVuk)(5)\tilde u_{q,m}^{(i)}=W_O\Big(w_{qq}W_Vu_q+\sum\limits_{v_k\in\mathcal V_{S_i^m}}w_{qk}W_Vu_k\Big)\tag 5

按照如下方式计算针对查询节点vqv_q的支持节点的表示:

u~k,q(i)=WO(wkkWVuk+vj(VSim\{vk}){vq}wkjWVuj)(6)\tilde u_{k,q}^{(i)}=W_O\Big(w_{kk}W_Vu_k+\sum\limits_{v_j\in\mathcal (V_{S_i^m}\backslash\{v_k\})\cup\{v_q\}}w_{kj}W_Vu_j\Big)\tag 6

本文作者最终采用的是多头注意力机制,将每个头的输入拼接到一起后再通过WOW_O映射成原始输入维度。得到转换后的节点嵌入后,我们重新计算针对查询节点vqv_q的正类原型和负类原型及其预测概率:

c~m,q(i)=1SimvkVSimu~k,q(i),m{+,}l^vq,yi=exp(dist(u~q,+(i),c~+,q(i)))m{+,}exp(dist(u~q,m(i),c~m,q(i)))(7)\begin{aligned} \tilde c_{m,q}^{(i)}&=\frac{1}{|S_i^m|}\sum\limits_{v_k\in V_{S_i^m}}\tilde u_{k,q}^{(i)},m\in\{+,-\}\\ \hat l_{v_q,y_i}&=\frac{exp(-dist(\tilde u_{q,+}^{(i)},\tilde c_{+,q}^{(i)}))}{\sum_{m\in\{+,-\}}exp(-dist(\tilde u_{q,m}^{(i)},\tilde c_{m,q}^{(i)}))} \end{aligned}\tag 7

最终的元学习目标函数也调整为:

minU,ΘTi(vq,lvq,yiQi)L(l^vq,yi,lvq,yi)+λΘ22\mathop{min}_{U,\Theta}\sum_{\mathcal T_i}\sum_{(v_q,l_{v_q},y_i\in\mathcal Q_i)}\mathcal L(\hat l_{v_q,y_i},l_{v_q,y_i})+\lambda\sum||\Theta||_2^2

其中Θ\Theta是转换函数的所有参数(WQ,WKWVW_Q,W_K和W_V),λ>0\lambda>0是平衡因子。(损失里面为什么最小化参数?)

2.3 模型优化

一种典型的方式是最小化structural损失和meta损失之和,但是图的结构信息在训练开始阶段并没有被正确的嵌入进去,节点的表示具有一定随机性,这对小样本分类任务没有意义。因此一种比较好的优化策略是:在开始阶段针对structural模块进行优化,然后逐渐增加meta-learning模块损失的权重。

为了实现这个策略,作者引入了一个概率阈值τ\tau,在每一步训练过程中分别按照τ\tau1τ1-\tau的概率对structural模块和meta-learning模块进行优化。阈值τ\tau按照阶梯式方式逐渐从1衰减到0:τ=1/(1+γstepNdecay)\tau=1/(1+\gamma\lfloor\frac{step}{N_{decay}}\rfloor),其中γ\gamma表示衰减率,stepstep表示当前的step number,NdecayN_{decay}表示衰减次数。时间复杂度如下:

  • structural模块:时间复杂度为O(kdE)O(kd|\mathcal E|),k表示每次迭代过程中负类节点数,d表示节点嵌入维度,E|\mathcal E|表示边数;

  • 元学习模块:时间主要花费在转换函数(自注意力模块)上,m表示查询节点数量,n表示负类或正类支持节点数量;计算query、key和value向量花费时间O(mndd)O(mndd')dd'表示三个向量的维度;计算注意力权重和value向量之和花费时间O(mn2d)O(mn^2d');计算最终输出向量花费时间O(mndd)O(mndd')

因此总的时间复杂度为O(kdE+mndd+mn2d)O(kd|\mathcal E|+mndd'+mn^2d')。整个算法优化流程如下:

我们最终的目标是:给定少量标签为yYnovely\in\mathcal Y_{novel}的支持节点,我们能判断其他节点是否具有标签yy,这可以看做是一个小样本节点分类任务T=(S,Q,y)\mathcal T=(S,Q,y)。训练好模型后,我们可以得到task-agnostic节点表示UU和参数为Θ\Theta的转换函数Tr()Tr(·)。因此,如果我们要判别一个查询节点vqQv_q\in Q的标签,我们只需要查询表示矩阵UU得到其嵌入uqu_q,然后通过公式5和6的转换函数得到其转换后的嵌入表示u~q+\tilde u_q^+u~q\tilde u_q^-,再根据公式7计算其预测概率即可。算法流程如下:

3. 实验

3.1 数据集

每个数据集的标签按照6:2:2划分成训练集、验证集和测试集。训练阶段将训练标签看做known标签,并从里面采样小样本节点分类任务。将验证和测试标签看做novel标签并从中分别采样1000个任务。我们采用测试任务上的平均分类表现来比较不同方法的性能。KS,+,KS,,KQ,+,KQ,K_{S,+},K_{S,-},K_{Q,+},K_{Q,-}分别表示支持集和查询集的正、负类样本数量。baseline方法采用Label Propagation,LINE,Node2Vec,GCN和Meta-GNN。

3.2 对比试验

设置KS,+=KQ,+K_{S,+}=K_{Q,+}KS,=KQ,K_{S,-}=K_{Q,-},用K,+K_{*,+}K,K_{*,-}简化表示正类样本和负类样本数量。考虑到负类样本数量更容易获取,作者设置K,+=20K_{*,+}=20K,=40K_{*,-}=40,对比结果如下表所示:

具体参数设置可以查看原始论文补充材料。

3.3 消融实验

验证不同模块的作用,提出5种变体:

  1. 没有转换函数
  2. 直接将查询节点和支持节点的表示直接喂给self-attention而不是像公式3一样将正类和负类分开
  3. 两个模块的损失的平衡因子从{101,101,...,102}\{10^{-1},10^{-1},...,10{2}\}里面选取
  4. 最开始先学习节点嵌入,然后将其固定(即两个模块分开单独训练)
  5. 每个节点的表示用one-hot向量表示,节点嵌入仅仅通过meta-learning模块训练

这五种变体分别用V1,V2,V3,V4,V5V_1,V_2,V_3,V_4,V_5表示,是研究结果如下表所示:

3.4 额外实验

  1. 正类和负类节点数量

    在负类样本小于等于正类样本时,MetaTNE和Planetoid方法表现相当,但是当负类样本数大于正类样本数时,MetaTNE方法由于Planetoid。这证明了我们的方法优于Planetoid,因为在实际应用中负类节点数量通常多于正类节点数量。

  2. 查询节点数量

    在前面试验中查询集合支持集的正负类样本数量相同,但是在实际应用中查询集正负类节点数量可能不同,因此作者进一步测试了查询及节点数对模型表现的影响。(感觉原文对这一部分实验结果的分析有点勉强)

  3. 更少的正类节点数量

    设置正类样本数量为5。

  4. 节点表示可视化

    如图a,未对节点嵌入做调整时,查询节点(真实类别为负类)距离正类原型更近,分类错误;对节点嵌入作调整后,查询节点(真实类别为负类)表示Query(+)到正类正类原型的距离大于Query(-)到负类原型的距离。图b也是同样地结果,说明作者提出的转换函数是有效的。

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信