diff --git a/tests/gcm/test_auto.py b/tests/gcm/test_auto.py index 1e28231087..26fe2df9a8 100644 --- a/tests/gcm/test_auto.py +++ b/tests/gcm/test_auto.py @@ -41,7 +41,9 @@ def _generate_non_linear_regression_data(): def _generate_linear_classification_data(): - X = np.random.normal(0, 1, (100, 5)) + # 500 samples instead of 100 to give cross-validation a stable enough signal + # so that LogisticRegression reliably wins over SVC for linear data. + X = np.random.normal(0, 1, (500, 5)) Y = (np.sum(X * np.random.uniform(-5, 5, X.shape[1]), axis=1) > 0).astype(str) return X, Y