diff --git a/src/FSharp.Stats/Correlation.fs b/src/FSharp.Stats/Correlation.fs index 5182b213..d930ad24 100644 --- a/src/FSharp.Stats/Correlation.fs +++ b/src/FSharp.Stats/Correlation.fs @@ -140,33 +140,41 @@ module Correlation = /// /// let inline pearsonWeighted (seq1:seq<'T>) (seq2:seq<'T>) (weights:seq<'T>) : float = - // TODO: solve in a prettier coding fashion - if Seq.length seq1 <> Seq.length seq2 || Seq.length seq2 <> Seq.length weights then failwithf "input arguments are not the same length" - let zero = LanguagePrimitives.GenericZero< 'T > - let one = LanguagePrimitives.GenericOne<'T> - let weightedMean xVal wVal = - let a = Seq.fold2 (fun acc xi wi -> acc + (xi * wi)) zero xVal wVal |> float - let b = Seq.sum wVal|> float - a / b - let weightedCoVariance xVal yVal wVal = - let weightedMeanXW = weightedMean xVal wVal - let weightedMeanYW = weightedMean yVal wVal - let a = - Seq.map3 (fun xi yi wi -> - (float wi) * ((float xi) - weightedMeanXW) * ((float yi) - weightedMeanYW) - ) xVal yVal wVal - |> Seq.sum - let b = - Seq.sum wVal - |> float - a / b - let weightedCorrelation xVal yVal wVal = - let a = weightedCoVariance xVal yVal wVal - let b = - (weightedCoVariance xVal xVal wVal) * (weightedCoVariance yVal yVal wVal) - |> sqrt - a / b - weightedCorrelation seq1 seq2 weights + // Convert to arrays once (3 passes), then compute in 2 passes instead of the + // previous ~12 passes (3 Seq.length + 3x weightedCoVariance x 3 sub-passes each). 
+ let xs = Array.ofSeq seq1 + let ys = Array.ofSeq seq2 + let ws = Array.ofSeq weights + let n = xs.Length + if n <> ys.Length || n <> ws.Length then + failwithf "input arguments are not the same length" + if n = 0 then nan + else + // Pass 1: compute weighted means + let mutable wSum = 0.0 + let mutable wxSum = 0.0 + let mutable wySum = 0.0 + for i = 0 to n - 1 do + let w = float ws.[i] + wSum <- wSum + w + wxSum <- wxSum + w * float xs.[i] + wySum <- wySum + w * float ys.[i] + let xMean = wxSum / wSum + let yMean = wySum / wSum + // Pass 2: compute weighted covariance and variances from the means + let mutable covSum = 0.0 + let mutable varXSum = 0.0 + let mutable varYSum = 0.0 + for i = 0 to n - 1 do + let w = float ws.[i] + let dx = float xs.[i] - xMean + let dy = float ys.[i] - yMean + covSum <- covSum + w * dx * dy + varXSum <- varXSum + w * dx * dx + varYSum <- varYSum + w * dy * dy + // The wSum denominator cancels in the ratio, so the result is: + // cov(x,y) / sqrt(var(x) * var(y)) + covSum / sqrt (varXSum * varYSum) /// /// Calculates the weighted pearson correlation of two samples. 
diff --git a/tests/FSharp.Stats.Tests/Correlation.fs b/tests/FSharp.Stats.Tests/Correlation.fs index 926952f8..1252d02d 100644 --- a/tests/FSharp.Stats.Tests/Correlation.fs +++ b/tests/FSharp.Stats.Tests/Correlation.fs @@ -362,3 +362,50 @@ let spearmanCorrelationTests = Expect.floatClose Accuracy.high testCase5 0.6887298748 "Should be equal (double precision)" Expect.floatClose Accuracy.high testCase6 -0.632455532 "Should be equal (double precision)" ] + +[<Tests>] +let pearsonWeightedTests = + // Reference values verified with R: + // x <- c(1.1, 1.1, 1.2); y <- c(1.2, 0.9, 0.08); w <- c(0.2, 0.3, 0.5) + // library(weights); wtd.cor(x, y, weight = w)[1] # -0.9764159 + // + // x2 <- c(1.0, 2.0, 3.0, 4.0); y2 <- c(2.0, 4.0, 6.0, 8.0); w2 <- c(1.0, 1.0, 1.0, 1.0) + // wtd.cor(x2, y2, weight = w2) # 1.0 (perfectly correlated, uniform weights) + // + // x3 <- c(1.0, 2.0, 3.0, 4.0); y3 <- c(8.0, 6.0, 4.0, 2.0); w3 <- c(1.0, 1.0, 1.0, 1.0) + // wtd.cor(x3, y3, weight = w3) # -1.0 (perfectly anti-correlated) + + let x1 = [1.1; 1.1; 1.2] + let y1 = [1.2; 0.9; 0.08] + let w1 = [0.2; 0.3; 0.5] + + let x2 = [1.0; 2.0; 3.0; 4.0] + let y2 = [2.0; 4.0; 6.0; 8.0] + let wUniform = [1.0; 1.0; 1.0; 1.0] + + let x3 = [1.0; 2.0; 3.0; 4.0] + let y3 = [8.0; 6.0; 4.0; 2.0] + + testList "Correlation.Seq.pearsonWeighted" [ + testCase "docstring example" <| fun () -> + // matches the example in pearsonWeightedOfTriples docstring + let r = Seq.pearsonWeighted x1 y1 w1 + Expect.floatClose Accuracy.high r -0.9764158959 "weighted Pearson should match reference" + + testCase "ofTriples matches pearsonWeighted" <| fun () -> + let rDirect = Seq.pearsonWeighted x1 y1 w1 + let rTriples = Seq.zip3 x1 y1 w1 |> Seq.pearsonWeightedOfTriples + Expect.floatClose Accuracy.veryHigh rDirect rTriples "pearsonWeighted and pearsonWeightedOfTriples should agree" + + testCase "uniform weights: perfect positive correlation" <| fun () -> + let r = Seq.pearsonWeighted x2 y2 wUniform + Expect.floatClose Accuracy.veryHigh r 
1.0 "uniform-weighted Pearson should be 1.0 for perfectly correlated data" + + testCase "uniform weights: perfect negative correlation" <| fun () -> + let r = Seq.pearsonWeighted x3 y3 wUniform + Expect.floatClose Accuracy.veryHigh r -1.0 "uniform-weighted Pearson should be -1.0 for perfectly anti-correlated data" + + testCase "mismatched length throws" <| fun () -> + Expect.throws (fun () -> Seq.pearsonWeighted [1.0; 2.0] [1.0] [1.0; 1.0] |> ignore) + "should throw for mismatched sequence lengths" + ]