LSTM: padding out the batch size
When building an LSTM network, if the input does not fill a whole batch — that is, a step's input is smaller than the designed batch size of 128 or whatever you chose — the recurrent network will raise an error. This post gives two workarounds.
The first method was written fairly early on. The idea is to compare the required batch size against the dataset size (here X and y are the full dataset) so that the last batch always has a fixed size; even when the batch size is larger than the whole dataset, the batch can still be filled by repeating the data.
def evaluate(self, X, y, batch_size=None, writer=None, step=None):
    N = X.shape[0]
    if batch_size is None:
        batch_size = N
    total_loss = 0
    true_labels = []
    predict_values = []
    for i in range(0, N, batch_size):
        if i + batch_size <= N:
            X_batch = X[i:i + batch_size]
            y_batch = y[i:i + batch_size]
        else:
            # The last batch is short: pad it by wrapping around to the start.
            padding = i + batch_size - N
            if padding > N:
                # The shortfall exceeds the whole dataset: repeat the
                # dataset `times` extra times, then take `padding` rows.
                times = padding // N
                padding = padding - times * N
                N = N + times * N + padding  # N now counts samples actually evaluated
                # tile along axis 0 only (X may be 3-D for an LSTM, y 1-D or 2-D)
                X_batch = np.concatenate((X[i:i + batch_size],
                                          np.tile(X, (times,) + (1,) * (X.ndim - 1)),
                                          X[:padding]))
                y_batch = np.concatenate((y[i:i + batch_size],
                                          np.tile(y, (times,) + (1,) * (y.ndim - 1)),
                                          y[:padding]))
            else:
                N = N + padding  # N now counts samples actually evaluated
                X_batch = np.concatenate((X[i:i + batch_size], X[:padding]))
                y_batch = np.concatenate((y[i:i + batch_size], y[:padding]))
        feed = {
            self.model.X: X_batch,
            self.model.y: y_batch,
            self.model.training: False
        }
        loss = self.model.loss
        summary_op = self.model.summary_op
        step_loss, y_true, y_pred, summary = self.sess.run(
            [loss, self.model.y, self.model.prob, summary_op],
            feed_dict=feed)
        true_labels.append(y_true)
        predict_values.append(y_pred > config.threshold)
        total_loss += step_loss * X_batch.shape[0]
        if writer:
            writer.add_summary(summary, step)
    total_loss /= N
    # Compute the overall F1.
    # After np.array below, true_labels has shape
    # [number of eval iterations, batch size, number of classes],
    # so flatten everything except the class axis.
    true_labels = np.array(true_labels)
    true_labels = np.reshape(true_labels, (-1, true_labels.shape[-1]))
    predict_values = np.array(predict_values)
    predict_values = np.reshape(predict_values, (-1, predict_values.shape[-1]))
    total_acc = f1_score(true_labels, predict_values, average='samples')
    report = classification_report(y_true=true_labels, y_pred=predict_values)
    return total_loss, total_acc, report
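The wrap-around padding can be isolated into a small, framework-free helper. The sketch below (the name `pad_last_batch` is mine, not from the code above) pads the slice starting at index `i` up to `batch_size` rows by cycling back to the start of the array, repeating the whole array when the shortfall exceeds it:

```python
import numpy as np

def pad_last_batch(X, i, batch_size):
    """Return X[i:i+batch_size], padded to exactly batch_size rows by
    wrapping around to the start of X; if the shortfall is larger than
    len(X), the whole array is repeated as many times as needed."""
    N = X.shape[0]
    batch = X[i:i + batch_size]          # may be short near the end
    shortfall = batch_size - batch.shape[0]
    if shortfall > 0:
        times, rest = divmod(shortfall, N)
        reps = (times,) + (1,) * (X.ndim - 1)   # tile along axis 0 only
        batch = np.concatenate((batch, np.tile(X, reps), X[:rest]))
    return batch
```

For example, with a 5-row dataset, `pad_last_batch(X, 3, 4)` returns rows 3, 4, 0, 1.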
The second method came up while training on ECG data on 2019-09-22. Its advantage is simplicity: instead of padding, each batch is simply drawn from a random contiguous window of the dataset. Its drawback is that it cannot handle a batch size larger than the whole dataset (though that case is unlikely anyway).
import numpy as np

total_num = total_data.shape[0]  # number of samples in the full dataset
# randint's upper bound is exclusive, so +1 lets the window end flush
# with the last sample
start_idx = np.random.randint(total_num - batch_size + 1)
batch_data = total_data[start_idx : start_idx + batch_size]
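The same idea as a self-contained function (the name `random_batch` and the seeded `rng` are assumptions of mine, using NumPy's newer `Generator` API):

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded for reproducibility

def random_batch(total_data, batch_size):
    """Draw one batch from a random contiguous window of the dataset.
    Requires len(total_data) >= batch_size."""
    total_num = total_data.shape[0]
    # integers() has an exclusive upper bound; +1 makes every valid
    # start position reachable, including the one flush with the end
    start_idx = rng.integers(0, total_num - batch_size + 1)
    return total_data[start_idx:start_idx + batch_size]
```

When `batch_size` equals the dataset size, the only possible window is the whole dataset, so the function degenerates to returning it unchanged.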