Skip to content
38 changes: 30 additions & 8 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,11 @@ void kmeans_fit(
auto centroids_const = raft::make_device_matrix_view<const DataT, IndexT>(
cur_centroids_ptr, n_clusters, n_features);

iter_inertia = DataT{0};
auto d_iter_inertia = raft::make_device_scalar<DataT>(handle, DataT{0});
auto d_batch_cost = raft::make_device_scalar<DataT>(handle, DataT{0});
DataT* p_acc = d_iter_inertia.data_handle();
DataT* p_batch = d_batch_cost.data_handle();

data_batches.reset();
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn<DataT>;
std::optional<wt_iter_t> wt_it;
Expand All @@ -956,15 +960,33 @@ void kmeans_fit(
cur_batch_weights(static_cast<IndexT>(data_batch.offset()), wt_data, cur_batch_size);
}

DataT batch_cost = DataT{0};
cuvs::cluster::kmeans::cluster_cost(handle,
batch_data_view,
centroids_const,
raft::make_host_scalar_view(&batch_cost),
batch_sw);
std::optional<raft::device_vector_view<const DataT, IndexT>> batch_xnorm = std::nullopt;
if (need_compute_norms) {
if constexpr (data_on_device) {
batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size);
} else {
raft::copy(L2NormBatch.data_handle(),
h_norm_cache.data_handle() + data_batch.offset(),
cur_batch_size,
stream);
batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
L2NormBatch.data_handle(), cur_batch_size);
}
Comment on lines +963 to +975

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

HIGH: Guard the uncached host-norm path for zero-iteration fits.

Issue: For host data, this final inertia path always copies from h_norm_cache; if max_iter == 0, the training loop never populated it, so inertia is computed from uninitialized norms.
Why: This returns incorrect final inertia for a valid-looking parameter combination because max_iter is not rejected earlier.

Suggested fix
         if (need_compute_norms) {
           if constexpr (data_on_device) {
             batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
               L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size);
           } else {
-            raft::copy(L2NormBatch.data_handle(),
-                       h_norm_cache.data_handle() + data_batch.offset(),
-                       cur_batch_size,
-                       stream);
+            if (norms_cached) {
+              raft::copy(L2NormBatch.data_handle(),
+                         h_norm_cache.data_handle() + data_batch.offset(),
+                         cur_batch_size,
+                         stream);
+            } else {
+              compute_batch_norms(data_batch.data(), cur_batch_size);
+            }
             batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
               L2NormBatch.data_handle(), cur_batch_size);
           }
         }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/cluster/detail/kmeans.cuh` around lines 963 - 975, The host data path
(in the else block after checking data_on_device) unconditionally copies from
h_norm_cache without verifying it was populated. When max_iter equals zero, the
training loop never initializes h_norm_cache, causing the raft::copy operation
to read uninitialized memory and compute incorrect final inertia. Add a guard
condition in the else block to check if max_iter is zero, and handle this case
separately (either by skipping the norm computation or computing norms
on-the-fly) before attempting to copy from the unpopulated h_norm_cache.

}

cuvs::cluster::kmeans::cluster_cost(
handle, batch_data_view, centroids_const, d_batch_cost.view(), batch_sw, batch_xnorm);

iter_inertia += batch_cost;
raft::linalg::map_offset(handle,
raft::make_device_vector_view<DataT, int>(p_acc, 1),
[p_acc, p_batch] __device__(int) { return *p_acc + *p_batch; });
}

raft::copy(handle,
raft::make_host_scalar_view<DataT>(&iter_inertia),
raft::make_const_mdspan(d_iter_inertia.view()));
raft::resource::sync_stream(handle);
}

if (iter_inertia < inertia[0]) {
Expand Down
28 changes: 22 additions & 6 deletions cpp/src/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,17 @@ void min_cluster_distance(raft::resources const& handle,
* @param[in] centroids Cluster centroids [n_clusters x n_features]
* @param[out] cost Sum of squared distances to nearest centroid (device)
* @param[in] sample_weight Optional per-sample weights [n_samples]
* @param[in] X_norm Optional precomputed L2 norms of X rows [n_samples].
* When provided, the internal norm computation is skipped.
*/
template <typename DataT, typename IndexT>
void cluster_cost(
raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_scalar_view<DataT> cost,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight = std::nullopt)
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight = std::nullopt,
std::optional<raft::device_vector_view<const DataT, IndexT>> X_norm = std::nullopt)
{
auto stream = raft::resource::get_cuda_stream(handle);
auto n_clusters = centroids.extent(0);
Expand All @@ -398,8 +401,17 @@ void cluster_cost(

rmm::device_uvector<char> workspace(n_samples * sizeof(IndexT), stream);

auto x_norms = raft::make_device_vector<DataT>(handle, n_samples);
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, x_norms.view());
std::optional<raft::device_vector<DataT, IndexT>> x_norms_buf;
DataT* x_norms_ptr;
if (X_norm.has_value()) {
x_norms_ptr = const_cast<DataT*>(X_norm->data_handle());
} else {
x_norms_buf.emplace(raft::make_device_vector<DataT, IndexT>(handle, n_samples));
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle, X, x_norms_buf->view());
x_norms_ptr = x_norms_buf->data_handle();
}
auto x_norms_view = raft::make_device_vector_view<DataT, IndexT>(x_norms_ptr, n_samples);
Comment thread
coderabbitai[bot] marked this conversation as resolved.

auto min_cluster_distance = raft::make_device_vector<DataT>(handle, n_samples);
rmm::device_uvector<DataT> l2_norm_or_distance_buffer(0, stream);
Expand All @@ -412,7 +424,7 @@ void cluster_cost(
raft::make_device_matrix_view<DataT, IndexT>(
const_cast<DataT*>(centroids.data_handle()), n_clusters, n_features),
min_cluster_distance.view(),
x_norms.view(),
x_norms_view,
l2_norm_or_distance_buffer,
metric,
n_samples,
Expand Down Expand Up @@ -444,17 +456,21 @@ void cluster_cost(
* @param[in] centroids Cluster centroids [n_clusters x n_features]
* @param[out] cost Sum of squared distances to nearest centroid (host)
* @param[in] sample_weight Optional per-sample weights [n_samples]
* @param[in] X_norm Optional precomputed L2 norms of X rows [n_samples].
* When provided, the internal norm computation is skipped.
*/
template <typename DataT, typename IndexT>
void cluster_cost(
raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::host_scalar_view<DataT> cost,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight = std::nullopt)
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight = std::nullopt,
std::optional<raft::device_vector_view<const DataT, IndexT>> X_norm = std::nullopt)
{
auto device_cost = raft::make_device_scalar<DataT>(handle, DataT(0));
cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, device_cost.view(), sample_weight);
cuvs::cluster::kmeans::cluster_cost(
handle, X, centroids, device_cost.view(), sample_weight, X_norm);
raft::copy(handle, cost, raft::make_const_mdspan(device_cost.view()));
raft::resource::sync_stream(handle);
}
Expand Down
Loading