持续学习的目的是模仿人类在连续任务中持续积累知识的能力,其主要挑战是在持续学习新任务后如何保持对以前所学任务的表现,即避免灾难性遗忘(catastrophic forgetting)。持续学习和多任务学习(multi-task learning)的区别在于:后者在同一时间可以得到所有任务,模型可以同时学习所有任务;而在持续学习中任务 一个一个出现,模型在某一时刻只能学习一个任务的知识,并且在学习新知识的过程中避免遗忘旧知识。
南加州大学联合 Google Research 提出了一种解决持续学习(continual learning)的新方法通道式轻量级重编码(Channel-wise Lightweight Reprogramming [CLR]):通过在固定任务不变的 backbone 中添加可训练的轻量级模块,对每层通道的特征图进行重编程,使得重编程过的特征图适用于新任务。这个可训练的轻量级模块仅仅占整个backbone的0.6%,每个新任务都可以有自己的轻量级模块,理论上可以持续学习无穷多新任务而不会出现灾难性遗忘。文已发表在 ICCV 2023。
-
论文地址: https://arxiv.org/pdf/2307.11386.pdf
-
项目地址: https://github.com/gyhandy/Channel-wise-Lightweight-Reprogramming
-
数据集地址: http://ilab.usc.edu/andy/skill102
通常解决持续学习的方法主要分为三大类:基于正则化的方法、动态网络方法和重放方法。
-
基于正则化的方法是模型在学习新任务的过程中对参数更新添加限制,在学习新知识的同时巩固旧知识。
-
动态网络方法是在学习新任务的时候添加特定任务参数并对旧任务的权重进行限制。
-
重放方法假设在学习新任务的时候可以获取旧任务的部分数据,并与新任务一起训练。
本文提出的CLR方法是一种动态网络方法。下图表示了整个过程的 pipeline:研究者使用与任务无关的不可变部分作为共享的特定任务参数,并添加特定任务参数对通道特征进行重编码。与此同时为了尽可能地减少训练每个任务的重编码参数,研究者只需要调整模型中内核的大小,并学习从 backbone 到特定任务知识的通道线性映射来实现重编码。在持续学习中,对于每一个新任务都可以训练得到一个轻量级模型;这种轻量级的模型需要训练的参数很少,即使任务很多,总共需要训练的参数相对于大模型来说也很小,并且每一个轻量级模型都可以达到很好的效果。
持续学习关注于从数据流中学习的问题,即通过特定的顺序学习新任务,不断扩展其已获得的知识,同时避免遗忘以前的任务,因此如何避免灾难性遗忘是持续学习研究的主要问题。研究者从以下三个方面考虑:
-
重用而不是重学:对抗重编码(Adversarial Reprogramming [1])是一种通过扰动输入空间,在不重新学习网络参数的情况下,"重编码" 一个已经训练并冻结的网络来解决新任务的方法。研究者借用了 “重编码” 的思想,在原始模型的参数空间而不是输入空间进行了更轻量级但也更强大的重编程。
-
通道式转换可以连接两个不同的核:GhostNet [2] 的作者发现传统网络在训练后会得到一些相似的特征图,因此他们提出了一种新型网络架构 GhostNet:通过对现有特征图使用相对廉价的操作(比如线性变化)生成更多的特征图,以此来减小内存。受此启发,本文方法同样使用线性变换生成特征图来增强网络,这样就能以相对低廉的成本为各个新任务量身定制。
-
轻量级参数可以改变模型分布:BPN [3] 通过在全连接层中增加了有益的扰动偏差,使网络参数分布从一个任务转移到另一个任务。然而 BPN 只能处理全连接层,每个神经元只有一个标量偏置,因此改变网络的能力有限。相反研究者为卷积神经网络(CNN)设计了更强大的模式(在卷积核中增加 “重编码” 参数),从而在每项新任务中实现更好的性能。
通道式轻量级重编码首先用一个固定的 backbone 作为一个任务共享的结构,这可以是一个在相对多样性的数据集(ImageNet-1k, Pascal VOC)上进行监督学习的预训练模型,也可以是在无语义标签的代理任务上学习的自监督学习模型(DINO,SwAV)。不同于其他的持续学习方法(比如 SUPSUP 使用一个随机初始化的固定结构,CCLL 和 EFTs 使用第一个任务学习后的模型作为 backbone),CLR 使用的预训练模型可以提供多种视觉特征,但这些视觉特征在其他任务上需要 CLR 层进行重编码。具体来说,研究者利用通道式线性变化(channel-wise linear transformation)对原有卷积核产生的特征图像进行重编码。
图中展示了 CLR 的结构。CLR 适用于任何卷积神经网络,常见的卷积神经网络由 Conv 块(Residual 块)组成,包括卷积层、归一化层和激活层。
研究者首先把预训练的 backbone 固定,然后在每个固定卷积块中的卷积层后面加入通道式轻量级重编程层 (CLR 层)来对固定卷积核后的特征图进行通道式线性变化。
给定一张图片 X,对于每个卷积核 fkO,可以得到通过卷积核的特征图 X’,其中每个通道的特征可以表示为 x'k=fk(X);之后用 2D 卷积核来对 X’的每个通道 x'k 进行线性变化,假设每个卷积核 fk 对应的线性变化的卷积核为 CLRk(),那么可以得到重编码后的特征图 。研究者将 CLR 卷积核的初始化为同一变化核(即对于的 2D 卷积核,只有中间参数为 1,其余都为 0),因为这样可以使得最开始训练时原有固定 backbone 产生的特征和加入 CLR layer 后模型产生的特征相同。同时为了节约参数并防止过拟合,研究者并不会在的卷积核后面加入 CLR 层,CLR 层只会作用在的卷积核后。对于经过 CLR 作用的 ResNet50 来说,增加的可训练参数相比于固定的 ResNet50 backbone 只占 0.59%。
对于持续学习,加入 CLR 的模型(可训练的 CLR 参数和不可训练的 backbone)可以依次学习每个任务。在测试的时候,研究者假设有一个 task oracle 可以告诉模型测试图片属于哪个任务,之后固定的 backbone 和相对应的任务专有 CLR 参数可以进行最终预测。由于 CLR 具有绝对参数隔离的性质(每个任务对应的 CLR 层参数都不一样并且共享的 backbone 不会变化),因此 CLR 不会受到任务数量的影响。
数据集:研究者使用图像分类作为主要任务,实验室收集了 53 个图像分类数据集,有大约 180 万张图片和 1584 个种类。这 53 个数据集包含了 5 个不同的分类目标:物体识别,风格分类,场景分类,计数和医疗诊断。
基线:研究者选择了 13 种基线,大概可以分成 3 个种类
-
动态网络:PSP,SupSup,CCLL,Confit,EFTs
-
正则化:EWC,online-EWC,SI,LwF
-
还有一些不属于持续学习的基线,比如 SGD 和 SGD-LL。SGD 学习每个任务时对整个网络进行微调;SGD-LL 是一个变体,它对所有任务都使用一个固定的 backbone 和一个可学习的共享层,其长度等于所有任务最大的种类数量。
为了评估所有方法在克服灾难性遗忘的能力,研究者跟踪了学习新任务后每个任务的准确性。如果某个方法存在灾难性遗忘,那么在学习新任务后,同一任务的准确率就会很快下降。一个好的持续学习算法可以在学习新任务后保持原有的表现,这就意味着旧任务应受到新任务的影响最小。下图展示了本文方法从学完第 1 到第 53 个任务后第 1 个任务的准确率。总体而言,本文方法可以保持最高的准确率。更重要的是它很好地避免了灾难性遗忘并保持和原始训练方式得到的相同准确率无论持续学习多少个任务。
下图所有方法在学完全部任务后的平均准确率。平均准确率反映了持续学习方法的整体表现。由于每个任务的难易程度不同,当增加一项新任务时,所有任务的平均精确度可能会上升或下降,这取决于增加的任务是简单还是困难。
对于持续学习,虽然获得更高的平均准确率非常重要,但是一个好的算法也希望可以最大限度地减少对额外网络参数的要求和计算成本。"添加一项新任务的额外参数" 表示与原始 backbone 参数量的百分比。本文以 SGD 的计算成本为单位,其他方法的计算成本按 SGD 的成本进行归一化处理。
本文方法通过在相对多样化的数据集上使用监督学习或自监督学习的方法来训练得到预训练模型,从而作为与任务无关的不变参数。为了探究不同预训练方法的影响,本文选择了四种不同的、与任务无关的、使用不同数据集和任务训练出来的预训练模型。对于监督学习,研究者使用了在 ImageNet-1k 和 Pascal-VOC 在图像分类上的预训练模型;对于自监督学习,研究者使用了 DINO 和 SwAV 两种不同方法得到的预训练模型。下表展示了使用四种不同方法得到预训练模型的平均准确率,可以看出来无论哪种方法最后的结果都很高(注:Pascal-VOC 是一个比较小的数据集,所以准确率相对低一点),并且对不同的预训练 backbone 具有稳健性。
[1]. Adversarial reprogramming of neural networks.
[2]. Ghostnet: More features from cheap operations.
[3].Beneficial perturbation network for designing general adaptive artificial intelligence systems.
© 版权声明
文章版权归作者所有,未经允许请勿转载。
关注公众号,免费获取chatgpt账号
相关文章