Skip to content
5 changes: 5 additions & 0 deletions include/xsf/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include <cuda/std/cstddef>
#include <cuda/std/cstdint>
#include <cuda/std/limits>
#include <cuda/std/numeric>
#include <cuda/std/tuple>
#include <cuda/std/type_traits>
#include <cuda/std/utility>
Expand Down Expand Up @@ -156,6 +157,9 @@ XSF_HOST_DEVICE constexpr T clamp(T &v, T &lo, T &hi) {
template <typename T>
using numeric_limits = cuda::std::numeric_limits<T>;

using cuda::std::gcd;
using cuda::std::lcm;

Comment on lines +160 to +162
Copy link
Copy Markdown
Member

@steppi steppi Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't worry about the CUDA side. I still have to figure out cupy/cupy#9839 before we can even try these in CuPy. I'm pretty sure just adding `using cuda::std::gcd` won't work in all cases, and we actually need wrappers for stdlib functions like the other ones in this file. I recall I had suggested using `using` like this when Irwin first set this up, but there was a reason he had done things the way he did.

// Must use thrust for complex types in order to support CuPy
template <typename T>
using complex = thrust::complex<T>;
Expand Down Expand Up @@ -244,6 +248,7 @@ using cuda::std::uint64_t;
#include <iterator>
#include <limits>
#include <math.h>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <utility>
Expand Down
121 changes: 121 additions & 0 deletions include/xsf/stats.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "xsf/bessel.h"
#include "xsf/binom.h"
#include "xsf/cephes/bdtr.h"
#include "xsf/cephes/chdtr.h"
#include "xsf/cephes/fdtr.h"
Expand Down Expand Up @@ -284,6 +285,126 @@ inline double pdtrc(double k, double m) { return cephes::pdtrc(k, m); }

// Thin wrapper: forwards (k, y) to cephes::pdtri and returns its result unchanged.
inline double pdtri(int k, double y) {
    const double result = cephes::pdtri(k, y);
    return result;
}

namespace detail {

    /*
     * Run the two-buffer recurrence that builds the frequency table consumed by
     * the exact two-sample Cramér-von Mises routines in this file.
     *
     * Parameters
     *   m, n     : loop bounds; rows 0..m of the tables are written and the outer
     *              recurrence takes n + 1 steps (u = 0..n).
     *   a, b     : integer weights; each step shifts columns by (a*v - b*u)^2.
     *              cvm_2samp_freq_table passes a = lcm(m, n) / m and
     *              b = lcm(m, n) / n.
     *   gs       : table that holds the result on return. Its second extent
     *              fixes the number of columns processed.
     *   next_gs  : scratch table; every (row, column) read from `gs` is also
     *              written here, so it needs at least the same extents.
     *
     * NOTE(review): the handle swap and the final copy-back only make sense if
     * FreqTable2D is a non-owning view (e.g. an mdspan), so that copies of `gs`
     * and `next_gs` alias the caller's storage -- confirm for any new table type.
     */
    template <typename FreqTable2D>
    XSF_HOST_DEVICE inline void
    cvm_freq_table_all(int64_t m, int64_t n, int64_t a, int64_t b, FreqTable2D gs, FreqTable2D next_gs) {
        using T = typename FreqTable2D::value_type;
        const int64_t n_rows = m + 1;
        const int64_t n_cols = static_cast<int64_t>(gs.extent(1));

        // Seed: an all-zero table except for a single unit count at (0, 0).
        for (int64_t row = 0; row < n_rows; ++row) {
            for (int64_t col = 0; col < n_cols; ++col) {
                gs(row, col) = T(0);
            }
        }
        gs(0, 0) = T(1);

        for (int64_t step = 0; step <= n; ++step) {
            // Row 0 has no predecessor row, so it is just the previous table's
            // row 0 shifted right by (b * step)^2 columns, zero-filled on the left.
            {
                const int64_t diff = -b * step;
                const int64_t shift = diff * diff;
                for (int64_t col = 0; col < n_cols; ++col) {
                    if (col < shift) {
                        next_gs(0, col) = T(0);
                    } else {
                        next_gs(0, col) = gs(0, col - shift);
                    }
                }
            }

            // Rows 1..m combine the row just written above (same step) with the
            // same row of the previous step, both shifted by (a*row - b*step)^2.
            for (int64_t row = 1; row < n_rows; ++row) {
                const int64_t diff = a * row - b * step;
                const int64_t shift = diff * diff;
                for (int64_t col = 0; col < n_cols; ++col) {
                    if (col < shift) {
                        next_gs(row, col) = T(0);
                    } else {
                        next_gs(row, col) = next_gs(row - 1, col - shift) + gs(row, col - shift);
                    }
                }
            }

            // Exchange the two handles: the table just written becomes the input
            // of the next step.
            FreqTable2D previous = gs;
            gs = next_gs;
            next_gs = previous;
        }

        // n + 1 exchanges happened. For odd n that is an even count and the result
        // already sits in the caller's `gs` storage.
        if (n % 2 != 0) {
            return;
        }
        // For even n the handles are crossed: local `gs` views the caller's
        // `next_gs` storage (which holds the result) and local `next_gs` views the
        // caller's `gs` storage, so copy the result back where the caller reads it.
        for (int64_t row = 0; row < n_rows; ++row) {
            for (int64_t col = 0; col < n_cols; ++col) {
                next_gs(row, col) = gs(row, col);
            }
        }
    }

} // namespace detail

template <typename FreqTable2D>
XSF_HOST_DEVICE inline void cvm_2samp_freq_table(int64_t m, int64_t n, FreqTable2D freq_table, FreqTable2D workspace) {
    /*
     * Build the frequency table of the exact two-sample Cramér-von Mises
     * test for sample sizes m and n. The table does not depend on the value
     * of the test statistic, so it can be computed once per (m, n).
     *
     * freq_table receives the result; workspace is a scratch table handed to
     * detail::cvm_freq_table_all as its second buffer.
     *
     * If m or n is not positive, a domain error is reported through set_error
     * and neither table is touched.
     *
     * "[1]" below is Xiao, Gordon & Yakovlev (2006), cited in full in
     * pval_cvm_2samp_exact.
     */
    const bool sizes_are_positive = (m > 0) && (n > 0);
    if (!sizes_are_positive) {
        set_error("cvm_2samp_freq_table", SF_ERROR_DOMAIN, "m and n must be positive");
        return;
    }

    // [1, p. 3]
    const int64_t common_multiple = std::lcm(m, n);

    // [1, p. 4], below eq. 3
    const int64_t weight_m = common_multiple / m;
    const int64_t weight_n = common_multiple / n;

    detail::cvm_freq_table_all(m, n, weight_m, weight_n, freq_table, workspace);
}

template <typename FreqTable2D>
XSF_HOST_DEVICE inline double pval_cvm_2samp_exact(double s, int64_t m, int64_t n, FreqTable2D freq_table) {
    /*
     * Compute the exact p-value of the Cramér-von Mises two-sample test
     * for a given value s of the test statistic and where m and n are the sizes
     * of the samples.
     *
     * Parameters
     *   s          : value of the test statistic ($U$ in [2]).
     *   m, n       : sample sizes; both must be positive.
     *   freq_table : frequency table; row m is read over columns
     *                0 .. freq_table.extent(1) - 1. It is expected to have been
     *                filled by cvm_2samp_freq_table for the same m and n.
     *
     * Returns
     *   sum_{k >= zeta} freq_table(m, k) / binom(m + n, m), where zeta is the
     *   cutoff derived from s below.
     *   NaN (with SF_ERROR_DOMAIN reported) if m <= 0 or n <= 0.
     *   NaN if s is NaN.
     *
     * [1] Y. Xiao, A. Gordon, and A. Yakovlev, "A C++ Program for
     *     the Cramér-Von Mises Two-Sample Test", J. Stat. Soft.,
     *     vol. 17, no. 8, pp. 1-15, Dec. 2006.
     * [2] T. W. Anderson "On the Distribution of the Two-Sample Cramér-von Mises
     *     Criterion," The Annals of Mathematical Statistics, Ann. Math. Statist.
     *     33(3), 1148-1159, (September, 1962)
     */
    if (m <= 0 || n <= 0) {
        set_error("pval_cvm_2samp_exact", SF_ERROR_DOMAIN, "m and n must be positive");
        return std::numeric_limits<double>::quiet_NaN();
    }
    // A NaN statistic has no meaningful cutoff; propagate NaN instead of
    // converting NaN to an integer below (which is undefined behaviour).
    if (std::isnan(s)) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    // [1, p. 3]
    int64_t lcm = std::lcm(m, n);
    // Combine Eq. 9 in [2] with Eq. 2 in [1] and solve for $\zeta$
    // Hint: `s` is $U$ in [2], and $T_2$ in [1] is $T$ in [2]
    int64_t mn = m * n;

    // Uses double floor division since s is double. The cutoff is kept as a
    // double here: for very large or infinite |s| it does not fit in int64_t,
    // and a floating-point -> integer conversion of such a value is undefined.
    const double zeta = std::floor((lcm * lcm * (m + n) * (6.0 * s - mn * (4.0 * mn - 1))) / (6.0 * mn * mn));

    int64_t K = static_cast<int64_t>(freq_table.extent(1));

    // Clamp the cutoff into [0, K] before converting:
    //   zeta <= 0  -> every column contributes (also prevents negative indexing),
    //   zeta >= K  -> no column contributes and the p-value is 0.
    int64_t k0;
    if (zeta <= 0.0) {
        k0 = 0;
    } else if (zeta >= static_cast<double>(K)) {
        k0 = K;
    } else {
        k0 = static_cast<int64_t>(zeta);
    }

    int64_t sum_freq = 0;
    for (int64_t k = k0; k < K; ++k) {
        sum_freq += freq_table(m, k);
    }

    // Total number of ways to split m + n pooled observations into the two samples.
    double combinations = xsf::binom(static_cast<double>(m + n), static_cast<double>(m));
    return sum_freq / combinations;
}

// Thin wrapper: forwards (n, x) to cephes::smirnov and returns its result unchanged.
inline double smirnov(int n, double x) {
    const double result = cephes::smirnov(n, x);
    return result;
}

// Thin wrapper: forwards (n, x) to cephes::smirnovc and returns its result unchanged.
inline double smirnovc(int n, double x) {
    const double result = cephes::smirnovc(n, x);
    return result;
}
Expand Down
53 changes: 53 additions & 0 deletions tests/xsf_tests/test_pval_cvm_2samp_exact.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "../testing_utils.h"
#define MDSPAN_USE_PAREN_OPERATOR 1
#include <xsf/stats.h>
#include <xsf/third_party/kokkos/mdspan.hpp>

/*
// Reference values computed with scipy.stats._hypotests._pval_cvm_2samp_exact

import numpy as np
from scipy import stats

rng = np.random.default_rng(seed=42)

list_m = rng.integers(3, 30, size=5)
list_n = rng.integers(3, 30, size=5)
rtol = 1e-10

for m, n in zip(list_m, list_n):
    x = rng.standard_normal(m)
    y = rng.standard_normal(n)
    res = stats.cramervonmises_2samp(x, y, method="exact")
    T = res.statistic
    # Convert normalized statistic T to the unnormalized U
    U = m * n * (m + n) * T + m * n * (4 * m * n - 1) / 6
    p_value = stats._hypotests._pval_cvm_2samp_exact(U, m, n)
    assert np.isclose(res.pvalue, p_value, rtol=rtol), "The p-values do not match!"
    print(f"U={U}, m={m}, n={n}, p-value={p_value}")
*/
// Checks xsf::pval_cvm_2samp_exact against the reference p-values listed in
// the cases below, using a frequency table built by xsf::cvm_2samp_freq_table.
TEST_CASE("pval_cvm_2samp_exact test", "[pval_cvm_2samp_exact][xsf_tests]") {
// Each case is (s, m, n, expected p-value, relative tolerance).
using test_case = std::tuple<double, int, int, double, double>;
auto [s, m, n, pval_expected, rtol] = GENERATE(
test_case{8862.0, 14, 8, 0.2679738562091503, 1e-10},
test_case{3491.0000000000005, 14, 5, 0.34657722738218094, 1e-10},
test_case{12559.0, 5, 26, 0.11812654860485784, 1e-10},
test_case{8901.0, 23, 5, 0.9907610907610908, 1e-10} //, test_case{119376.0, 20, 21, 0.5716351061359124, 1e-10}
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When m=20 and n=21, the test was extremely slow, likely due to it taking too much memory.

);

// Number of table columns: indices 0 .. (m + n) * lcm^2 inclusive.
const int64_t lcm = std::lcm(m, n);
const int64_t K = (m + n) * lcm * lcm + 1;

// Two zero-initialised (m + 1) x K buffers: one for the result table and one
// for the scratch table that cvm_2samp_freq_table takes as `workspace`.
std::vector<int64_t> buf1((m + 1) * K, 0);
std::vector<int64_t> buf2((m + 1) * K, 0);

// Non-owning 2-D views over the buffers.
using mdspan_2d = std::mdspan<int64_t, std::dextents<size_t, 2>>;
mdspan_2d gs(buf1.data(), static_cast<size_t>(m + 1), static_cast<size_t>(K));
mdspan_2d next_gs(buf2.data(), static_cast<size_t>(m + 1), static_cast<size_t>(K));

// Build the table into `gs`, then evaluate the p-value for statistic s from it.
xsf::cvm_2samp_freq_table(m, n, gs, next_gs);
auto pval = xsf::pval_cvm_2samp_exact(s, m, n, gs);
const double rel_error = xsf::extended_relative_error(pval, pval_expected);
CAPTURE(s, m, n, K, pval, pval_expected, rel_error);
REQUIRE(rel_error <= rtol);
}
Loading