Fundamentals of Deep Learning Example Code, Chapter 3: Implementing Neural Networks with TensorFlow (MNIST)
I bought the book Fundamentals of Deep Learning (딥러닝의 정석) to study deep learning.
The book's example code is available at https://github.com/darksigma/Fundamentals-of-deep-learning-book.
The source in that GitHub repository often failed to run as-is.
Below is the example code from Chapter 3, Implementing Neural Networks in TensorFlow, which trains on the MNIST data, followed by the training results.
Unlike the book's example, I added dropout, as sketched below.
If the MNIST data files are not present at the expected location on the first run, they are downloaded automatically before training starts.
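For reference, the dropout modification amounts to wrapping each hidden layer's ReLU output in tf.nn.dropout, driven by a keep_prob placeholder so that dropout can be switched off at evaluation time. A minimal sketch of the pattern, using the same names as the full listing below:

# During training keep_prob is fed as 0.5; during validation/testing it is fed as 1.0,
# so no units are dropped at evaluation time. tf.nn.dropout zeroes units at random
# and scales the survivors by 1/keep_prob, keeping expected activations unchanged.
_hidden_1 = layer(x, [784, n_hidden_1], [n_hidden_1])  # ReLU hidden layer
hidden_1 = tf.nn.dropout(_hidden_1, keep_prob)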
[Source Code]
import tensorflow as tf
import time, shutil, os
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
#tf.set_random_seed(777)

# Downloads the MNIST data to ../MNIST_data/ on the first run if it is not already there.
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)

# Architecture
n_hidden_1 = 256
n_hidden_2 = 256

# Parameters
learning_rate = 0.001
training_epochs = 30
batch_size = 100
display_step = 1

def layer(input, weight_shape, bias_shape):
    # He-style initialization: stddev = sqrt(2 / fan_in), well suited to ReLU units
    weight_init = tf.random_normal_initializer(stddev=(2.0/weight_shape[0])**0.5)
    bias_init = tf.constant_initializer(value=0)
    W = tf.get_variable("W", weight_shape, initializer=weight_init)
    b = tf.get_variable("b", bias_shape, initializer=bias_init)
    return tf.nn.relu(tf.matmul(input, W) + b)

def inference(x, keep_prob):
    with tf.variable_scope("hidden_1"):
        _hidden_1 = layer(x, [784, n_hidden_1], [n_hidden_1])
        hidden_1 = tf.nn.dropout(_hidden_1, keep_prob)
    with tf.variable_scope("hidden_2"):
        _hidden_2 = layer(hidden_1, [n_hidden_1, n_hidden_2], [n_hidden_2])
        hidden_2 = tf.nn.dropout(_hidden_2, keep_prob)
    with tf.variable_scope("output"):
        output = layer(hidden_2, [n_hidden_2, 10], [10])
    return output

# Loss function
def loss(output, y):
    xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=y)
    loss = tf.reduce_mean(xentropy)
    return loss

# Training
def training(cost, global_step):
    tf.summary.scalar("cost", cost)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(cost, global_step=global_step)
    return train_op

# Evaluate
def evaluate(output, y):
    correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar("validation", accuracy)
    return accuracy

if __name__ == '__main__':
    if os.path.exists('mlp_logs/'):
        shutil.rmtree('mlp_logs/', ignore_errors=True)
    with tf.Graph().as_default():
        with tf.variable_scope('mlp_model'):
            x = tf.placeholder(tf.float32, [None, 784])  # MNIST image of shape 28*28=784
            y = tf.placeholder(tf.float32, [None, 10])   # 0~9 digit recognition -> 10 classes
            keep_prob = tf.placeholder(tf.float32)

            output = inference(x, keep_prob)
            cost = loss(output, y)
            global_step = tf.Variable(0, name='global_step', trainable=False)
            train_op = training(cost, global_step)
            eval_op = evaluate(output, y)
            summary_op = tf.summary.merge_all()
            saver = tf.train.Saver()
            sess = tf.Session()
            summary_writer = tf.summary.FileWriter('mlp_logs/', graph_def=sess.graph_def)
            init_op = tf.global_variables_initializer()
            sess.run(init_op)
            #saver.restore(sess, "mlp_logs/model-checkpoint-66000")

            # Training cycle
            for epoch in range(training_epochs):
                avg_cost = 0.
                total_batch = int(mnist.train.num_examples/batch_size)
                for i in range(total_batch):
                    minibatch_x, minibatch_y = mnist.train.next_batch(batch_size)
                    sess.run(train_op, feed_dict={x: minibatch_x, y: minibatch_y, keep_prob: 0.5})
                    avg_cost += sess.run(cost, feed_dict={x: minibatch_x, y: minibatch_y,
                                                          keep_prob: 0.5})/total_batch
                if epoch % display_step == 0:
                    print("Epoch:", '%04d' % (epoch+1), 'cost=', '{:.9f}'.format(avg_cost))
                    summary_str, accuracy = sess.run([summary_op, eval_op],
                        feed_dict={x: mnist.validation.images, y: mnist.validation.labels,
                                   keep_prob: 1})
                    print("validation accuracy :", accuracy)
                    summary_writer.add_summary(summary_str, sess.run(global_step))
                    saver.save(sess, 'mlp_logs/model-checkpoint', global_step=global_step)

            print("Optimization Finished!")
            accuracy = sess.run(eval_op,
                feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1})
            print("Test Accuracy:", accuracy)
[Execution Results]
Python 3.7.3 (default, Apr 24 2019, 15:29:51) [MSC v.1915 64 bit (AMD64)]
Type "copyright", "credits" or "license" for more information.
IPython 7.6.1 -- An enhanced Interactive Python.
runfile('D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py',
wdir='D:/working/deep_learning/Fundamentals-of-deep-learning')
WARNING:tensorflow:From
D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:14:
read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist)
is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from
tensorflow/models.
WARNING:tensorflow:From
C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260:
maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is
deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From
C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262:
extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist)
is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From
C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267:
extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist)
is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From
C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110:
dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist)
is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From
C:\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290:
DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist)
is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from
tensorflow/models.
WARNING:tensorflow:From
C:\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263:
colocate_with (from tensorflow.python.framework.ops) is deprecated and will
be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From
D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:36:
calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is
deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 -
keep_prob`.
WARNING:tensorflow:From
D:/working/deep_learning/Fundamentals-of-deep-learning/ch3_multilayer_perceptron_mnist.py:49:
softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is
deprecated and will be removed in a future version.
Instructions for updating:
Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.
See `tf.nn.softmax_cross_entropy_with_logits_v2`.
WARNING:tensorflow:Passing a `GraphDef` to the SummaryWriter is deprecated.
Pass a `Graph` object instead, such as `sess.graph`.
Epoch: 0001 cost= 0.525223048
validation accuracy : 0.9502
Epoch: 0002 cost= 0.229476243
validation accuracy : 0.9648
Epoch: 0003 cost= 0.182049453
validation accuracy : 0.9718
Epoch: 0004 cost= 0.154824135
validation accuracy : 0.9738
Epoch: 0005 cost= 0.135052014
validation accuracy : 0.9762
Epoch: 0006 cost= 0.120618501
validation accuracy : 0.9764
WARNING:tensorflow:From
C:\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py:966:
remove_checkpoint (from tensorflow.python.training.checkpoint_management) is
deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
Epoch: 0007 cost= 0.113985141
validation accuracy : 0.9788
Epoch: 0008 cost= 0.105176964
validation accuracy : 0.9786
Epoch: 0009 cost= 0.099285459
validation accuracy : 0.9792
Epoch: 0010 cost= 0.093149076
validation accuracy : 0.978
Epoch: 0011 cost= 0.089750568
validation accuracy : 0.9792
Epoch: 0012 cost= 0.082228115
validation accuracy : 0.9814
Epoch: 0013 cost= 0.083891792
validation accuracy : 0.9818
Epoch: 0014 cost= 0.080721498
validation accuracy : 0.9822
Epoch: 0015 cost= 0.077188218
validation accuracy : 0.983
Epoch: 0016 cost= 0.073490976
validation accuracy : 0.9804
Epoch: 0017 cost= 0.071725895
validation accuracy : 0.981
Epoch: 0018 cost= 0.069402803
validation accuracy : 0.9826
Epoch: 0019 cost= 0.068292858
validation accuracy : 0.9814
Epoch: 0020 cost= 0.069043860
validation accuracy : 0.9816
Epoch: 0021 cost= 0.063489767
validation accuracy : 0.9828
Epoch: 0022 cost= 0.063236821
validation accuracy : 0.9836
Epoch: 0023 cost= 0.066826821
validation accuracy : 0.9836
Epoch: 0024 cost= 0.062443590
validation accuracy : 0.9818
Epoch: 0025 cost= 0.057299115
validation accuracy : 0.9818
Epoch: 0026 cost= 0.059020603
validation accuracy : 0.9814
Epoch: 0027 cost= 0.056841880
validation accuracy : 0.9826
Epoch: 0028 cost= 0.059074335
validation accuracy : 0.9816
Epoch: 0029 cost= 0.058189238
validation accuracy : 0.9828
Epoch: 0030 cost= 0.053373935
validation accuracy : 0.9832
Optimization Finished!
Test Accuracy: 0.9826
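Each deprecation warning above suggests its own replacement. On a late TensorFlow 1.x release (1.13+), the flagged calls could be updated roughly as follows; this is a sketch assembled from the warning messages, not a tested rewrite:

# dropout: `rate` replaces `keep_prob` (rate = 1 - keep_prob)
hidden_1 = tf.nn.dropout(_hidden_1, rate=1 - keep_prob)

# cross-entropy: the _v2 variant allows gradients to flow into the labels by default
xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=output, labels=y)

# summary writer: pass the Graph object itself instead of a GraphDef
summary_writer = tf.summary.FileWriter('mlp_logs/', graph=sess.graph)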
Related posts:
Fundamentals of Deep Learning example code, Chapter 4: Beyond Gradient Descent
Fundamentals of Deep Learning example code, Chapter 5: Convolutional Neural Networks (MNIST)
Fundamentals of Deep Learning example code, Chapter 5: Convolutional Neural Networks (CIFAR-10 training)
Fundamentals of Deep Learning example code, Chapter 6: Embedding and Representation Learning (autoencoder MNIST)