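I'm new to machine learning and am trying to apply this tutorial to my tabular data. My input is a pandas DataFrame containing the features (I encoded the categorical columns as floats) and a pandas Series containing the labels, which are integers. Here is my code: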
import tensorflow as tf
from tensorflow import feature_column
from tensorflow.keras import layers

def df_to_dataset(dataframe, labels, shuffle=True, batch_size=32):
    # Pair each row's feature dict with its label
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(dataframe))
    ds = ds.batch(batch_size)
    return ds

feature_columns = []
# numeric cols
for header in list(X_train):
    feature_columns.append(feature_column.numeric_column(header))
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)

batch_size = 32
train_ds = df_to_dataset(X_train, y_train, batch_size=batch_size)
val_ds = df_to_dataset(X_val, y_val, shuffle=False, batch_size=batch_size)
test_ds = df_to_dataset(X_test, y_test, shuffle=False,
                        batch_size=batch_size)

model = tf.keras.Sequential([
    feature_layer,
    layers.Dense(128, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'],)
              #run_eagerly=True)

model.fit(train_ds,
          validation_data=val_ds,
          epochs=5)

loss, accuracy = model.evaluate(test_ds)
print("Accuracy", accuracy)
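First, when I print the shapes of the dataset, all the feature shapes show up as (?,). I looked this up but didn't really understand what it means, so I moved on. When I run the model, I get the following output: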
W0625 16:28:50.013361 140172694484864 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Epoch 1/5
2437/2437 [==============================] - 12s 5ms/step - loss: -368933040.3744 - acc: 0.0000e+00 - val_loss: -1374389959.0878 - val_acc: 0.0000e+00
Epoch 2/5
2437/2437 [==============================] - 11s 4ms/step - loss: -4239125012.7993 - acc: 0.0000e+00 - val_loss: -8055676778.8942 - val_acc: 0.0000e+00
Epoch 3/5
2437/2437 [==============================] - 11s 4ms/step - loss: -14449097654.0468 - acc: 0.0000e+00 - val_loss: -21844830544.8532 - val_acc: 0.0000e+00
Epoch 4/5
2437/2437 [==============================] - 11s 4ms/step - loss: -32560744568.1740 - acc: 0.0000e+00 - val_loss: -44181551604.6596 - val_acc: 0.0000e+00
Epoch 5/5
2437/2437 [==============================] - 11s 4ms/step - loss: -60235093753.8022 - acc: 0.0000e+00 - val_loss: -76823729189.8015 - val_acc: 0.0000e+00
1219/1219 [==============================] - 3s 2ms/step - loss: -78553874677.2896 - acc: 0.0000e+00
Accuracy 0.0
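A minimal pure-Python sketch of the per-example binary cross-entropy formula, -(y * log(p) + (1 - y) * log(1 - p)), shows why a loss like the one above keeps falling below zero when the labels are integers outside {0, 1}. It assumes predictions are clipped to [1e-7, 1 - 1e-7] (my understanding of Keras's default epsilon) and that `'accuracy'` is resolved to binary accuracy for a one-unit sigmoid output; the function names are hypothetical helpers, not Keras APIs.

```python
import math

def binary_crossentropy(y_true, p, eps=1e-7):
    """Per-example binary cross-entropy, with p clipped to [eps, 1 - eps]."""
    p = min(max(p, eps), 1 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))

def binary_accuracy(y_true, p, threshold=0.5):
    """1.0 if the thresholded 0/1 prediction equals the label, else 0.0."""
    return 1.0 if float(p > threshold) == y_true else 0.0

# With a valid 0/1 label the loss is never negative.
print(binary_crossentropy(1, 0.9))

# With a label such as 90 the (1 - y) factor is -89, so the loss
# becomes more and more negative as the sigmoid output approaches 1.
for p in (0.9, 0.999, 0.999999):
    print(p, binary_crossentropy(90, p))

# A thresholded prediction is 0.0 or 1.0, which can never equal 90.
print(binary_accuracy(90, 0.999))
```

Since the optimizer minimizes the loss, it can keep pushing the output toward 1 and the loss toward minus infinity, which matches the ever-decreasing values in the log, while accuracy stays at 0.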
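Obviously something is wrong, but I don't even know where to start debugging it, so any tips are welcome!

EDIT: The labels I'm trying to predict are scores from 80 to 100, so I switched from binary_crossentropy to mean_squared_error. Now the loss is still very large, but positive, and the accuracy is still 0.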
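A minimal pure-Python sketch, assuming the final layer still uses `activation='sigmoid'` after the switch to mean squared error. Because a sigmoid output never exceeds 1, the MSE against scores of 80 to 100 has a large floor, while min-max scaling the labels to [0, 1] removes that floor. The label values and the 80/100 bounds in `min_max_scale` are illustrative assumptions taken from the edit, and the helper names are hypothetical.

```python
def mean_squared_error(y_true, y_pred):
    """Average of squared differences between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def min_max_scale(values, lo=80.0, hi=100.0):
    """Map scores in [lo, hi] onto [0, 1] (lo/hi are assumed bounds)."""
    return [(v - lo) / (hi - lo) for v in values]

# Hypothetical scores in the 80-100 range described in the edit.
labels = [80, 85, 90, 95, 100]

# A sigmoid output is bounded by 1, so predicting 1.0 everywhere is the
# closest it can get; the MSE cannot drop below this value.
sigmoid_ceiling = [1.0] * len(labels)
print(mean_squared_error(labels, sigmoid_ceiling))   # 7971.0

# After scaling, every target lies in the sigmoid's range (the endpoints
# are only approached), so the loss floor disappears.
scaled = min_max_scale(labels)
print(scaled)                                        # [0.0, 0.25, 0.5, 0.75, 1.0]
print(mean_squared_error(scaled, scaled))            # 0.0
```

The same scaling could be applied to `y_train`, `y_val` and `y_test` before calling `df_to_dataset`, with predictions mapped back via `p * 20 + 80`; the other common option is a linear output (`layers.Dense(1)` without an activation) on the raw scores. Either way, `'accuracy'` is not meaningful for continuous targets (for a single-unit output it most likely compares a thresholded 0/1 prediction to the label), so a regression metric such as `'mae'` is more informative.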