MELR: Meta-Learning via Modeling Episode-Level Relationships for Few-Shot Learning

https://openreview.net/pdf?id=D3PcGLdMx0

MELR: Meta-Learning via Modeling Episode-Level Relationships for Few-Shot Learning,2021,ICLR

总结:这是一篇针对小样本学习模型稳定性的文章。作者从小样本元学习模型抗干扰能力比较弱,在poor sampling情况下稳定性比较差这一点入手,提出了一种新的元学习框架MELR。具体来说基于CEAM和CECR这两个技术,提升了元学习模型在poor sampling下的稳定性。前者是一种跨episode的自注意力方法,后者是一种跨episode的正则化方法。个人觉得模型的实际意义不是很大,实际应用中,尤其标准数据集中这种poor sampling并不常见,对模型的性能应该没太多影响。作者提出的MELR模型和传统模型相比增加了一定复杂度,但是从标准数据集上的实验来看,性能并没有提高太多。不过,个人认为小样本模型抗干扰能力或者说稳定性确实是一个比较有研究价值的方向。

1. 简介

1.1 摘要

Most recent few-shot learning (FSL) approaches are based on episodic training whereby each episode samples few training instances (shots) per class to imitate the test condition. However, this strict adhering to test condition has a negative side effect, that is, the trained model is susceptible to the poor sampling of few shots. In this work, for the first time, this problem is addressed by exploiting inter-episode relationships. Specifically, a novel meta-learning via modeling episode-level relationships (MELR) framework is proposed. By sampling two episodes containing the same set of classes for meta-training, MELR is designed to ensure that the meta-learned model is robust against the presence of poorly-sampled shots in the meta-test stage. This is achieved through two key components: (1)a Cross-Episode Attention Module (CEAM) to improve the ability of alleviating the effects of poorly-sampled shots, and (2) a Cross-Episode Consistency Regularization (CECR) to enforce that the two classifiers learned from the two episodes are consistent even when there are unrepresentative instances. Extensive experiments for non-transductive standard FSL on two benchmarks show that our MELR achieves 1.0%–5.0% improvements over the baseline (i.e., ProtoNet) used for FSLin our model and outperforms the latest competitors under the same settings.

现有小样本学习方法大多数都是基于episodic训练,每个episode中每个类别只采样少量样本,来模拟测试环境。但是这种严格的遵循测试条件存在负面影响,即训练出的模型容易受到小样本poor sampling的影响。本文,我们首次通过挖掘不同episode之间的关系来解决这个问题。具体来说,我们通过对episode-level关系进行建模,提出了一种新的元学习方法MELR。通过每次采样两个类别相同的episode用于元训练,MELR能够保证模型在poorly-sampled的测试环境下也能保持稳定。这个功能是通过两个关键组件实现的:(1)Cross-Episode Attention Module(CEAM),帮助模型抵消pooly-sample shots的影响;(2)Cross-Episode Consistency Regularization(CECR),当存在unrepresentative样本时,让来自不同episode的两个分类器竟可能保持一致。两个标准FSL数据集下的大量非转导实验表明我们的MELR模型和baseline相比能取得1%~5%的性能提升,并且在相同设置下优于最新的竞争对手。

注:什么是poorly-sampled样本?就是品质比较差、噪声比较多的样本。比如,用一只背对着我们,同时身体被部分遮挡的猫来训练分类器,这时训练出的模型可能无法识别品质好的猫的图片。

1.2 本文工作

背景: 传统的深度神经网络,比如CNN,依赖于大量有标签数据,但是在很多实际应用场景中无法获取充足的有标签样本。因此,近些年来小样本学习得到了广泛研究,尤其是基于元学习的小样本学习方法。

动机: 现有的元学习方法基本都是采用episode 方式训练模型,即在训练时利用样本充足的基类构造episode。每个episode都模拟测试环境构建的,即一个N-way K-shot的任务。这种完全模拟测试环境的方式可以保证元学习模型能够快速适应新的小样本任务。但是这会给模型带来负面影响:模型容易受到poor sampling的影响,即模型稳定性不够

原因: 上述所说的弊端,本质上还是因为小样本场景下,训练样本数量不够多,模型抗噪声能力比较弱,容易受到对抗攻击,导致模型稳定性较差,尤其在1-shot设定下。比如,在1-shot设定下的support集中,采用一张背对着我们同时身体被部分遮挡的猫来训练分类器,那么查询集中一张品质较高的猫可能就无法被识别。

本文工作: 作者提出了一种新的元学习框架MELR,通过CEAM和CECR两个功能组件提高模型在poor sampling下的稳定性 。前者是一个跨episode的自注意力方法,后者是一个跨episode的正则化方法。

2. 方法

在传统元学习方法中,生成N-way K-shot episode的方式如下:

(1)从基类CbC_b中选取N个类别,并re-index得到Ce={1,2,...,N}C_e=\{1,2,...,N\}

(2)CeC_e中每个类别随机选取K个support样本和Q个query样本,得到Se={(xi,yi)yiCe,i=1,2,...,N×K}S_e=\{(x_i,y_i)|y_i\in C_e,i=1,2,...,N\times K\}Qe={(xi,yi)yiCe,i=1,2,...,N×Q}Q_e=\{(x_i,y_i)|y_i\in C_e,i=1,2,...,N\times Q\}

对于episode e,损失函数通常定义如下:

Lfsc(e)=E(xi,yi)QeL(yi,f(ψ(xi);Se))(1)L_{f s c}(e)=\mathbb{E}_{\left(x_{i}, y_{i}\right) \in \mathcal{Q}_{e}} L\left(y_{i}, f\left(\psi\left(x_{i}\right) ; \mathcal{S}_{e}\right)\right)\tag 1

其中ψ\psi表示特征提取函数,f(;Se):RdRNf\left(\cdot ; \mathcal{S}_{e}\right): \mathbb{R}^{d} \rightarrow \mathbb{R}^{N}表示得分函数(比如softmax),L(,)L(·,·)表示损失函数(通常用交叉熵损失)。通过多个上面这样的episode不断训练模型后,在meta-test episode上测试模型。

上图展示了本文提出的MELR模型框架图。可以看到,和传统元学习相比,区别在于每一次训练都会构建两个类别相同的episode,然后经过CEAM对所有样本特征进行转换后,再计算查询集样本的类别得分,最后损失函数中会添加一项正则化损失。下面具体介绍下作者提出的CEAM和CECR方法。

2.1 CEAM

每一轮从同一个类别集合CeC_e采样两个episode,e(1)=(Se(1),Qe(1))e^{(1)}=(S_e^{(1)},Q_e^{(1)})e(2)=(Se(2),Qe(2))e^{(2)}=(S_e^{(2)},Q_e^{(2)})e(1)e(2)=e^{(1)}\cap e^{(2)}=\emptyset。我了降低badly-sampled样本的影响,作者提出了CEAM——跨episode注意力模型。

具体来说,S(1)=[ψ(xi)T;xiSe(1)]RNK×d\mathbf{S}^{(1)}=\left[\psi\left(x_{i}\right)^{T} ; x_{i} \in \mathcal{S}_{e}^{(1)}\right] \in \mathbb{R}^{N K \times d}Q(1)=[ψ(xi)T;xiQe(1)]RNQ×d\mathbf{Q}^{(1)}=\left[\psi\left(x_{i}\right)^{T} ; x_{i} \in \mathcal{Q}_{e}^{(1)}\right] \in\mathbb R^{NQ\times d}分别表示支持集和查询集样本的特征矩阵,S(2)\mathbf S^{(2)}Q(2)\mathbf Q^{(2)}同理。同时F(1)=[S(1);Q(1)]RN(K+Q)×d\mathbf{F}^{(1)}=\left[\mathbf{S}^{(1)} ; \mathbf{Q}^{(1)}\right] \in \mathbb{R}^{N(K+Q) \times d}F(2)\mathbf F^{(2)}同理。具体计算方法如下:

F^(1)=CEAM(F(1),S(2),S(2))=F(1)+softmax(FQ(1)SK(2)Td)SV(2)(2)\hat{\mathbf{F}}^{(1)}=\operatorname{CEAM}\left(\mathbf{F}^{(1)}, \mathbf{S}^{(2)}, \mathbf{S}^{(2)}\right)=\mathbf{F}^{(1)}+\operatorname{softmax}\left(\frac{\mathbf{F}_{Q}^{(1)} \mathbf{S}_{K}^{(2) T}}{\sqrt{d}}\right) \mathbf{S}_{V}^{(2)}\tag 2

其中输入的Q、K、V三元组是将原有特征矩阵映射到一个潜在空间中,计算方式如下:

FQ(1)=F(1)WQRN(K+Q)×d(3)\begin{aligned}\mathbf{F}_{Q}^{(1)}&=\mathbf{F}^{(1)} \mathbf{W}_{Q} \in \mathbb{R}^{N(K+Q) \times d} \end{aligned}\tag 3

SK(2)=S(2)WKRNK×d(4)\begin{aligned} \mathbf{S}_{K}^{(2)}=\mathbf{S}^{(2)} \mathbf{W}_{K} \in \mathbb{R}^{N K \times d} \end{aligned}\tag 4

SV(2)=S(2)WVRNK×d(5)\begin{aligned} \mathbf{S}_{V}^{(2)}&=\mathbf{S}^{(2)} \mathbf{W}_{V} \in \mathbb{R}^{N K \times d} \end{aligned}\tag 5

对于e(2)e^{(2)},计算方式也是一样的:

F^(2)=CEAM(F(2),S(1),S(1))=F(2)+softmax(FQ(2)SK(1)Td)SV(1)(6)\hat{\mathbf{F}}^{(2)}=\operatorname{CEAM}\left(\mathbf{F}^{(2)}, \mathbf{S}^{(1)}, \mathbf{S}^{(1)}\right)=\mathbf{F}^{(2)}+\operatorname{softmax}\left(\frac{\mathbf{F}_{Q}^{(2)} \mathbf{S}_{K}^{(1) T}}{\sqrt{d}}\right) \mathbf{S}_{V}^{(1)}\tag 6

其中WQ,WK and WV\mathbf{W}_{Q}, \mathbf{W}_{K} \text { and } \mathbf{W}_{V}都是可学习参数。

2.2 CECR

为了进一步提高模型对badly-sampled样本的抗干扰能力,作者提出了CECR方法,强迫两个分类器能够保持一致的预测。具体来说,CECR采用基于知识蒸馏的策略。f(;S^(1)):RdRNf(\cdot ; \hat{\mathbf{S}}^{(1)}): \mathbb{R}^{d} \rightarrow \mathbb{R}^{N}f(;S^(2)):RdRNf(\cdot ; \hat{\mathbf{S}}^{(2)}): \mathbb{R}^{d} \rightarrow \mathbb{R}^{N}分别表示基于S^(1)\mathbf{\hat S}^{(1)}S^(2)\mathbf{\hat S}^{(2)}训练得到的分类器。为了评价哪个分类器性能更好,我们在Q^e(1,2)=Q^e(1)Q^e(2)={(q^i(1,2),yi(1,2)),i=1,2,,2NQ})\hat{\mathcal{Q}}_{e}^{(1,2)}=\hat{\mathcal{Q}}_{e}^{(1)} \cup \hat{\mathcal{Q}}_{e}^{(2)}=\{(\hat{\mathbf{q}}_{i}^{(1,2)}, y_{i}^{(1,2)}), i=1,2, \cdots, 2 N Q\})计算两个分类器的准确度。

我们将准确度高的分类器作为teacher分类器,另一个作为student分类器。下面假设f(;S^(1))f(\cdot ; \hat{\mathbf{S}}^{(1)})准确度高于f(;S^(2))f(\cdot ; \hat{\mathbf{S}}^{(2)}),知识蒸馏损失定义如下:

Lcecr(e(1),e(2);T)=E(q^i(1,2),yi(1,2))Q^e(1,2)L(f(q^i(1,2);S^(1)),f(q^i(1,2);S^(2));T)(7)L_{c e c r}\left(e^{(1)}, e^{(2)} ; T\right)=\mathbb{E}_{\left(\hat{\mathbf{q}}_{i}^{(1,2)}, y_{i}^{(1,2)}\right) \in \hat{\mathcal{Q}}_{e}^{(1,2)} }L^{\prime}\left(f\left(\hat{\mathbf{q}}_{i}^{(1,2)} ; \hat{\mathbf{S}}^{(1)}\right), f\left(\hat{\mathbf{q}}_{i}^{(1,2)} ; \hat{\mathbf{S}}^{(2)}\right) ; T\right)\tag 7

TT温度参数,如果使用softmax函数作为分类器,则有:

L(f(q^i(1,2);S^(1)),f(q^i(1,2);S^(2));T)=j=1Nσj(f(q^i(1,2);S^(1));T)log(σj(f(q^i(1,2);S^(2));T))(8)\begin{aligned}& L^{\prime}\left(f\left(\hat{\mathbf{q}}_{i}^{(1,2)} ; \hat{\mathbf{S}}^{(1)}\right), f\left(\hat{\mathbf{q}}_{i}^{(1,2)} ; \hat{\mathbf{S}}^{(2)}\right) ; T\right) \\=&-\sum_{j=1}^{N} \sigma_{j}\left(f\left(\hat{\mathbf{q}}_{i}^{(1,2)} ; \hat{\mathbf{S}}^{(1)}\right) ; T\right) \log \left(\sigma_{j}\left(f\left(\hat{\mathbf{q}}_{i}^{(1,2)} ; \hat{\mathbf{S}}^{(2)}\right) ; T\right)\right)\end{aligned}\tag 8

需要注意的是反向传播时,f(;S^(1))f(·;\mathbf{\hat S}^{(1)})的梯度会被掐断,因为teacher分类器的输出用来指导学生分类器了。(这应该是知识蒸馏中的常规操作)。

2.3 模型训练

对于每个episode的分类损失函数定义如下:

Lfsc(e)=E(q^i,yi)Q^eL(yi,fProtoNet (q^i;S^))=E(q^i,yi)Q^elogσyi(fProtoNet (q^i;S^))(9)\begin{aligned}L_{f s c}(e) &=\mathbb{E}_{\left(\hat{\mathbf{q}}_{i}, y_{i}\right) \in \hat{\mathcal{Q}}_{e}} L\left(y_{i}, f_{\text {ProtoNet }}\left(\hat{\mathbf{q}}_{i} ; \hat{\mathbf{S}}\right)\right) \\&=\mathbb{E}_{\left(\hat{\mathbf{q}}_{i}, y_{i}\right) \in \hat{\mathcal{Q}}_{e}}-\log \sigma_{y_{i}}\left(f_{\text {ProtoNet }}\left(\hat{\mathbf{q}}_{i} ; \hat{\mathbf{S}}\right)\right)\end{aligned}\tag 9

在结合CECR中的模型损失,我们定义模型的最终损失为:

Ltotal =12(Lfsc(e(1))+Lfsc(e(2)))+λLcecr(e(t),e(s);T)(10)L_{\text {total }}=\frac{1}{2}\left(L_{f s c}\left(e^{(1)}\right)+L_{f s c}\left(e^{(2)}\right)\right)+\lambda L_{c e c r}\left(e^{(t)}, e^{(s)} ; T\right)\tag {10}

模型的伪代码如下图所示:

3. 实验

数据集: miniImageNet和tieredImageNet

实现细节: 分别使用Conv4-64,Conv4-512和ResNet作为特征提取器。为了加速训练过程,三个模型都在之前工作基础上先进行了预训练。

3.1 对比实验

注: ProtoNet \text { ProtoNet }^{\dagger}表示训练是也采样两个episode。

  1. ResNet-12优于Conv4-512,Conv4-512优于Conv4-64。这个是符合我们的直觉的,更大的神经网络,学习到的图像特征质量更好。
  2. MELR和其他所有方法相比,在所有设定下都取得了最优效果。
  3. 和baseline ProtoNet相比,MELR在1-shot下取得的性能提升大于5-shot,因为1-shot下模型poor sampling的可能性更大。这从侧面佐证了作者提出的MELR模型可以提高模型稳定性。

3.2 拓展实验

  1. 消融实验

    (1)图(a)表示对MELR的各个组件进行消融,可以看到CECR和CEAM对模型性能的提升都有帮助。

    (2)图(b)对CEAM的不同实现方式进行了对比,本文CEAM的实现方式是将support样本作为“keys”和“values”,另一个episode的所有样本作为“queries”,表示为SupportAllSupport\rightarrow All。还可以将prototypes作为“keys”和“values”或者单独将query样本作为"queries",分别表示为: Prototype  Query, Support  Query, and Prototype  All \text { Prototype } \rightarrow \text { Query, Support } \rightarrow \text { Query, and Prototype } \rightarrow \text { All }

    (3)图©对比了CECR的不同实现方法。

  2. 可视化

    (1)图(a)~©可视化了  ProtoNet \text { ProtoNet }^{\dagger}ProtoNet+CEAMProtoNet^{\dagger}+CEAMProtoNet+CEAM+CECRProtoNet^{\dagger}+CEAM+CECR 三个模型中的数据分布。

    (2)图(d)~(e)可是花了查询集合支持集中的注意力权重热图。

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信