Single-hidden-layer neural network for MNIST handwritten digit recognition with Keras

February 11, 2017

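The MNIST dataset contains 60,000 training samples and 10,000 test samples; each sample is a 28×28-pixel grayscale image [2]. The goal is to learn from these images to recognize (classify) the Arabic digit (0-9) shown in each one.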
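To make the image layout concrete, here is a minimal sketch of how one flattened sample maps back to a 28×28 grid. It assumes the "label followed by 784 pixel values" row layout that the CSV-based code in this post slices on, and it assumes the pixels are stored row by row; the row itself is synthetic, not real MNIST data.

```python
import numpy as np

# Hypothetical CSV row: one label followed by 784 pixel intensities (0-255).
# The values are random placeholders, not real MNIST data.
rng = np.random.default_rng(0)
row = np.concatenate(([7], rng.integers(0, 256, size=784)))

label = int(row[0])             # first column: the digit class
pixels = row[1:]                # remaining 784 columns: flattened image
image = pixels.reshape(28, 28)  # assumed row-major: pixel k sits at (k // 28, k % 28)

print(label, pixels.shape, image.shape)
```

A dense network like the one in this post consumes the flat 784-vector directly; the reshape is only needed for viewing the digit or for convolutional models.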

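MNIST is a classic dataset that is commonly used as a benchmark for classification algorithms. Different methods achieve different error rates on it: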

| Type | Classifier | Distortion | Preprocessing | Error rate (%) |
|---|---|---|---|---|
| Linear classifier | Pairwise linear classifier | None | Deskewing | 7.6 [9] |
| K-Nearest Neighbors | K-NN with non-linear deformation (P2DHMDM) | None | Shiftable edges | 0.52 [17] |
| Boosted Stumps | Product of stumps on Haar features | None | Haar features | 0.87 [18] |
| Non-Linear Classifier | 40 PCA + quadratic classifier | None | None | 3.3 [9] |
| Support vector machine | Virtual SVM, deg-9 poly, 2-pixel jittered | None | Deskewing | 0.56 [19] |
| Neural network | 2-layer 784-800-10 | None | None | 1.6 [20] |
| Neural network | 2-layer 784-800-10 | Elastic distortions | None | 0.7 [20] |
| Deep neural network | 6-layer 784-2500-2000-1500-1000-500-10 | Elastic distortions | None | 0.35 [21] |
| Convolutional neural network | 6-layer 784-40-80-500-1000-2000-10 | None | Expansion of the training data | 0.31 [14] |
| Convolutional neural network | 6-layer 784-50-100-500-1000-10-10 | None | Expansion of the training data | 0.27 [15] |
| Convolutional neural network | Committee of 35 CNNs, 1-20-P-40-P-150-10 | Elastic distortions | Width normalizations | 0.23 [8] |
| Convolutional neural network | Committee of 5 CNNs, 6-layer 784-50-100-500-1000-10-10 | None | Expansion of the training data | 0.21 [16] |

The goal here is not to optimize performance but simply to implement the model with Keras.

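The network has one hidden layer. The input layer uses the tanh activation, the hidden layer also uses tanh, and the output layer uses softmax. Note that in the code the layer called the input layer here is itself a 128-unit fully connected layer, so the model stacks two 128-unit tanh layers, each followed by dropout, ahead of the 10-way softmax output.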
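As an illustration of what these activations compute, the following is a minimal NumPy sketch of a forward pass through tanh layers followed by a softmax output. The layer sizes (784, 128, 128, 10) are taken from the Keras code in this post; the weights and the input batch are random placeholders, the helper names `softmax` and `forward` are made up for this example, and dropout is left out because it is only active during training.

```python
import numpy as np

def softmax(z):
    # Subtract the row-wise maximum for numerical stability, then normalize.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(x, params):
    # Two tanh layers followed by a softmax output layer.
    (W1, b1), (W2, b2), (W3, b3) = params
    h1 = np.tanh(x @ W1 + b1)
    h2 = np.tanh(h1 @ W2 + b2)
    return softmax(h2 @ W3 + b3)

rng = np.random.default_rng(0)
sizes = [784, 128, 128, 10]
# Placeholder parameters: small normal weights and zero biases (not trained values).
params = [(rng.normal(0.0, 0.05, size=(m, n)), np.zeros(n))
          for m, n in zip(sizes[:-1], sizes[1:])]

x = rng.random((4, 784))        # a batch of 4 synthetic "images"
probs = forward(x, params)
print(probs.shape)              # (4, 10)
print(probs.sum(axis=1))        # each row sums to 1 up to floating-point error
```

Each output row is a probability distribution over the ten digits, which is why the Keras model is trained with the categorical cross-entropy loss.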


Code:

# Targets the Keras 2+ API.
import numpy as np
import pandas as pd
from keras.layers import Dense, Dropout, Input
from keras.models import Sequential
from keras.utils import to_categorical

# np.random.seed(133723)  # for reproducibility

trainFileName = "./data/train.csv"
nb_classes = 10
nb_epoch = 20
batch_size = 128

# Load the labelled CSV: column 0 is the label, the remaining columns are pixels.
# rawData = np.genfromtxt(trainFileName, delimiter=',')
rawData = pd.read_csv(trainFileName).to_numpy()
print("raw data size", rawData.shape)

# Split the labelled data: first 70% for training, last 30% for testing.
splitIndex = int(rawData.shape[0] * 0.7)
trainData = rawData[:splitIndex]
testData = rawData[splitIndex:]

# Separate the labels (first column) from the pixel features.
trainLabel = trainData[:, 0]
testLabel = testData[:, 0]
trainData = trainData[:, 1:].astype("float32")
testData = testData[:, 1:].astype("float32")
input_shape = trainData.shape[1]

# One-hot encode the labels for categorical cross-entropy.
y_train = to_categorical(trainLabel, nb_classes)
y_test = to_categorical(testLabel, nb_classes)
print("convert train label to ", y_train.shape[1], "classes")
print("convert test label to ", y_test.shape[1], "classes")

# Keras model
print("trainData shape:", trainData.shape, "y_train shape", y_train.shape)
print("testData shape:", testData.shape, "y_test shape", y_test.shape)

model = Sequential()
model.add(Input(shape=(input_shape,)))
model.add(Dense(128, kernel_initializer='random_normal', activation='tanh'))
model.add(Dropout(0.25))
model.add(Dense(128, kernel_initializer='random_normal', activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes, kernel_initializer='random_normal', activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(trainData, y_train, batch_size=batch_size, epochs=nb_epoch,
          verbose=1, validation_data=(testData, y_test))

# Evaluate on the held-out 30%.
score = model.evaluate(testData, y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])
W, b = model.layers[0].get_weights()  # weights and biases of the first Dense layer
# print('Weights=', W, '\n biases=', b)

# Compare the first 10 true labels with the first 10 predicted labels.
y_test_pred = model.predict(testData)
print(np.argmax(y_test, axis=1)[:10], np.argmax(y_test_pred, axis=1)[:10])

# Predict the unlabelled test file and save (index, predicted class) pairs.
predictFileName = "./data/test.csv"
x_pre = pd.read_csv(predictFileName).to_numpy().astype("float32")
print("x_pre shape", x_pre.shape)
y_pred = model.predict(x_pre)
pre_class = np.argmax(y_pred, axis=1)
print("predict class :", pre_class)
index = np.arange(1, len(pre_class) + 1)  # 1-based row index
print("index", index)

np.savetxt("./data/pre.csv", np.column_stack((index, pre_class)), delimiter=',', fmt='%d')
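The label handling in the code above (one-hot encoding the labels before training, then turning predicted probabilities back into class indices) can be sketched in plain NumPy. This is only an illustration of the idea, not the Keras implementation; the helper names `one_hot` and `to_classes` are hypothetical.

```python
import numpy as np

def one_hot(labels, num_classes):
    # Build a (n_samples, num_classes) matrix with a single 1 per row.
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, num_classes), dtype="float32")
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

def to_classes(probabilities):
    # Pick the index of the largest probability in each row.
    return np.argmax(probabilities, axis=1)

labels = np.array([3, 0, 9])
encoded = one_hot(labels, 10)
print(encoded.shape)        # (3, 10)
print(to_classes(encoded))  # [3 0 9]
```

Applying `to_classes` to the softmax output of the model gives the predicted digit for each row, which is what ends up in the prediction file.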


References

[1] https://www.kaggle.com/c/digit-recognizer/data

[2] https://en.wikipedia.org/wiki/MNIST_database

Dong Wang

Master student of computer science at Uppsala University in Sweden. My primary research interests are deep learning, computer vision, federated learning and internet-of-things.
