找到合适的 CNN 模型架构和参数

数据挖掘 神经网络 美国有线电视新闻网 训练 图像预处理 卷积神经网络
2022-03-13 13:49:09

我目前正在创建一个 CNN 模型,用于对字体是否为ArialVerdanaTimes New Roman进行分类Georgia总而言之,有类16因为我还考虑过检测字体是regular,bold还是. 所以italicsbold italics4 fonts * 4 styles = 16 classes

我在训练中使用的数据如下:

 Training data set : 800 image patches of 256 * 256 dimension (50 for each class)
 Validation data set : 320 image patches of 256 * 256 dimension (20 for each class)
 Testing data set : 160 image patches of 256 * 256 dimension (10 for each class)

以下是我的数据的示例屏幕截图:

在此处输入图像描述

下面是我的初始代码:

import numpy as np
import keras
from keras import backend as K
from keras.models import Sequential
from keras.layers import Activation
from keras.layers.core import Dense, Flatten
from keras.optimizers import Adam
from keras.metrics import categorical_crossentropy
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import *
from matplotlib import pyplot as plt
import itertools
import matplotlib.pyplot as plt
import pickle


image_width = 256
image_height = 256

train_path = 'font_model_data/train'
valid_path =  'font_model_data/valid'
test_path = 'font_model_data/test'


train_batches = ImageDataGenerator().flow_from_directory(train_path, target_size=(image_width, image_height), classes=['1','2','3','4', '5', '6', '7', '8', '9', '10', '11', '12','13', '14', '15', '16'], batch_size = 16)
valid_batches = ImageDataGenerator().flow_from_directory(valid_path, target_size=(image_width, image_height), classes=['1','2','3','4', '5', '6', '7', '8', '9', '10', '11', '12','13', '14', '15', '16'], batch_size = 16)
test_batches = ImageDataGenerator().flow_from_directory(test_path, target_size=(image_width, image_height), classes=['1','2','3','4', '5', '6', '7', '8', '9', '10', '11', '12','13', '14', '15', '16'], batch_size = 160)


 imgs, labels = next(train_batches)

 #CNN model
 model = Sequential([
     Conv2D(32, (3,3), activation='relu', input_shape=(image_width, image_height, 3)),
     Flatten(),
     Dense(16, activation='softmax'),
 ])

 print(model.summary())

 model.compile(Adam(lr=.0001),loss='categorical_crossentropy', metrics=['accuracy'])
 model.fit_generator(train_batches, steps_per_epoch = 50, validation_data= valid_batches, validation_steps = 20, epochs = 1, verbose = 2)

 model_pickle = open('cnn_font_model.pickle', 'wb')
 pickle.dump(model, model_pickle)
 model_pickle.close()
 print('Training Done.')

 test_imgs, test_labels = next(test_batches)

 predictions = model.predict_generator(test_batches, steps = 160, verbose = 2)
 print(predictions)

谁能建议我如何知道正确的网络架构和参数以获得最佳精度?我应该如何开始调整我的网络?

1个回答

在深度学习的许多情况下,从容量非常高且可能过拟合的模型开始效果很好。从那时起,您可以减少模型容量以缩小训练和验证错误之间的差距。在Goodwell 的深度学习书的这一章中,您可以找到对手动超参数选择以及它们如何影响模型容量的很好描述。

此外,对于许多任务,已经存在精心设计的解决方案。因此,请检查对类似任务有效的方法并尝试这些架构。例如,MNIST 手写识别与您的任务有些相似。Wikipedia提供了几种在 MNIST 上运行良好的架构。

文章“Handwritten Digit Recognition using Convolutional Neural Networks in Python with Keras”也包含了 MNIST 的架构。同样,这可能与您的任务足够接近,因此我建议您尝试一下。这篇文章包括一个与您的架构非常相似但也更复杂的架构。

在其他情况下,您还可以查看预训练模型但在这里它甚至可能不需要。