Solving the MNIST Digit Recognition Problem with TensorFlow
2019-10-14
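Preamble
This MNIST digit recognition problem is the first neural network I have implemented. I mostly just typed in the code from the book, but the process still gave me a working understanding of how a neural network is trained, and it doubled as a review of the basic TensorFlow and neural network concepts from the earlier chapters.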
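Introduction to MNIST
MNIST is a well-known dataset of handwritten digits (60,000 training images and 10,000 test images, each 28×28 pixels) and is commonly used as an introductory example for deep learning.
The MNIST dataset can be downloaded from http://yann.lecun.com/exdb/mnist/
TensorFlow provides a helper for handling MNIST data (input_data.read_data_sets from tensorflow.examples.tutorials.mnist, used in the code below) that downloads the data automatically and converts it into a format the network can consume.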
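read_data_sets does the conversion for you. To show roughly what "converting the format" involves, here is a minimal stdlib-only sketch, assuming the IDX layout described on the MNIST page (a big-endian header of 32-bit integers followed by one unsigned byte per pixel or label). The names parse_images, parse_labels and load_gz are hypothetical and not part of TensorFlow; the sketch only mirrors two things the TensorFlow helper does when called with one_hot=True: scaling pixels to [0, 1] and turning labels into one-hot vectors.

```python
import gzip
import struct

# Magic numbers at the start of the MNIST IDX files (big-endian int32)
IMAGE_MAGIC = 2051  # 0x00000803: unsigned bytes, 3 dimensions
LABEL_MAGIC = 2049  # 0x00000801: unsigned bytes, 1 dimension


def parse_images(data):
    """Parse decompressed IDX image bytes into flat lists of floats in [0, 1]."""
    magic, count, rows, cols = struct.unpack(">IIII", data[:16])
    if magic != IMAGE_MAGIC:
        raise ValueError("not an IDX image file: magic %d" % magic)
    size = rows * cols  # 28 * 28 = 784 for real MNIST images
    pixels = data[16:16 + count * size]
    # Flatten each image and scale 0..255 to 0.0..1.0
    return [[b / 255.0 for b in pixels[i * size:(i + 1) * size]]
            for i in range(count)]


def parse_labels(data, num_classes=10, one_hot=True):
    """Parse decompressed IDX label bytes, optionally as one-hot vectors."""
    magic, count = struct.unpack(">II", data[:8])
    if magic != LABEL_MAGIC:
        raise ValueError("not an IDX label file: magic %d" % magic)
    labels = list(data[8:8 + count])
    if not one_hot:
        return labels
    return [[1.0 if c == label else 0.0 for c in range(num_classes)]
            for label in labels]


def load_gz(path, parser):
    """Read a file such as train-images-idx3-ubyte.gz and parse it."""
    with gzip.open(path, "rb") as f:
        return parser(f.read())
```

The TensorFlow helper also splits the first 5,000 training images off as a validation set, which is where mnist.validation comes from, leaving 55,000 examples in mnist.train.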
Training the Neural Network
Here is the code first:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# MNIST-related constants
INPUT_NODE = 784    # 28x28 pixels per image
OUTPUT_NODE = 10    # ten digit classes, 0-9

# Neural network parameters
LAYER1_NODE = 500               # nodes in the single hidden layer
BATCH_SIZE = 100                # examples per training batch
LEARNING_RATE_BASE = 0.8        # initial learning rate
LEARNING_RATE_DECAY = 0.99      # learning rate decay rate
REGULARIZATION_RATE = 0.0001    # weight of the L2 regularization term in the loss
TRAINING_STEPS = 30000          # number of training steps
MOVING_AVERAGE_DECAY = 0.99     # moving average decay rate


def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    """Forward pass: one ReLU hidden layer followed by a linear output layer.

    If avg_class is given, the moving-average (shadow) values of the
    parameters are used instead of their current values.
    """
    if avg_class is None:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        # No softmax here: the loss function applies it to the logits
        return tf.matmul(layer1, weights2) + biases2
    else:
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1)
        )
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)


def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # Hidden-layer and output-layer parameters
    weights1 = tf.Variable(
        tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
    weights2 = tf.Variable(
        tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    # Forward pass with the current parameter values
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Step counter; not trainable, so the moving average skips it
    global_step = tf.Variable(0, trainable=False)

    # Maintain moving averages of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Forward pass with the moving-average parameter values
    average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)

    # Cross-entropy loss; labels are one-hot, so argmax recovers the class index
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # L2 regularization on the weights only (biases are not regularized)
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization

    # Exponentially decaying learning rate (decay_steps = mnist.train.num_examples)
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, mnist.train.num_examples, LEARNING_RATE_DECAY
    )
    train_step = tf.train.GradientDescentOptimizer(learning_rate) \
        .minimize(loss, global_step=global_step)

    # One train_op runs both the gradient-descent update and the moving-average update
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    # Accuracy of the moving-average model
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        validate_feed = {
            x: mnist.validation.images,
            y_: mnist.validation.labels
        }
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        for i in range(TRAINING_STEPS):
            # Report validation and test accuracy every 1000 steps
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                test_acc = sess.run(accuracy, feed_dict=test_feed)
                print("After %d training step(s), validation accuracy "
                      "using average model is %g , test accuracy is %g" % (i, validate_acc, test_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        # Final test accuracy after training finishes
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training step(s), test accuracy using average model is %g" % (TRAINING_STEPS, test_acc))


def main(argv=None):
    mnist = input_data.read_data_sets("/temp/data", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()
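The learning_rate line above passes mnist.train.num_examples (55,000) as decay_steps. tf.train.exponential_decay is documented as computing base_rate * decay_rate ^ (global_step / decay_steps), with the exponent floored when staircase=True. The pure-Python sketch below restates that formula (it is not TensorFlow's code) and compares the rate reached after 30,000 steps with that setting against a decay_steps of one epoch (num_examples / BATCH_SIZE = 550 batches), which is a common alternative.

```python
def exponential_decay(base_rate, global_step, decay_steps, decay_rate, staircase=False):
    """Decayed rate: base_rate * decay_rate ** (global_step / decay_steps).

    With staircase=True the exponent is floored, so the rate drops in jumps.
    """
    if staircase:
        exponent = global_step // decay_steps
    else:
        exponent = global_step / decay_steps
    return base_rate * decay_rate ** exponent


BASE, DECAY = 0.8, 0.99               # LEARNING_RATE_BASE, LEARNING_RATE_DECAY
NUM_EXAMPLES, BATCH_SIZE = 55000, 100
STEPS = 30000                         # TRAINING_STEPS

# decay_steps as in the code above: the whole training-set size
per_dataset = exponential_decay(BASE, STEPS, NUM_EXAMPLES, DECAY)
# decay_steps of one epoch: batches per pass over the training data
per_epoch = exponential_decay(BASE, STEPS, NUM_EXAMPLES / BATCH_SIZE, DECAY)

print("final rate, decay_steps=num_examples: %.4f" % per_dataset)
print("final rate, decay_steps=num_examples/BATCH_SIZE: %.4f" % per_epoch)
```

With decay_steps equal to the dataset size the rate stays close to 0.8 for the whole run (about 0.80 at the end), while one-epoch decay brings it down to roughly 0.46. Either way the run above still converges, as the log shows.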
And here is the output:
Extracting /temp/data\train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:/Users/lesil/PycharmProjects/matchzoo/MNIST.py:93: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /temp/data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting /temp/data\t10k-images-idx3-ubyte.gz
Extracting /temp/data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
2019-08-11 11:43:46.478172: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
After 0 training step(s), validation accuracy using average model is 0.1596 , test accuracy is 0.1702
After 1000 training step(s), validation accuracy using average model is 0.9766 , test accuracy is 0.975
After 2000 training step(s), validation accuracy using average model is 0.9812 , test accuracy is 0.9809
After 3000 training step(s), validation accuracy using average model is 0.9828 , test accuracy is 0.9828
After 4000 training step(s), validation accuracy using average model is 0.9836 , test accuracy is 0.9837
After 5000 training step(s), validation accuracy using average model is 0.9834 , test accuracy is 0.9835
After 6000 training step(s), validation accuracy using average model is 0.985 , test accuracy is 0.985
After 7000 training step(s), validation accuracy using average model is 0.9846 , test accuracy is 0.9845
After 8000 training step(s), validation accuracy using average model is 0.9852 , test accuracy is 0.9842
After 9000 training step(s), validation accuracy using average model is 0.9844 , test accuracy is 0.9852
After 10000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9844
After 11000 training step(s), validation accuracy using average model is 0.9854 , test accuracy is 0.9845
After 12000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.984
After 13000 training step(s), validation accuracy using average model is 0.9844 , test accuracy is 0.984
After 14000 training step(s), validation accuracy using average model is 0.9854 , test accuracy is 0.9842
After 15000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.9842
After 16000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.9841
After 17000 training step(s), validation accuracy using average model is 0.9856 , test accuracy is 0.9838
After 18000 training step(s), validation accuracy using average model is 0.9848 , test accuracy is 0.9848
After 19000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9835
After 20000 training step(s), validation accuracy using average model is 0.9864 , test accuracy is 0.9844
After 21000 training step(s), validation accuracy using average model is 0.9868 , test accuracy is 0.9845
After 22000 training step(s), validation accuracy using average model is 0.9856 , test accuracy is 0.9844
After 23000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9842
After 24000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.9845
After 25000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.9845
After 26000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9843
After 27000 training step(s), validation accuracy using average model is 0.9864 , test accuracy is 0.984
After 28000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9843
After 29000 training step(s), validation accuracy using average model is 0.9864 , test accuracy is 0.9842
After 30000 training step(s), test accuracy using average model is 0.9846
Process finished with exit code 0
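The accuracy figures in this log come from correct_prediction and accuracy in the code (computed on average_y), and the quantity minimized during training is the mean of tf.nn.sparse_softmax_cross_entropy_with_logits on y. As a rough, per-example illustration of what those two compute, here is a pure-Python sketch; the helper names are hypothetical and this is not TensorFlow's implementation.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label]).

    Subtracting the max logit (log-sum-exp trick) keeps exp() from overflowing.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(batch_logits, one_hot_labels):
    """Fraction of examples whose top-scoring class matches the one-hot label."""
    correct = sum(argmax(z) == argmax(y) for z, y in zip(batch_logits, one_hot_labels))
    return correct / len(batch_logits)
```

Because argmax is taken over the raw logits, accuracy does not need the softmax at all; only the loss does.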
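A Few Pitfalls
- Some of the code in the book has indentation errors, and in Python an indentation error is simply fatal. Use the training flow (that is, the train function) to work out where each line belongs and correct the original indentation.
- When using the L2 regularizer, note that it is l2 (lowercase letter L), not 12 (the digits one and two). The IDE offered no autocompletion here (why?), so this typo is easy to make.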
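For reference, the correctly spelled tf.contrib.layers.l2_regularizer(scale) returns a function that computes scale * tf.nn.l2_loss(weights), and tf.nn.l2_loss is half the sum of squared entries. The pure-Python sketch below reproduces that arithmetic with hypothetical helper names and small made-up weight matrices standing in for the real 784×500 and 500×10 ones.

```python
def l2_loss(weights):
    """Half the sum of squares over every entry of a nested weight matrix."""
    return sum(w * w for row in weights for w in row) / 2


def l2_regularizer(scale):
    """Return a function mapping a weight matrix to scale * l2_loss(weights)."""
    def regularize(weights):
        return scale * l2_loss(weights)
    return regularize


REGULARIZATION_RATE = 0.0001
regularizer = l2_regularizer(REGULARIZATION_RATE)  # l2: letter L, then 2

# Made-up stand-ins for weights1 and weights2
weights1 = [[1.0, -2.0], [0.5, 0.0]]
weights2 = [[3.0], [-1.0]]
regularization = regularizer(weights1) + regularizer(weights2)
```

As in the training code, only the weight matrices are passed to the regularizer; the biases are left out of the penalty.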
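The Training Process
One round of training
In the graph built by train, one round of training goes like this: first compute the network's forward-propagation result with the current parameters (y), use it to compute the loss and update the parameters by gradient descent (train_step), then update the moving averages of all variables that represent network parameters (variables_averages_op). The forward pass that uses the moving-averaged parameters (average_y) is what the validation and test accuracies are measured on.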
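The moving average here is tf.train.ExponentialMovingAverage. Its documented update is shadow = decay * shadow + (1 - decay) * variable, and because the code passes global_step as num_updates, the decay actually used is min(decay, (1 + num_updates) / (10 + num_updates)), so the averages follow the parameters closely early in training. Below is a pure-Python sketch of that arithmetic on scalar parameters; the MovingAverage class and ema_decay function are hypothetical and not TensorFlow's API.

```python
def ema_decay(decay, num_updates=None):
    """Effective decay; capped by (1 + n) / (10 + n) when num_updates is given."""
    if num_updates is None:
        return decay
    return min(decay, (1 + num_updates) / (10 + num_updates))


class MovingAverage:
    """Keeps one shadow value per named scalar parameter (hypothetical helper)."""

    def __init__(self, decay, params):
        self.decay = decay
        # Shadow values start at the parameters' initial values
        self.shadow = dict(params)

    def update(self, params, num_updates=None):
        d = ema_decay(self.decay, num_updates)
        for name, value in params.items():
            # shadow = decay * shadow + (1 - decay) * variable
            self.shadow[name] = d * self.shadow[name] + (1 - d) * value

    def average(self, name):
        return self.shadow[name]


# Usage: the parameter jumps from 0.0 to 1.0 after one training step
ema = MovingAverage(0.99, {"w": 0.0})
ema.update({"w": 1.0}, num_updates=0)  # effective decay is 0.1 at step 0
```

In this sketch average() plays the role that avg_class.average(...) plays in inference: the evaluation pass reads the shadow values rather than the raw parameters.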