diff --git a/include/xsf/config.h b/include/xsf/config.h
index ec7f0bbff8..1ad72238c7 100644
--- a/include/xsf/config.h
+++ b/include/xsf/config.h
@@ -60,6 +60,7 @@
 #include
 #include
 #include
+#include <cuda/std/numeric>
 #include
 #include
 #include
@@ -156,6 +157,9 @@ XSF_HOST_DEVICE constexpr T clamp(T &v, T &lo, T &hi) {
 template
 using numeric_limits = cuda::std::numeric_limits;
 
+using cuda::std::gcd;
+using cuda::std::lcm;
+
 // Must use thrust for complex types in order to support CuPy
 template
 using complex = thrust::complex;
@@ -244,6 +248,7 @@ using cuda::std::uint64_t;
 #include
 #include
 #include
+#include <numeric>
 #include
 #include
 #include
diff --git a/include/xsf/stats.h b/include/xsf/stats.h
index 2c78c9aedc..8e53b82c9f 100644
--- a/include/xsf/stats.h
+++ b/include/xsf/stats.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "xsf/bessel.h"
+#include "xsf/binom.h"
 #include "xsf/cephes/bdtr.h"
 #include "xsf/cephes/chdtr.h"
 #include "xsf/cephes/fdtr.h"
@@ -284,6 +285,191 @@ inline double pdtrc(double k, double m) { return cephes::pdtrc(k, m); }
 
 inline double pdtri(int k, double y) { return cephes::pdtri(k, y); }
 
+namespace detail {
+
+    /*
+     * Build the frequency table of the integer-valued two-sample
+     * Cramér-von Mises statistic by dynamic programming over 0 <= u <= n
+     * and 0 <= v <= m.
+     *
+     * Each step shifts the statistic by d^2 with d = a * v - b * u. On
+     * return, gs(v, k) holds the accumulated count for row v and statistic
+     * value k; contributions whose value would be >= gs.extent(1) are dropped.
+     *
+     * `next_gs` is scratch space. The views are taken by value, so swapping
+     * them below does not affect the caller, and the result is always left in
+     * the buffer the caller passed as `gs`.
+     *
+     * Preconditions (not checked here): both views have at least m + 1 rows,
+     * `gs` has at least one column, and `next_gs` has at least as many columns
+     * as `gs`.
+     */
+    template <typename FreqTable2D>
+    XSF_HOST_DEVICE inline void
+    cvm_freq_table_all(int64_t m, int64_t n, int64_t a, int64_t b, FreqTable2D gs, FreqTable2D next_gs) {
+        using T = typename FreqTable2D::value_type;
+        const int64_t K = static_cast<int64_t>(gs.extent(1));
+
+        // initialize gs to 0
+        for (int64_t v = 0; v < m + 1; ++v) {
+            for (int64_t k = 0; k < K; ++k) {
+                gs(v, k) = T(0);
+            }
+        }
+        // base case: gs(0, 0) = 1
+        gs(0, 0) = T(1);
+
+        for (int64_t u = 0; u < n + 1; ++u) {
+            // v = 0: no next_gs(v - 1, ...) term
+            {
+                const int64_t d = -b * u;
+                const int64_t d2 = d * d;
+                // Clamp the shift to K so both loops below stay inside [0, K).
+                const int64_t kstart = (d2 < K) ? d2 : K;
+                // next_gs(0, k) = gs(0, k - d2) for k >= d2, else 0
+                for (int64_t k = 0; k < kstart; ++k) {
+                    next_gs(0, k) = T(0);
+                }
+                for (int64_t k = kstart; k < K; ++k) {
+                    next_gs(0, k) = gs(0, k - d2);
+                }
+            }
+            // v > 0: both terms contribute
+            for (int64_t v = 1; v < m + 1; ++v) {
+                const int64_t d = a * v - b * u;
+                const int64_t d2 = d * d; // d^2 = (a*v - b*u)^2
+                const int64_t kstart = (d2 < K) ? d2 : K;
+                for (int64_t k = 0; k < kstart; ++k) {
+                    next_gs(v, k) = T(0);
+                }
+                for (int64_t k = kstart; k < K; ++k) {
+                    next_gs(v, k) = next_gs(v - 1, k - d2) + gs(v, k - d2);
+                }
+            }
+            FreqTable2D tmp = gs;
+            gs = next_gs;
+            next_gs = tmp;
+        }
+        // We swap `gs` and `next_gs` at each u-step, so buffer parity depends on n.
+        // If n is even, the final table ends up in the original `next_gs` buffer;
+        // copy it back so the caller can always read results from the original `gs`.
+        if (n % 2 == 0) {
+            for (int64_t v = 0; v < m + 1; ++v) {
+                for (int64_t k = 0; k < K; ++k) {
+                    next_gs(v, k) = gs(v, k);
+                }
+            }
+        }
+    }
+
+} // namespace detail
+
+/*
+ * Generate the exact Cramér-von Mises two-sample frequency table for
+ * sample sizes m and n. The table is independent of the scalar statistic.
+ *
+ * On success the table is written to `freq_table`; `workspace` is scratch
+ * space whose contents are overwritten. `freq_table` needs at least m + 1
+ * rows and one column, and `workspace` at least m + 1 rows and as many
+ * columns as `freq_table`. Statistic values >= freq_table.extent(1) are not
+ * represented in the table.
+ *
+ * Reports SF_ERROR_DOMAIN through set_error and leaves both buffers
+ * untouched if m or n is not positive or if the buffers are too small.
+ *
+ * [1] Y. Xiao, A. Gordon, and A. Yakovlev, "A C++ Program for
+ *     the Cramér-Von Mises Two-Sample Test", J. Stat. Soft.,
+ *     vol. 17, no. 8, pp. 1-15, Dec. 2006.
+ */
+template <typename FreqTable2D>
+XSF_HOST_DEVICE inline void cvm_2samp_freq_table(int64_t m, int64_t n, FreqTable2D freq_table, FreqTable2D workspace) {
+    if (m <= 0 || n <= 0) {
+        set_error("cvm_2samp_freq_table", SF_ERROR_DOMAIN, "m and n must be positive");
+        return;
+    }
+    // The table is indexed as (v, k) with 0 <= v <= m and is seeded at (0, 0);
+    // reject buffers that cannot hold those indices instead of writing out of
+    // bounds.
+    const int64_t rows = static_cast<int64_t>(freq_table.extent(0));
+    const int64_t cols = static_cast<int64_t>(freq_table.extent(1));
+    if (rows < m + 1 || cols < 1 || static_cast<int64_t>(workspace.extent(0)) < m + 1 ||
+        static_cast<int64_t>(workspace.extent(1)) < cols) {
+        set_error("cvm_2samp_freq_table", SF_ERROR_DOMAIN, "freq_table and workspace are too small");
+        return;
+    }
+    // [1, p. 3]
+    const int64_t lcm = std::lcm(m, n);
+    // [1, p. 4], below eq. 3
+    const int64_t a = lcm / m;
+    const int64_t b = lcm / n;
+
+    detail::cvm_freq_table_all(m, n, a, b, freq_table, workspace);
+}
+
+/*
+ * Compute the exact p-value of the Cramér-von Mises two-sample test
+ * for a given value s of the test statistic and where m and n are the sizes
+ * of the samples.
+ *
+ * `freq_table` is expected to be the table filled by cvm_2samp_freq_table
+ * for the same m and n; only row m is read.
+ *
+ * Returns NaN if s is NaN. Returns NaN and reports SF_ERROR_DOMAIN through
+ * set_error if m or n is not positive or if `freq_table` has fewer than
+ * m + 1 rows.
+ *
+ * [1] Y. Xiao, A. Gordon, and A. Yakovlev, "A C++ Program for
+ *     the Cramér-Von Mises Two-Sample Test", J. Stat. Soft.,
+ *     vol. 17, no. 8, pp. 1-15, Dec. 2006.
+ * [2] T. W. Anderson "On the Distribution of the Two-Sample Cramér-von Mises
+ *     Criterion," The Annals of Mathematical Statistics, Ann. Math. Statist.
+ *     33(3), 1148-1159, (September, 1962)
+ */
+template <typename FreqTable2D>
+XSF_HOST_DEVICE inline double pval_cvm_2samp_exact(double s, int64_t m, int64_t n, FreqTable2D freq_table) {
+    if (m <= 0 || n <= 0) {
+        set_error("pval_cvm_2samp_exact", SF_ERROR_DOMAIN, "m and n must be positive");
+        return std::numeric_limits<double>::quiet_NaN();
+    }
+    // Row m is read below; refuse a table that does not contain it.
+    if (static_cast<int64_t>(freq_table.extent(0)) < m + 1) {
+        set_error("pval_cvm_2samp_exact", SF_ERROR_DOMAIN, "freq_table must have at least m + 1 rows");
+        return std::numeric_limits<double>::quiet_NaN();
+    }
+    // [1, p. 3]
+    const int64_t lcm = std::lcm(m, n);
+    // Combine Eq. 9 in [2] with Eq. 2 in [1] and solve for $\zeta$
+    // Hint: `s` is $U$ in [2], and $T_2$ in [1] is $T$ in [2]
+    const int64_t mn = m * n;
+
+    // Uses double floor division since s is double
+    const double zeta = std::floor((lcm * lcm * (m + n) * (6.0 * s - mn * (4.0 * mn - 1))) / (6.0 * mn * mn));
+    if (std::isnan(zeta)) {
+        // s was NaN: propagate it instead of converting NaN to an integer.
+        return std::numeric_limits<double>::quiet_NaN();
+    }
+
+    const int64_t K = static_cast<int64_t>(freq_table.extent(1));
+
+    // Clamp zeta to [0, K] while it is still a double: this prevents negative
+    // indexing when zeta < 0 and avoids converting an out-of-range (or
+    // infinite) value to int64_t, which is undefined behaviour.
+    int64_t k0 = 0;
+    if (zeta >= static_cast<double>(K)) {
+        k0 = K;
+    } else if (zeta > 0) {
+        k0 = static_cast<int64_t>(zeta);
+    }
+
+    int64_t sum_freq = 0;
+    for (int64_t k = k0; k < K; ++k) {
+        sum_freq += freq_table(m, k);
+    }
+
+    const double combinations = xsf::binom(static_cast<double>(m + n), static_cast<double>(m));
+    return sum_freq / combinations;
+}
+
 inline double smirnov(int n, double x) { return cephes::smirnov(n, x); }
 
 inline double smirnovc(int n, double x) { return cephes::smirnovc(n, x); }
diff --git a/tests/xsf_tests/test_pval_cvm_2samp_exact.cpp b/tests/xsf_tests/test_pval_cvm_2samp_exact.cpp
new file mode 100644
index 0000000000..84f8840ed1
--- /dev/null
+++ b/tests/xsf_tests/test_pval_cvm_2samp_exact.cpp
@@ -0,0 +1,53 @@
+#include "../testing_utils.h"
+#define MDSPAN_USE_PAREN_OPERATOR 1
+#include <xsf/third_party/kokkos/mdspan.hpp>
+#include <xsf/stats.h>
+
+/*
+// Reference values computed with scipy.stats._hypotests._pval_cvm_2samp_exact
+
+import numpy as np
+from scipy import stats
+
+rng = np.random.default_rng(seed=42)
+
+list_m = rng.integers(3, 30, size=5)
+list_n = rng.integers(3, 30, size=5)
+rtol = 1e-10
+
+for m, n in zip(list_m, list_n):
+    x = rng.standard_normal(m)
+    y = rng.standard_normal(n)
+    res = stats.cramervonmises_2samp(x, y, method="exact")
+    T = res.statistic
+    # Convert normalized statistic T to the unnormalized U
+    U = m * n * (m + n) * T + m * n * (4 * m * n - 1) / 6
+    p_value = stats._hypotests._pval_cvm_2samp_exact(U, m, n)
+    assert np.isclose(res.pvalue, p_value, rtol=rtol), "The p-values do not match!"
+    print(f"U={U}, m={m}, n={n}, p-value={p_value}")
+*/
+TEST_CASE("pval_cvm_2samp_exact test", "[pval_cvm_2samp_exact][xsf_tests]") {
+    using test_case = std::tuple<double, int64_t, int64_t, double, double>;
+    auto [s, m, n, pval_expected, rtol] = GENERATE(
+        test_case{8862.0, 14, 8, 0.2679738562091503, 1e-10},
+        test_case{3491.0000000000005, 14, 5, 0.34657722738218094, 1e-10},
+        test_case{12559.0, 5, 26, 0.11812654860485784, 1e-10},
+        test_case{8901.0, 23, 5, 0.9907610907610908, 1e-10} //, test_case{119376.0, 20, 21, 0.5716351061359124, 1e-10}
+    );
+
+    const int64_t lcm = std::lcm(m, n);
+    const int64_t K = (m + n) * lcm * lcm + 1;
+
+    std::vector<int64_t> buf1((m + 1) * K, 0);
+    std::vector<int64_t> buf2((m + 1) * K, 0);
+
+    using mdspan_2d = std::mdspan<int64_t, std::dextents<size_t, 2>>;
+    mdspan_2d gs(buf1.data(), static_cast<size_t>(m + 1), static_cast<size_t>(K));
+    mdspan_2d next_gs(buf2.data(), static_cast<size_t>(m + 1), static_cast<size_t>(K));
+
+    xsf::cvm_2samp_freq_table(m, n, gs, next_gs);
+    auto pval = xsf::pval_cvm_2samp_exact(s, m, n, gs);
+    const double rel_error = xsf::extended_relative_error(pval, pval_expected);
+    CAPTURE(s, m, n, K, pval, pval_expected, rel_error);
+    REQUIRE(rel_error <= rtol);
+}