Solving the MNIST Digit Recognition Problem with TensorFlow

A few words first

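This MNIST digit recognition problem is the first neural network I have implemented. Although I mostly just typed in the code from the book, it still gave me some understanding of how a neural network is trained, and it was a chance to review the basic TensorFlow and neural-network concepts from the book's earlier chapters.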

Introduction to MNIST

MNIST is a very well-known handwritten digit recognition dataset and is commonly used as an introductory example for deep learning.

The MNIST dataset can be downloaded from http://yann.lecun.com/exdb/mnist/.

TensorFlow provides a class for handling MNIST data that can automatically download the dataset and convert it into a usable format.
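The TensorFlow helper used in the training code is input_data.read_data_sets. To give a rough idea of what "converting the format" involves, here is a minimal standard-library sketch that parses the IDX files described on the MNIST page: each 28x28 image is flattened into 784 values scaled to [0, 1] (matching INPUT_NODE = 784, and roughly what the TensorFlow loader produces by default), and labels are turned into one-hot vectors (what one_hot=True requests). The function names are hypothetical helpers, not TensorFlow APIs; the sketch assumes the files are already downloaded, and unlike the real loader it does not split off a validation set.

```python
import gzip
import struct


def parse_idx(raw: bytes):
    """Parse an IDX byte string (format described at yann.lecun.com/exdb/mnist/).

    Returns (shape, values), where values is a flat list of unsigned bytes.
    Only the unsigned-byte element type (0x08) used by MNIST is handled.
    """
    zero1, zero2, dtype, ndim = struct.unpack(">BBBB", raw[:4])
    if zero1 != 0 or zero2 != 0 or dtype != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    # One big-endian 32-bit size per dimension follows the magic number.
    shape = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    values = list(raw[4 + 4 * ndim:])
    expected = 1
    for dim in shape:
        expected *= dim
    if len(values) != expected:
        raise ValueError("payload size does not match header")
    return shape, values


def load_images(raw: bytes):
    """Flatten each image into one row of rows*cols floats scaled to [0, 1]."""
    (count, rows, cols), pixels = parse_idx(raw)
    size = rows * cols
    return [[p / 255.0 for p in pixels[i * size:(i + 1) * size]] for i in range(count)]


def to_one_hot(labels, num_classes=10):
    """Turn class indices into one-hot vectors."""
    return [[1.0 if c == label else 0.0 for c in range(num_classes)] for label in labels]


def read_gz(path):
    """The MNIST files are distributed gzip-compressed."""
    with gzip.open(path, "rb") as f:
        return f.read()


# Usage with a tiny synthetic file (2 images of 2x2 pixels) instead of the real download.
header = struct.pack(">BBBBIII", 0, 0, 0x08, 3, 2, 2, 2)
images = load_images(header + bytes([0, 255, 51, 102, 0, 0, 0, 255]))
print(images[0])  # [0.0, 1.0, 0.2, 0.4]
```

With the real data, read_gz("train-images-idx3-ubyte.gz") would supply the raw bytes for load_images.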

Training the neural network

Let me just paste the code first.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# MNIST-related constants: 28x28 = 784 input pixels, 10 digit classes
INPUT_NODE = 784
OUTPUT_NODE = 10

# Network and training hyperparameters
LAYER1_NODE = 500
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99


def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    # Forward pass of a network with one ReLU hidden layer.
    # If avg_class is given, use the moving-average (shadow) values of the parameters.
    if avg_class is None:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        # No softmax here: the loss function applies it internally.
        return tf.matmul(layer1, weights2) + biases2
    else:
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1)
        )
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)


def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    weights1 = tf.Variable(
        tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
    weights2 = tf.Variable(
        tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

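    # Forward pass with the current parameter values (no moving average).
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Step counter; trainable=False keeps it out of tf.trainable_variables(),
    # so it is neither optimized nor averaged.
    global_step = tf.Variable(0, trainable=False)

    # Passing global_step makes the effective decay start small and grow toward MOVING_AVERAGE_DECAY.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)

    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Forward pass using the moving-average (shadow) values; used for evaluation.
    average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)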

    # The sparse version expects class indices, so convert the one-hot y_ with argmax.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization

    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, mnist.train.num_examples, LEARNING_RATE_DECAY
    )

    train_step = tf.train.GradientDescentOptimizer(learning_rate) \
        .minimize(loss, global_step=global_step)

    # Group the gradient step and the moving-average update into a single op.
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        validate_feed = {
            x: mnist.validation.images,
            y_: mnist.validation.labels
        }

        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                test_acc = sess.run(accuracy, feed_dict=test_feed)
                print("After %d training step(s), validation accuracy "
                      "using average model is %g , test accuracy is %g" % (i, validate_acc, test_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training step(s), test accuracy using average model is %g" % (TRAINING_STEPS, test_acc))


def main(argv=None):
    mnist = input_data.read_data_sets("/temp/data", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()
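The learning rate in train comes from tf.train.exponential_decay, which with its default staircase=False computes base_rate * decay_rate ** (global_step / decay_steps). The pure-Python sketch below reproduces that formula to show what the settings above mean. Because decay_steps is mnist.train.num_examples (assumed here to be 55,000, the training split size with the default validation set) while each step consumes only BATCH_SIZE = 100 images, the rate barely decays over 30,000 steps. The once-per-epoch variant (decay_steps = 55,000 / 100 = 550) is shown only for comparison, as an assumption about a common choice. The helper name decayed_learning_rate is hypothetical.

```python
def decayed_learning_rate(base_rate, global_step, decay_steps, decay_rate, staircase=False):
    """Same formula as tf.train.exponential_decay in TF 1.x."""
    exponent = global_step / decay_steps
    if staircase:
        # staircase=True decays in discrete jumps, once every decay_steps steps.
        exponent = global_step // decay_steps
    return base_rate * decay_rate ** exponent


LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
TRAINING_STEPS = 30000
BATCH_SIZE = 100
NUM_EXAMPLES = 55000  # assumed size of mnist.train with the default validation split

# decay_steps as written in the code: one factor of 0.99 per 55,000 steps.
final_lr = decayed_learning_rate(
    LEARNING_RATE_BASE, TRAINING_STEPS, NUM_EXAMPLES, LEARNING_RATE_DECAY)
# Comparison: one factor of 0.99 per epoch (550 batches of 100 images).
final_lr_per_epoch = decayed_learning_rate(
    LEARNING_RATE_BASE, TRAINING_STEPS, NUM_EXAMPLES / BATCH_SIZE, LEARNING_RATE_DECAY)
print(round(final_lr, 4), round(final_lr_per_epoch, 4))
```

With the settings as written, the rate only falls from 0.8 to about 0.796 by the end of training. The per-epoch variant would end near 0.462.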

And here is the output:

Extracting /temp/data\train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:/Users/lesil/PycharmProjects/matchzoo/MNIST.py:93: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /temp/data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting /temp/data\t10k-images-idx3-ubyte.gz
Extracting /temp/data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\lesil\Anaconda3\envs\matchzoo\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
2019-08-11 11:43:46.478172: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
After 0 training step(s), validation accuracy using average model is 0.1596 , test accuracy is 0.1702
After 1000 training step(s), validation accuracy using average model is 0.9766 , test accuracy is 0.975
After 2000 training step(s), validation accuracy using average model is 0.9812 , test accuracy is 0.9809
After 3000 training step(s), validation accuracy using average model is 0.9828 , test accuracy is 0.9828
After 4000 training step(s), validation accuracy using average model is 0.9836 , test accuracy is 0.9837
After 5000 training step(s), validation accuracy using average model is 0.9834 , test accuracy is 0.9835
After 6000 training step(s), validation accuracy using average model is 0.985 , test accuracy is 0.985
After 7000 training step(s), validation accuracy using average model is 0.9846 , test accuracy is 0.9845
After 8000 training step(s), validation accuracy using average model is 0.9852 , test accuracy is 0.9842
After 9000 training step(s), validation accuracy using average model is 0.9844 , test accuracy is 0.9852
After 10000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9844
After 11000 training step(s), validation accuracy using average model is 0.9854 , test accuracy is 0.9845
After 12000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.984
After 13000 training step(s), validation accuracy using average model is 0.9844 , test accuracy is 0.984
After 14000 training step(s), validation accuracy using average model is 0.9854 , test accuracy is 0.9842
After 15000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.9842
After 16000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.9841
After 17000 training step(s), validation accuracy using average model is 0.9856 , test accuracy is 0.9838
After 18000 training step(s), validation accuracy using average model is 0.9848 , test accuracy is 0.9848
After 19000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9835
After 20000 training step(s), validation accuracy using average model is 0.9864 , test accuracy is 0.9844
After 21000 training step(s), validation accuracy using average model is 0.9868 , test accuracy is 0.9845
After 22000 training step(s), validation accuracy using average model is 0.9856 , test accuracy is 0.9844
After 23000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9842
After 24000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.9845
After 25000 training step(s), validation accuracy using average model is 0.9862 , test accuracy is 0.9845
After 26000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9843
After 27000 training step(s), validation accuracy using average model is 0.9864 , test accuracy is 0.984
After 28000 training step(s), validation accuracy using average model is 0.9858 , test accuracy is 0.9843
After 29000 training step(s), validation accuracy using average model is 0.9864 , test accuracy is 0.9842
After 30000 training step(s), test accuracy using average model is 0.9846

Process finished with exit code 0
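Both the loss and the accuracy figures above rely on argmax over the one-hot labels. tf.argmax(y_, 1) turns each one-hot row back into a class index, which is the label format tf.nn.sparse_softmax_cross_entropy_with_logits expects. The printed accuracy is the fraction of rows where the predicted index matches the labelled one. The sketch below mirrors those two computations in plain Python on made-up logits. The function names are hypothetical, and the sketch only illustrates the math, not TensorFlow's own kernels.

```python
import math


def argmax(row):
    """Index of the largest entry in one row (ties go to the first index here)."""
    return max(range(len(row)), key=lambda i: row[i])


def sparse_softmax_cross_entropy(logits, label):
    """-log(softmax(logits)[label]), using max-subtraction for numerical stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def accuracy(batch_logits, one_hot_labels):
    """Fraction of rows whose predicted class equals the labelled class."""
    correct = [argmax(y) == argmax(t) for y, t in zip(batch_logits, one_hot_labels)]
    return sum(correct) / len(correct)


# Two made-up 3-class examples (MNIST has 10 classes).
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
one_hot = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
labels = [argmax(t) for t in one_hot]  # one-hot -> class index, like tf.argmax(y_, 1)
losses = [sparse_softmax_cross_entropy(z, c) for z, c in zip(logits, labels)]
mean_loss = sum(losses) / len(losses)  # like tf.reduce_mean(cross_entropy)
print(labels, accuracy(logits, one_hot))  # [0, 1] 0.5
```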

A few pitfalls

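  • Some of the code in the book has indentation errors, and in Python an indentation error is fatal. Use the training process (that is, the train function) as a guide to fix the original indentation.
  • When using the L2 regularization loss, note that the name is l2 (lowercase L), not 12. The IDE gives no autocompletion here (why?), so this typo is easy to make.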
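For reference, tf.contrib.layers.l2_regularizer(scale) returns a function that computes scale * tf.nn.l2_loss(w), and tf.nn.l2_loss is half the sum of squared entries (no square root). In the code only weights1 and weights2 are regularized, not the biases. The plain-Python sketch below shows the same arithmetic on small made-up matrices standing in for the real 784x500 and 500x10 weights. The names l2_loss and l2_penalty are hypothetical.

```python
REGULARIZATION_RATE = 0.0001


def l2_loss(matrix):
    """Half the sum of squared entries, the quantity tf.nn.l2_loss computes."""
    return sum(v * v for row in matrix for v in row) / 2


def l2_penalty(scale, *weight_matrices):
    """scale * l2_loss(w), summed over each regularized weight matrix."""
    return sum(scale * l2_loss(w) for w in weight_matrices)


# Made-up stand-ins for the real weight matrices.
weights1 = [[1.0, -2.0], [0.5, 0.0]]
weights2 = [[3.0], [-1.0]]
penalty = l2_penalty(REGULARIZATION_RATE, weights1, weights2)
# In train(): loss = cross_entropy_mean + penalty
print(penalty)
```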

The training process

One round of training

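First, compute the network's forward-propagation result under the current parameters. Then apply a moving average to every variable that represents a network parameter. Finally, compute the forward propagation again using the moving-averaged values.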
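In the code, y is the first forward pass, variables_averages_op performs the moving-average update, and average_y is the second pass. tf.train.ExponentialMovingAverage keeps a shadow copy of each variable and updates it as shadow = decay * shadow + (1 - decay) * variable. Because global_step is passed as num_updates, the decay actually used is min(MOVING_AVERAGE_DECAY, (1 + step) / (10 + step)). The shadow values therefore follow the weights closely early in training and change more smoothly later. The sketch below simulates this for a single scalar. ScalarEMA is a hypothetical class, and it assumes the shadow starts at the variable's initial value, as TensorFlow initializes it.

```python
class ScalarEMA:
    """Mimics tf.train.ExponentialMovingAverage(decay, num_updates) for one scalar."""

    def __init__(self, decay, initial_value):
        self.decay = decay
        self.shadow = initial_value  # shadow starts at the variable's initial value

    def update(self, value, num_updates=None):
        decay = self.decay
        if num_updates is not None:
            # Early in training the effective decay is smaller, so the average moves faster.
            decay = min(decay, (1 + num_updates) / (10 + num_updates))
        self.shadow = decay * self.shadow + (1 - decay) * value
        return self.shadow


MOVING_AVERAGE_DECAY = 0.99
ema = ScalarEMA(MOVING_AVERAGE_DECAY, initial_value=0.0)
# Step 0: decay = min(0.99, 1/10) = 0.1, so shadow = 0.1 * 0.0 + 0.9 * 5.0 = 4.5.
print(ema.update(5.0, num_updates=0))
# Step 1000: decay = min(0.99, 1001/1010) = 0.99, so the shadow barely moves.
print(ema.update(5.0, num_updates=1000))
```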
