UNet - 训练数据train
创始人
2024-01-29 02:32:21
0

目录

1. train 训练数据

2. Loss 值

3. 完整代码


1. train 训练数据

训练的代码只是在之前图像分类的基础上做了一些更改,具体的可以看下面的文章

pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类https://blog.csdn.net/qq_44886601/article/details/127498256

首先,导入之前定义的UNet 网络

然后,加载训练集和测试集

因为加载数据集被重写过,所以这里传入的是训练的图像,然后根据里面的replace就能找到对应的标签

这里训练的时候可以将数据打乱,测试的时候没有必要,batch_size 因为电脑硬件的问题设置成2,再大的话这里内存就会不够了

 

然后定义优化器和损失函数,这里用的是BCE加上sigmoid的损失函数

训练的时候,要将模式改为train模式,然后训练的步骤很常规

梯度清零->前向传播->计算损失函数->反向传播->更新参数

 

这里测试的时候有些区别

因为这里UNet 网络的输出是一幅图像,而之前将label改为了二值图像(归一化后是0 1)。所以这里计算准确率的时候,将预测的图像也变为二值图像,计算准确率用的是对应图像像素点的灰度值是否相等的方法

 

最后保留最好准确率的那个参数就行了

 

2. Loss 值

这是跑了20 个epoch的输出

 

3. 完整代码

from model import UNet                  # 导入Unet 网络
from dataset import Data_Loader         # 数据处理
from torch import optim
import torch.nn as nn
import torch# 网络训练模块
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # GPU or CPU
print(device)
net = UNet(in_channels=1, num_classes=1)                                # 加载网络
net.to(device)                                                          # 将网络加载到device上# 加载训练集
train_path = "./data/train/image"
trainset = Data_Loader(train_path)
train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=2,shuffle=True)# len(trainset)  样本总数:21# 加载测试集
test_path = "./data/test/image"
testset = Data_Loader(test_path)
test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=2)optimizer = optim.RMSprop(net.parameters(),lr = 0.000001,weight_decay=1e-8,momentum=0.9)     # 定义优化器
criterion = nn.BCEWithLogitsLoss()                                                           # 定义损失函数save_path = './UNet.pth'        # 网络参数的保存路径
best_acc = 0.0                  # 保存最好的准确率for epoch in range(20):net.train()     # 训练模式running_loss = 0.0for image,label in train_loader:                   # 读取数据和labeloptimizer.zero_grad()                          # 梯度清零pred = net(image.to(device))                   # 前向传播loss = criterion(pred, label.to(device))       # 计算损失loss.backward()                                # 反向传播optimizer.step()                               # 梯度下降running_loss += loss.item()                    # 计算损失和net.eval()  # 测试模式acc = 0.0   # 正确率total = 0with torch.no_grad():for test_image, test_label in test_loader:outputs = net(test_image.to(device))     # 前向传播outputs[outputs >= 0] = 1  # 将预测图片转为二值图片outputs[outputs < 0] = 0acc += (outputs == test_label.to(device)).sum().item() / (480*480)     # 计算预测图片与真实图片像素点一致的精度:acc = 相同的 / 总个数total += test_label.size(0)accurate = acc / total  # 计算整个test上面的正确率print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f %%' %(epoch + 1, running_loss, accurate*100))if accurate > best_acc:     # 保留最好的精度best_acc = accuratetorch.save(net.state_dict(), save_path)     # 保存网络参数

相关内容

热门资讯

喜欢穿一身黑的男生性格(喜欢穿... 今天百科达人给各位分享喜欢穿一身黑的男生性格的知识,其中也会对喜欢穿一身黑衣服的男人人好相处吗进行解...
发春是什么意思(思春和发春是什... 本篇文章极速百科给大家谈谈发春是什么意思,以及思春和发春是什么意思对应的知识点,希望对各位有所帮助,...
网络用语zl是什么意思(zl是... 今天给各位分享网络用语zl是什么意思的知识,其中也会对zl是啥意思是什么网络用语进行解释,如果能碰巧...
为什么酷狗音乐自己唱的歌不能下... 本篇文章极速百科小编给大家谈谈为什么酷狗音乐自己唱的歌不能下载到本地?,以及为什么酷狗下载的歌曲不是...
华为下载未安装的文件去哪找(华... 今天百科达人给各位分享华为下载未安装的文件去哪找的知识,其中也会对华为下载未安装的文件去哪找到进行解...
怎么往应用助手里添加应用(应用... 今天百科达人给各位分享怎么往应用助手里添加应用的知识,其中也会对应用助手怎么添加微信进行解释,如果能...
家里可以做假山养金鱼吗(假山能... 今天百科达人给各位分享家里可以做假山养金鱼吗的知识,其中也会对假山能放鱼缸里吗进行解释,如果能碰巧解...
四分五裂是什么生肖什么动物(四... 本篇文章极速百科小编给大家谈谈四分五裂是什么生肖什么动物,以及四分五裂打一生肖是什么对应的知识点,希...
一帆风顺二龙腾飞三阳开泰祝福语... 本篇文章极速百科给大家谈谈一帆风顺二龙腾飞三阳开泰祝福语,以及一帆风顺二龙腾飞三阳开泰祝福语结婚对应...
美团联名卡审核成功待激活(美团... 今天百科达人给各位分享美团联名卡审核成功待激活的知识,其中也会对美团联名卡审核未通过进行解释,如果能...