-
-
Notifications
You must be signed in to change notification settings - Fork 22
ENH: Implement the exact p-value of the Cramér-von Mises two-sample test #108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| #include "xsf/cephes/tukey.h" | ||
| #include "xsf/erf.h" | ||
| #include "xsf/gamma.h" | ||
| #include <numeric> | ||
|
|
||
| namespace xsf { | ||
|
|
||
|
|
@@ -284,6 +285,87 @@ inline double pdtrc(double k, double m) { return cephes::pdtrc(k, m); } | |
|
|
||
| inline double pdtri(int k, double y) { return cephes::pdtri(k, y); } | ||
|
|
||
/*
 * Binomial coefficient "n choose k", i.e. n! / (k! * (n - k)!).
 *
 * Returns 0 when k < 0 or k > n, since no such selection exists.
 * The result is exact whenever the coefficient itself fits in int64_t:
 * each step divides out the gcd before multiplying, so no intermediate
 * value exceeds the next coefficient C(n, i + 1).
 */
inline int64_t comb(int64_t n, int64_t k) {
    if (k < 0 || k > n) {
        return 0;
    }
    // Symmetry C(n, k) == C(n, n - k) keeps the loop as short as possible.
    if (k > n - k) {
        k = n - k;
    }
    int64_t result = 1;
    for (int64_t i = 0; i < k; ++i) {
        // Invariant: result == C(n, i). The next coefficient is
        // result * (n - i) / (i + 1), which is an integer. After removing
        // g = gcd(result, i + 1), the remaining divisor (i + 1) / g is
        // coprime to result / g and therefore divides (n - i) exactly.
        const int64_t divisor = i + 1;
        const int64_t g = std::gcd(result, divisor);
        result = (result / g) * ((n - i) / (divisor / g));
    }
    return result;
}
|
|
||
| /* | ||
| * Compute the exact p-value of the Cramer-von Mises two-sample test | ||
| * for a given value s of the test statistic. | ||
| * m and n are the sizes of the samples. | ||
| * | ||
| * [1] Y. Xiao, A. Gordon, and A. Yakovlev, "A C++ Program for | ||
| * the Cramér-Von Mises Two-Sample Test", J. Stat. Soft., | ||
| * vol. 17, no. 8, pp. 1-15, Dec. 2006. | ||
| * [2] T. W. Anderson "On the Distribution of the Two-Sample Cramer-von Mises | ||
| * Criterion," The Annals of Mathematical Statistics, Ann. Math. Statist. | ||
| * 33(3), 1148-1159, (September, 1962) | ||
| */ | ||
| inline double pval_cvm_2samp_exact(double s, int64_t m, int64_t n) { | ||
| // [1, p. 3] | ||
| int64_t lcm = std::lcm(m, n); | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| // [1, p. 4], below eq. 3 | ||
| int64_t a = lcm / m; | ||
| int64_t b = lcm / n; | ||
|
Comment on lines
+305
to
+306
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| // Combine Eq. 9 in [2] with Eq. 2 in [1] and solve for $\zeta$ | ||
| // Hint: `s` is $U$ in [2], and $T_2$ in [1] is $T$ in [2] | ||
| int64_t mn = m * n; | ||
| // Uses double floor division since s is double | ||
| int64_t zeta = std::floor((lcm * lcm * (m + n) * (6.0 * s - mn * (4.0 * mn - 1))) / (6.0 * mn * mn)); | ||
|
Comment on lines
+307
to
+311
Member
Author
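There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using consistent notation and following the Python comment: from [1, page 3], $T = \frac{mn}{L^2 (m+n)^2}\,\zeta$, where $L = \operatorname{lcm}(m, n)$. Rewriting $T$ in terms of $U$ from [2, eq. 9] (and using the same notation as in [1]): $T = \frac{U}{mn(m+n)} - \frac{4mn - 1}{6(m+n)}$.
Simple algebra gives $\zeta = \frac{L^2 (m+n)\,\bigl(6U - mn(4mn - 1)\bigr)}{6 m^2 n^2}$. We then floor $\zeta$, as the tabulated values it is compared against are integers. |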
||
| int64_t combinations = comb(m + n, m); | ||
| // Each frequency table maps value -> frequency, | ||
| // mirroring the 2-row numpy array where row 0 = values, row 1 = frequencies | ||
| using FreqTable = std::map<int64_t, int64_t>; | ||
| // the frequency table of g_{u, v}^+ defined in [1, p. 6] | ||
| // gs[0] = {0: 1}, gs[1..m] = empty | ||
| std::vector<FreqTable> gs(m + 1); | ||
|
Comment on lines
+315
to
+318
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @mdhaber, you want this to be able to work in CuPy, right? Also, rather than putting this whole thing into a scalar kernel and making a ufunc out of it, it seems like it may be better to take the idea of the original function and decompose it into the ufuncs that are needed to make it work. I think the key component here would be a gufunc to generate the frequency table.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes. OK, good to know. |
||
| gs[0][0] = 1; | ||
|
Comment on lines
+318
to
+319
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Initialize the frequency tables of $g_{u,v}^+$. Each table is a 2-row array: g = [[zeta_1, zeta_2, ..., zeta_k],
[count_1, count_2, ..., count_k]], where row 0 = distinct values of $\zeta$ and row 1 = their frequencies. Initially: gs[0] = {0: 1} and gs[1..m] are empty. |
||
| for (int64_t u = 0; u < n + 1; ++u) { | ||
| std::vector<FreqTable> next_gs; | ||
| FreqTable tmp; | ||
| for (int64_t v = 0; v < m + 1; ++v) { | ||
| // Calculate g recursively with eq. 11 in [1]. Even though it | ||
| // doesn't look like it, this also does 12/13 (all of Algorithm 1). | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It appears that (12) is not needed, as this is produced automatically by how For Starting from Here, we use that |
||
| const FreqTable &g = gs[v]; | ||
| // Merge tmp and g: for common keys sum frequencies, | ||
| // keep unique keys from both sides. | ||
| // (equivalent to np.intersect1d + concatenate logic) | ||
| FreqTable merged; | ||
| for (const auto &[key, freq] : tmp) { | ||
| merged[key] += freq; | ||
| } | ||
| for (const auto &[key, freq] : g) { | ||
| merged[key] += freq; | ||
| } | ||
|
Comment on lines
+326
to
+336
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Translation of the original Python code. vi, i0, i1 = np.intersect1d(tmp[0], g[0], return_indices=True)
tmp = np.concatenate([
np.stack([vi, tmp[1, i0] + g[1, i1]]),
np.delete(tmp, i0, 1),
np.delete(g, i1, 1)
], 1)
|
||
| int64_t diff = a * v - b * u; | ||
| int64_t res = diff * diff; | ||
| // tmp[0] += res (shift all keys by res) | ||
| tmp.clear(); | ||
| for (const auto &[key, freq] : merged) { | ||
| tmp[key + res] += freq; | ||
| } | ||
|
Comment on lines
+337
to
+343
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| next_gs.push_back(tmp); | ||
| } | ||
| gs = std::move(next_gs); | ||
| } | ||
|
Comment on lines
+320
to
+347
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you provide additional information that helps the reader compare this to either the original code, which is pretty hard to follow, or the algorithm, which looks simple but has terms like
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree; I've found the Python code quite difficult to follow. I’ll provide more detail and include screenshots of the relevant sections of the paper in the coming days. |
||
| // (equivalent to return np.float64(np.sum(freq[value >= zeta]) / combinations)) | ||
| const FreqTable &final_table = gs[m]; | ||
| int64_t sum_freq = 0; | ||
| for (const auto &[value, freq] : final_table) { | ||
| if (value >= zeta) { | ||
| sum_freq += freq; | ||
| } | ||
| } | ||
| return static_cast<double>(sum_freq) / static_cast<double>(combinations); | ||
| } | ||
|
Comment on lines
+348
to
+357
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Compute p-value |
||
|
|
||
| inline double smirnov(int n, double x) { return cephes::smirnov(n, x); } | ||
|
|
||
| inline double smirnovc(int n, double x) { return cephes::smirnovc(n, x); } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| #include "../testing_utils.h" | ||
| #include <xsf/stats.h> | ||
|
|
||
| /* | ||
| // Reference values computed with scipy.stats._hypotests._pval_cvm_2samp_exact | ||
|
|
||
| import numpy as np | ||
| from scipy import stats | ||
|
|
||
| rng = np.random.default_rng(seed=42) | ||
|
|
||
| list_m = rng.integers(3, 30, size=5) | ||
| list_n = rng.integers(3, 30, size=5) | ||
| rtol = 1e-10 | ||
|
|
||
| for m, n in zip(list_m, list_n): | ||
| x = rng.standard_normal(m) | ||
| y = rng.standard_normal(n) | ||
| res = stats.cramervonmises_2samp(x, y, method="exact") | ||
| T = res.statistic | ||
| # Convert normalized statistic T to the unnormalized U | ||
| U = m * n * (m + n) * T + m * n * (4 * m * n - 1) / 6 | ||
| p_value = stats._hypotests._pval_cvm_2samp_exact(U, m, n) | ||
| assert np.isclose(res.pvalue, p_value, rtol=rtol), "The p-values do not match!" | ||
| print(f"U={U}, m={m}, n={n}, p-value={p_value}") | ||
| */ | ||
| TEST_CASE("pval_cvm_2samp_exact test", "[pval_cvm_2samp_exact][xsf_tests]") { | ||
| using test_case = std::tuple<double, int, int, double, double>; | ||
| auto [s, m, n, pval_expected, rtol] = GENERATE( | ||
| test_case{12559.0, 5, 26, 0.11812654860485784, 1e-10}, test_case{8901.0, 23, 5, 0.9907610907610908, 1e-10}, | ||
| test_case{119376.0, 20, 21, 0.5716351061359124, 1e-10}, test_case{8862.0, 14, 8, 0.2679738562091503, 1e-10}, | ||
| test_case{3491.0000000000005, 14, 5, 0.34657722738218094, 1e-10} | ||
| ); | ||
| const auto pval = xsf::pval_cvm_2samp_exact(s, m, n); | ||
| const auto rel_error = xsf::extended_relative_error(pval, pval_expected); | ||
| CAPTURE(s, m, n, pval, pval_expected, rel_error); | ||
| REQUIRE(rel_error <= rtol); | ||
| } |



Uh oh!
There was an error while loading. Please reload this page.