How to get the output of a Keras LSTM layer?

data-mining neural-network keras nlp lstm embeddings
2022-02-24 20:55:50

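I want to get the output (i.e. the vectors) of the LSTM layer of a network that I built with Keras in Python and trained to classify sentences (i.e. sequences). How can I do that?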

My attempt is as follows:

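Is it correct to use the function model.predict()? I found this video (LSTM in Keras | Understanding LSTM input and output shapes), which explains that the input to the LSTM layer (i.e. the output of the embedding layer) has size (number of sequences, number of inputs, embedding dimension), and that the corresponding LSTM output has size (number of sequences, number of LSTM units). In the linked video, the latter (i.e. the output vector of the LSTM layer) is obtained by calling model.predict(encodedsequences_data) right after adding the LSTM layer. For example, suppose I train a neural network to classify positive and negative reviews as follows: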

import numpy as np
# import everything from tensorflow.keras instead of mixing standalone keras and tf.keras
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LSTM, Embedding

# define documents
docs = np.array(['Well done!',
        'Good work',
        'Great effort',
        'nice work',
        'Excellent!',
        'Weak',
        'Poor effort!',
        'not good',
        'poor work',
        'Could have done better.'])

# define class labels
labels = np.array([1,1,1,1,1,0,0,0,0,0])

# train the tokenizer
tokenizer = Tokenizer()
# fit the tokenizer
tokenizer.fit_on_texts(docs)
# encode the sentences
encoded_docs = tokenizer.texts_to_sequences(docs)

vocab_size=len(tokenizer.word_index)+1 

# pad documents to a max length of 4 words
max_length = 4
padded_docs = pad_sequences(encoded_docs, maxlen=max_length, padding='post')
embedding_dim=100

# define the model
model = Sequential()
model.add(Embedding(vocab_size, embedding_dim, input_length=max_length, name='embeddings'))
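model.add(LSTM(64))
# NOTE: at this point the model contains only Embedding -> LSTM and its weights are still
# untrained, so this returns the LSTM output (shape (10, 64)) computed with random weights
output=model.predict(padded_docs)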
model.add(Dropout(0.25))
model.add(Dense(64))
model.add(Dropout(0.3))
model.add(Dense(1, activation='sigmoid'))
model.summary()

# compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# fit the model
model.fit(padded_docs, labels, epochs=50, verbose=2)

So I should think of the input to the LSTM layer as an array of size (10, 4, 100), while the output is an array whose shape is given by:

In[10]: output.shape
Out[10]: (10, 64)

In practice, output is the following numpy.ndarray:

In[11]: output
Out[11]: 
array([[-7.35682389e-03, -6.29833259e-04, -2.14141682e-02,
         5.49282366e-03, -1.68905873e-02, -7.86065124e-03,
         1.14580495e-02,  1.25549696e-02, -9.16293636e-03,
         6.39621960e-03, -1.34323994e-02, -2.12187809e-03,
         8.44217744e-03,  1.45898620e-02, -1.40892563e-03,
        -3.41916122e-02, -1.31929619e-02,  9.33299214e-03,
         1.19191762e-02, -1.00926403e-03,  1.26794688e-02,
        -2.21233014e-02, -1.12434607e-02, -4.41948650e-03,
         8.63359030e-03, -1.87689364e-02,  7.90755264e-03,
        -9.07730684e-03, -7.35392375e-03, -8.41679424e-03,
         6.20685238e-03, -3.13799526e-03,  1.42355347e-02,
         3.77556833e-04, -7.31376186e-03,  4.97561414e-03,
        -1.09350188e-02, -7.71270739e-03,  1.80931657e-03,
        -4.15941747e-03,  1.03279343e-02,  1.07305320e-02,
        -5.30181499e-03,  6.01283915e-04, -3.80512699e-03,
        -1.37944380e-02, -5.06241946e-03,  3.49981769e-04,
         6.85051316e-03, -3.79729504e-03,  1.81085169e-02,
        -2.42224871e-03, -1.16188182e-02, -9.10035986e-03,
         7.19825737e-03, -8.12581368e-03,  2.17414256e-02,
         1.95931066e-02, -1.33228181e-02,  5.87116787e-03,
         3.40227783e-02,  2.07215771e-02, -4.87452417e-05,
        -2.10380089e-02],
       [-5.61810751e-03,  1.05204089e-02, -1.52371302e-02,
        -2.87347310e-03, -1.25838947e-02, -1.18132159e-02,
         6.62136683e-03,  1.97880785e-03, -6.42609270e-03,
         1.00167897e-02, -1.96973495e-02, -8.16532318e-03,
         8.49343836e-03,  1.55395102e-02, -5.61557105e-03,
        -3.35718431e-02, -1.70064531e-02, -3.50781716e-03,
         1.04237692e-02,  3.42868188e-05,  1.28740557e-02,
        -2.28579864e-02, -3.93496035e-03, -1.70960138e-03,
         1.12372153e-02, -1.32894730e-02,  1.38472551e-02,
        -1.48426415e-02, -1.18360845e-02, -8.26376118e-03,
         2.58601783e-03, -1.19493445e-02,  9.43981484e-03,
         6.56166160e-03, -4.16467572e-03,  1.13104628e-02,
        -1.37887159e-02, -7.19879614e-03, -2.71008583e-04,
        -7.01515237e-03,  1.21428408e-02,  1.30166933e-02,
        -6.22257078e-03, -7.79333059e-03,  1.75176319e-04,
        -1.44108124e-02, -7.61069916e-03,  7.94319343e-03,
        -3.57268099e-03,  5.53885289e-03,  1.59333441e-02,
        -6.24595489e-03, -1.24001894e-02, -6.18426641e-03,
        -2.75318534e-03, -9.32588615e-03,  2.57411599e-02,
         1.81322880e-02, -1.50167728e-02,  2.56175001e-04,
         2.89641470e-02,  1.89595874e-02, -5.74650709e-03,
        -1.62386764e-02],
       [-1.57758372e-03,  6.51306845e-03, -1.83870271e-02,
        -2.17917864e-03, -1.48072215e-02, -6.46742154e-03,
         2.84540933e-03,  1.15868803e-02, -8.95035919e-03,
         1.18787242e-02, -1.53372595e-02, -1.28116598e-03,
         1.38210272e-02,  1.25349807e-02,  3.48864007e-03,
        -2.82158218e-02, -1.32004097e-02,  3.89576564e-03,
         1.16172284e-02,  6.00448996e-03,  1.25120645e-02,
        -2.08042618e-02, -9.06666648e-03, -9.38181765e-03,
         5.93809411e-03, -1.40472483e-02,  1.23529136e-02,
        -6.72566192e-03, -1.32737402e-02, -7.24557228e-03,
        -1.34050148e-03, -7.91837182e-03,  8.88528675e-03,
         4.43336181e-03, -3.90838226e-03,  7.95213319e-03,
        -1.97365284e-02, -7.12051382e-03, -1.71131018e-04,
         9.33982781e-04,  1.24091003e-02,  6.72526145e-03,
        -8.91984720e-03, -6.33321749e-03, -3.09348427e-04,
        -1.35311736e-02, -2.99455877e-03,  4.07836633e-03,
        -1.31862273e-03, -2.41302908e-03,  8.97983275e-03,
         3.61930369e-03, -5.18017821e-03, -6.46935450e-03,
        -8.58186861e-04, -2.87145306e-03,  1.99557561e-02,
         2.37323977e-02, -1.33994315e-02,  8.58740974e-03,
         2.87993196e-02,  2.31237561e-02, -9.56221425e-04,
        -1.58641450e-02],
       [-7.23353820e-03,  1.03754746e-02, -1.44966058e-02,
        -5.24685532e-03, -7.83862360e-03, -1.28473695e-02,
         7.85209332e-03,  4.59963316e-03, -6.33948809e-03,
         1.19380560e-02, -2.20133327e-02, -1.17637683e-02,
         4.68229316e-03,  1.56141464e-02, -2.19842512e-03,
        -3.82021852e-02, -1.75803266e-02, -1.37230149e-03,
         9.44487844e-03, -3.54365772e-03,  1.21249473e-02,
        -2.55316924e-02, -9.54593590e-04, -4.05333284e-03,
         1.15399351e-02, -7.02163018e-03,  1.45128630e-02,
        -1.76015086e-02, -1.56135382e-02, -6.89028949e-03,
         2.78898864e-03, -6.99541951e-03,  1.22199748e-02,
         7.81757478e-03, -3.65910749e-03,  9.27608926e-03,
        -1.18314605e-02, -7.95996375e-03, -5.45767276e-03,
         2.41609639e-03,  1.30290538e-02,  1.31718945e-02,
        -9.51540284e-03, -9.02444124e-03, -4.85872338e-03,
        -1.42892599e-02, -5.70658036e-03,  5.16991317e-03,
         1.60913187e-04,  5.72681241e-03,  1.93462688e-02,
        -7.89363962e-03, -1.39182042e-02, -1.19014597e-02,
        -3.61843006e-04, -9.27691348e-03,  2.38749031e-02,
         2.55510695e-02, -1.21669313e-02,  2.98210932e-03,
         3.11383940e-02,  2.01738160e-02, -1.07578654e-03,
        -1.72522105e-02],
       [-1.20945638e-02,  8.68875161e-03, -2.62735188e-02,
         1.12946620e-02, -1.22309905e-02, -5.58001362e-03,
         1.08549614e-02,  8.95095989e-03, -7.41497055e-03,
         1.16792703e-02, -2.28939448e-02, -3.50253959e-03,
         9.97900032e-03,  1.66020077e-02, -7.89092761e-03,
        -4.07132506e-02, -1.78568307e-02,  4.34662355e-03,
         1.21228509e-02,  3.13125411e-03,  1.45842955e-02,
        -2.05406677e-02, -8.31084419e-03, -4.01034905e-03,
         7.12802447e-03, -1.84220411e-02,  1.13592790e-02,
        -1.26814209e-02, -1.18051590e-02, -8.80334340e-03,
         3.80245410e-03, -8.09166487e-03,  1.76429395e-02,
        -1.92356238e-04, -5.44061745e-03,  1.36529943e-02,
        -1.77526288e-02, -7.82714784e-03,  1.20329263e-03,
        -4.59810719e-03,  1.16233192e-02,  1.27270408e-02,
        -7.62180379e-03, -9.63985734e-03, -7.18248449e-03,
        -2.18094457e-02, -1.00004785e-02, -1.08398555e-03,
         3.66409164e-04, -2.32298975e-03,  1.85762774e-02,
        -3.05683468e-03, -1.08189192e-02, -8.16119835e-03,
         2.15096259e-03, -1.27328513e-02,  3.22906636e-02,
         2.98772007e-02, -1.85494740e-02, -6.53657946e-04,
         3.88556197e-02,  2.41267327e-02, -4.84270183e-03,
        -2.17875950e-02],
       [-5.25501464e-03,  8.72639474e-03, -2.05799732e-02,
         6.14312338e-03, -1.72808673e-02, -1.31866131e-02,
         5.83314849e-03,  6.02243049e-03, -3.84115218e-03,
         1.41651463e-02, -2.30942499e-02, -5.55078313e-03,
         1.04255294e-02,  2.30545755e-02, -2.18844530e-03,
        -3.62902768e-02, -1.92042850e-02,  1.22342422e-03,
         1.13154203e-02, -2.17898004e-03,  1.37291076e-02,
        -2.04059239e-02, -8.93826876e-03, -3.92630836e-03,
         8.99609085e-03, -1.99639872e-02,  9.40152071e-03,
        -1.68196503e-02, -1.08547490e-02, -4.02889075e-03,
         1.14458892e-02, -9.00383666e-03,  1.34997619e-02,
         3.90134053e-03, -6.80747256e-03,  9.08577070e-03,
        -1.72130466e-02, -1.32173812e-02, -1.73806830e-03,
        -4.78662038e-03,  1.92851052e-02,  9.79739800e-03,
        -4.36533149e-03, -8.25465377e-03, -3.25881690e-03,
        -1.42115178e-02, -3.79414624e-03,  3.12958076e-03,
         2.55106570e-04,  1.77397695e-03,  2.18093134e-02,
        -1.31484016e-03, -1.34956567e-02, -5.44582447e-03,
         2.86075217e-03, -2.14463435e-02,  3.33723240e-02,
         2.67816409e-02, -1.40035218e-02,  4.13230434e-03,
         3.57808806e-02,  2.61333492e-02, -8.73044250e-04,
        -2.66207643e-02],
       [-7.28952140e-03,  3.69984098e-03, -1.50594991e-02,
        -2.92313565e-03, -1.21962698e-02, -1.36810802e-02,
         6.82729529e-03,  6.73223054e-03, -5.93179138e-04,
         1.36305839e-02, -1.97144989e-02, -3.86482291e-03,
         1.80094223e-02,  1.42702283e-02, -1.32999849e-03,
        -3.40334214e-02, -1.76202524e-02,  2.30349993e-04,
         1.09610325e-02,  5.54277329e-03,  7.91644678e-03,
        -2.16093995e-02, -1.39994686e-02, -3.96339269e-03,
         8.55575502e-03, -9.74893570e-03,  1.05637731e-02,
        -4.55878396e-03, -1.34231234e-02, -5.41363843e-04,
         4.89153492e-04, -6.66437577e-03,  1.07367048e-02,
         3.43973073e-03, -4.18765657e-03,  6.89268531e-03,
        -1.19383521e-02, -8.89711361e-03, -4.44964785e-03,
         5.09598665e-03,  1.03713274e-02,  8.53536651e-03,
        -8.85974988e-03, -5.51856030e-03, -6.80169091e-04,
        -1.59089398e-02,  4.49734647e-03,  5.99729270e-03,
         2.46776640e-03, -1.70147279e-03,  5.19089587e-03,
         1.00052624e-03, -5.67172188e-03, -3.38913826e-03,
         2.77179107e-03, -6.60816021e-03,  2.12577283e-02,
         2.13659275e-02, -1.32223312e-02,  2.25704932e-03,
         2.57163309e-02,  2.04398781e-02, -1.69396808e-04,
        -1.73870064e-02],
       [-3.72720137e-03,  1.10780252e-02, -1.58859324e-02,
         6.66555809e-03, -1.44280717e-02, -8.00609123e-03,
         6.45924266e-03,  8.54926067e-04, -1.88849677e-04,
         8.98782630e-03, -1.54129220e-02, -1.08604110e-03,
         1.07514234e-02,  1.73932631e-02, -1.32540641e-02,
        -3.20286751e-02, -1.51650719e-02, -4.61112056e-03,
         1.35826627e-02,  2.19380498e-04,  1.40464576e-02,
        -1.90667752e-02, -4.66079684e-03, -3.83105595e-03,
         7.03095598e-03, -1.30360266e-02,  1.06415655e-02,
        -9.59097221e-03, -9.75623727e-03, -7.97409564e-03,
         5.58470841e-03, -9.38377157e-03,  7.96551071e-03,
        -4.83714510e-04, -2.05025147e-03,  8.32131505e-03,
        -1.56590100e-02, -1.04131605e-02, -3.17667797e-03,
        -5.72070433e-03,  1.51903769e-02,  1.19651994e-02,
        -4.08067275e-03, -1.01910857e-02,  3.11617507e-03,
        -1.46057652e-02, -4.41839360e-03,  5.06953197e-03,
        -4.96832095e-03,  7.81613402e-03,  1.70495212e-02,
         2.49758250e-05, -7.83495232e-03, -4.50461730e-03,
        -2.93536019e-03, -1.30293816e-02,  2.36500446e-02,
         1.58338025e-02, -1.86773948e-02,  7.61539035e-04,
         2.71951333e-02,  1.39996661e-02, -5.14069805e-03,
        -1.78940073e-02],
       [-6.54364191e-03,  5.09772869e-03, -1.04034524e-02,
        -3.73475719e-03, -1.19625181e-02, -1.40891932e-02,
         9.06457752e-03,  3.41503625e-03, -4.86045564e-03,
         1.25511140e-02, -2.17867326e-02, -7.47563411e-03,
         1.25360051e-02,  1.62971858e-02, -3.65820760e-03,
        -3.51158790e-02, -1.73445251e-02, -2.78897071e-03,
         6.79790135e-03,  3.37651756e-04,  1.03165153e-02,
        -2.39069872e-02, -7.28563499e-03, -7.42059347e-05,
         1.29556786e-02, -1.14473784e-02,  1.45603633e-02,
        -1.20491283e-02, -1.42659768e-02, -3.70843848e-03,
         4.51912638e-03, -7.15361023e-03,  1.15718255e-02,
         6.00760849e-03, -6.92916662e-03,  7.49018928e-03,
        -1.01785287e-02, -8.63973703e-03, -4.90589021e-03,
        -4.61561285e-04,  1.11296903e-02,  9.36738681e-03,
        -7.41511211e-03, -6.98906882e-03, -3.12256836e-03,
        -1.77129637e-02, -7.80898728e-04,  9.68260039e-03,
         3.10223433e-03,  4.86978563e-03,  1.11100767e-02,
        -9.04789940e-03, -1.10857347e-02, -6.04295917e-03,
         3.34106589e-04, -8.40245467e-03,  2.65593827e-02,
         1.84700415e-02, -1.15374802e-02, -1.12538924e-03,
         2.62181889e-02,  1.91304646e-02, -2.57981522e-03,
        -1.63868871e-02],
       [ 8.44749855e-04, -1.66689996e-02, -8.96400423e-04,
         6.72107562e-03, -2.39076628e-03,  2.60190992e-03,
         9.26916581e-03,  8.82431399e-03, -7.25202635e-03,
        -1.06594963e-02, -5.55762183e-03, -1.29527331e-03,
        -1.64227630e-03,  2.20268476e-03,  1.16184941e-02,
        -1.13204587e-02,  5.73474122e-03,  1.36803361e-02,
         5.84689900e-03,  9.08445287e-03, -1.40777905e-03,
        -1.06005892e-02, -3.84641928e-03,  2.70416541e-03,
         4.00838861e-03,  2.82439473e-03,  2.29762960e-03,
        -1.00108888e-03, -4.37536306e-04, -5.20851184e-03,
         4.79862979e-03, -4.48886072e-03,  8.50347336e-03,
        -1.42790927e-02, -1.26732215e-02,  4.63177776e-03,
        -1.20126840e-03,  5.99153014e-03,  1.22683365e-02,
         3.37655889e-04, -1.39692798e-02,  3.31070763e-03,
        -5.39805554e-03,  2.78319512e-02, -6.16331352e-03,
         5.77836717e-03, -1.19619851e-03, -1.58095714e-02,
         3.77729093e-03,  1.85370538e-03,  6.35961816e-03,
         3.49745946e-03, -1.27604958e-02, -9.54155251e-03,
         1.23564843e-02, -7.57912593e-03, -3.30073433e-03,
        -7.29874987e-03, -6.47724420e-03, -1.41964471e-02,
         1.36867436e-02,  5.99695602e-03,  9.53863724e-04,
         5.32696443e-03]], dtype=float32)

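Does this (multidimensional) array represent the output of the LSTM layer, and can I take each inner vector as the vector representation of the corresponding sentence (i.e. vector1 -> sentence1, and so on)?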
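On the row-per-sentence question: with the default return_sequences=False, a Keras LSTM layer returns only its final hidden state for each sequence, so row i of the (10, 64) array is computed from sentence i alone. The sketch below is a from-scratch NumPy implementation of the standard LSTM equations, not Keras's own code; the random weights and the random stand-in for the (10, 4, 100) embedding output are assumptions made for illustration. It shows the (10, 64) shape, and that return_sequences=True would instead give one 64-dimensional vector per time step, whose last step equals the default output.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, W, U, b, return_sequences=False):
    """Standard LSTM forward pass over x of shape (batch, timesteps, features).

    W: (features, 4*units), U: (units, 4*units), b: (4*units,), gates ordered i, f, c, o.
    Returns the final hidden state (batch, units), or every hidden state
    (batch, timesteps, units) when return_sequences=True.
    """
    batch, timesteps, _ = x.shape
    units = U.shape[0]
    h = np.zeros((batch, units))  # hidden state
    c = np.zeros((batch, units))  # cell state
    hs = []
    for t in range(timesteps):
        z = x[:, t, :] @ W + h @ U + b           # pre-activations of all four gates
        i = sigmoid(z[:, :units])                # input gate
        f = sigmoid(z[:, units:2 * units])       # forget gate
        g = np.tanh(z[:, 2 * units:3 * units])   # candidate cell state
        o = sigmoid(z[:, 3 * units:])            # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
        hs.append(h)
    return np.stack(hs, axis=1) if return_sequences else h

rng = np.random.default_rng(0)
n_sequences, max_length, embedding_dim, units = 10, 4, 100, 64
x = rng.normal(size=(n_sequences, max_length, embedding_dim))  # stand-in for the embedding output
W = rng.normal(scale=0.1, size=(embedding_dim, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)

last = lstm_forward(x, W, U, b)                         # like LSTM(64)
seq = lstm_forward(x, W, U, b, return_sequences=True)   # like LSTM(64, return_sequences=True)
print(last.shape)                          # (10, 64): one vector per sentence
print(seq.shape)                           # (10, 4, 64): one vector per time step
print(np.allclose(seq[:, -1, :], last))    # True
```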

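According to the YouTube video I linked, this seems to be correct, and I also found other tutorials that seem to do the same (for example this blog post, LSTM: Understanding Output Types). But I am not convinced. My doubt is: why would model.predict() return the output of the LSTM layer? Python's help says about this function: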

Generates output predictions for the input samples.

Or also in 3.6. scikit-learn: machine learning in Python:

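In a supervised estimator: model.predict(): given a trained model, predict the label of a new set of data. This method accepts one argument, the new data X_new (e.g. model.predict(X_new)), and returns the learned label for each object in the array.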

So, given these explanations, why does model.predict() return the LSTM output vectors? What exactly does model.predict() do?
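As far as I understand it, model.predict(x) is not tied to classification: it runs x, in batches, through every layer of the model and returns the output of the last layer. In the question's code it is called right after model.add(LSTM(64)), when the model only contains Embedding -> LSTM, so the last layer is the LSTM. A minimal sketch, assuming TensorFlow 2.x's tf.keras and a hypothetical random integer array in place of padded_docs; the same two layer objects are shared by a model without and a model with a classification head:

```python
import numpy as np
from tensorflow import keras

vocab_size, max_length, embedding_dim = 15, 4, 100
# Hypothetical integer-encoded input shaped like padded_docs (10 sentences x 4 tokens).
x = np.random.default_rng(0).integers(1, vocab_size, size=(10, max_length))

embedding = keras.layers.Embedding(vocab_size, embedding_dim)
lstm = keras.layers.LSTM(64)

# Model whose last layer is the LSTM: predict() returns the LSTM output.
partial = keras.Sequential([keras.Input(shape=(max_length,)), embedding, lstm])
lstm_out = partial.predict(x, verbose=0)
print(lstm_out.shape)  # (10, 64)

# predict() is a batched forward pass: chaining the layers by hand gives the same numbers.
manual = np.asarray(lstm(embedding(x)))
print(np.allclose(lstm_out, manual, atol=1e-5))  # True

# Same layers plus a sigmoid head: the last layer is now Dense(1), so predict() returns probabilities.
full = keras.Sequential([keras.Input(shape=(max_length,)), embedding, lstm,
                         keras.layers.Dense(1, activation='sigmoid')])
probs = full.predict(x, verbose=0)
print(probs.shape)  # (10, 1)
```

In other words, predict() returns "the LSTM output" only because, at the moment it is called, the LSTM happens to be the model's last layer.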

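In summary, I would like to know whether the procedure I reported, using model.predict(), correctly obtains the output of an LSTM layer in Keras, and why. If it does not, could you suggest the correct procedure?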

Thanks a lot in advance.

---- EDIT ----: Another approach is to follow the explanation in https://keras.io/getting_started/faq/#how-can-i-obtain-the-output-of-an-intermediate-layer-feature-extraction

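from tensorflow import keras

# layer_name: name of the LSTM layer (see model.summary()); data: e.g. padded_docs
intermediate_layer_model = keras.Model(inputs=model.input,
                                       outputs=model.get_layer(layer_name).output)
intermediate_output = intermediate_layer_model(data)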

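where layer_name is the name of my LSTM layer and data = padded_docs. I think this procedure is definitely correct, since it is in the Keras documentation. I also think it is similar to the first approach, model.predict(padded_docs): on the same linked Keras page, the difference between y=model(x) and y=model.predict(x) is discussed, and it says that both mean "run the model on x and retrieve the output y".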

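So, in summary, perhaps both of them run the input data (in this case our sentences) through the model, or through the model up to the layer of interest, and retrieve the output computed with the model's parameters (weights), which are updated by training? Is this what getting the output of a layer (in this case an LSTM layer) means?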
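Both approaches do run the sentences through the layers up to the LSTM with the weights those layers have at the moment of the call; the difference is when that happens. A sub-model built as in the Keras FAQ shares its layer objects with the full model, so calling it after fit() gives LSTM vectors computed with the trained weights, while the model.predict() call in the question ran before fit(). A sketch assuming TensorFlow 2.x's tf.keras, a hypothetical random stand-in for padded_docs, and the LSTM explicitly named 'lstm_out' (a name chosen here for illustration):

```python
import numpy as np
from tensorflow import keras

vocab_size, max_length, embedding_dim = 15, 4, 100
rng = np.random.default_rng(0)
x = rng.integers(1, vocab_size, size=(10, max_length))  # hypothetical stand-in for padded_docs
y = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])             # labels from the question

# Same architecture as the question, with an explicit name on the LSTM layer.
model = keras.Sequential([
    keras.Input(shape=(max_length,)),
    keras.layers.Embedding(vocab_size, embedding_dim, name='embeddings'),
    keras.layers.LSTM(64, name='lstm_out'),
    keras.layers.Dropout(0.25),
    keras.layers.Dense(64),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(x, y, epochs=5, verbose=0)

# Sub-model that reuses the trained layers but stops at the LSTM.
feature_model = keras.Model(inputs=model.input,
                            outputs=model.get_layer('lstm_out').output)

features_call = np.asarray(feature_model(x))            # like y = model(x)
features_predict = feature_model.predict(x, verbose=0)  # like y = model.predict(x)
print(features_call.shape)                              # (10, 64): one vector per sentence
print(np.allclose(features_call, features_predict, atol=1e-5))  # True
```

model(x) and model.predict(x) give the same vectors here; per the same FAQ page, predict() loops over the data in batches, while calling the model directly returns a tensor and can be used inside a gradient computation.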

1 Answer

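In this case, each text in "docs" is first encoded as integers in the range of the vocabulary size, and then each encoded array is padded with zeros up to max_length. If you check the padded output for a single text, it looks like this: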

array([6, 2, 0, 0])
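This encoding can be reproduced without Keras. The following is a simplified pure-Python re-implementation written for illustration, not Keras's code; it assumes Tokenizer's default behavior (lowercase, replace the default punctuation filters with spaces, split on whitespace, rank words by descending frequency with ties kept in first-seen order, index 0 reserved for padding) and pad_sequences(..., padding='post') with the default front truncation. For 'Well done!' it gives the array above:

```python
# Default character filters of Keras's Tokenizer (assumed here).
FILTERS = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'
TABLE = str.maketrans({ch: ' ' for ch in FILTERS})

def to_words(text):
    return text.lower().translate(TABLE).split()

def build_word_index(texts):
    counts = {}
    for text in texts:
        for w in to_words(text):
            counts[w] = counts.get(w, 0) + 1  # dict preserves first-seen order
    ranked = sorted(counts, key=counts.get, reverse=True)  # stable sort: ties keep first-seen order
    return {w: i for i, w in enumerate(ranked, start=1)}  # index 0 is left for padding

def encode_and_pad(texts, word_index, maxlen):
    padded = []
    for text in texts:
        seq = [word_index[w] for w in to_words(text) if w in word_index]
        seq = seq[-maxlen:]                              # drop tokens from the front if too long
        padded.append(seq + [0] * (maxlen - len(seq)))   # pad zeros at the end ('post')
    return padded

docs = ['Well done!', 'Good work', 'Great effort', 'nice work', 'Excellent!',
        'Weak', 'Poor effort!', 'not good', 'poor work', 'Could have done better.']
word_index = build_word_index(docs)
padded_docs = encode_and_pad(docs, word_index, maxlen=4)
print(padded_docs[0])        # [6, 2, 0, 0]
print(len(word_index) + 1)   # 15, i.e. vocab_size
```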

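You set the embedding dimension (embedding_dim) to 100. This means that each element of the padded array above is converted into a 100-dimensional vector. You then define the LSTM network with Keras. If you check the output shape of the embedding layer, you get an array of size (10, 4, 100): the 10 input samples of length 4 have each had every token converted into 100 dimensions.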
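To make the (10, 4, 100) shape concrete: an Embedding layer behaves like a trainable lookup table with one row of length embedding_dim per token index. A minimal NumPy sketch of that lookup, with a random matrix standing in for the trained embedding weights and a hypothetical random batch shaped like padded_docs (both assumptions):

```python
import numpy as np

vocab_size, max_length, embedding_dim = 15, 4, 100
rng = np.random.default_rng(0)

# Lookup table: row k is the vector for token index k (random here, learned in Keras).
embedding_matrix = rng.normal(size=(vocab_size, embedding_dim))

# Hypothetical integer batch: 10 sentences x 4 tokens, 0 = padding.
padded = rng.integers(0, vocab_size, size=(10, max_length))
padded[0] = [6, 2, 0, 0]  # the padded sentence shown above

embedded = embedding_matrix[padded]  # fancy indexing replaces each index with its row
print(embedded.shape)                                        # (10, 4, 100)
print(np.array_equal(embedded[0, 0], embedding_matrix[6]))   # True
```

Both padding positions of the first sentence map to the same row (index 0), so the LSTM still receives a vector at each of the 4 time steps.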

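Finally, after fitting the model with padded_docs as input and labels as the target variable, you can predict new documents, which must first be converted into the same format as padded_docs. Only then can the trained model, including its LSTM layer, make predictions on them. The predictions are values in the range (0, 1), not exactly 0 or 1, as the output below shows (values such as 0.9, 0.8, etc.).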

array([[9.9962938e-01],
       [8.5913503e-01],
       [9.9966836e-01],
       [9.9902046e-01],
       [5.9763002e-01],
       [5.9763002e-01],
       [2.4047494e-04],
       [4.7051907e-04],
       [6.3633323e-03],
       [3.6294390e-05]], dtype=float32)
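These are the outputs of the final Dense(1, activation='sigmoid') layer, i.e. one probability per document, not the 64-dimensional LSTM vectors. To turn them into class labels, a common convention (an assumption here, not something the model decides) is a 0.5 threshold. The sketch copies the values above and, as a further assumption, treats them as predictions for the question's 10 documents in order:

```python
import numpy as np

# Sigmoid outputs copied from the array above (one per document).
probs = np.array([[9.9962938e-01], [8.5913503e-01], [9.9966836e-01], [9.9902046e-01],
                  [5.9763002e-01], [5.9763002e-01], [2.4047494e-04], [4.7051907e-04],
                  [6.3633323e-03], [3.6294390e-05]], dtype=np.float32)
labels = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])  # class labels from the question

predicted = (probs[:, 0] > 0.5).astype(int)  # 1 = positive review, 0 = negative
print(predicted.tolist())                    # [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
print((predicted == labels).mean())          # 0.9
```

Under that assumption, only the sixth document ('Weak') falls on the wrong side of the threshold.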

I hope this sheds some light on the question.