scikit-learn for kmeans 中的惯性 cython 实现如何工作?

数据挖掘 scikit-学习 k-均值
2022-02-27 13:02:43

具体来说,这个&符号代表什么?为什么列索引总是0?

    cpdef floating _inertia_dense(
        np.ndarray[floating, ndim=2, mode='c'] X,  # IN
        floating[::1] sample_weight,               # IN
        floating[:, ::1] centers,                  # IN
        int[::1] labels):                          # IN
    """Compute inertia for dense input data
    Sum of squared distance between each sample and its assigned center.
    """
    cdef:
        int n_samples = X.shape[0]
        int n_features = X.shape[1]
        int i, j

        floating sq_dist = 0.0
        floating inertia = 0.0

    for i in range(n_samples):
        j = labels[i]
        sq_dist = _euclidean_dense_dense(&X[i, 0], &centers[j, 0],
                                         n_features, True)
        inertia += sq_dist * sample_weight[i]

    return inertia

1个回答

&是 中的“地址”运算符c,这似乎是它在这里的使用方式。请参阅 两个SO 帖子。

注意签名_euclidean_dense_dense

cdef floating _euclidean_dense_dense(
        floating* a, # IN
        floating* b, # IN
        int n_features,
        bint squared) nogil:

前两个输入是指针。所以你需要传递地址,而不是数据的副本。

还要注意,只有第一列中元素的地址被传递。如果你看一下它的定义_euclidean_dense_dense就会变得更清楚:该函数实际上循环遍历其计算中其余列的地址。