Fundamentals of Deep Learning ch 6 skip-gram model example code

This post presents the skip-gram model example code from Chapter 6, Embeddings and Representation Learning, of Fundamentals of Deep Learning.

The code and data for generating the dataset used by the skip-gram model can be downloaded from the book's github page. After downloading, you need to add 'import matplotlib.pyplot as plt' to input_word_data.py.
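
As a quick sanity check, the generated dataset can be inspected directly: generate_batch returns integer word ids for (target, context) pairs, and reverse_dictionary maps the ids back to words. A minimal sketch (assuming input_word_data.py is importable as an archive package, the same layout used in the full script below):

import archive.input_word_data as data

# Draw one small batch: 8 (target, context) pairs, reusing each target word
# twice within a one-word window on each side.
batch_x, batch_y = data.generate_batch(8, 2, 1)
for target, context in zip(batch_x, batch_y[:, 0]):
    print(data.reverse_dictionary[target], '->', data.reverse_dictionary[context])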

The source for the skip-gram model is given below. This file can also be downloaded from github, but a few lines may raise errors depending on your environment, so the code in this post has been modified slightly so that it runs.

import os,shutil
import archive.input_word_data as data
import numpy as np
import tensorflow as tf
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

tf.reset_default_graph()
tf.set_random_seed(777)

# TRAINING PARAMETERS
batch_size = 128                                             # Number of training examples per batch
embedding_size = 128                                        # Dimension of embedding vectors
skip_window = 5                                             # Window size for context to the left and right of target
num_skips = 4                                               # How many times to reuse target to generate a label for context.
batches_per_epoch = int(data.data_size*num_skips/batch_size)     # Number of batches per epoch of training
training_epochs = 3                                         # Number of epochs to utilize for training
neg_size = 64                                               # Number of negative samples to use for NCE
display_step = 10000                                         # Frequency with which to print statistics
val_step = 100000                                            # Frequency with which to perform validation
learning_rate = 0.1                                         # Learning rate for SGD

# NEAREST NEIGHBORS VALIDATION PARAMETERS
val_size = 20
val_dist_span = 500
val_examples = np.random.choice(val_dist_span, val_size, replace=False)
top_match = 8
plot_num = 500


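# Embedding layer: map integer word ids to dense vectors; the matrix is initialized uniformly in [-1, 1).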
def embedding_layer(x, embedding_shape):
    with tf.variable_scope('embedding'):
        embedding_init = tf.random_uniform(embedding_shape, -1.0, 1.0)
        embedding_matrix = tf.get_variable('E',initializer=embedding_init)
    return tf.nn.embedding_lookup(embedding_matrix,x), embedding_matrix

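# NCE loss: instead of a full softmax over the vocabulary, contrast the true context word against neg_size sampled negative words.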
def noise_contrastive_loss(embedding_lookup, weight_shape, bias_shape, y ):
    with tf.variable_scope('nce'):
        nce_weight_init = tf.truncated_normal(weight_shape, stddev=1.0/(weight_shape[1])**0.5)
        nce_bias_init = tf.zeros(bias_shape)
        nce_W = tf.get_variable('W', initializer=nce_weight_init)
        nce_b = tf.get_variable('b',initializer=nce_bias_init)
        total_loss = tf.nn.nce_loss(weights = nce_W, biases=nce_b, inputs = embedding_lookup, labels = y,
                                    num_sampled = neg_size, num_classes=data.vocabulary_size)
        return tf.reduce_mean(total_loss)


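# Training op: plain SGD on the NCE cost; the cost is also logged as a TensorBoard scalar summary.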
def training(cost, global_step):
    with tf.variable_scope('training'):
        tf.summary.scalar('cost',cost)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(cost,global_step=global_step)
    return train_op

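# Validation: L2-normalize the embedding matrix and compute cosine similarity between the validation words and every word in the vocabulary.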
def validation(embedding_matrix, x_val):
    norm = tf.reduce_sum(embedding_matrix**2,1,keep_dims=True)**0.5
    normalized = embedding_matrix/norm
    val_embeddings = tf.nn.embedding_lookup(normalized,x_val)
    cosine_similarity = tf.matmul(val_embeddings, normalized, transpose_b=True)
    return normalized, cosine_similarity


if __name__ == '__main__':
    
    if os.path.exists('skipgram_logs/'):
        shutil.rmtree('skipgram_logs/',ignore_errors=True)
    
    with tf.variable_scope('skipgram_model'):
        
        x = tf.placeholder(tf.int32, shape=[batch_size])
        y = tf.placeholder(tf.int32, [batch_size,1])
        val = tf.constant(val_examples, dtype=tf.int32)
        global_step = tf.Variable(0,name='global_step', trainable=False)
        
        e_lookup, e_matrix = embedding_layer(x, [data.vocabulary_size, embedding_size])
        cost = noise_contrastive_loss(e_lookup, [data.vocabulary_size, embedding_size], [data.vocabulary_size],y)
        train_op = training(cost, global_step)
        val_op = validation(e_matrix,val)
        
        sess = tf.Session()
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter('skipgram_logs/',sess.graph)
        
        sess.run(tf.global_variables_initializer())
        
        step = 0
        avg_cost = 0
        
        for epoch in range(training_epochs):
            for minibatch in range(batches_per_epoch):
                step += 1
                minibatch_x, minibatch_y = data.generate_batch(batch_size, num_skips, skip_window)
                feed_dict = {x:minibatch_x, y:minibatch_y}
                
                _, new_cost, train_summary = sess.run([train_op, cost,summary_op],feed_dict=feed_dict)
                summary_writer.add_summary(train_summary,sess.run(global_step))
                avg_cost += new_cost/display_step
                
                if step % display_step == 0:
                    print("Elapsed: ", str(step), " batches. Cost =", "{:.9f}".format(avg_cost))
                    avg_cost = 0
                    
                if step % val_step == 0:
                    _, similarity = sess.run(val_op)
                    for i in range(val_size):
                        val_word = data.reverse_dictionary[val_examples[i]]
                        neighbors = (-similarity[i,:]).argsort()
                        neighbors = neighbors[1:top_match+1]
                        print_str = 'Nearest neighbor of ' + val_word + ' :'
                        for k in range(top_match):
                            print_str += ' '+data.reverse_dictionary[neighbors[k]] + ','
                        
                        print(print_str)
                
        final_embeddings,_= sess.run(val_op)
    
    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
    plot_embedding = np.asfarray(final_embeddings[:plot_num,:], dtype='float')
    low_dim_embs = tsne.fit_transform(plot_embedding)
    labels = [data.reverse_dictionary[i] for i in range(plot_num)]
    data.plot_with_labels(low_dim_embs, labels)

Below is a graph of how the training cost changes. After about 200k steps the cost no longer changes as much as before; it looks as if training is hardly making progress any more.
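
Since the script writes the cost as a scalar summary to skipgram_logs/, the same curve can be viewed with TensorBoard:

tensorboard --logdir=skipgram_logs/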

According to the book, the model finds more closely related words as training progresses. Below are the nearest-neighbor words found at an early stage of training and at a stage where training is nearly finished.

Elapsed:  100000  batches. Cost = 5.025829982
Nearest neighbor of official : new, turned, ark, approximately, keep, sentence, lunar, implementation,
Nearest neighbor of full : agave, inside, acts, processing, across, users, preserved, rainy,
Nearest neighbor of important : account, miles, l, met, olympic, pull, elect, ft,
Nearest neighbor of west : listing, bullet, interior, names, much, arms, infectious, zeus,
Nearest neighbor of just : buddha, perspective, page, absolute, width, throughout, alive, bulgarian,
Nearest neighbor of t : ammonia, independent, belarus, sources, alaska, otherwise, least, biography,
Nearest neighbor of television : reduction, said, footballer, rapidly, nine, group, generals, in,
Nearest neighbor of km : hunting, succession, writer, one, equally, brazil, isbn, felt,
Nearest neighbor of development : ada, conflict, less, honor, exclusive, centuries, recovered, exception,
Nearest neighbor of this : brown, alaska, browns, he, naturally, so, knights, distinctive,
Nearest neighbor of western : kg, potential, man, anna, r, inventor, claiming, out,
Nearest neighbor of science : province, UNK, apple, human, slavery, underwent, contributed, finite,
Nearest neighbor of islands : electron, helping, diameter, labels, aluminium, execution, of, mean,
Nearest neighbor of under : alaska, beijing, clark, do, defence, victims, appearance, president,
Nearest neighbor of generally : gore, lincoln, davis, independence, area, mount, designs, calculations,
Nearest neighbor of given : million, direction, loyal, branch, powered, gun, executive, developers,
Nearest neighbor of case : swedish, era, anime, sent, daughters, ruled, caesar, are,
Nearest neighbor of must : UNK, nuremberg, verbs, ca, wwii, nintendo, mr, technical,
Nearest neighbor of when : arbitration, increased, quotes, popular, reports, hall, believed, lincoln,
Nearest neighbor of china : ranked, loan, mountain, please, sexual, vessels, english, targeted,


Elapsed:  1500000  batches. Cost = 4.732197985
Nearest neighbor of official : turned, approximately, infectious, new, ark, keep, assigned, primitive,
Nearest neighbor of full : agave, inside, rainy, descriptive, conjugation, UNK, reed, processing,
Nearest neighbor of important : banner, pull, unfortunately, account, probe, actual, steady, treating,
Nearest neighbor of west : infectious, rounds, bullet, listing, healthy, karate, boston, compressed,
Nearest neighbor of just : buddha, perspective, absolute, placement, page, adequate, entity, lens,
Nearest neighbor of t : i, ammonia, orthography, resignation, glacial, you, predominantly, lisa,
Nearest neighbor of television : reduction, shanghai, footballer, rapidly, op, generals, insects, album,
Nearest neighbor of km : three, one, hunting, isbn, total, five, six, succession,
Nearest neighbor of development : ada, diocese, exclusive, honor, conflict, recovered, dances, employed,
Nearest neighbor of this : it, the, brown, equivalent, so, naturally, converted, a,
Nearest neighbor of western : anna, mughal, kg, claiming, potential, asia, southern, man,
Nearest neighbor of science : underwent, security, burned, asserted, human, anthem, ranges, slavery,
Nearest neighbor of islands : electron, helping, labels, isotope, diameter, execution, beatles, rouge,
Nearest neighbor of under : beijing, defence, harper, dunes, alaska, dictator, clark, reasonably,
Nearest neighbor of generally : speculation, calculations, montreal, products, implied, stanley, designs, spread,
Nearest neighbor of given : loyal, direction, transported, subordinate, graves, connecticut, million, if,
Nearest neighbor of case : daughters, elementary, prevented, emancipation, swedish, era, implications, expect,
Nearest neighbor of must : verbs, safely, anyway, hole, neighboring, bearing, nuremberg, some,
Nearest neighbor of when : arbitration, decides, believed, in, battleships, quotes, reports, wait,
Nearest neighbor of china : please, loan, buses, stuart, improving, ranked, mountain, vessels,

Below is a visualization of the skip-gram embeddings using t-SNE. Single letters are clustered in the lower left, and words with some degree of relatedness appear to be grouped together.
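
For reference, the plot_with_labels helper called at the end of the script is roughly of the following form (a minimal sketch, not necessarily identical to the function shipped with input_word_data.py):

import matplotlib.pyplot as plt

def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
    # Scatter the 2-D t-SNE coordinates and annotate each point with its word.
    plt.figure(figsize=(18, 18))
    for i, label in enumerate(labels):
        x, y = low_dim_embs[i, :]
        plt.scatter(x, y)
        plt.annotate(label, xy=(x, y), xytext=(5, 2),
                     textcoords='offset points', ha='right', va='bottom')
    plt.savefig(filename)
    plt.show()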
