什么是机器学习中的批处理?

数据挖掘 Python 神经网络 深度学习 rnn
2022-02-18 23:57:47

Karpathy 的 LSTM 批处理网络LSTM 批处理网络以批处理方式操作

def checkSequentialMatchesBatch():
    """ check LSTM I/O forward/backward interactions """
    n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
    input_size = 10
    WLSTM = LSTM.init(input_size, d) # input size, hidden size
    X = np.random.randn(n,b,input_size)
    #...

def checkBatchGradient():
    """ check that the batch gradient is correct """
    # lets gradient check this beast
    n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
    input_size = 10
    WLSTM = LSTM.init(input_size, d) # input size, hidden size
    X = np.random.randn(n,b,input_size)
    #...

批量申请什么?我只熟悉输入一个热词表示向量,无法理解批量的 LSTM 学习过程。请在文本处理方面进行说明。

提前致谢。

2个回答

公认的答案是正确的,但从分类的角度考虑批次也可能会有所帮助。

假设您有一个二进制分类问题,您尝试使用多层感知器来解决,每个类有 1000 个示例。

在训练模型时,您不想等到模型看到所有数据后才执行权重更新。这是计算效率低下的。相反,例如,您从每个类中抽取 100 个随机示例,并将其称为“批次”。您在该批次上训练模型,执行权重更新,然后移动到下一个批次,直到您看到训练集中的所有示例。以这种方式通过训练集的一次称为“epoch”。

批处理是数据集中的一组实例。例如,一批 100 个文本样本将被用于一起训练您的模型。