Fundamentals of Deep Learning (딥러닝의 정석) example code, Chapter 3: Implementing Neural Networks in TensorFlow (MNIST)

To study deep learning, I bought the book Fundamentals of Deep Learning (Korean edition: 딥러닝의 정석).
The book's example code is available at https://github.com/darksigma/Fundamentals-of-deep-learning-book.
In my experience, however, the code in that GitHub repository often failed to run as-is.

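Below is the example code from Chapter 3, "Implementing Neural Networks in TensorFlow", which trains a multilayer perceptron on the MNIST data, followed by the training results.
Unlike the book's example, this version applies dropout after each of the two hidden layers (keep_prob 0.5 during training, 1 during evaluation).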
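As a rough illustration of what `tf.nn.dropout` does with `keep_prob`, here is a pure-Python sketch of inverted dropout. The `dropout` helper is hypothetical and works on a plain list rather than a tensor. Each value is kept with probability `keep_prob` and scaled by `1/keep_prob`, otherwise zeroed, so the expected activation is unchanged and nothing needs rescaling at evaluation time, where the script feeds `keep_prob: 1`.

```python
import random

def dropout(values, keep_prob, rng=random):
    """Inverted dropout on a plain list (hypothetical helper).

    Each value is kept with probability keep_prob and divided by keep_prob;
    otherwise it is replaced by 0.0, so the expected value of every element
    is unchanged.
    """
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]

# Training-time behaviour, matching keep_prob: 0.5 in the script
rng = random.Random(0)
dropped = dropout([1.0] * 10000, 0.5, rng)
print(sum(1 for v in dropped if v == 0.0))  # roughly 5000 values zeroed
print(sum(dropped) / len(dropped))          # mean stays close to 1.0

# Evaluation-time behaviour, matching keep_prob: 1 in the script
print(dropout([0.3, 1.2], 1.0))             # [0.3, 1.2]
```

Newer TensorFlow 1.x releases prefer the `rate` argument instead of `keep_prob`, with `rate = 1 - keep_prob`, as the deprecation warning in the output points out.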
The first time the code runs, the MNIST files are downloaded to ../MNIST_data/ if they are not already there; training starts after that.
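With `one_hot=True`, `read_data_sets` returns each label as a 10-element vector containing a single 1, and (with its default settings) each 28×28 image as a flat vector of 784 values. That is why the placeholders below have shapes `[None,784]` and `[None,10]`. A small pure-Python sketch of both conversions, using hypothetical helper names:

```python
def one_hot(labels, num_classes=10):
    """Hypothetical helper: integer labels -> one-hot rows (the y format)."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} is outside 0..{num_classes - 1}")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows

def flatten(image):
    """Hypothetical helper: 28x28 image (list of rows) -> 784-long vector (the x format)."""
    return [pixel for row in image for pixel in row]

print(one_hot([3])[0])                                # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(len(flatten([[0.0] * 28 for _ in range(28)])))  # 784
```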


[Source code]

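# Requires TensorFlow 1.x (tf.placeholder, tf.Session and the tutorials MNIST loader are not available in TF 2.x)
import tensorflow as tf
import os, shutil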
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
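# Note: tf.set_random_seed() seeds the current default graph, but the model below is
# built inside a new tf.Graph(); for repeatable runs, call it inside that block instead.
#tf.set_random_seed(777)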
mnist = input_data.read_data_sets("../MNIST_data/",one_hot=True)

# Architecture
n_hidden_1 = 256
n_hidden_2 = 256

# Parameters
learning_rate = 0.001
training_epochs = 30
batch_size = 100
display_step = 1

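def layer(input, weight_shape, bias_shape):
    # Fully connected layer: relu(input @ W + b).
    # W uses He initialization, N(0, 2 / fan_in) with fan_in = weight_shape[0].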
    weight_init = tf.random_normal_initializer(stddev=(2.0/weight_shape[0])**0.5)
    bias_init = tf.constant_initializer(value=0)
    W = tf.get_variable("W",weight_shape,initializer=weight_init)
    b = tf.get_variable("b",bias_shape,initializer=bias_init)
    return tf.nn.relu(tf.matmul(input,W)+b)

def inference(x,keep_prob):
    # Two fully connected hidden layers, each followed by dropout
    with tf.variable_scope("hidden_1"):
        _hidden_1 = layer(x,[784,n_hidden_1],[n_hidden_1])
        hidden_1 = tf.nn.dropout(_hidden_1,keep_prob)
    
    with tf.variable_scope("hidden_2"):
        _hidden_2 = layer(hidden_1,[n_hidden_1,n_hidden_2],[n_hidden_2])
        hidden_2 = tf.nn.dropout(_hidden_2,keep_prob)
    
    with tf.variable_scope("output"):
        # layer() applies ReLU here too, so the 10 returned logits are never negative
        output = layer(hidden_2,[n_hidden_2,10],[10])
        
    return output

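# Loss: mean softmax cross-entropy over the mini-batch
# (output is passed as logits; the softmax is applied inside the op)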
def loss(output, y):
    xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=y)
    loss = tf.reduce_mean(xentropy)
    return loss

# Training: one Adam step per call; also increments global_step and logs cost for TensorBoard
def training(cost, global_step):
    tf.summary.scalar("cost",cost)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(cost,global_step=global_step)
    return train_op

# Evaluation: fraction of examples whose predicted class (argmax) matches the label
def evaluate(output, y):
    correct_prediction = tf.equal(tf.argmax(output,1), tf.argmax(y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    tf.summary.scalar("validation",accuracy)
    return accuracy

if __name__ == '__main__':
    
    # Start from an empty log directory (this also deletes any previous checkpoints)
    if os.path.exists('mlp_logs/'):
        shutil.rmtree('mlp_logs/',ignore_errors=True)
    
    with tf.Graph().as_default():
        with tf.variable_scope('mlp_model'):
            x = tf.placeholder(tf.float32, [None,784]) #mnist data image of shape 28*28=784
            y = tf.placeholder(tf.float32,[None,10]) #0~9 digits recognition -> 10 classes
            keep_prob = tf.placeholder(tf.float32)
            
            output = inference(x,keep_prob)
            cost = loss(output,y)
            global_step = tf.Variable(0,name='global_step',trainable=False)
            train_op = training(cost,global_step)
            eval_op = evaluate(output,y)
            
            summary_op = tf.summary.merge_all()
            saver=tf.train.Saver()
            
            sess=tf.Session()
            
            summary_writer = tf.summary.FileWriter('mlp_logs/', graph=sess.graph)
            
            init_op = tf.global_variables_initializer()
            
            sess.run(init_op)
            
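            # To resume from a checkpoint, uncomment the line below. Note that mlp_logs/
            # is deleted at startup, so the checkpoint must be stored somewhere else.
            #saver.restore(sess, "mlp_logs/model-checkpoint-66000")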
            
            #Training cycle
            for epoch in range(training_epochs):
                
                avg_cost = 0.
                total_batch= int(mnist.train.num_examples/batch_size)
                
                for i in range(total_batch):
                    minibatch_x, minibatch_y = mnist.train.next_batch(batch_size)
                    
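                    # One update step; fetch the mini-batch cost from the same run so it
                    # uses the same dropout mask and no second forward pass is needed
                    _, batch_cost = sess.run([train_op, cost], feed_dict={x:minibatch_x, y:minibatch_y, keep_prob:0.5})
                    avg_cost += batch_cost/total_batch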
                
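                if epoch % display_step == 0:
                    print("Epoch:", '%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_cost))
                    # Validation accuracy with dropout disabled (keep_prob=1).
                    # summary_op runs on validation data, so the 'cost' scalar in TensorBoard is the validation cost.
                    summary_str,accuracy = sess.run([summary_op,eval_op], feed_dict={x:mnist.validation.images, y:mnist.validation.labels, keep_prob:1})
                    print("validation accuracy :",accuracy)
                    summary_writer.add_summary(summary_str,sess.run(global_step))
                    # tf.train.Saver keeps only the 5 most recent checkpoints by default
                    saver.save(sess,'mlp_logs/model-checkpoint',global_step=global_step)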
            
            print("Optimization Finished!")
            
            # Final accuracy on the test set, dropout disabled
            accuracy = sess.run(eval_op, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob:1})
            print("Test Accuracy:", accuracy)
            summary_writer.close()
            sess.close()

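`layer()` draws its weights from a normal distribution with standard deviation `sqrt(2 / fan_in)`, where `fan_in = weight_shape[0]`. This is the He initialization scale for ReLU layers, intended to keep the variance of activations roughly constant from one layer to the next. A pure-Python sketch of the same draw (the `he_normal` name is hypothetical), plus the standard deviations the script's three layers end up with:

```python
import math
import random

def he_normal(fan_in, fan_out, rng=random):
    """Hypothetical helper: fan_in x fan_out weights drawn from N(0, 2 / fan_in),
    the same distribution layer() configures via random_normal_initializer."""
    stddev = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, stddev) for _ in range(fan_out)] for _ in range(fan_in)]

# Standard deviations used by the script's three layers
for name, fan_in in [("hidden_1", 784), ("hidden_2", 256), ("output", 256)]:
    print(name, round(math.sqrt(2.0 / fan_in), 4))

# Empirical check: the mean squared weight should be close to 2 / fan_in
weights = he_normal(784, 8, random.Random(0))
flat = [w for row in weights for w in row]
print(sum(w * w for w in flat) / len(flat), 2.0 / 784)
```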

[Output]

Python 3.7.3 (default, Apr 24 2019, 15:29:51) [MSC v.1915 64 bit (AMD64)]
Type "copyright", "credits" or "license" for more information.

IPython 7.6.1 -- An enhanced Interactive Python.

runfile('D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py', wdir='D:/working/deep_learning/Fundamentals-of-deep-learning')
WARNING:tensorflow:From D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:14: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:36: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
WARNING:tensorflow:From D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:49: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.

WARNING:tensorflow:Passing a `GraphDef` to the SummaryWriter is deprecated. Pass a `Graph` object instead, such as `sess.graph`.
Epoch: 0001 cost= 0.525223048
validation accuracy : 0.9502
Epoch: 0002 cost= 0.229476243
validation accuracy : 0.9648
Epoch: 0003 cost= 0.182049453
validation accuracy : 0.9718
Epoch: 0004 cost= 0.154824135
validation accuracy : 0.9738
Epoch: 0005 cost= 0.135052014
validation accuracy : 0.9762
Epoch: 0006 cost= 0.120618501
validation accuracy : 0.9764
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py:966: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
Epoch: 0007 cost= 0.113985141
validation accuracy : 0.9788
Epoch: 0008 cost= 0.105176964
validation accuracy : 0.9786
Epoch: 0009 cost= 0.099285459
validation accuracy : 0.9792
Epoch: 0010 cost= 0.093149076
validation accuracy : 0.978
Epoch: 0011 cost= 0.089750568
validation accuracy : 0.9792
Epoch: 0012 cost= 0.082228115
validation accuracy : 0.9814
Epoch: 0013 cost= 0.083891792
validation accuracy : 0.9818
Epoch: 0014 cost= 0.080721498
validation accuracy : 0.9822
Epoch: 0015 cost= 0.077188218
validation accuracy : 0.983
Epoch: 0016 cost= 0.073490976
validation accuracy : 0.9804
Epoch: 0017 cost= 0.071725895
validation accuracy : 0.981
Epoch: 0018 cost= 0.069402803
validation accuracy : 0.9826
Epoch: 0019 cost= 0.068292858
validation accuracy : 0.9814
Epoch: 0020 cost= 0.069043860
validation accuracy : 0.9816
Epoch: 0021 cost= 0.063489767
validation accuracy : 0.9828
Epoch: 0022 cost= 0.063236821
validation accuracy : 0.9836
Epoch: 0023 cost= 0.066826821
validation accuracy : 0.9836
Epoch: 0024 cost= 0.062443590
validation accuracy : 0.9818
Epoch: 0025 cost= 0.057299115
validation accuracy : 0.9818
Epoch: 0026 cost= 0.059020603
validation accuracy : 0.9814
Epoch: 0027 cost= 0.056841880
validation accuracy : 0.9826
Epoch: 0028 cost= 0.059074335
validation accuracy : 0.9816
Epoch: 0029 cost= 0.058189238
validation accuracy : 0.9828
Epoch: 0030 cost= 0.053373935
validation accuracy : 0.9832
Optimization Finished!
Test Accuracy: 0.9826
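The reported accuracies are fractions of correctly classified images. With the `read_data_sets` defaults the validation set holds 5,000 images and the test set 10,000, so `Test Accuracy: 0.9826` corresponds to 9,826 correct test predictions. `evaluate()` computes this with `tf.argmax`, `tf.equal` and `tf.reduce_mean`; here is a plain-Python sketch of the same calculation on a tiny made-up batch (helper names are hypothetical).

```python
def argmax(row):
    """Index of the largest value; this helper returns the first index on ties."""
    return max(range(len(row)), key=row.__getitem__)

def accuracy(outputs, one_hot_labels):
    """Fraction of rows whose predicted class matches the label, mirroring
    the tf.argmax / tf.equal / tf.reduce_mean steps in evaluate()."""
    correct = [argmax(o) == argmax(t) for o, t in zip(outputs, one_hot_labels)]
    return sum(correct) / len(correct)

outputs = [[0.1, 2.0, 0.0], [1.5, 0.2, 0.3], [0.0, 0.4, 0.9], [0.7, 0.8, 0.1]]
labels  = [[0, 1, 0],       [1, 0, 0],       [0, 1, 0],       [0, 1, 0]]
print(accuracy(outputs, labels))  # 0.75 (the third row is misclassified)
```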
