本文共 5405 字,大约阅读时间需要 18 分钟。
近年来,深度学习技术在图像识别和目标检测领域取得了显著进展。然而,这些模型的训练依赖于大量标注数据,其可分类性严受数据范围的限制。在现实场景中,模型往往难以泛化到未见过的新类别,这就引出了少样本学习(Few-Shot Learning, FSL)这一重要研究课题。
少样本学习是机器学习的一个重要分支,旨在通过小量训练数据实现对新任务的分类能力。与传统监督学习不同,少样本学习模型仅需少量训练样本即可泛化到新类别。例如,在医学图像分析中,某些罕见疾病可能缺乏足够的训练数据,少样本学习提供了一个有效的解决方案。
少样本学习的分类可以根据训练样本数量进一步细分为以下几种类型:
解决 Few-Shot Learning 问题通常有两种主要策略:
数据级方法 (Data-Level Adaptation, DLA):如果基础数据集不足以支持目标任务的训练,DLA通过引入更多数据来解决问题。例如,在分类任务中,利用基础数据集中的其他类别图像进行预训练。
参数级方法 (Parameter-Level Adaptation, PLA):从参数空间角度优化模型,通过限制模型的自由度、使用正则化等方法防止过拟合。这种方法依赖于元学习技术,通过在大型参数空间中搜索最佳模型。
目前有四种主要的少样本学习图像分类算法:
基于梯度的元学习 (Gradient-Based Meta-Learning, GBML)
GBML通过基础模型训练和任务表示共享特征来实现元学习。MAML算法通过少量梯度步骤确保元参数的可靠初始化,使模型快速适应新任务。匹配网络 (Matching Networks)
匹配网络是解决 Few-Shot Learning 问题的第一个度量学习方法。它通过计算支持集和查询集图像的特征嵌入余弦相似度,利用交叉熵损失更新特征嵌入模型。原型网络 (Prototypical Networks)
原型网络通过对每个类别的支持集图像嵌入进行平均,生成类别原型。查询图像嵌入与原型进行比较,输出最接近原型的类别标签。关系网络 (Relation Networks)
关系网络基于可训练的距离函数,输入查询图像嵌入与类别原型,输出可学习的分类关系分数。这种方法通过改进距离计算模块提升了分类性能。CLIP(Contrastive Language-Image Pre-Training)是一种跨模态预训练模型,无需针对任务进行优化即可实现零样本分类。通过对比语言和图像特征,CLIP可以将图像与文本描述关联起来。
引入依赖包
pip install ftfy regex tqdmpip install git+https://github.com/openai/CLIP.gitimport numpy as npfrom packaging import pkg_resourcesprint("Torch version:", torch.__version__)加载模型
import clipclip.available_modelsmodel, preprocess = clip.load("ViT-B/32")model = model.cuda().eval()input_resolution = model.visual.input_resolutioncontext_length = model.context_lengthvocab_size = model.vocab_sizeprint(f"Model parameters: {sum(int(np.prod(p.shape)) for p in model.parameters())}:")print("Input resolution:", input_resolution)print("Context length:", context_length)print("Vocab size:", vocab_size)图像预处理与特征提取
import osimport skimagefrom PIL import Imageimport numpy as npimport torchfrom collections import OrderedDictfrom packaging import pkg_resourcesplt.figure(figsize=(16, 5))descriptions = { "page": "a page of text about segmentation", "chelsea": "a facial photo of a tabby cat", "astronaut": "a portrait of an astronaut with the American flag", "rocket": "a rocket standing on a launchpad", "motorcycle_right": "a red motorcycle standing in a garage", "camera": "a person looking at a camera on a tripod", "horse": "a black-and-white silhouette of a horse", "coffee": "a cup of coffee on a saucer"}original_images = []texts = []for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg"): name = os.path.splitext(filename)[0] if name not in descriptions: continue image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB") plt.subplot(2, 4, len(images) + 1) plt.imshow(image) plt.title(f"{filename}\n{descriptions[name]}") plt.xticks([]) plt.yticks([]) original_images.append(image) texts.append(descriptions[name])特征归一化与相似性计算
image_input = torch.tensor(np.stack(images)).cuda()text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()with torch.no_grad(): image_features = model.encode_image(image_input).float() text_features = model.encode_text(text_tokens).float()image_features /= image_features.norm(dim=-1, keepdim=True)text_features /= text_features.norm(dim=-1, keepdim=True)similarity = text_features.cpu().numpy().Tcount = len(descriptions)plt.figure(figsize=(20, 14))plt.imshow(similarity, vmin=0.1, vmax=0.3)plt.colorbar()plt.yticks(range(count), [texts[i] for i in range(count)], fontsize=18)plt.xticks([])for i, image in enumerate(original_images): plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower") for x in range(similarity.shape[1]): for y in range(similarity.shape[0]): plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12) for side in ["left", "top", "right", "bottom"]: plt.gca().spines[side].set_visible(False) plt.xlim([-0.5, count - 0.5]) plt.ylim([count + 0.5, -2]) plt.title("Cosine similarity between text and image features", size=20)零样本分类示例
from torchvision.datasets import CIFAR100cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]text_tokens = clip.tokenize(text_descriptions).cuda()with torch.no_grad(): text_features = model.encode_text(text_tokens).float() text_features /= text_features.norm(dim=-1, keepdim=True) text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)plt.figure(figsize=(16, 16))for i, image in enumerate(original_images): plt.subplot(4, 4, 2 * i + 1) plt.imshow(image) plt.axis("off") plt.subplot(4, 4, 2 * i + 2) y = np.arange(top_probs.shape[-1]) plt.grid() plt.barh(y, top_probs[i]) plt.gca().invert_yaxis() plt.gca().set_axisbelow(True) plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()]) plt.xlabel("probability") plt.subplots_adjust(wspace=0.5)plt.show()通过上述方法,我们可以利用 CLIP 模型实现零样本分类任务,模型能够在没有标注数据的情况下准确识别新类别。
转载地址:http://jhsfk.baihongyu.com/