Concept Learners for Few-Shot Learning

https://arxiv.org/pdf/2007.07375

https://github.com/snap-stanford/comet

Concept Learners for Few-Shot Learning,2021,ICLR

总结:文章立足点比较好,从人类认知模式出发,即人类认知的核心在于结构化、可复用的concepts,比如鸟类生物的羽毛、翅膀、喙,通过组合这些concepts我们可以快速、准确的识别鸟类生物。但是现有的元学习方法都是学习一个joint、unstructured先验知识,这限制了模型的泛化能力。因此作者从concept这个维度,提出了COMET模型,通过聚合多个不同concept learner学到的信息进行小样本分类。但是从实验结果看,COMET的主要缺点是对concept的依赖过大。使用人工标注concept时,和baseline相比COMET性能有很大幅度提升,但是使用自动提取的concept时,性能提升有限,而且这种性能的提升可能更多的来自于模型的复杂度(因为N个concept learner相当于堆叠了N个原型网络,模型复杂度提高了N倍)。

1. 简介

1.1 摘要

Developing algorithms that are able to generalize to a novel task given only a few labeled examples represents a fundamental challenge in closing the gap between machine- and human-level performance. The core of human cognition lies in the structured, reusable concepts that help us to rapidly adapt to new tasks and provide reasoning behind our decisions. However, existing meta-learning methods learn complex representations across prior labeled tasks without imposing any structure on the learned representations. Here we propose COMET, a meta-learning method that improves generalization ability by learning to learn along human-interpretable concept dimensions. Instead of learning a joint unstructured metric space, COMET learns mappings of high-level concepts into semi-structured metricspaces, and effectively combines the outputs of independent concept learners. We evaluate our model on few-shot tasks from diverse domains, including fine-grained image classification, document categorization and cell type annotationon a novel dataset from a biological domain developed in our work. COMET significantly outperforms strong meta-learning baselines, achieving 6–15% relative improvement on the most challenging 1-shot learning tasks, while unlike existing methods providing interpretations behind the model’s predictions.

为了缩小机器和人类学习性能之间的差距,开发能够在之给定少量样本下就能适应新任务的算法是一个基本挑战。人类认知的核心在于结构化、可复用的concepts,它可以帮助我们快速适应新任务并提供决策背后的推理。但是现有的元学习方法在学习复杂表征的过程中,并不会添加任何结构。本文我们提出了COMET模型——一种元学习方法,从人类可解释的concept维度出发,通过learning to learn来提高模型泛化能力。和元学习不同,COMET不是学习一个联合的无结构的度量空间,而是学习将高层级的concept映射到半结构化度量空间,然后讲这些独立的concept learners有效结合到一起。我们在多个领域,包括图像分类、文本分类、生物,评估了我们的方法。在1-shot设定下,和其他元学习方法相比,COMET取得了6%-15%的性能提升,同时不像其他方法,我们的模型是可解释的。

1.2 本文工作

先验知识: 要理解本文,需要先理解“concept”这一概念。所谓“concept”,我觉得可以理解成物体具有代表性的局部组件,以鸟为例,它的羽毛、嘴巴、翅膀等都是非常重要的的concept,可以帮助人类进行识别。“ Intuitively, concepts can be seen as part-based representations of the input and reflect the way humans reason about the world. ”原文中这句话也说的比较形象。

动机: 人类的知识是结构化的,由各种可以重复使用的concepts组成。比如我们在识别新的鸟类物种时,我们早就知道了羽毛、翅膀、喙这些concept的大概样子,所以来了一张鸟类照片,我们只需要关注这些具体concepts,然后将它们组合到一起就能识别它的种类。但是现有的元学习方法,学习到的是一个joint and unstructured先验知识,这限制了模型的泛化能力。

本文工作: 从人类可解释的concept维度,提出了一种新的元学习模型COMET。COMET为每个concept学习一个单独的度量空间,这里作者借鉴原型网络思想,为每个concept学习一个concept prototype。然后将不同concept learners和concept prototype中的信息有效聚合在一起进行最终预测。和现有的元学习方法相比,比如原型网络、匹配网络等等,COMET模型是可解释的,这是十分重要的,尤其在小样本场景下。

2. 方法

假设: 假定输入数据的维度可以划分成若干相关维度的子集,分别对应人类可解释的高层级的concept。这些集合可能有重叠、噪声、不完整,但是在许多真实场景中这种划分是存在的。例如,CV中concept可以对应图像的segment,NLP中对应语义相关的单词,生物学中可以对应于外部的知识库和本体。其实在很多领域中,已经有现成的concepts,或者可以通过现有技术进行自动生成。

Concept符号表示:C={c(j)}j=1N\mathcal C = \{\mathbf c^{(j)}\}_{j=1}^N表示N个concepts集合,其中每个concept c(j){0,1}D\mathbb c^{(j)}\in\{0,1\}^D为一个D维binary向量,cij=1c_i^{j}=1表示 ithi-th 维度可以用来描述该concept,DD表示输入数据的维度。对于C\mathcal C,我们不做任何限制,这也就意味着里面可以有重复或者冗余的concept。

一、算法细节

COMET算法如上图1所示,它不是只学习一个映射函数fθ:RDRMf_{\boldsymbol{\theta}}: \mathbb{R}^{D} \rightarrow \mathbb{R}^{M},而是为每一个concept都学习一个单独的embedding函数fθ(j):RDRMf_{\boldsymbol{\theta}}^{(j)}: \mathbb{R}^{D} \rightarrow \mathbb{R}^{M}。如图1所示,concept嵌入函数被称之为concept learners,是一个深度神经网络下的非线性函数。每个concept learner j都会计算一个concept prototypes pk(j)p_k^{(j)},表示 kk 类别下该concept的原型。具体计算方式如下:

pk(j)=1Sk(xi,yi)Skfθ(j)(xic(j))(3)\mathbf{p}_{k}^{(j)}=\frac{1}{\left|\mathcal{S}_{k}\right|} \sum_{\left(\mathbf{x}_{i}, y_{i}\right) \in \mathcal{S}_{k}} f_{\boldsymbol{\theta}}^{(j)}\left(\mathbf{x}_{i} \circ \mathbf{c}^{(j)}\right)\tag 3

其中\circ表示Hadamard product。最终每个类别 kk 都会得到N个concept原型 {pk(j)}j=1K\{p_k^{(j)}\}_{j=1}^K

二、类别判定

给定一个查询数据 xq\mathbf x_q,先得到它的所有concept embeddings后,然后通过计算这些embedding到对应concept原型之间的距离来判定其类别。具体计算公式如下:

pθ(y=kxq)=exp(jd(fθ(j)(xqc(j)),pk(j)))kexp(jd(fθ(j)(xqc(j)),pk(j)))(4)p_{\boldsymbol{\theta}}\left(y=k \mid \mathbf{x}_{q}\right)=\frac{\exp \left(-\sum_{j} d\left(f_{\boldsymbol{\theta}}^{(j)}\left(\mathbf{x}_{q} \circ \mathbf{c}^{(j)}\right), \mathbf{p}_{k}^{(j)}\right)\right)}{\sum_{k^{\prime}} \exp \left(-\sum_{j} d\left(f_{\boldsymbol{\theta}}^{(j)}\left(\mathrm{x}_{q} \circ \mathbf{c}^{(j)}\right), \mathbf{p}_{k^{\prime}}^{(j)}\right)\right)}\tag 4

其中d()d(·)表示距离函数,作者这里采用欧式距离,因为实验中作者发现欧式距离效果比余弦距离好。

三、可解释性

除了算法本身外,COMET还可以通过下面两个功能,对算法背后的预测结果提供强有力的解释。

  1. Local and global concept importance score

    COMET模型中,每个类别不在是由一个原型表示,而是由N个concept原型表示。

    • local score:对于某个查询样本xq\mathbf x_qd(fθ(j)(xqc(j)),pk(j))d\left(f_{\boldsymbol{\theta}}^{(j)}\left(\mathbf{x}_{q} \circ \mathbf{c}^{(j)}\right), \mathbf{p}_{k}^{(j)}\right)的倒数可以表示local得分,得分越高表示该部分concept对于该样本类型判别来说越重要。这也直观的解释了模型在预测xq\mathbb x_q类型时,背后的推理过程。
    • global score:同样地,对于一组查询样本或者整个类别样本,我们可以从global score角度对模型的行为进行解释。通过计算每种concept embedding和concept prototype之间距离的平均值,可以得到对于这一组样本,哪中concept最重要。
  2. 寻找局部相似样本

    给定一个concept jj,COMET可以通过计算concept embedding和concept prototype之间的距离并进行排序,找到局部相似或者差异很大的样本。

3. 实验

数据集: 分别在CV、NLP、生物三个领域对COMET的性能进行了测试,使用的数据集分别为CUB、Reuters和Tabula Muris。

Baselines: FineTune(ICLR 2019)、匹配网络、MAML、关系网络、MetaOptNet、DeepEMD(只能用于图像分类)和原型网络。

实验设置: 所有实验都是5-way设置,测试阶段随机采样600个episode计算平均精度。CUB数据集上采用4层CNN为骨架,输入大小为84×8484\times 84。更具体信息可以参看原文附件A。

3.1 基础实验

一、对比实验

下表1展示了COMET和其他baseline的实验结果,可以看到COMET模型在性能上取得了大幅度提升,平均接近10%。

为了证明COMET性能的提升不是额外的权重带来的(因为COMET中N个concept learner相当于N个原型网络),作者对比了COMET和强化原型网络,结果如下表2所示:

其中第一行表示训练多个原型网络投票表决,第二行表示concept learner之间共享权重,第三行为完整的COMET。可以看到和ProtoNetEns相比,COMET的性能仍然有较大幅度提升尤其在后两个数据集上。不过ProtoNetEns和ProtoNet相比性能也得到了较多提升,因此说明COMET的性能提升也有部分来源于额外的权重(堆复杂度)。

二、concept数量消融

可以看到,随着concept数量增加,模型性能会提升,但concept躲到一定程度后,性能提升并不明显,存在性能上限。另外从上图右边可以看出,concept数量即使在1500时模型性能并没有受到影响,这证明了COMET的稳定性很好,不会受冗余的concept的影响。

3.2 无监督concept标注

前面的数据集都是基于人工标注的concept进行的,下面的实验利用自动提取的concept测试COMET的性能。下图5展示了CUB数据集上自动提取的landmarks,concepts数量设定为30:

虽然提取出的坐标比较粗糙,但是如下表3所示,和最好方法相比,COMET方法依旧有性能提升:

这里可以看到和人工标记concept相比,COMET带来的性能提升大幅度下降。虽然基于自动提取concept的COMET性能也得到了部分提升,但是我觉得这种性能的提升可能来源于模型复杂度,即额外参数。作者这里没有和ProtoNetEns进行对比,我觉得如果对比两者性能的话,差别应该不大。因此,可以看出COMET对concept的依赖是很大的,如果使用模糊的concept,很难带来令人满意的性能提升。

3.3 可解释性

回答下列4个问题:

  1. 对于某个查询样本,哪个concept对于类型判别最重要

  2. 对于某个类别,哪个concept最重要

  3. 哪些样本具有局部相似性

    global concept指的是将整张图片看做一个concept,这时往往反映的时背景相似度。

  4. 哪些样本能够最好的代表concept原型

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2022 Yin Peng
  • 引擎: Hexo   |  主题:修改自 Ayer
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信