介绍
机器学习,人工智能,模式识别课题项目,基于tensorflow机器学习库使用CNN算法通过对四种花卉数据集进行训练,得出训练模型。同时基于Django框架开发可视化系统,实现上传图片预测是否为玫瑰,蒲公英,郁金香,向日葵等花卉,并集成后台管理系统,可查看概率值,以及上传预测信息。
技术栈
- 机器学习库:tensorflow
- 算法:CNN
- WEB框架:Django
核心部分
# 进行batch的训练
try:
# 执行MAX_STEP步的训练,一步一个batch
for step in np.arange(MAX_STEP):
if coord.should_stop():
break
_, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])
# 每隔50步打印一次当前的loss以及acc,同时记录log,写入writer
if step % 10 == 0:
print(\'Step %d, train loss = %.2f, train accuracy = %.2f%%\' % (step, tra_loss, tra_acc * 100.0))
summary_str = sess.run(summary_op)
train_writer.add_summary(summary_str, step)
# 每隔100步,保存一次训练好的模型
if (step + 1) == MAX_STEP:
checkpoint_path = os.path.join(logs_train_dir, \'model.ckpt\')
saver.save(sess, checkpoint_path, global_step=step)
except tf.errors.OutOfRangeError:
print(\'Done training -- epoch limit reached\')
finally:
coord.request_stop()
联系v:sql2202