作者简介:在校大学生一枚,华为云享专家,阿里云专家博主,腾云先锋(TDP)成员,云曦智划项目总负责人,全国高等学校计算机教学与产业实践资源建设专家委员会(TIPCC)志愿者,以及编程爱好者,期待和大家一起学习,一起进步~
.
博客主页:ぃ灵彧が的学习日志
.
本文专栏:人工智能
.
专栏寄语:若你决定灿烂,山无遮,海无拦
.
中草药识别案例是图像分类问题,相较于目标检测、实例分割、行为识别、轨迹跟踪等难度较大的计算机视觉任务,图像分类只需要让计算机『看出』图片里的物体类别,更为基础但极为重要。图像分类在许多领域都有着广泛的应用,如:安防领域的智能视频分析和人脸识别等,医学领域的中草药识别,互联网领域基于内容的图像检索和相册自动归类,农业领域的害虫识别等。
本实践代码运行的环境配置如下:Python版本为3.7,PaddlePaddle版本为2.0.0,操作平台为AI Studio。大部分深度学习项目都要经过以下几个过程:数据准备、模型配置、模型训练、模型评估。
import paddle
import numpy as np
import matplotlib.pyplot as plt
print(paddle.__version__)# cpu/gpu环境选择,在 paddle.set_device() 输入对应运行设备。
# device = paddle.set_device('gpu')
本案例整体结构如下所示:
现在我们开始训练模型,训练步骤如下:
model = VGGNet()
model.train()
# 配置loss函数
cross_entropy = paddle.nn.CrossEntropyLoss()
# 配置参数优化器
optimizer = paddle.optimizer.Adam(learning_rate=train_parameters['learning_strategy']['lr'],parameters=model.parameters()) steps = 0
Iters, total_loss, total_acc = [], [], []for epo in range(train_parameters['num_epochs']):for _, data in enumerate(train_loader()):steps += 1x_data = data[0]y_data = data[1]predicts, acc = model(x_data, y_data)loss = cross_entropy(predicts, y_data)loss.backward()optimizer.step()optimizer.clear_grad()if steps % train_parameters["skip_steps"] == 0:Iters.append(steps)total_loss.append(loss.numpy()[0])total_acc.append(acc.numpy()[0])#打印中间过程print('epo: {}, step: {}, loss is: {}, acc is: {}'\.format(epo, steps, loss.numpy(), acc.numpy()))#保存模型参数if steps % train_parameters["save_steps"] == 0:save_path = train_parameters["checkpoints"]+"/"+"save_dir_" + str(steps) + '.pdparams'print('save model to: ' + save_path)paddle.save(model.state_dict(),save_path)
paddle.save(model.state_dict(),train_parameters["checkpoints"]+"/"+"save_dir_final.pdparams")
draw_process("trainning loss","red",Iters,total_loss,"trainning loss")
draw_process("trainning acc","green",Iters,total_acc,"trainning acc")
改变batch_size优化模型
batch_size指的是一次训练所选取的样本数。
在网络训练过程中,batch_size过大或者过小都会影响训练的性能和速度,batch_size过小,花费时间多,同时梯度震荡严重,不利于收敛;batch_size过大,不同batch的梯度方向没有任何变化,容易陷入局部极小值。
例如,在本案例中,我们直接使用神经网络通常设置的batch_size=16
,训练35个epochs之后模型在验证集上的准确率为: 0.825
在合理范围内,增大batch_size会提高显存的利用率,提高大矩阵乘法的并行化效率,减少每个epoch需要训练的迭代次数。在一定范围内,batch size越大,其确定的下降方向越准,引起训练时准确率震荡越小。
在本案例中,我们设置batch_size=32
,同样训练35个epochs,模型在验证集上的准确率为: 0.842
当然,过大的batch_size同样会降低模型性能。
在本案例中,我们设置batch_size=48
,训练35个epochs之后模型在验证集上的准确率为: 0.817
从以上的实验结果对比中,我们可以清楚的了解到,在模型优化的过程中,找到合适的batch_size是很重要的。
我们使用验证集来评估训练过程保存的最后一个模型,首先加载模型参数,之后遍历验证集进行预测并输出平均准确率
# 模型评估
# 加载训练过程保存的最后一个模型
model__state_dict = paddle.load('work/checkpoints/save_dir_final.pdparams')
model_eval = VGGNet()
model_eval.set_state_dict(model__state_dict)
model_eval.eval()
accs = []
# 开始评估
for _, data in enumerate(eval_loader()):x_data = data[0]y_data = data[1]predicts = model_eval(x_data)acc = paddle.metric.accuracy(predicts, y_data)accs.append(acc.numpy()[0])
print('模型在验证集上的准确率为:',np.mean(accs))
采用与训练过程同样的图片转换方式对测试集图片进行预处理
def load_image(img_path):'''预测图片预处理'''img = Image.open(img_path) if img.mode != 'RGB': img = img.convert('RGB') img = img.resize((224, 224), Image.BILINEAR)img = np.array(img).astype('float32') img = img.transpose((2, 0, 1)) / 255 # HWC to CHW 及归一化return imglabel_dic = train_parameters['label_dict']
我们使用训练过程保存的最后一个模型预测测试集中的图片,首先加载模型,预测并输出每张图片的预测值
import time
# 加载训练过程保存的最后一个模型
model__state_dict = paddle.load('work/checkpoints/save_dir_final.pdparams')
model_predict = VGGNet()
model_predict.set_state_dict(model__state_dict)
model_predict.eval()
infer_imgs_path = os.listdir("infer")
# print(infer_imgs_path)# 预测所有图片
for infer_img_path in infer_imgs_path:infer_img = load_image("infer/"+infer_img_path)infer_img = infer_img[np.newaxis,:, : ,:] #reshape(-1,3,224,224)infer_img = paddle.to_tensor(infer_img)result = model_predict(infer_img)lab = np.argmax(result.numpy())print("样本: {},被预测为:{}".format(infer_img_path,label_dic[str(lab)]))img = Image.open("infer/"+infer_img_path)plt.imshow(img)plt.axis('off')plt.show()sys.stdout.flush()time.sleep(0.5)
输出结果如下图所示:
–
本系列文章内容为根据清华社出版的《自然语言处理实践》所作的相关笔记和感悟,其中代码均为基于百度飞桨开发,若有任何侵权和不妥之处,请私信于我,定积极配合处理,看到必回!!!
最后,引用本次活动的一句话,来作为文章的结语~( ̄▽ ̄~)~:
【学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。】
ps:更多精彩内容还请进入本文专栏:人工智能,进行查看,欢迎大家支持与指教啊~( ̄▽ ̄~)~
上一篇:k8s部署mysql一主两从
下一篇:几率波量子雷达/反事实量子通信