当前位置:网站首页>猫狗分类-简单CNN
猫狗分类-简单CNN
2022-07-13 18:20:00 【booze-J】
猫狗分类的数据集可以查看图像数据预处理。
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
简单CNN实现猫狗分类代码:
1.导入第三方库
from keras.models import Sequential
from keras.layers import Convolution2D,MaxPooling2D
from keras.layers import Activation,Dropout,Flatten,Dense
from tensorflow.keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
import os
2.定义模型
# 定义模型
model = Sequential()
model.add(Convolution2D(input_shape=(150,150,3),filters=32,kernel_size=3,strides=1,padding="same",activation="relu"))
model.add(Convolution2D(filters=32,kernel_size=3,strides=1,padding="same",activation="relu"))
model.add(MaxPooling2D(pool_size=2,strides=2,padding="valid"))
model.add(Convolution2D(filters=64,kernel_size=3,strides=1,padding="same",activation="relu"))
model.add(Convolution2D(filters=64,kernel_size=3,strides=1,padding="same",activation="relu"))
model.add(MaxPooling2D(pool_size=2,strides=2,padding="valid"))
model.add(Convolution2D(filters=128,kernel_size=3,strides=1,padding="same",activation="relu"))
model.add(Convolution2D(filters=128,kernel_size=3,strides=1,padding="same",activation="relu"))
model.add(MaxPooling2D(pool_size=2,strides=2,padding="valid"))
model.add(Flatten())
model.add(Dense(64,activation="relu"))
model.add(Dropout(0.5))
model.add(Dense(2,activation="softmax"))
# 定义优化器
adam = Adam(lr=1e-4)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=adam,
loss="categorical_crossentropy",
metrics=['accuracy']
)
# 查看模型的结构
model.summary()
运行结果:
3.训练数据和测试数据生成
# 训练集数据生成
train_datagen = ImageDataGenerator(
rescale=1./255,# 归一化处理
shear_range=0.2,# 随机裁剪
zoom_range=0.2,# 图片放大
horizontal_flip=True # 水平翻转
)
# 测试集数据处理
test_datagen = ImageDataGenerator(rescale=1./255)
测试集通常不需要做复杂的数据生成,测试集只是未来测试才用,而训练集时为了
做数据生成是为了:
- 1.增加数据量,使得图片各种各样
- 2.增加模型的鲁棒性,使其泛化性更好
flow_from_directory:
- directory:目标文件夹路径,对于每一个类,该文件夹都要包含一个子文件夹.子文件夹中任何JPG、PNG、BNP、PPM的图片都会被生成器使用
- target_size:整数tuple,默认为(256,256).图像将被resize成该尺寸
- color_mode:颜色模式,为"grayscale",“rgb"之一,默认为"rgb”.代表这些图片是否会被转换为单通道或三通道的图片.
- classes:可选参数,为子文件夹的列表,如[‘dogs’,‘cats’]默认为None.若未提供,则该类别列表将从directory下的子文件夹名称/结构自动推断,每一个子文件夹都会被认为是一个新的类。(类别的顺序将按照字母表顺序映射到标签值)。通过属性class_indices可获得文件夹名与类的序号的对应字典。
- class_mode:“categorical”,“binary”,“sparse"或None之一,默认为"categorical”.该参数决定了返回的标签数组的式,"categorical"会返回2D的one-hot编码标签,"binary"返回1D的二值标签."sparse"返回1D的整数标签,如果为None则不返回任何标签,生成器将仅仅生成batch数据,这种情况在使用model.predict_generator()和model.evaluate_generator()等函数时会用到.
- batch_size:batch数据的大小,默认为32
- shuffle:是否打乱数据,默认为True
- seed:可选参数,打乱数据和进行变换时的随机数种子
batch_size = 32
# 生成训练数据
train_generator = train_datagen.flow_from_directory(
"../input/cat-and-dog-classify/train/train",# 训练数据路径
target_size=(150,150),# 设置图片大小
batch_size=batch_size # 批次大小
)
# 测试数据
test_generator = test_datagen.flow_from_directory(
"../input/cat-and-dog-classify/test/test",# 训练数据路径
target_size=(150,150),# 设置图片大小
batch_size=batch_size # 批次大小
)
# 统计文件个数
totalFileCount_train = sum([len(files) for root,dirs,files in os.walk("../input/cat-and-dog-classify/train/train")])
totalFileCount_test = sum([len(files) for root,dirs,files in os.walk("../input/cat-and-dog-classify/test/test")])
4.训练模型
model.fit_generator(
train_generator,
steps_per_epoch=totalFileCount_train/batch_size,
epochs=50,
validation_data=test_generator,
validation_steps=totalFileCount_train/batch_size
)
# 保存模型
model.save("CNN1.h5")
这里面有一个steps_per_epoch=totalFileCount_train/batch_size这个是计算每批次的总步数,一批次的总步数等于数据量除以batch_size。
运行结果:
边栏推荐
- Flutter:环境搭建、项目创建
- Small program graduation project of wechat enterprise company (1) development outline
- 【每日一题】在二叉树中找到两个节点的最近公共祖先
- Configuration and vant component of jump and navigation of wechat applet page
- 分账系统如何给连锁便利店带来交易效率革命?
- 最后一篇CSDN博客
- "Telecom grade" has been running for many years, and CICA technology has launched the core transaction database antdb7.0
- 命令提示符查看某端口占用情况,并清除占用
- Cocoscreator animation and particles move according to the painting path
- [daily question 1] find the nearest common ancestor of two nodes in the binary tree
猜你喜欢

Small program graduation project of wechat enterprise company (1) development outline

Simple canvas animation principle

Common regular expressions

Flutter:环境搭建、项目创建

Mobile automation uses commands to view the app package name

How to disable shutter raisedbutton

App测试流程及测试点

第54章 业务逻辑之折扣、商品类别实体定义实现

小白必学的现货黄金知识(24个术语)

C primer plus learning notes - 4. File IO (input / output)
随机推荐
一、mysql的安装部署
"Everyday Mathematics" serial 59: February 28
51单片机智能家居环境检测 烟雾温度GSM短信提示报警器(原理图+程序+仿真+PCB)
[daily question 1] binary search tree and bidirectional linked list
猫狗分类-VGG16-bottleneck
【u-boot】u-boot Sandbox编译构建和使用总结
CVPR | 基于密度与深度分解的自增强非成对图像去雾
How to disable shutter raisedbutton
小程序毕设作品之微信教室预约小程序毕业设计(3)后台功能
Mallbook: how to promote the rapid implementation of supply chain finance through the supply chain settlement management system?
分许伦敦银最新走势还得看K线
golang 处理web post、get请求以及string to json格式的转化
Day102.尚医通项目
ui文件转换为py文件方法
送你的代码上太空,一起开发“最伟大的作品”
Differences and relations between malloc, vmalloc and kmalloc, free, kfree and vfree
How does the distribution system bring a revolution in transaction efficiency to chain convenience stores?
小程序毕设作品之微信教室预约小程序毕业设计(5)任务书
Libraries in dart
EasyCVR服务无法启动的原因分析及磁盘空间易满的处理小技巧