Pytorch 分类网络训练方法(Resnet152为例)
创始人
2024-01-30 11:04:07
0

Pytorch resnet 分类网络训练方法

文章目录

  • Pytorch resnet 分类网络训练方法
    • conda环境安装与配置
      • 新建conda环境并激活
      • 配置
    • 准备数据集
    • 训练
      • 修改训练脚本
      • 执行脚本训练
    • 测试效果
    • 转换模型

conda环境安装与配置

新建conda环境并激活

conda create --name pytorch170_py3_7 python=3.7
conda activate pytorch170_py3_7

配置

官网上面找与本机cuda对应的版本安装

  1. 选择合适的版本安装
# CUDA 10.1
conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 -c pytorch
# CPU Only
conda install pytorch==1.7.0 torchvision==0.8.0 cpuonly -c pytorch
  2. 检测环境是否成功
python  
import torch  
print(torch.__version__)            # 查看pytorch版本
print(torch.version.cuda)           # 查看cuda版本
print(torch.cuda.is_available())    # 验证cuda是否可用

准备数据集

  1. 在训练数据目录下新建 train 和 val 两个文件夹

  2. 在train文件夹下新建 各个分类文件夹。如:class1,class2,class3,class4

  3. 在val文件夹下同样如此

  4. 将不同分类的图片放到各子文件夹里面

#最终文件结构为
-数据集根目录:
----------------train:
----------------------class1
----------------------class2
----------------------class3
----------------------class4
----------------val:
----------------------class1
----------------------class2
----------------------class3
----------------------class4

训练

修改训练脚本

from __future__ import print_function 
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy#训练数据目录
# Root directory of the training data (expects train/ and val/ subfolders).
data_dir = "./DataSets"

# Network to train: [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "resnet"

# Output path for the trained weight file.
output_path = "./resnet152.pth"

# Learning rate for SGD.
lr_rate = 0.001

# Number of target classes.
num_classes = 4

# Batch size — limited by available (GPU) memory.
batch_size = 8

# Number of training epochs.
num_epochs = 500
# False: fine-tune all parameters; True: update only the newly added final layer.
feature_extract = False


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    """Train and validate `model`, keeping the weights of the best-val-accuracy epoch.

    Args:
        model: the network to train (already on `device`).
        dataloaders: dict with 'train' and 'val' DataLoaders.
        criterion: loss function.
        optimizer: optimizer over the trainable parameters.
        num_epochs: number of epochs to run.
        is_inception: True for Inception v3, which has an auxiliary head in train mode.

    Returns:
        (best_model, train_acc_history, train_loss_history, val_acc_history, val_loss_history)
        where the histories are plain Python floats, one entry per epoch.
    """
    since = time.time()

    val_acc_history = []
    val_loss_history = []
    train_acc_history = []
    train_loss_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # training mode (dropout/batch-norm active)
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients.
                optimizer.zero_grad()

                # Forward pass; track gradient history only in train phase.
                with torch.set_grad_enabled(phase == 'train'):
                    # Special case for Inception: in training it has an auxiliary
                    # output, so the loss sums the final and auxiliary outputs;
                    # in testing only the final output is used.
                    if is_inception and phase == 'train':
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # Backward + optimize only in the training phase.
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep a deep copy of the best-performing weights on the val set.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            # BUGFIX: store plain floats, not 0-dim tensors — epoch_acc may live
            # on the GPU and matplotlib cannot plot CUDA tensors.
            if phase == 'train':
                train_acc_history.append(epoch_acc.item())
                train_loss_history.append(epoch_loss)
            else:
                val_acc_history.append(epoch_acc.item())
                val_loss_history.append(epoch_loss)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Load the best model weights before returning.
    model.load_state_dict(best_model_wts)
    return model, train_acc_history, train_loss_history, val_acc_history, val_loss_history


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all pretrained weights when feature extracting, so only newly
    created layers receive gradient updates."""
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build one of the supported torchvision models with its classifier head
    replaced to emit `num_classes` outputs.

    Returns:
        (model, input_size) — input_size is the image side length the network expects.
    """
    # These variables are set per-architecture in the chain below.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # Resnet152
        model_ft = models.resnet152(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 480  # standard ResNet input is 224; this recipe trains at 480

    elif model_name == "alexnet":
        # Alexnet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # Squeezenet — its head is a 1x1 conv, not a Linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # Densenet
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3 — expects (299,299) images and has an auxiliary output.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net.
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net.
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size


if __name__ == '__main__':
    print("PyTorch Version: ", torch.__version__)
    print("Torchvision Version: ", torchvision.__version__)

    # Initialize the model for this run.
    model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
    # Print the model we just instantiated.
    print(model_ft)

    # Augment in training; deterministic center-crop in validation.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'val': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }

    print("Initializing Datasets and Dataloaders...")
    # Create training and validation datasets.
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val']}
    # Create training and validation dataloaders.
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                       shuffle=True, num_workers=4)
                        for x in ['train', 'val']}

    # Detect if we have a GPU available; train_model reads this module global.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Send the model to the GPU.
    model_ft = model_ft.to(device)

    # Gather the parameters to be optimized. When feature extracting, only the
    # freshly created head still has requires_grad=True.
    params_to_update = model_ft.parameters()
    print("Params to learn:")
    if feature_extract:
        params_to_update = []
        for name, param in model_ft.named_parameters():
            if param.requires_grad == True:
                params_to_update.append(param)
                print("\t", name)
    else:
        for name, param in model_ft.named_parameters():
            if param.requires_grad == True:
                print("\t", name)

    # Observe that all (trainable) parameters are being optimized.
    optimizer_ft = optim.SGD(params_to_update, lr=lr_rate, momentum=0.9)

    # Setup the loss function.
    criterion = nn.CrossEntropyLoss()

    # Train and evaluate, then persist the best weights.
    model_ft, trainacc, trainloss, valacc, valloss = train_model(
        model_ft, dataloaders_dict, criterion, optimizer_ft,
        num_epochs=num_epochs, is_inception=(model_name == "inception"))
    torch.save(model_ft.state_dict(), output_path)

    # Plot accuracy and loss curves over the epochs.
    plt.title("Acc&Loss")
    plt.xlabel("Training Epochs")
    plt.ylabel("Value")
    plt.plot(range(1, num_epochs + 1), trainacc, label="TrainAcc")
    plt.plot(range(1, num_epochs + 1), trainloss, label="TrainLoss")
    plt.plot(range(1, num_epochs + 1), valacc, label="ValAcc")
    plt.plot(range(1, num_epochs + 1), valloss, label="ValLoss")
    plt.ylim((0, 1.))
    plt.xticks(np.arange(1, num_epochs + 1, 1.0))
    plt.legend()
    plt.show()

执行脚本训练

python finetuning_torchvision_models_tutorial.py

测试效果

#!/usr/bin/python
# -*- coding: utf-8 -*-
import torchvision as tv
import torchvision.transforms as transforms
import torch
from PIL import Image
import torch.nn as nn

# Must match the input size used during training (see training script: 480).
input_size = 480
# Class names in the same order as the training ImageFolder class indices.
names = ['class1', 'class2', 'class3', 'class4']


def pridict():
    """Load the trained ResNet-152 weights and classify a single test image."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = tv.models.resnet152()
    num_ftrs = model.fc.in_features
    # BUGFIX: the classifier head must have the same number of outputs as during
    # training (num_classes = 4 == len(names)); the original hard-coded 5, which
    # makes load_state_dict fail with a size-mismatch error.
    model.fc = nn.Linear(num_ftrs, len(names))
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load("resnet152.pth", map_location=device))
    model = model.to(device)
    model.eval()  # inference mode: disable dropout / freeze batch-norm stats

    # Load the test image and apply the same preprocessing as validation.
    img = Image.open('4.jpg')
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img = transform(img)
    img = img.unsqueeze(0)  # add the batch dimension: (1, 3, H, W)
    img = img.to(device)

    with torch.no_grad():
        py = model(img)
    _, predicted = torch.max(py, 1)  # index of the highest-scoring class
    classIndex_ = predicted[0]
    print('预测结果', names[classIndex_])


if __name__ == '__main__':
    pridict()

转换模型

将模型转换为c++可用的模型

import torch
import torchvision
import torch.nn as nn

# Export the trained classifier as a TorchScript module usable from C++ (libtorch).
# model = torchvision.models.resnet50(pretrained=True)
model = torchvision.models.resnet152()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 4)  # rebuild the head with the trained class count
# BUGFIX: the original fused `load_state_dict(...)` and `model.eval()` onto one
# line (a syntax error). map_location="cpu" also lets GPU-saved weights load on
# a CPU-only machine — the example tensor below is on the CPU anyway.
model.load_state_dict(torch.load("resnet152.pth", map_location="cpu"))
model.eval()  # disable dropout / batch-norm updates before tracing
# Trace with a dummy input of the training resolution (1, 3, 480, 480).
example = torch.rand(1, 3, 480, 480)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("resnet152.pt")

相关内容

热门资讯

喜欢穿一身黑的男生性格(喜欢穿... 今天百科达人给各位分享喜欢穿一身黑的男生性格的知识,其中也会对喜欢穿一身黑衣服的男人人好相处吗进行解...
发春是什么意思(思春和发春是什... 本篇文章极速百科给大家谈谈发春是什么意思,以及思春和发春是什么意思对应的知识点,希望对各位有所帮助,...
网络用语zl是什么意思(zl是... 今天给各位分享网络用语zl是什么意思的知识,其中也会对zl是啥意思是什么网络用语进行解释,如果能碰巧...
为什么酷狗音乐自己唱的歌不能下... 本篇文章极速百科小编给大家谈谈为什么酷狗音乐自己唱的歌不能下载到本地?,以及为什么酷狗下载的歌曲不是...
华为下载未安装的文件去哪找(华... 今天百科达人给各位分享华为下载未安装的文件去哪找的知识,其中也会对华为下载未安装的文件去哪找到进行解...
怎么往应用助手里添加应用(应用... 今天百科达人给各位分享怎么往应用助手里添加应用的知识,其中也会对应用助手怎么添加微信进行解释,如果能...
家里可以做假山养金鱼吗(假山能... 今天百科达人给各位分享家里可以做假山养金鱼吗的知识,其中也会对假山能放鱼缸里吗进行解释,如果能碰巧解...
四分五裂是什么生肖什么动物(四... 本篇文章极速百科小编给大家谈谈四分五裂是什么生肖什么动物,以及四分五裂打一生肖是什么对应的知识点,希...
一帆风顺二龙腾飞三阳开泰祝福语... 本篇文章极速百科给大家谈谈一帆风顺二龙腾飞三阳开泰祝福语,以及一帆风顺二龙腾飞三阳开泰祝福语结婚对应...
美团联名卡审核成功待激活(美团... 今天百科达人给各位分享美团联名卡审核成功待激活的知识,其中也会对美团联名卡审核未通过进行解释,如果能...