Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 34 additions & 34 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@

import argparse
import torch
import pickle
import numpy as np
import os
import math
import random
import pickle
import numpy as np
import os
import math
import random
import sys
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.io

import data
import data

from sklearn.decomposition import PCA
from torch import nn, optim
Expand Down Expand Up @@ -127,10 +127,10 @@
test_2_rnn_inp = data.get_rnn_input(
test_2_tokens, test_2_counts, test_2_times, args.num_times, args.vocab_size, args.num_docs_test)

## get embeddings
## get embeddings
print('Getting embeddings ...')
emb_path = args.emb_path
vect_path = os.path.join(args.data_path.split('/')[0], 'embeddings.pkl')
vect_path = os.path.join(args.data_path.split('/')[0], 'embeddings.pkl')
vectors = {}
with open(emb_path, 'rb') as f:
for l in f:
Expand All @@ -142,7 +142,7 @@
embeddings = np.zeros((vocab_size, args.emb_size))
words_found = 0
for i, word in enumerate(vocab):
try:
try:
embeddings[i] = vectors[word]
words_found += 1
except KeyError:
Expand All @@ -162,9 +162,9 @@
if args.mode == 'eval':
ckpt = args.load_from
else:
ckpt = os.path.join(args.save_path,
ckpt = os.path.join(args.save_path,
'detm_{}_K_{}_Htheta_{}_Optim_{}_Clip_{}_ThetaAct_{}_Lr_{}_Bsz_{}_RhoSize_{}_L_{}_minDF_{}_trainEmbeddings_{}'.format(
args.dataset, args.num_topics, args.t_hidden_size, args.optimizer, args.clip, args.theta_act,
args.dataset, args.num_topics, args.t_hidden_size, args.optimizer, args.clip, args.theta_act,
args.lr, args.batch_size, args.rho_size, args.eta_nlayers, args.min_df, args.train_embeddings))

## define model and optimizer
Expand Down Expand Up @@ -202,7 +202,7 @@ def train(epoch):
acc_kl_alpha_loss = 0
cnt = 0
indices = torch.randperm(args.num_docs_train)
indices = torch.split(indices, args.batch_size)
indices = torch.split(indices, args.batch_size)
for idx, ind in enumerate(indices):
optimizer.zero_grad()
model.zero_grad()
Expand All @@ -228,20 +228,20 @@ def train(epoch):
cnt += 1

if idx % args.log_interval == 0 and idx > 0:
cur_loss = round(acc_loss / cnt, 2)
cur_nll = round(acc_nll / cnt, 2)
cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
cur_kl_eta = round(acc_kl_eta_loss / cnt, 2)
cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2)
cur_loss = round(acc_loss / cnt, 2)
cur_nll = round(acc_nll / cnt, 2)
cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
cur_kl_eta = round(acc_kl_eta_loss / cnt, 2)
cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2)
lr = optimizer.param_groups[0]['lr']
print('Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} .. KL_eta: {} .. KL_alpha: {} .. Rec_loss: {} .. NELBO: {}'.format(
epoch, idx, len(indices), lr, cur_kl_theta, cur_kl_eta, cur_kl_alpha, cur_nll, cur_loss))
cur_loss = round(acc_loss / cnt, 2)
cur_nll = round(acc_nll / cnt, 2)
cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
cur_kl_eta = round(acc_kl_eta_loss / cnt, 2)
cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2)

cur_loss = round(acc_loss / cnt, 2)
cur_nll = round(acc_nll / cnt, 2)
cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
cur_kl_eta = round(acc_kl_eta_loss / cnt, 2)
cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2)
lr = optimizer.param_groups[0]['lr']
print('*'*100)
print('Epoch----->{} .. LR: {} .. KL_theta: {} .. KL_eta: {} .. KL_alpha: {} .. Rec_loss: {} .. NELBO: {}'.format(
Expand All @@ -254,7 +254,7 @@ def visualize():
model.eval()
with torch.no_grad():
alpha = model.mu_q_alpha
beta = model.get_beta(alpha)
beta = model.get_beta(alpha)
print('beta: ', beta.size())
print('\n')
print('#'*100)
Expand All @@ -267,7 +267,7 @@ def visualize():
top_words = list(gamma.cpu().numpy().argsort()[-args.num_words+1:][::-1])
topic_words = [vocab[a] for a in top_words]
topics_words.append(' '.join(topic_words))
print('Topic {} .. Time: {} ===> {}'.format(k, t, topic_words))
print('Topic {} .. Time: {} ===> {}'.format(k, t, topic_words))

print('\n')
print('Visualize word embeddings ...')
Expand All @@ -284,8 +284,8 @@ def visualize():

# print('\n')
# print('Visualize word evolution ...')
# topic_0 = None ### k
# queries_0 = ['woman', 'gender', 'man', 'mankind', 'humankind'] ### v
# topic_0 = None ### k
# queries_0 = ['woman', 'gender', 'man', 'mankind', 'humankind'] ### v

# topic_1 = None
# queries_1 = ['africa', 'colonial', 'racist', 'democratic']
Expand Down Expand Up @@ -329,7 +329,7 @@ def get_theta(eta, bows):
q_theta = model.q_theta(inp)
mu_theta = model.mu_q_theta(q_theta)
theta = F.softmax(mu_theta, dim=-1)
return theta
return theta

def get_completion_ppl(source):
"""Returns document completion perplexity.
Expand Down Expand Up @@ -375,7 +375,7 @@ def get_completion_ppl(source):
print('{} PPL: {}'.format(source.upper(), ppl_all))
print('*'*100)
return ppl_all
else:
else:
indices = torch.split(torch.tensor(range(args.num_docs_test)), args.eval_batch_size)
tokens_1 = test_1_tokens
counts_1 = test_1_counts
Expand Down Expand Up @@ -440,7 +440,7 @@ def get_topic_quality():
model.eval()
with torch.no_grad():
alpha = model.mu_q_alpha
beta = model.get_beta(alpha)
beta = model.get_beta(alpha)
print('beta: ', beta.size())

print('\n')
Expand Down Expand Up @@ -504,17 +504,17 @@ def get_topic_quality():
scipy.io.savemat(ckpt+'_beta.mat', {'values': beta}, do_compression=True)
if args.train_embeddings:
print('saving word embedding matrix rho...')
rho = model.rho.weight.cpu().numpy()
rho = model.rho.weight.cpu().detach().numpy()
scipy.io.savemat(ckpt+'_rho.mat', {'values': rho}, do_compression=True)
print('computing validation perplexity...')
val_ppl = get_completion_ppl('val')
print('computing test perplexity...')
test_ppl = get_completion_ppl('test')
else:
else:
with open(ckpt, 'rb') as f:
model = torch.load(f)
model = model.to(device)

print('saving alpha...')
with torch.no_grad():
alpha = model.mu_q_alpha.cpu().numpy()
Expand Down