使用 numpy.vstack 展开以连接的 28x28 mnist 图像。为什么 numpy.vstack 这么慢?

数据挖掘 喀拉斯 张量流 麻木的
2022-03-07 06:11:41

我正在尝试将 60000,28,28 mnist 数字列表简单地重塑为 60000,784 numpy 数组,其中数字已展开。

为此,代码如下:

(xdata,xlabel),(ydata,ylabel)=tf.keras.datasets.mnist.load_data()
newxdata=np.array([])
cnt=0
for i in xdata:
 tmpx=i.ravel()
 if cnt == 0:
  newxdata=np.concatenate((newxdata,tmpx))
 else:
  newxdata=np.vstack((newxdata,tmpx))
 cnt=cnt+1

为什么这需要这么长时间才能运行?有没有办法加快速度?最终,数据将以较小的批量输入到 keras 模型中。当要求批量大小时,编写一个执行循环展开的生成器会提高性能还是不会产生影响?

2个回答

您可以简单地通过以下方式重塑 numpy 数组:

newxdata =  xdata.reshape((60000,28*28))

例如。或者简单地说:

newxdata =  xdata.reshape((len(xdata),-1))

请注意,reshape 是一个 numpy 函数,它也可以用作:

import numpy as np
newxdata =  np.reshape(xdata, (60000,-1))

为了加快您的循环,您也可以使用multiprocessing.PoolCuPyNumba等库。

我接受了这个建议,并对实施进行了一些补充。

最终,新转换的数据将被一个需要输入(batch_size,unrolled image)的keras函数摄取,所以我创建了一个生成器函数.returnbatch(batch_size),它返回一个(batchsize,unrolledimage)