Fundamentals of Deep Learning Example Code: Chapter 3, Implementing Neural Networks in TensorFlow (MNIST)

To study deep learning, I bought the book Fundamentals of Deep Learning (딥러닝의 정석).
The book's example code can be downloaded from https://github.com/darksigma/Fundamentals-of-deep-learning-book.
However, the code in that repository often fails to run as published.

Below is the example code from Chapter 3, "Implementing Neural Networks in TensorFlow," which trains a model on the MNIST dataset, followed by the training results.
Unlike the book's example, dropout has been added.
If the MNIST files are not present at the given path, they are downloaded automatically on the first run.
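
For reference, the data loader used below returns train / validation / test splits (55,000 / 5,000 / 10,000 images), and each split exposes .images, .labels, and a next_batch() helper. A quick way to inspect the returned object:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)
print(mnist.train.num_examples)   # 55000
print(mnist.train.images.shape)   # (55000, 784); pixels scaled to [0, 1]
print(mnist.train.labels.shape)   # (55000, 10); one-hot encoded digits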


[Source Code]

import tensorflow as tf
import shutil, os
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
#tf.set_random_seed(777)
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)

# Architecture
n_hidden_1 = 256
n_hidden_2 = 256

# Parameters
learning_rate = 0.001
training_epochs = 30
batch_size = 100
display_step = 1

def layer(input, weight_shape, bias_shape):
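    # He initialization: stddev = sqrt(2 / fan_in), a common choice for ReLU units
    # (fan_in = 784 for the first hidden layer gives stddev ≈ 0.0505).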
    weight_init = tf.random_normal_initializer(stddev=(2.0/weight_shape[0])**0.5)
    bias_init = tf.constant_initializer(value=0)
    W = tf.get_variable("W",weight_shape,initializer=weight_init)
    b = tf.get_variable("b",bias_shape,initializer=bias_init)
    return tf.nn.relu(tf.matmul(input,W)+b)

def inference(x,keep_prob):
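    # Two ReLU hidden layers, each followed by dropout. tf.nn.dropout scales the
    # kept activations by 1/keep_prob, so no rescaling is needed at evaluation time.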
    with tf.variable_scope("hidden_1"):
        _hidden_1 = layer(x,[784,n_hidden_1],[n_hidden_1])
        hidden_1 = tf.nn.dropout(_hidden_1,keep_prob)
    
    with tf.variable_scope("hidden_2"):
        _hidden_2 = layer(hidden_1,[n_hidden_1,n_hidden_2],[n_hidden_2])
        hidden_2 = tf.nn.dropout(_hidden_2,keep_prob)
    
    with tf.variable_scope("output"):
        output = layer(hidden_2,[n_hidden_2,10],[10])
        
    return output

#loss func
def loss(output, y):
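    # softmax_cross_entropy_with_logits expects unnormalized logits and applies softmax internally.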
    xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=y)
    loss = tf.reduce_mean(xentropy)
    return loss

#Training
def training(cost, global_step):
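    # Log the cost for TensorBoard and minimize it with Adam; minimize() increments global_step on each update.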
    tf.summary.scalar("cost",cost)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(cost,global_step=global_step)
    return train_op

#Evaluate
def evaluate(output, y):
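    # Fraction of samples whose highest-scoring class matches the one-hot label.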
    correct_prediction = tf.equal(tf.argmax(output,1), tf.argmax(y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    tf.summary.scalar("validation",accuracy)
    return accuracy

if __name__ == '__main__':
    
    if os.path.exists('mlp_logs/'):
        shutil.rmtree('mlp_logs/',ignore_errors=True)
    
    with tf.Graph().as_default():
        with tf.variable_scope('mlp_model'):
            x = tf.placeholder(tf.float32, [None,784]) #mnist data image of shape 28*28=784
            y = tf.placeholder(tf.float32,[None,10]) #0~9 digits recognition -> 10 classes
            keep_prob = tf.placeholder(tf.float32)
            
            output = inference(x,keep_prob)
            cost = loss(output,y)
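            # global_step counts parameter updates; Saver appends it to checkpoint filenames.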
            global_step = tf.Variable(0,name='global_step',trainable=False)
            train_op = training(cost,global_step)
            eval_op = evaluate(output,y)
            
            summary_op = tf.summary.merge_all()
            saver=tf.train.Saver()
            
            sess=tf.Session()
            
            summary_writer = tf.summary.FileWriter('mlp_logs/', graph=sess.graph)  # pass the Graph itself; graph_def= is deprecated
            
            init_op = tf.global_variables_initializer()
            
            sess.run(init_op)
            
            # Uncomment to resume training from a previously saved checkpoint:
            #saver.restore(sess, "mlp_logs/model-checkpoint-66000")
            
            #Training cycle
            for epoch in range(training_epochs):
                
                avg_cost = 0.
                total_batch= int(mnist.train.num_examples/batch_size)
                
                for i in range(total_batch):
                    minibatch_x, minibatch_y = mnist.train.next_batch(batch_size)
                    
                    # Run the update and fetch the batch cost in one call, avoiding a second forward pass.
                    _, batch_cost = sess.run([train_op, cost], feed_dict={x: minibatch_x, y: minibatch_y, keep_prob: 0.5})
                    avg_cost += batch_cost / total_batch
                
                if epoch % display_step == 0:
                    print("Epoch:", '%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_cost))
                    # Disable dropout (keep_prob=1.0) when evaluating on the validation set.
                    summary_str,accuracy = sess.run([summary_op,eval_op], feed_dict={x:mnist.validation.images, y:mnist.validation.labels, keep_prob:1})
                    print("validation accuracy :",accuracy)
                    summary_writer.add_summary(summary_str,sess.run(global_step))
                    saver.save(sess,'mlp_logs/model-checkpoint',global_step=global_step)
            
            print("Optimization Finished!")
            
            accuracy = sess.run(eval_op, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob:1})
            print("Test Accuracy:", accuracy)


[Execution Results]

Python 3.7.3 (default, Apr 24 2019, 15:29:51) [MSC v.1915 64 bit (AMD64)]
Type "copyright", "credits" or "license" for more information.

IPython 7.6.1 -- An enhanced Interactive Python.

runfile('D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py', wdir='D:/working/deep_learning/Fundamentals-of-deep-learning')
WARNING:tensorflow:From D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:14: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:36: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
WARNING:tensorflow:From D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:49: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.

WARNING:tensorflow:Passing a `GraphDef` to the SummaryWriter is deprecated. Pass a `Graph` object instead, such as `sess.graph`.
Epoch: 0001 cost= 0.525223048
validation accuracy : 0.9502
Epoch: 0002 cost= 0.229476243
validation accuracy : 0.9648
Epoch: 0003 cost= 0.182049453
validation accuracy : 0.9718
Epoch: 0004 cost= 0.154824135
validation accuracy : 0.9738
Epoch: 0005 cost= 0.135052014
validation accuracy : 0.9762
Epoch: 0006 cost= 0.120618501
validation accuracy : 0.9764
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py:966: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
Epoch: 0007 cost= 0.113985141
validation accuracy : 0.9788
Epoch: 0008 cost= 0.105176964
validation accuracy : 0.9786
Epoch: 0009 cost= 0.099285459
validation accuracy : 0.9792
Epoch: 0010 cost= 0.093149076
validation accuracy : 0.978
Epoch: 0011 cost= 0.089750568
validation accuracy : 0.9792
Epoch: 0012 cost= 0.082228115
validation accuracy : 0.9814
Epoch: 0013 cost= 0.083891792
validation accuracy : 0.9818
Epoch: 0014 cost= 0.080721498
validation accuracy : 0.9822
Epoch: 0015 cost= 0.077188218
validation accuracy : 0.983
Epoch: 0016 cost= 0.073490976
validation accuracy : 0.9804
Epoch: 0017 cost= 0.071725895
validation accuracy : 0.981
Epoch: 0018 cost= 0.069402803
validation accuracy : 0.9826
Epoch: 0019 cost= 0.068292858
validation accuracy : 0.9814
Epoch: 0020 cost= 0.069043860
validation accuracy : 0.9816
Epoch: 0021 cost= 0.063489767
validation accuracy : 0.9828
Epoch: 0022 cost= 0.063236821
validation accuracy : 0.9836
Epoch: 0023 cost= 0.066826821
validation accuracy : 0.9836
Epoch: 0024 cost= 0.062443590
validation accuracy : 0.9818
Epoch: 0025 cost= 0.057299115
validation accuracy : 0.9818
Epoch: 0026 cost= 0.059020603
validation accuracy : 0.9814
Epoch: 0027 cost= 0.056841880
validation accuracy : 0.9826
Epoch: 0028 cost= 0.059074335
validation accuracy : 0.9816
Epoch: 0029 cost= 0.058189238
validation accuracy : 0.9828
Epoch: 0030 cost= 0.053373935
validation accuracy : 0.9832
Optimization Finished!
Test Accuracy: 0.9826
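
Most of the warnings above come from TF 1.x APIs that have since been deprecated or removed. For reference only, below is a minimal sketch of the same two-hidden-layer MLP with dropout written against tf.keras; it assumes TensorFlow 2.x and is not the book's code. Note that Keras loads all 60,000 training images, so validation_split stands in for the separate 5,000-image validation set used above.

import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation="relu", input_shape=(784,)),
    tf.keras.layers.Dropout(0.5),   # rate = 1 - keep_prob
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),      # logits; softmax is applied inside the loss
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=100, epochs=30, validation_split=0.1)
model.evaluate(x_test, y_test)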
