快速入门指南:使用timm进行模型训练的开端

标题:快速入门指南:使用timm进行模型训练的开端

友情链接:ACEJoy

导读:
本文旨在帮助开发人员快速了解如何将timm集成到模型训练流程中。首先,您需要安装timm。接下来,我们将通过示例代码演示如何加载预训练模型、列出具有预训练权重的模型、微调预训练模型、以及如何使用预训练模型进行特征提取、图像增强和推理。让我们一起开始这个令人兴奋的旅程吧!

正文:
快速入门
本快速入门指南旨在帮助开发人员快速了解如何将timm集成到他们的模型训练流程中。在开始之前,您需要先安装timm。有关安装的详细信息,请参阅安装指南。

加载预训练模型
通过create_model()函数可以加载预训练模型。下面的示例展示了如何加载预训练的mobilenetv3_large_100模型。

import timm

model = timm.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()

需要注意的是,默认情况下返回的PyTorch模型处于训练模式,如果要进行推理,需要调用.eval()方法将其设置为评估模式。

列出具有预训练权重的模型
要列出timm中打包的具有预训练权重的模型,可以使用list_models()函数。如果指定pretrained=True,该函数将只返回具有预训练权重的模型名称。

import timm
from pprint import pprint

model_names = timm.list_models(pretrained=True)
pprint(model_names)

您还可以使用特定模式来列出名称中包含特定字符串的模型。

微调预训练模型
要微调任何预训练模型,只需更改分类器(即最后一层)即可。

model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=num_finetune_classes)

如果要在自己的数据集上进行微调,请编写一个PyTorch训练循环或调整timm的训练脚本以使用自己的数据集。

使用预训练模型进行特征提取
在不修改网络结构的情况下,可以使用model.forward_features(input)来替代通常的model(input)方法,对任何模型进行特征提取。这将跳过头部分类器和全局池化操作。有关更详细的使用timm进行特征提取的指南,请参阅特征提取部分。

图像增强
为了将图像转换为模型接受的有效输入,可以使用timm.data.create_transform()函数,并提供模型期望的输入尺寸。这将返回一个通用的变换对象,其中包含了一些合理的默认设置。

timm.data.create_transform((3, 224, 224))

预训练模型在训练时应用了特定的数据转换。如果您在图像上使用错误的转换,模型将无法理解所见的图像!要了解给定预训练模型使用了哪些转换,可以查看其预训练配置。

model.pretrained_cfg

您可以使用timm.data.resolve_data_config()函数来解析与数据相关的配置。

data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)

使用预训练模型进行推理
下面,我们将结合前面的内容,使用预训练模型进行推理。首先,我们需要一张图片作为输入。我们从网络上加载一张叶子标题:快速入门指南:timm模型训练的开端

导读:
本文将带您快速了解如何将timm集成到模型训练流程中。从安装timm开始,到加载预训练模型、列出可用的预训练模型、微调模型、使用模型进行特征提取、图像增强和推理等,我们将逐步介绍这些内容。让我们开始这个令人兴奋的旅程吧!

正文:
快速入门
本快速入门指南旨在帮助开发人员快速了解如何将timm集成到模型训练流程中。在开始之前,您需要先安装timm,具体安装方法请参考官方的安装指南。

加载预训练模型
要加载预训练模型,可以使用create_model()函数。以下示例展示了如何加载预训练的mobilenetv3_large_100模型。

import timm

model = timm.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()

需要注意的是,默认情况下返回的PyTorch模型处于训练模式,如果要进行推理,需要调用.eval()方法将其设置为评估模式。

列出具有预训练权重的模型
要列出timm中已打包的具有预训练权重的模型,可以使用list_models()函数。如果指定pretrained=True,该函数将只返回具有预训练权重的模型名称。

import timm
from pprint import pprint

model_names = timm.list_models(pretrained=True)
pprint(model_names)

您还可以使用特定模式来列出名称中包含特定字符串的模型。

微调预训练模型
要微调任何预训练模型,只需更改分类器(即最后一层)即可。

model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=num_finetune_classes)

如果要在自己的数据集上进行微调,请编写一个PyTorch训练循环或调整timm的训练脚本以使用自己的数据集。

使用预训练模型进行特征提取
在不修改网络结构的情况下,可以使用model.forward_features(input)来替代通常的model(input)方法,对任何模型进行特征提取。这将跳过头部分类器和全局池化操作。有关更详细的使用timm进行特征提取的指南,请参考官方文档中的特征提取部分。

图像增强
要将图像转换为模型可接受的有效输入,可以使用timm.data.create_transform()函数,并提供模型所期望的输入尺寸。这将返回一个通用的转换对象,其中包含了一些合理的默认设置。

timm.data.create_transform((3, 224, 224))

需要注意的是,预训练模型在训练时应用了特定的数据转换。如果您在图像上使用错误的转换,模型将无法理解所见的图像!要了解给定预训练模型使用了哪些转换,可以查看其预训练配置。

model.pretrained_cfg

您可以使用timm.data.resolve_data_config()函数来解析与数据相关的配置。

data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)

使用预训练模型进行推理
下面,我们将结合前面的内容,使用预训练模型进行推理。首先,我们需要准备一张图片作为输入。我们从网络上加载了一张叶子的图片。


import requests
from PIL import Image
from io import BytesIO

url = 'https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

接下来,我们再次创建模型和转换。这次,我们确保将模型设置为评估模式。

model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
transform = timm.data.create_transform(
    **timm.data.resolve_data_config(model.pretrained_cfg)
)

我们可以将图片通过转换函数进行处理,以便与模型进行推理。

image_tensor = transform(image)
image_tensor.shape

现在,我们可以将处理后的图片传递给模型进行预测。这里我们使用unsqueeze(0)将其添加一个批次维度,因为模型需要一个批次的输入。

output = model(image_tensor.unsqueeze(0))
output.shape

为了得到预测的概率,我们对输出应用softmax函数。这会得到一个形状为(num_classes,)的张量。

probabilities = torch.nn.functional.softmax(output[0], dim=0)
probabilities.shape

接下来,我们使用torch.topk找到前5个预测类别的索引和概率值。

values, indices = torch.topk(probabilities, 5)
indices

通过检查对应于最高概率的索引值,我们可以查看模型的预测结果。

IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
[{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]

至此,我们完成了使用预训练模型进行推理的过程。我们加载了一张图片,将其通过转换函数处理后输入模型,并得到了模型的预测结果。通过这个示例,我们展示了如何使用timm进行快速入门。

结语:
本文通过一个快速入门的示例,介绍了如何使用timm进行模型训练。我们展示了加载预训练模型、列出具有预训练权重的模型、微调预训练模型、使用预训练模型进行特征提取、图像增强和推理的方法。希望本文对您了解和使用timm有所帮助,让您能够更轻松地进行模型训练和推理!

发表评论