Relative and Absolute Location Embedding for Few-Shot Node Classification on Graph

https://ojs.aaai.org/index.php/AAAI/article/view/16551/16358

Relative and Absolute Location Embedding for Few-Shot Node Classification on Graph,2021,AAAI

总结:本文的出发点是:单个元任务中不同节点间的相互依赖关系应该被编码到prior knowledge中,但是两个节点在整个图中的距离可能较远,GNNs无法捕获这种long range dependencies。作者的解决方案就是,在两个节点u和v之间采样路径,使用path encoder计算路径编码,同时考虑不同任务之间学习到的先验知识应该被对齐到整个图上,所以从task和graph两个level上分别计算location embedding。另外考虑到远距离节点之间的路径采样效率很低,提出了一种新的基于hub节点的路径采样策略。总的来说,文章写得挺好的,实验结果也挺好的,学习价值较高,可惜作者暂时没有提供源码。

1. 简介

1.1 摘要

Node classification is an important problem on graphs. While recent advances in graph neural networks achieve promising performance, they require abundant labeled nodes for training. However, in many practical scenarios, there often exist novel classes in which only one or a few labeled nodes are available as supervision, known as few-shot node classification. Although meta-learning has been widely used in vision and language domains to address few-shot learning, its adoption on graphs has been limited. In particular, graph nodes in a few-shot task are not independent and relate to each other. To deal with this, we propose a novel model called Relative and Absolute Location Embedding (RALE) hinged on the concept of hub nodes. Specifically, RALE captures the task-level dependency by assigning each node a relative location within a task, as well as the graph-level dependency by assigning each node an absolute location on the graph to further align different tasks toward learning a transferable prior. Finally,extensive experiments on three public datasets demonstratethe state-of-the-art performance of RALE.

节点分类作为图学习中非常重要的任务之一,尽管最近GNNs取得了不错的表现,但是这些模型需要大量有标签节点用于训练。在许多实际应用中,经常存在一些新类别只要少量有标签节点用来训练模型,称之为小样本节点分类。尽管元学习已经被广泛用于解决CV和NLP领域的小样本学习,但将其应用到图结构数据上存在很大限制。特别是在小样本任务中图节点并不是独立的,而是相互关联的。为了解决这个问题,我们基于hub nodes理念提出了一个新的模型,称之为Relative and Absolute Location Embedding(RALE)。具体来说,RALE通过节点在单个任务中的相对位置捕捉task-level依赖,同时根据节点在图上的绝对位置来捕捉graph-level依赖,来对其不同任务,学习一个可迁移的先验知识。

1.2 本文工作

动机: GNNs虽然在图分类任务中取得了不错的效果,但是通常需要大量有标签数据。节点分类任务的一种典型设置是:给定类别集合,每个节点属于一个类别,模型目标就是基于有标签节点预测没有标签节点的类别标签。但是在很多场景下,我们需要处理训练时未见到的新类。一方面,存在一些类别,称之为基类,具有充足样本,但另一方面新类只有少量有标签节点。例如图1(a)所示的toy citation网络,“SVM”、“Neural network”这类主题的节点很多,但是“Explainable AI”和“Fair ML”这类主题节点很少。本文作者研究的就是针对这些novel calss的分类问题,称之为小样本节点分类。

现有工作弊端: 现有的几篇利用元学习做小样本节点分类的文章都是遵循通用的meta-learning paradigm。具体来说它们将小样本节点分类定义为一系列分类任务(即元任务),每个任务的目的是:经过少量support nodes微调先验知识后,预测query node的类别,如图1(b)所示。在现有的这些方法中,单个任务中节点的相互依赖并没有被显示建模,也没有被整合到可迁移的先验知识中。

作者方案: 为了捕获GNNs无法捕获到的单个任务内节点间的long-range dependencies,作者先采样节点之间的路径,然后使用path encoder计算节点的location embedding,从task和graph两个level捕获节点间的依赖,用于小样本节点分类。

2. 方法

整个模型大致可以分成图3所示三个部分:

  1. 图3(a),给定一张图,利用Graph encoder为每个节点计算一个embedding,同时采样得到meta task;
  2. 图3(b),针对每个task中的节点,首先根据hub节点采样一个路径集合(分两部分,task level和graph level),再结合node embedding和path encoder学习path embedding;
  3. 根据node embedding,RL embedding和AL embedding三个向量进行节点分类。

2.1 计算位置嵌入

单纯利用GNNs作为graph encoder存在一些弊端,因为GNNs为了防止过平滑问题,层数通常较浅,这导致难以对距离较远的两个节点之间的关系进行建模。因此作者在graph encoder基础上提出了一个path encoder,对距离较远的两个节点之间的依赖进行建模。具体来说,假设Pu,v\mathcal P_{u,v}表示节点u和v之间的路径集合(路径长度在最大长度之内),位置嵌入计算公式如下:

evR=ϕ({Pu,v:uR};θg,θp)(1)\mathbf{e}_{v}^{\mathcal{R}}=\phi\left(\left\{\mathcal{P}_{u, v}: u \in \mathcal{R}\right\} ; \theta_{g}, \theta_{p}\right)\tag 1

其中ϕ\phi表示嵌入函数,Pu,v\mathcal P_{u,v}表示路径集合,R\mathcal R表示reference节点集合,θg\theta_g表示graph encoder参数,θp\theta_p表示path encoder参数。

一、Graph encoder

ϕg(;θg)\phi_g(·;\theta_g)表示GNN encoder,参数为θg\theta_g,其每一层计算方式如下:

hvi=M(hvi1,{hui1,uNv};θg)(2)\mathbf{h}_{v}^{i}=\mathcal{M}\left(\mathbf{h}_{v}^{i-1},\left\{\mathbf{h}_{u}^{i-1}, \forall u \in \mathcal{N}_{v}\right\} ; \theta_{g}\right)\tag 2

其中hviRdih_v^i\in\mathbb R^{d_i}表示第i层节点v的d嵌入向量,Nv\mathcal N_v表示v的邻居节点集合,M\mathcal M表示消息传播函数,hv0=xvh_v^0=x_v。用ϕg(v;θg)=hvRdg\phi_g(v;\theta_g)=h_v\in\mathbb R^{d_g}表示节点v最终嵌入向量。

二、Path encoder

给定一条路径p=(v1,v2,...,vs)p=(v_1,v_2,...,v_s),根据graph encoder的输出可以得到一组嵌入向量P=ϕg(p;θg)=(hv1,hv2,...,hvs)P=\phi_g(p;\theta_g)=(\mathbf{h_{v1},h_{v2},...,h_{vs}})。Path encoder模型可表示如下:

p=ϕp(P;θp)(3)\mathbf{p}=\phi_{p}\left(P ; \theta_{p}\right)\tag 3

其中ϕ(;θp)\phi(·;\theta_p)表示一个序列模型比如RNN或者Transformer,参数为θp\theta_p

三、Location embedding

利用前面两个encoder,给定起始点u,我们可以计算节点v的location embedding:

ev{u}=AGGR({ϕp(ϕg(p;θg);θp):pPu,v})(4)\mathbf{e}_{v}^{\{u\}}=\operatorname{AGGR}\left(\left\{\phi_{p}\left(\phi_{g}\left(p ; \theta_{g}\right) ; \theta_{p}\right): \forall p \in \mathcal{P}_{u, v}\right\}\right)\tag 4

其中AGGR()AGGR(·)表示聚合函数,比如平均池化,Puv\mathcal P_{uv}表示从u起始终到v的路径集合。如果给定起始点集合R\mathcal R,节点v的location计算方法如下:

evR=ϕ({Pu,v:uR};θg,θp)=AGGR({ev{u}:uR})(5)\begin{aligned} \mathbf{e}_{v}^{\mathcal{R}} &=\phi\left(\left\{\mathcal{P}_{u, v}: u \in \mathcal{R}\right\} ; \theta_{g}, \theta_{p}\right) \\ &=\operatorname{AGGR}\left(\left\{\mathbf{e}_{v}^{\{u\}}: \forall u \in \mathcal{R}\right\}\right) \end{aligned}\tag 5

2.2 路径采样

2.1部分介绍了图3中各种encoder和aggregator的详细计算方法,下面主要介绍图3中(b1)和(b2)部分的详细过程。

如前文所述,在小样本任务中support节点和query节点在图中的距离可能很远,为了捕捉远距离节点间的long-range dependencies,作者利用hubs节点概念,从两个维度计算location embedding来解决这个问题。

所谓Hub节点,指的是PageRank算法种得分最高的若干个节点,它们对于整个图来说非常重要。用HV\mathcal{H\subset V}表示通过某种centrality方法计算得到的hub节点集合。hub节点在task level和graph level分别扮演者两个重要的角色。

一、Task level

给定任务t=(St,Qt)t=(S_t,Q_t),为了捕捉该任务下所有节点之间的依赖,我们为每个节点v分配一个relative location,称之为RL。具体来说,对于vStQt\forall v\in S_t\cup Q_t,它的RL embedding计算方式如下:

evSt=ϕ({Ps,v:sSt};θg,θp)(6)\mathbf{e}_{v}^{S_{t}}=\phi\left(\left\{\mathcal{P}_{s, v}: s \in S_{t}\right\} ; \theta_{g}, \theta_{p}\right)\tag 6

其中Ps,v\mathcal P_{s,v}表示路径集合,可以通过起始节点为s,终点节点为v,最大路径长度为lpl_p的随机游走得到。但是在图中节点s和v之间的距离可能很远,导致原生的随机游走方法效率会很低,而且很大可能在lpl_p步内找不到一条合适路径。

为了解决这个问题,作者通过hub节点来采样路径。具体来说,首先分别从s节点和v节点开始,终点为hub节点,路径长度为lp/2l_p/2,利用随机游走采样路径得到segment sampling。然后对这些segments进行重组得到path,比如图3(b)的(b1)部分存在v6v9v_6\rightarrow v_9v8v9v_8\rightarrow v_9两个segments可以重组成路径v8v9v6v_8\rightarrow v_9\rightarrow v_6

二、Graph level

RL embedding是对单个任务内所有节点之间的依赖进行建模,为了学习一个通用的可迁移先验知识,需要将这种task level依赖关系对其到graph level。作者将hub节点H\mathcal H看做globel reference节点集合,给每个节点分配一个absolute location,称之为AL,并计算对应的AL embedding:

evH=ϕ({Ph,v:hH};θg,θp)(7)\mathbf{e}_{v}^{\mathcal{H}}=\phi\left(\left\{\mathcal{P}_{h, v}: h \in \mathcal{H}\right\} ; \theta_{g}, \theta_{p}\right)\tag 7

具体来说需要采样起始节点为hub节点,终点节点为v的路径(其实就是前面提到的segment,无需重复计算),然后使用path encoder计算对应embedding。

2.3 其他

一、分类层

给定任务t中的某个节点v,它的分类得分计算方式如下:

ψ(v;Θ)=SoFTMAX(σ(W[hvevStevH]))(8)\psi(v ; \Theta)=\operatorname{SoFTMAX}\left(\sigma\left(\mathbf{W}\left[\mathbf{h}_{v}\left\|\mathbf{e}_{v}^{S_{t}}\right\| \mathbf{e}_{v}^{\mathcal{H}}\right]\right)\right)\tag 8

二、目标函数

交叉熵损失最为目标函数:

L(St;Θ)=vSti=1mI(v)=iln(ψ(v;Θ)[i])(9)L\left(S_{t} ; \Theta\right)=-\sum_{v \in S_{t}} \sum_{i=1}^{m} \mathbb{I}_{\ell(v)=i} \ln (\psi(v ; \Theta)[i])\tag 9

三、复杂度

整个算法包含两部分:一是segment采样和path重构(这部分只需要执行一次),二是基于元任务优化模型目标。作者文中说伪代码和复杂度分析在附件中提供,但是我这个版本的论文中并没有supplementary。

3. 实验

3.1 实验设置

数据集划分:

Baselines:

  1. GNNs:GCN,GraphSAGE,GAT,训练阶段利用base class上的分类任务训练GNN模型,测试阶段冻结GNN部分参数,用support集微调分类层参数。
  2. GNN+‘s:还是上面三种GNNs架构,不过每种模型都在base calsses上进行预训练,测试阶段利用support set对模型参数微调,在query set上测试。
  3. 元学习模型:Meta-GNN,Proto-GNN

对于所有元学习模型,都采用GraphSAGE作为GNN架构,对于RALE,采用自监督模型作为path encoder。

3.2 实验结果

对比GNNs和GNN+’s,可以发现没有绝对的优劣之分,用测试任务对GNNs的聚合函数进行微调可能导致过拟合,因此直接使用微调方式无法解决小样本问题。

Meta-GNN、Proto-GNN和GNNs、GNN+'s相比,性能只取得了很小的提升,表明如果不对节点间的依赖进行建模,元学习对于小样本图学习的帮助有限。

RALE相比之下,在大部分实验设置下都取得了优异的表现。但是在Amazon数据集,5-shot设定下,RALE表现较差。作者分析可能的原因是: The possible reason is that the co-purchasing ties between diverse items on a large e-commerce platform like Amazon are weaker than email ex-changes between users or topical relatedness between posts participated in by the same user on social networks. Thus, it becomes less important to capture the dependencies, and the attention mechanism in GAT/GAT+ is more suited to weighing diverse neighbors especially when given more shots。大概意思是亚马逊这种大型电子商务平台上不同物体间co-purchasing比较弱,因此对dependency的捕捉不太依赖,而GAT中的注意力机制更适合对各种邻居进行加权,尤其在more shots情况下。(感觉也没解释太清楚)

3.3 模型分析

一、消融实验

RALE\r表示去掉RL embedding,RALE\a表示去掉AL embedding,RALE\ar表示都去掉。

二、hubs和随机游走

hub节点比例越小,两个节点之间可能存在的路径就越少,图5a展示了不同hub ratio下模型准确率,可以看到随着hub ratio增加,模型准确度会增加,但是hub ratio达到5%后,模型性能基本不变。

图5b-d分别展示了不同lpl_p、d和w下模型性能,lpl_p表示最大路径长度,ww表示每个节点的number of walks,ll表示walk length。这三个量的增大,都会导致模型性能和coverage ratio的增加。但是在相同设置下,使用hub节点(图中的红线)采样路径模型准确度更高,表示作者的hub-based采样策略是有效的。

三、参数敏感性

针对embedding的维度和学习率进行了敏感性实验。

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信