网站首页 文章专栏 LSTM 凑齐 batch 大小
LSTM 凑齐 batch 大小
创建于:2021-07-04 08:29:46 更新于:2021-12-02 14:00:19 羽瀚尘 81
深度学习 深度学习,LSTM

简介

如果要设计一些lstm网络,如果输入不是整数batch,即每次的输入不是设计的128等数字,循环神经网络就会出错。这里给出两种方法。

方法1

该方法设计得比较早,原理就是判断需要的batch size和数据集的size。其中X和y为总体数据集。基本功能就是让最后一个batch的大小固定,特点就是即使batch比整个数据集大也能通过重复数据的方法补足batch。

def evaluate(self, X, y, batch_size=None, writer=None, step=None):
        N = X.shape[0]
        if batch_size == None:
            batch_size = N
        
        total_loss = 0

        true_labels = []
        predict_values = []
        
        for i in range(0, N, batch_size):
            
            if i + batch_size <= N:
                X_batch = X[i:i + batch_size]
                y_batch = y[i:i + batch_size]
            else:
                padding =  i + batch_size - N 
                

                # if padding size is larger than N 
                if padding > N:
                    times = int(padding / N) 
                    padding = padding - times * N 

                    N = N + times*N + padding
                    X_batch = np.concatenate((X[i:i + batch_size], \
                                np.tile(X,(times,1)), X[:padding]))
                    y_batch = np.concatenate((y[i:i + batch_size], \
                                np.tile(y,(times,1)), y[:padding]))
                else:
                    N = N  + padding
                    X_batch = np.concatenate((X[i:i + batch_size],  X[:padding]))
                    y_batch = np.concatenate((y[i:i + batch_size],  y[:padding]))
            feed = {
                self.model.X: X_batch,
                self.model.y: y_batch,
                self.model.training: False
            }
            
            loss = self.model.loss
            summary_op = self.model.summary_op
            
            step_loss, y_true, y_pred, summary = self.sess.run([loss,  \
                                                            self.model.y, \
                                                            self.model.prob,  \
                                                            summary_op], feed_dict=feed)
            true_labels.append(y_true)
            predict_values.append(y_pred > config.threshold)

           
            total_loss += step_loss * X_batch.shape[0]
            
            
            if writer:
                writer.add_summary(summary, step)
      
        total_loss /= N

        # 计算所有的f1

        '''
        下行运算完后,true_labels各维度为[batch大小, 类别数目, val分了多少次运行完毕]
        '''
        true_labels = np.array(true_labels) 
        true_labels = np.reshape(true_labels, (-1,true_labels.shape[-2]))

        predict_values = np.array(predict_values)
        predict_values = np.reshape(predict_values, (-1,predict_values.shape[-2]))

        total_acc = f1_score(true_labels, predict_values, average='samples')
        report = classification_report(y_true=true_labels, y_pred=predict_values)
        
        return total_loss, total_acc, report

方法2

该方法是20190922训练ECG时想出来的,优点是简单,缺点是不能应对batch设置得比整个数据集还要大的情况(不过这种情况也没那么容易发生).

import numpy as np
start_idx = np.random.randint(total_num - batch_size)
batch_data = total_data[start_idx : start_idx + batch_size]