标题:快速入门指南:使用timm进行模型训练的开端
友情链接:ACEJoy
导读:
本文旨在帮助开发人员快速了解如何将timm集成到模型训练流程中。首先,您需要安装timm。接下来,我们将通过示例代码演示如何加载预训练模型、列出具有预训练权重的模型、微调预训练模型、以及如何使用预训练模型进行特征提取、图像增强和推理。让我们一起开始这个令人兴奋的旅程吧!
正文:
快速入门
本快速入门指南旨在帮助开发人员快速了解如何将timm集成到他们的模型训练流程中。在开始之前,您需要先安装timm。有关安装的详细信息,请参阅安装指南。
加载预训练模型
通过create_model()函数可以加载预训练模型。下面的示例展示了如何加载预训练的mobilenetv3_large_100模型。
import timm
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()
需要注意的是,默认情况下返回的PyTorch模型处于训练模式,如果要进行推理,需要调用.eval()
方法将其设置为评估模式。
列出具有预训练权重的模型
要列出timm中打包的具有预训练权重的模型,可以使用list_models()函数。如果指定pretrained=True,该函数将只返回具有预训练权重的模型名称。
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
您还可以使用特定模式来列出名称中包含特定字符串的模型。
微调预训练模型
要微调任何预训练模型,只需更改分类器(即最后一层)即可。
model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=num_finetune_classes)
如果要在自己的数据集上进行微调,请编写一个PyTorch训练循环或调整timm的训练脚本以使用自己的数据集。
使用预训练模型进行特征提取
在不修改网络结构的情况下,可以使用model.forward_features(input)来替代通常的model(input)方法,对任何模型进行特征提取。这将跳过头部分类器和全局池化操作。有关更详细的使用timm进行特征提取的指南,请参阅特征提取部分。
图像增强
为了将图像转换为模型接受的有效输入,可以使用timm.data.create_transform()函数,并提供模型期望的输入尺寸。这将返回一个通用的变换对象,其中包含了一些合理的默认设置。
timm.data.create_transform((3, 224, 224))
预训练模型在训练时应用了特定的数据转换。如果您在图像上使用错误的转换,模型将无法理解所见的图像!要了解给定预训练模型使用了哪些转换,可以查看其预训练配置。
model.pretrained_cfg
您可以使用timm.data.resolve_data_config()函数来解析与数据相关的配置。
data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)
使用预训练模型进行推理
下面,我们将结合前面的内容,使用预训练模型进行推理。首先,我们需要一张图片作为输入。我们从网络上加载一张叶子标题:快速入门指南:timm模型训练的开端
导读:
本文将带您快速了解如何将timm集成到模型训练流程中。从安装timm开始,到加载预训练模型、列出可用的预训练模型、微调模型、使用模型进行特征提取、图像增强和推理等,我们将逐步介绍这些内容。让我们开始这个令人兴奋的旅程吧!
正文:
快速入门
本快速入门指南旨在帮助开发人员快速了解如何将timm集成到模型训练流程中。在开始之前,您需要先安装timm,具体安装方法请参考官方的安装指南。
加载预训练模型
要加载预训练模型,可以使用create_model()函数。以下示例展示了如何加载预训练的mobilenetv3_large_100模型。
import timm
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()
需要注意的是,默认情况下返回的PyTorch模型处于训练模式,如果要进行推理,需要调用.eval()
方法将其设置为评估模式。
列出具有预训练权重的模型
要列出timm中已打包的具有预训练权重的模型,可以使用list_models()函数。如果指定pretrained=True,该函数将只返回具有预训练权重的模型名称。
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
您还可以使用特定模式来列出名称中包含特定字符串的模型。
微调预训练模型
要微调任何预训练模型,只需更改分类器(即最后一层)即可。
model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=num_finetune_classes)
如果要在自己的数据集上进行微调,请编写一个PyTorch训练循环或调整timm的训练脚本以使用自己的数据集。
使用预训练模型进行特征提取
在不修改网络结构的情况下,可以使用model.forward_features(input)来替代通常的model(input)方法,对任何模型进行特征提取。这将跳过头部分类器和全局池化操作。有关更详细的使用timm进行特征提取的指南,请参考官方文档中的特征提取部分。
图像增强
要将图像转换为模型可接受的有效输入,可以使用timm.data.create_transform()函数,并提供模型所期望的输入尺寸。这将返回一个通用的转换对象,其中包含了一些合理的默认设置。
timm.data.create_transform((3, 224, 224))
需要注意的是,预训练模型在训练时应用了特定的数据转换。如果您在图像上使用错误的转换,模型将无法理解所见的图像!要了解给定预训练模型使用了哪些转换,可以查看其预训练配置。
model.pretrained_cfg
您可以使用timm.data.resolve_data_config()函数来解析与数据相关的配置。
data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)
使用预训练模型进行推理
下面,我们将结合前面的内容,使用预训练模型进行推理。首先,我们需要准备一张图片作为输入。我们从网络上加载了一张叶子的图片。
import requests
from PIL import Image
from io import BytesIO
url = 'https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image
接下来,我们再次创建模型和转换。这次,我们确保将模型设置为评估模式。
model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
transform = timm.data.create_transform(
**timm.data.resolve_data_config(model.pretrained_cfg)
)
我们可以将图片通过转换函数进行处理,以便与模型进行推理。
image_tensor = transform(image)
image_tensor.shape
现在,我们可以将处理后的图片传递给模型进行预测。这里我们使用unsqueeze(0)将其添加一个批次维度,因为模型需要一个批次的输入。
output = model(image_tensor.unsqueeze(0))
output.shape
为了得到预测的概率,我们对输出应用softmax函数。这会得到一个形状为(num_classes,)的张量。
probabilities = torch.nn.functional.softmax(output[0], dim=0)
probabilities.shape
接下来,我们使用torch.topk找到前5个预测类别的索引和概率值。
values, indices = torch.topk(probabilities, 5)
indices
通过检查对应于最高概率的索引值,我们可以查看模型的预测结果。
IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
[{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
至此,我们完成了使用预训练模型进行推理的过程。我们加载了一张图片,将其通过转换函数处理后输入模型,并得到了模型的预测结果。通过这个示例,我们展示了如何使用timm进行快速入门。
结语:
本文通过一个快速入门的示例,介绍了如何使用timm进行模型训练。我们展示了加载预训练模型、列出具有预训练权重的模型、微调预训练模型、使用预训练模型进行特征提取、图像增强和推理的方法。希望本文对您了解和使用timm有所帮助,让您能够更轻松地进行模型训练和推理!