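I want to get the output (i.e., the vector) of the LSTM layer of a network that I built with Keras in Python and trained to classify sentences (i.e., sequences). How can I do this?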
My attempt is as follows:
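Is it correct to use model.predict() for this? I found this video, LSTM in Keras | Understanding LSTM input and output shapes, which explains that the input of an LSTM layer (i.e., what comes out of the embedding layer) is a tensor of shape (number of sequences, number of inputs, embedding dimension), and that the corresponding LSTM output has shape (number of sequences, number of LSTM units). In the linked video, this latter vector (i.e., the output of the LSTM layer) is obtained by calling model.predict(encodedsequences_data) right after the LSTM layer is added. For example, suppose I train a neural network to classify positive and negative reviews as follows: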
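To make the shape claim concrete, here is a minimal NumPy sketch of one LSTM layer's forward pass. It assumes the standard LSTM equations with sigmoid gates and tanh cell/output activations (the defaults of Keras's LSTM layer), a zero initial state and no masking; the function name lstm_forward and the random weights are hypothetical. With return_sequences=False (the Keras default) only the hidden state after the last time step is returned, which is why the time axis disappears and (10, 4, 100) becomes (10, 64).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_forward(x, W, U, b, units, return_sequences=False):
    """Forward pass of a single LSTM layer (no masking, zero initial state).

    x: (batch, timesteps, input_dim)
    W: (input_dim, 4 * units)  input kernel, gate blocks ordered i, f, g, o
    U: (units, 4 * units)      recurrent kernel, same block order
    b: (4 * units,)            bias
    """
    batch, timesteps, _ = x.shape
    h = np.zeros((batch, units))  # hidden state
    c = np.zeros((batch, units))  # cell state
    hs = []
    for t in range(timesteps):
        z = x[:, t, :] @ W + h @ U + b               # (batch, 4 * units)
        i = sigmoid(z[:, :units])                    # input gate
        f = sigmoid(z[:, units:2 * units])           # forget gate
        g = np.tanh(z[:, 2 * units:3 * units])       # candidate cell values
        o = sigmoid(z[:, 3 * units:])                # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
        hs.append(h)
    if return_sequences:
        return np.stack(hs, axis=1)                  # (batch, timesteps, units)
    return h                                         # (batch, units): last hidden state only


rng = np.random.default_rng(0)
n_seq, n_steps, emb_dim, units = 10, 4, 100, 64      # the shapes from the question
x = rng.normal(size=(n_seq, n_steps, emb_dim))       # stands in for the embedding output
W = rng.normal(scale=0.1, size=(emb_dim, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)

print(lstm_forward(x, W, U, b, units).shape)                          # (10, 64)
print(lstm_forward(x, W, U, b, units, return_sequences=True).shape)   # (10, 4, 64)
```

Each of the 10 rows of the (10, 64) result is the final hidden state of one input sequence.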
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LSTM, Embedding
# define documents
docs = np.array(['Well done!',
'Good work',
'Great effort',
'nice work',
'Excellent!',
'Weak',
'Poor effort!',
'not good',
'poor work',
'Could have done better.'])
# define class labels
labels = np.array([1,1,1,1,1,0,0,0,0,0])
# train the tokenizer
tokenizer = Tokenizer()
# fit the tokenizer
tokenizer.fit_on_texts(docs)
# encode the sentences
encoded_docs = tokenizer.texts_to_sequences(docs)
vocab_size=len(tokenizer.word_index)+1
# pad documents to a max length of 4 words
max_length = 4
padded_docs = pad_sequences(encoded_docs, maxlen=max_length, padding='post')
embedding_dim=100
# define the model
model = Sequential()
model.add(Embedding(vocab_size, embedding_dim, input_length=max_length, name='embeddings'))
model.add(LSTM(64))
output=model.predict(padded_docs)
model.add(Dropout(0.25))
model.add(Dense(64))
model.add(Dropout(0.3))
model.add(Dense(1, activation='sigmoid'))
model.summary()
# compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# fit the model
model.fit(padded_docs, labels, epochs=50, verbose=2)
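Note that in the code above model.predict(padded_docs) is called before Dropout and the Dense layers are added and before model.fit(), so at that moment the model ends at the LSTM and its weights are still at their random initialization; the (10, 64) array therefore comes from an untrained LSTM. A sketch of an ordering that returns vectors from the trained network: build the full classifier, train it, then wrap its layers in a sub-model that stops at the LSTM (the Keras FAQ approach quoted in the edit below). It assumes TensorFlow 2.x with tf.keras, uses the functional API instead of Sequential so that model.input is always defined, and replaces the tokenized reviews with random integer IDs (hypothetical data) so that it is self-contained.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical stand-in for padded_docs: 10 post-padded sentences of 4 word IDs.
rng = np.random.default_rng(0)
vocab_size, max_length, embedding_dim = 15, 4, 100
padded_docs = rng.integers(1, vocab_size, size=(10, max_length)).astype("int32")
labels = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])

# 1) Build the full classifier first; give the LSTM layer an explicit name.
inputs = tf.keras.Input(shape=(max_length,), dtype="int32")
x = layers.Embedding(vocab_size, embedding_dim, name="embeddings")(inputs)
x = layers.LSTM(64, name="lstm")(x)
x = layers.Dropout(0.25)(x)
x = layers.Dense(64)(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.Model(inputs, outputs)

# 2) Train the whole model (fewer epochs than the question, just to keep it quick).
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(padded_docs, labels, epochs=10, verbose=0)

# 3) Then build a sub-model from the trained layers that stops at the LSTM.
feature_extractor = tf.keras.Model(inputs=model.input,
                                   outputs=model.get_layer("lstm").output)
lstm_vectors = feature_extractor.predict(padded_docs, verbose=0)

print(model.predict(padded_docs, verbose=0).shape)  # (10, 1): one probability per sentence
print(lstm_vectors.shape)                           # (10, 64): one LSTM vector per sentence
```

Here model.predict() returns the classifier's probabilities, and only the sub-model returns the LSTM vectors.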
So the input of the LSTM layer should be a tensor of shape (10, 4, 100). The output, instead, has the following shape:
In[10]: output.shape
Out[10]: (10, 64)
In practice, it is the following numpy.ndarray:
In[11]: output
Out[11]:
array([[-7.35682389e-03, -6.29833259e-04, -2.14141682e-02,
5.49282366e-03, -1.68905873e-02, -7.86065124e-03,
1.14580495e-02, 1.25549696e-02, -9.16293636e-03,
6.39621960e-03, -1.34323994e-02, -2.12187809e-03,
8.44217744e-03, 1.45898620e-02, -1.40892563e-03,
-3.41916122e-02, -1.31929619e-02, 9.33299214e-03,
1.19191762e-02, -1.00926403e-03, 1.26794688e-02,
-2.21233014e-02, -1.12434607e-02, -4.41948650e-03,
8.63359030e-03, -1.87689364e-02, 7.90755264e-03,
-9.07730684e-03, -7.35392375e-03, -8.41679424e-03,
6.20685238e-03, -3.13799526e-03, 1.42355347e-02,
3.77556833e-04, -7.31376186e-03, 4.97561414e-03,
-1.09350188e-02, -7.71270739e-03, 1.80931657e-03,
-4.15941747e-03, 1.03279343e-02, 1.07305320e-02,
-5.30181499e-03, 6.01283915e-04, -3.80512699e-03,
-1.37944380e-02, -5.06241946e-03, 3.49981769e-04,
6.85051316e-03, -3.79729504e-03, 1.81085169e-02,
-2.42224871e-03, -1.16188182e-02, -9.10035986e-03,
7.19825737e-03, -8.12581368e-03, 2.17414256e-02,
1.95931066e-02, -1.33228181e-02, 5.87116787e-03,
3.40227783e-02, 2.07215771e-02, -4.87452417e-05,
-2.10380089e-02],
[-5.61810751e-03, 1.05204089e-02, -1.52371302e-02,
-2.87347310e-03, -1.25838947e-02, -1.18132159e-02,
6.62136683e-03, 1.97880785e-03, -6.42609270e-03,
1.00167897e-02, -1.96973495e-02, -8.16532318e-03,
8.49343836e-03, 1.55395102e-02, -5.61557105e-03,
-3.35718431e-02, -1.70064531e-02, -3.50781716e-03,
1.04237692e-02, 3.42868188e-05, 1.28740557e-02,
-2.28579864e-02, -3.93496035e-03, -1.70960138e-03,
1.12372153e-02, -1.32894730e-02, 1.38472551e-02,
-1.48426415e-02, -1.18360845e-02, -8.26376118e-03,
2.58601783e-03, -1.19493445e-02, 9.43981484e-03,
6.56166160e-03, -4.16467572e-03, 1.13104628e-02,
-1.37887159e-02, -7.19879614e-03, -2.71008583e-04,
-7.01515237e-03, 1.21428408e-02, 1.30166933e-02,
-6.22257078e-03, -7.79333059e-03, 1.75176319e-04,
-1.44108124e-02, -7.61069916e-03, 7.94319343e-03,
-3.57268099e-03, 5.53885289e-03, 1.59333441e-02,
-6.24595489e-03, -1.24001894e-02, -6.18426641e-03,
-2.75318534e-03, -9.32588615e-03, 2.57411599e-02,
1.81322880e-02, -1.50167728e-02, 2.56175001e-04,
2.89641470e-02, 1.89595874e-02, -5.74650709e-03,
-1.62386764e-02],
[-1.57758372e-03, 6.51306845e-03, -1.83870271e-02,
-2.17917864e-03, -1.48072215e-02, -6.46742154e-03,
2.84540933e-03, 1.15868803e-02, -8.95035919e-03,
1.18787242e-02, -1.53372595e-02, -1.28116598e-03,
1.38210272e-02, 1.25349807e-02, 3.48864007e-03,
-2.82158218e-02, -1.32004097e-02, 3.89576564e-03,
1.16172284e-02, 6.00448996e-03, 1.25120645e-02,
-2.08042618e-02, -9.06666648e-03, -9.38181765e-03,
5.93809411e-03, -1.40472483e-02, 1.23529136e-02,
-6.72566192e-03, -1.32737402e-02, -7.24557228e-03,
-1.34050148e-03, -7.91837182e-03, 8.88528675e-03,
4.43336181e-03, -3.90838226e-03, 7.95213319e-03,
-1.97365284e-02, -7.12051382e-03, -1.71131018e-04,
9.33982781e-04, 1.24091003e-02, 6.72526145e-03,
-8.91984720e-03, -6.33321749e-03, -3.09348427e-04,
-1.35311736e-02, -2.99455877e-03, 4.07836633e-03,
-1.31862273e-03, -2.41302908e-03, 8.97983275e-03,
3.61930369e-03, -5.18017821e-03, -6.46935450e-03,
-8.58186861e-04, -2.87145306e-03, 1.99557561e-02,
2.37323977e-02, -1.33994315e-02, 8.58740974e-03,
2.87993196e-02, 2.31237561e-02, -9.56221425e-04,
-1.58641450e-02],
[-7.23353820e-03, 1.03754746e-02, -1.44966058e-02,
-5.24685532e-03, -7.83862360e-03, -1.28473695e-02,
7.85209332e-03, 4.59963316e-03, -6.33948809e-03,
1.19380560e-02, -2.20133327e-02, -1.17637683e-02,
4.68229316e-03, 1.56141464e-02, -2.19842512e-03,
-3.82021852e-02, -1.75803266e-02, -1.37230149e-03,
9.44487844e-03, -3.54365772e-03, 1.21249473e-02,
-2.55316924e-02, -9.54593590e-04, -4.05333284e-03,
1.15399351e-02, -7.02163018e-03, 1.45128630e-02,
-1.76015086e-02, -1.56135382e-02, -6.89028949e-03,
2.78898864e-03, -6.99541951e-03, 1.22199748e-02,
7.81757478e-03, -3.65910749e-03, 9.27608926e-03,
-1.18314605e-02, -7.95996375e-03, -5.45767276e-03,
2.41609639e-03, 1.30290538e-02, 1.31718945e-02,
-9.51540284e-03, -9.02444124e-03, -4.85872338e-03,
-1.42892599e-02, -5.70658036e-03, 5.16991317e-03,
1.60913187e-04, 5.72681241e-03, 1.93462688e-02,
-7.89363962e-03, -1.39182042e-02, -1.19014597e-02,
-3.61843006e-04, -9.27691348e-03, 2.38749031e-02,
2.55510695e-02, -1.21669313e-02, 2.98210932e-03,
3.11383940e-02, 2.01738160e-02, -1.07578654e-03,
-1.72522105e-02],
[-1.20945638e-02, 8.68875161e-03, -2.62735188e-02,
1.12946620e-02, -1.22309905e-02, -5.58001362e-03,
1.08549614e-02, 8.95095989e-03, -7.41497055e-03,
1.16792703e-02, -2.28939448e-02, -3.50253959e-03,
9.97900032e-03, 1.66020077e-02, -7.89092761e-03,
-4.07132506e-02, -1.78568307e-02, 4.34662355e-03,
1.21228509e-02, 3.13125411e-03, 1.45842955e-02,
-2.05406677e-02, -8.31084419e-03, -4.01034905e-03,
7.12802447e-03, -1.84220411e-02, 1.13592790e-02,
-1.26814209e-02, -1.18051590e-02, -8.80334340e-03,
3.80245410e-03, -8.09166487e-03, 1.76429395e-02,
-1.92356238e-04, -5.44061745e-03, 1.36529943e-02,
-1.77526288e-02, -7.82714784e-03, 1.20329263e-03,
-4.59810719e-03, 1.16233192e-02, 1.27270408e-02,
-7.62180379e-03, -9.63985734e-03, -7.18248449e-03,
-2.18094457e-02, -1.00004785e-02, -1.08398555e-03,
3.66409164e-04, -2.32298975e-03, 1.85762774e-02,
-3.05683468e-03, -1.08189192e-02, -8.16119835e-03,
2.15096259e-03, -1.27328513e-02, 3.22906636e-02,
2.98772007e-02, -1.85494740e-02, -6.53657946e-04,
3.88556197e-02, 2.41267327e-02, -4.84270183e-03,
-2.17875950e-02],
[-5.25501464e-03, 8.72639474e-03, -2.05799732e-02,
6.14312338e-03, -1.72808673e-02, -1.31866131e-02,
5.83314849e-03, 6.02243049e-03, -3.84115218e-03,
1.41651463e-02, -2.30942499e-02, -5.55078313e-03,
1.04255294e-02, 2.30545755e-02, -2.18844530e-03,
-3.62902768e-02, -1.92042850e-02, 1.22342422e-03,
1.13154203e-02, -2.17898004e-03, 1.37291076e-02,
-2.04059239e-02, -8.93826876e-03, -3.92630836e-03,
8.99609085e-03, -1.99639872e-02, 9.40152071e-03,
-1.68196503e-02, -1.08547490e-02, -4.02889075e-03,
1.14458892e-02, -9.00383666e-03, 1.34997619e-02,
3.90134053e-03, -6.80747256e-03, 9.08577070e-03,
-1.72130466e-02, -1.32173812e-02, -1.73806830e-03,
-4.78662038e-03, 1.92851052e-02, 9.79739800e-03,
-4.36533149e-03, -8.25465377e-03, -3.25881690e-03,
-1.42115178e-02, -3.79414624e-03, 3.12958076e-03,
2.55106570e-04, 1.77397695e-03, 2.18093134e-02,
-1.31484016e-03, -1.34956567e-02, -5.44582447e-03,
2.86075217e-03, -2.14463435e-02, 3.33723240e-02,
2.67816409e-02, -1.40035218e-02, 4.13230434e-03,
3.57808806e-02, 2.61333492e-02, -8.73044250e-04,
-2.66207643e-02],
[-7.28952140e-03, 3.69984098e-03, -1.50594991e-02,
-2.92313565e-03, -1.21962698e-02, -1.36810802e-02,
6.82729529e-03, 6.73223054e-03, -5.93179138e-04,
1.36305839e-02, -1.97144989e-02, -3.86482291e-03,
1.80094223e-02, 1.42702283e-02, -1.32999849e-03,
-3.40334214e-02, -1.76202524e-02, 2.30349993e-04,
1.09610325e-02, 5.54277329e-03, 7.91644678e-03,
-2.16093995e-02, -1.39994686e-02, -3.96339269e-03,
8.55575502e-03, -9.74893570e-03, 1.05637731e-02,
-4.55878396e-03, -1.34231234e-02, -5.41363843e-04,
4.89153492e-04, -6.66437577e-03, 1.07367048e-02,
3.43973073e-03, -4.18765657e-03, 6.89268531e-03,
-1.19383521e-02, -8.89711361e-03, -4.44964785e-03,
5.09598665e-03, 1.03713274e-02, 8.53536651e-03,
-8.85974988e-03, -5.51856030e-03, -6.80169091e-04,
-1.59089398e-02, 4.49734647e-03, 5.99729270e-03,
2.46776640e-03, -1.70147279e-03, 5.19089587e-03,
1.00052624e-03, -5.67172188e-03, -3.38913826e-03,
2.77179107e-03, -6.60816021e-03, 2.12577283e-02,
2.13659275e-02, -1.32223312e-02, 2.25704932e-03,
2.57163309e-02, 2.04398781e-02, -1.69396808e-04,
-1.73870064e-02],
[-3.72720137e-03, 1.10780252e-02, -1.58859324e-02,
6.66555809e-03, -1.44280717e-02, -8.00609123e-03,
6.45924266e-03, 8.54926067e-04, -1.88849677e-04,
8.98782630e-03, -1.54129220e-02, -1.08604110e-03,
1.07514234e-02, 1.73932631e-02, -1.32540641e-02,
-3.20286751e-02, -1.51650719e-02, -4.61112056e-03,
1.35826627e-02, 2.19380498e-04, 1.40464576e-02,
-1.90667752e-02, -4.66079684e-03, -3.83105595e-03,
7.03095598e-03, -1.30360266e-02, 1.06415655e-02,
-9.59097221e-03, -9.75623727e-03, -7.97409564e-03,
5.58470841e-03, -9.38377157e-03, 7.96551071e-03,
-4.83714510e-04, -2.05025147e-03, 8.32131505e-03,
-1.56590100e-02, -1.04131605e-02, -3.17667797e-03,
-5.72070433e-03, 1.51903769e-02, 1.19651994e-02,
-4.08067275e-03, -1.01910857e-02, 3.11617507e-03,
-1.46057652e-02, -4.41839360e-03, 5.06953197e-03,
-4.96832095e-03, 7.81613402e-03, 1.70495212e-02,
2.49758250e-05, -7.83495232e-03, -4.50461730e-03,
-2.93536019e-03, -1.30293816e-02, 2.36500446e-02,
1.58338025e-02, -1.86773948e-02, 7.61539035e-04,
2.71951333e-02, 1.39996661e-02, -5.14069805e-03,
-1.78940073e-02],
[-6.54364191e-03, 5.09772869e-03, -1.04034524e-02,
-3.73475719e-03, -1.19625181e-02, -1.40891932e-02,
9.06457752e-03, 3.41503625e-03, -4.86045564e-03,
1.25511140e-02, -2.17867326e-02, -7.47563411e-03,
1.25360051e-02, 1.62971858e-02, -3.65820760e-03,
-3.51158790e-02, -1.73445251e-02, -2.78897071e-03,
6.79790135e-03, 3.37651756e-04, 1.03165153e-02,
-2.39069872e-02, -7.28563499e-03, -7.42059347e-05,
1.29556786e-02, -1.14473784e-02, 1.45603633e-02,
-1.20491283e-02, -1.42659768e-02, -3.70843848e-03,
4.51912638e-03, -7.15361023e-03, 1.15718255e-02,
6.00760849e-03, -6.92916662e-03, 7.49018928e-03,
-1.01785287e-02, -8.63973703e-03, -4.90589021e-03,
-4.61561285e-04, 1.11296903e-02, 9.36738681e-03,
-7.41511211e-03, -6.98906882e-03, -3.12256836e-03,
-1.77129637e-02, -7.80898728e-04, 9.68260039e-03,
3.10223433e-03, 4.86978563e-03, 1.11100767e-02,
-9.04789940e-03, -1.10857347e-02, -6.04295917e-03,
3.34106589e-04, -8.40245467e-03, 2.65593827e-02,
1.84700415e-02, -1.15374802e-02, -1.12538924e-03,
2.62181889e-02, 1.91304646e-02, -2.57981522e-03,
-1.63868871e-02],
[ 8.44749855e-04, -1.66689996e-02, -8.96400423e-04,
6.72107562e-03, -2.39076628e-03, 2.60190992e-03,
9.26916581e-03, 8.82431399e-03, -7.25202635e-03,
-1.06594963e-02, -5.55762183e-03, -1.29527331e-03,
-1.64227630e-03, 2.20268476e-03, 1.16184941e-02,
-1.13204587e-02, 5.73474122e-03, 1.36803361e-02,
5.84689900e-03, 9.08445287e-03, -1.40777905e-03,
-1.06005892e-02, -3.84641928e-03, 2.70416541e-03,
4.00838861e-03, 2.82439473e-03, 2.29762960e-03,
-1.00108888e-03, -4.37536306e-04, -5.20851184e-03,
4.79862979e-03, -4.48886072e-03, 8.50347336e-03,
-1.42790927e-02, -1.26732215e-02, 4.63177776e-03,
-1.20126840e-03, 5.99153014e-03, 1.22683365e-02,
3.37655889e-04, -1.39692798e-02, 3.31070763e-03,
-5.39805554e-03, 2.78319512e-02, -6.16331352e-03,
5.77836717e-03, -1.19619851e-03, -1.58095714e-02,
3.77729093e-03, 1.85370538e-03, 6.35961816e-03,
3.49745946e-03, -1.27604958e-02, -9.54155251e-03,
1.23564843e-02, -7.57912593e-03, -3.30073433e-03,
-7.29874987e-03, -6.47724420e-03, -1.41964471e-02,
1.36867436e-02, 5.99695602e-03, 9.53863724e-04,
5.32696443e-03]], dtype=float32)
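Does this (multi-dimensional) array represent the output of the LSTM layer, and can each inner vector be taken as the vector representation of the corresponding sentence (i.e., vector1 -> sentence1, and so on)?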
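Whether row i belongs to sentence i can be checked empirically: an LSTM processes every sample in a batch independently, so running one sentence on its own should reproduce the corresponding row of the batch output. A minimal sketch, assuming TensorFlow 2.x tf.keras and random integer IDs as a hypothetical stand-in for padded_docs (the weights are untrained, which does not matter for this check):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

rng = np.random.default_rng(1)
docs_ids = rng.integers(1, 15, size=(10, 4)).astype("int32")  # hypothetical padded_docs

# Embedding -> LSTM only, i.e. the same model the question has at its predict() call.
inputs = tf.keras.Input(shape=(4,), dtype="int32")
x = layers.Embedding(15, 100)(inputs)
lstm_out = layers.LSTM(64)(x)
encoder = tf.keras.Model(inputs, lstm_out)

all_vectors = encoder.predict(docs_ids, verbose=0)   # (10, 64): whole batch
single = encoder.predict(docs_ids[3:4], verbose=0)   # (1, 64): the 4th sentence alone

# Row 3 of the batch output should match the 4th sentence encoded on its own.
print(np.allclose(all_vectors[3], single[0], atol=1e-5))
```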
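According to the YouTube video I linked, this seems to be correct, and I have also found other tutorials that appear to do the same (e.g., this blog post: LSTM: Understanding Output Types). But I am not convinced. My doubt is: why should model.predict() return the output of the LSTM layer? Python's help says of this function: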
Generates output predictions for the input samples.
Or also, in 3.6. scikit-learn: machine learning in Python:
In supervised estimators:
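model.predict(): given a trained model, predict the label of a new set of data. This method accepts one argument, the new data X_new (e.g. model.predict(X_new)), and returns the learned label for each object in the array.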
So, given these explanations, why should model.predict() return the LSTM output vector? What exactly does model.predict() do?
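The two descriptions agree once "prediction" is read as "the output of the last layer the model contains at the time of the call". The toy class below is a hypothetical, framework-free illustration of that idea (it is not how Keras implements predict()): calling predict after only the first layer has been added returns intermediate features, which mirrors the model.predict(padded_docs) call placed right after model.add(LSTM(64)) in the code above.

```python
import numpy as np


class ToySequential:
    """Hypothetical stand-in for a Sequential model: an ordered list of layer functions."""

    def __init__(self):
        self.layers = []

    def add(self, layer):
        self.layers.append(layer)

    def predict(self, x):
        # Feed x through every layer added so far and return the last layer's output.
        for layer in self.layers:
            x = layer(x)
        return x

    def predict_up_to(self, x, k):
        # Feed x through the first k layers only, like a sub-model that stops early.
        for layer in self.layers[:k]:
            x = layer(x)
        return x


rng = np.random.default_rng(0)
W_feat = rng.normal(size=(4, 3))   # "feature" layer weights (stands in for Embedding + LSTM)
W_clf = rng.normal(size=(3, 1))    # "classifier" layer weights (stands in for Dense(1, sigmoid))
X = rng.normal(size=(2, 4))        # two hypothetical input samples

model = ToySequential()
model.add(lambda x: np.tanh(x @ W_feat))
early = model.predict(X)           # only the feature layer exists yet -> features

model.add(lambda h: 1.0 / (1.0 + np.exp(-(h @ W_clf))))
late = model.predict(X)            # full model -> one probability per sample

print(early.shape, late.shape)     # (2, 3) (2, 1)
```

The same call returns features or probabilities depending only on which layers exist when it runs.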
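In short, I would like to know whether the procedure I reported, using model.predict(), correctly obtains the output of an LSTM layer in Keras, and why. And if it does not, could you suggest the correct procedure?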
Many thanks in advance.
---- EDIT ----: Another approach is to use the method explained in https://keras.io/getting_started/faq/#how-can-i-obtain-the-output-of-an-intermediate-layer-feature-extraction:
from tensorflow import keras
intermediate_layer_model = keras.Model(inputs=model.input,
                                       outputs=model.get_layer(layer_name).output)
intermediate_output = intermediate_layer_model(data)
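with layer_name = the name of my LSTM layer and data = padded_docs. I think this procedure is certainly correct, since it comes from the Keras documentation. I also think it is similar to the first approach, model.predict(padded_docs). The same Keras FAQ page discusses the difference between y = model(x) and y = model.predict(x), and says that both mean "run the model on x and retrieve the output y".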
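So, to sum up: perhaps both run the input data (in this case our sentences) through the whole model, or through the model up to the layer of interest, and return the data as processed by the model with the parameters (weights) updated during training? Is that what obtaining the output of a layer (in this case an LSTM layer) means?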
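Both points can be checked directly. The sketch below assumes TensorFlow 2.x tf.keras; random integer IDs stand in for padded_docs and the layer sizes are arbitrary (hypothetical choices). It shows that (1) the FAQ-style sub-model reuses the same layer objects as the full model, so after model.fit() it returns vectors computed with the trained weights without being rebuilt, and (2) model.predict(x) and model(x, training=False) produce the same values, while model(x, training=True) turns Dropout on. According to the Keras FAQ, the practical difference is that predict() loops over the data in batches and returns NumPy arrays, while model(x) is a single in-memory call that returns tensors and is differentiable.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

rng = np.random.default_rng(3)
x_ids = rng.integers(1, 15, size=(10, 4)).astype("int32")   # hypothetical padded_docs
y = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])

inputs = tf.keras.Input(shape=(4,), dtype="int32")
h = layers.Embedding(15, 100)(inputs)
h = layers.LSTM(64, name="lstm")(h)
h = layers.Dropout(0.5)(h)
outputs = layers.Dense(1, activation="sigmoid")(h)
model = tf.keras.Model(inputs, outputs)

# Sub-model that stops at the LSTM; it reuses the SAME layer objects (and weights).
encoder = tf.keras.Model(inputs=model.input, outputs=model.get_layer("lstm").output)
before = encoder.predict(x_ids, verbose=0)

model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(x_ids, y, epochs=5, verbose=0)
after = encoder.predict(x_ids, verbose=0)      # no rebuild needed: weights are shared

# predict() and a direct call run the same forward pass in inference mode.
y_predict = model.predict(x_ids, verbose=0)                # batched loop, NumPy output
y_call = np.asarray(model(x_ids, training=False))          # single call, tensor output
y_train_mode = np.asarray(model(x_ids, training=True))     # Dropout active

print(encoder.get_layer("lstm") is model.get_layer("lstm"))  # True
print(np.allclose(before, after))                            # False: training changed the LSTM
print(np.allclose(y_predict, y_call, atol=1e-5))             # True
print(np.allclose(y_predict, y_train_mode, atol=1e-5))       # False: Dropout is random
```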