A Universal Representation Transformer Layer for Few-Shot Image Classification

https://arxiv.org/pdf/2006.11702

https://github.com/liulu112601/URT

A Universal Representation Transformer Layer for Few-Shot Image Classification,2021,ICLR

总结:文章提出的URT方法基于SUR方法改进而来,用于multi-domain下的小样本学习。现有的小样本学习方法基本都是从单个dataset/domain(样本充足)中采样元任务,但是这在实际应用中是不现实的。本文作者提出的URT方法,作者将Transformer中的自注意力机制整合到小样本学习中,从多个预训练好的backbone中自动选取合适的backbone学习样本表示。和SUR方法相比通用性更高,并且在大多数数据集上都取得了更好的分类性能。文章中图画的挺好看,实验比较丰富,可以借鉴到其他领域论文中。

1. 简介

1.1 摘要

Few-shot classification aims to recognize unseen classes when presented with only a small number of samples. We consider the problem of multi-domain few-shot image classification, where unseen classes and examples come from diverse data sources. This problem has seen growing interest and has inspired the development of benchmarks such as Meta-Dataset. A key challenge in this multi-domain setting is to effectively integrate the feature representations from the diverse set of training domains. Here, we propose a Universal Representation Transformer (URT) layer, that meta-learns to leverage universal features for few-shot classification by dynamically re-weighting and composing the most appropriate domain-specific representations. In experiments, we show that URT sets a new state-of-the-art result on Meta-Dataset. Specifically, it achieves top-performance on the highest number of data sources compared to competing methods. We analyze variants of URT and present a visualization of the attention score heatmaps that sheds light on how the model performs cross-domain generalization. Our code is available at https://github.com/liulu112601/URT.

小样本分类旨在只有少量样本的情况下能够区分unseen类别。本文作者考虑multi-domain小样本图像分类问题,unseen类别和样本来自不同的domain。该问题吸引了很多研究人员的关注,并刺激了相关标准数据集的发展比如Meta-Dataset。multi-domain设定下的一个关键挑战是如何有效地整合来自不同domain的特征信息。本文作者提出了一个Universal Representation Transformer(URT)层,通过动态加权,组合最合适的domain-specific表示来学习universal特征用于小样本分类。在实验部分,我们展示了URT在Meta-Dataset上取得了最优结果。具体来说,和competing方法相比,URT在大部分数据集上都取得了top-performance。我们分析了URT的各种变体,并且展示了模型学到的注意力权重来解释模型如何实现cross-domain generalization。我们的代码可以通过 https://github.com/liulu112601/URT 获取。

1.2 本文工作

应用场景: multi-domain few-shot classification,不仅仅每个task中训练样本少,而且训练集和测试集样本来自多个不同的域。这种场景下,模型不仅要解决常规的小样本分类存在的挑战(即每个类别只有少量样本),还要实现跨域学习。

现有方法: Triantafillou等人构造了一个用于跨域小样本分类的标准数据集Meta-Dataset,基于此人们也提出了一些该场景下的方法,其中最为突出的就是SUR(Selecting Universal Representation),本文提出URT就是在该方法基础上改进而来。

URT: SUR方法设计了一个流程手动为每个backbone加权,为新任务中样本学习一个通用表示,没有迁移(跨任务/跨域)学习。本文作者提出的URT方法,利用Transformer中的注意力机制自动学习pre-trained backbones的权重,学习task-adapted表示。

1.3 Problem Setting

小样本分类: 在每个类别只有少量样本的情况下实现分类,每个任务中包含一个支持集S和查询集Q。S中有N个类别,每个类别有K个样本,在S上训练模型,在Q上测试模型分类性能,这就是一个N-way-K-shot小样本分类任务。

元学习: 小样本场景下的一种学习技术。一种常用的训练元学习模型的方式是episodic training,在大型数据集中采样多个小样本任务T=(Q,S)T=(Q,S),用这些任务迭代训练模型,元学习模型优化目标通常如下:

minΘE(S,Q)p(T)[L(S,Q,Θ)],L(S,Q,Θ)=1Q(x,y)Qlogp(yx,S;Θ)+λΩ(Θ)(1)\min _{\Theta} \mathbb{E}_{(S, Q) \sim p(T)}[\mathcal{L}(S, Q, \Theta)], \mathcal{L}(S, Q, \Theta)=\frac{1}{|Q|} \sum_{(\boldsymbol{x}, y) \sim Q}-\log p(y \mid \boldsymbol{x}, S ; \Theta)+\lambda \Omega(\Theta)\tag 1

Meta-Dataset: 传统小样本分类任务的设定是N-way-K-shot,每个任务支持集中包含N个类别,每个类别有K个样本,而meta-task采样自同一个domain或者dataset(比如Omniglot和miniImageNet),但是这在实际应用中是不现实的。因此,Triantafillou等人提出了Meta-Dataset,它由10个不同数据集(domain)组成,其中8个用于训练。并且从benchmark中提取的每个任务中N和K的值时不同的。

2. URT

图片解释:(1)假设某个task中S集中两个类别“Waxcap”和“Fly aganic”,每个类别各有两张图片;Q集合有一张图片(上图第一张)。(2)模型有四个预训练好的backbones,分别计算上图5个样本representation{ri}\{r_i\}(上图中四种颜色矩形方块表示四个backbones学习到的特征向量concat在一起)。(3)每个类别所有样本的representation求平均值得到该类别的特征表示(类似类别原型)r(Si)r(S_i)。(4)对r(Si)r(S_i)进行自注意力操作得到该类别下的注意力权重。(5)该任务下所有类别注意力权重求平均值得到该任务下的注意力权重。(5)将该任务下的注意力权重和查询集样本的特征表示相乘得到查询集样本的最终表示用于分类。

2.1 Single-Head URT层

图1中展示的就是单头URT层,ri(x)r_i(x)表示第i个domain下backbone学习到的表示,该样本的universal表示定义为:

r(x)=concat(r1(x),,rm(x))(2)r(\mathbf{x})=\operatorname{concat}\left(r_{1}(\mathbf{x}), \ldots, r_{m}(\mathbf{x})\right)\tag 2

支持集中每个类别的表示(类似类别原型)定义为:

r(Sc)=1ScxScr(x)(3)r\left(S_{c}\right)=\frac{1}{\left|S_{c}\right|} \sum_{\boldsymbol{x} \in S_{c}} r(\boldsymbol{x})\tag 3

然后为每个类别单独执行自注意力操作,计算权重得分,具体操作图1右边所示:

(1)计算Queries qcq_cqc=Wqr(Sc)+bq\mathbf{q}_{c}=\mathbf{W}^{q} r\left(S_{c}\right)+\mathbf{b}^{q},其中WqW^qbqb^q为可学习参数。

(2)计算Keys ki,ck_{i,c}ki,c=Wkri(Sc)+bk\mathbf{k}_{i, c}=\mathbf{W}^{k} r_{i}\left(S_{c}\right)+\mathbf{b}^{k},其中WkW^kbkb^k为可学习参数,ri(Sc)=1/ScxScri(x)r_{i}\left(S_{c}\right)=1 /\left|S_{c}\right| \sum_{\boldsymbol{x} \in S_{c}} r_{i}(\boldsymbol{x})

(3)计算Attention scores αi\alpha_i,下式中l表示queries和keys的维度:

αi,c=exp(βi,c)iexp(βi,c),βi,c=qcki,cl(4)\alpha_{i, c}=\frac{\exp \left(\beta_{i, c}\right)}{\sum_{i^{\prime}} \exp \left(\beta_{i^{\prime}, c}\right)}, \beta_{i, c}=\frac{\mathbf{q}_{c}^{\top} \mathbf{k}_{i, c}}{\sqrt{l}}\tag 4

αi=cαi,cN(5)\alpha_{i}=\frac{\sum_{c} \alpha_{i, c}}{N}\tag 5

得到该任务下每个backbone的注意力得分αi\alpha_i后,重新计算查询集样本的表示:

ϕ(x)=iαiri(x)(6)\phi(\mathrm{x})=\sum_{i} \alpha_{i} r_{i}(\mathrm{x})\tag 6

2.2 Multi-Head URT Layer

H头URT层即将单头URT层的注意力计算部分重复执行H次,然后将公式6替换成公式7:

ϕ(x)=concat(ϕ1(x),,ϕII(x))(7)\phi(\mathrm{x})=\operatorname{concat}\left(\phi_{1}(\mathrm{x}), \ldots, \phi_{\mathrm{II}}(\mathrm{x})\right)\tag 7

这里根据经验,作者发现随机初始化每个头的注意力权重不能实现各个头之间的互补性和唯一性,因此作者参照其他人的做法添加了一个正则化损失,避免注意力得分的重复:

Ω(Θ)=(AAI)F2(8)\Omega(\Theta)=\left\|\left(\mathbf{A} \mathbf{A}^{\top}-\mathbf{I}\right)\right\|_{F}^{2}\tag 8

其中AA表示注意力得分矩阵。

2.3 模型训练

采用原型网络中的分类方法,根据样本特征表示和类别原型之间的距离计算样本属于该类别的概率:

p(y=cx,S;Θ)=exp(d(ϕ(x)pc))c=1Nexp(d(ϕ(x)pc))(9)p(y=c \mid x, S ; \Theta)=\frac{\exp \left(-d\left(\phi(x)-p_{c}\right)\right)}{\sum_{c^{\prime}=1}^{N} \exp \left(-d\left(\phi(x)-p_{c^{\prime}}\right)\right)}\tag 9

其中pc=1/ScxScϕ(x)\boldsymbol{p}_{c}=1 /\left|S_{c}\right| \sum_{\boldsymbol{x} \in S_{c}} \phi(\boldsymbol{x})表示类别c的原型。本文作者采用余弦相似度作为度量,具体算法如下图所示:

3. 实验

主要回答下面3个问题:

  1. URT和Meta-Dataset数据集上的state-of-art方法相比性能如何?
  2. URT中的自注意力是否具有可解释性?学习到的注意力得分是否有效?
  3. URT即使在backbone采用不同方式训练的情况下是否持续有效?

实验设置:采用Meta-Dataset数据集,预训练的backbone在训练URT时冻结参数,URT层训练10000个episodes,初始学习率为0.01,采用余弦学习率scheduler,训练episode有50%概率来自ImageNet数据源等等。

3.1 对比实验

作者对比了URT和SUR、一些基于微调迁移学习的baselines以及元学习方法,实验结果如下:

下图展示了不同方法在MNIST,CIFAR-10,CIFAR-100上的实验结果:

3.2 可解释性

下图展示了two heads URT在测试任务中生成的注意力权重:

左边蓝色表示first head学习到的注意力权重热图,右边橙色表示second head学习到的注意力权重热图。

3.3 通用性

为了进一步证明URT在学习universal representation上的优越性能,作者基于一组不同的backbone架构进行了实验。Following SUR,作者先在ILSVRC上训练一个backbone,然后在其他每个数据集上分别学习单独的FiLM层,学习domain-specific backbone。实验结果如下表所示:

3.4 消融实验

下表中w/o Wqw/o\ W^q表示计算queries时Wq=0W^q = 0,保留偏置bqb^qw/o Wkw/o\ W^k表示计算keys时Wk=0W^k=0,保留偏置bkb^k。其他行同理,实验结果如下表所示:

作者还对head数目做了消融实验,结果如下表所示:

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信