Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 35 additions & 27 deletions src/FSharp.Stats/Correlation.fs
Original file line number Diff line number Diff line change
Expand Up @@ -140,33 +140,41 @@ module Correlation =
/// </code>
/// </example>
let inline pearsonWeighted (seq1:seq<'T>) (seq2:seq<'T>) (weights:seq<'T>) : float =
// Weighted Pearson correlation of seq1 and seq2, where weights supplies one weight per
// observation. Returns a float; throws (via failwithf) when the three inputs differ in length.
//
// NOTE(review): this span comes from a flattened diff view whose header reports
// "35 additions & 27 deletions". The 27 lines from the TODO below through
// "weightedCorrelation seq1 seq2 weights" appear to be the REMOVED implementation, and the
// 35 lines after it the REPLACEMENT body -- confirm against the real file. Leading
// indentation also appears to have been lost in the capture.
//
// ---- previous implementation (presumably the removed side of the diff) ----
// TODO: solve in a prettier coding fashion
// Each Seq.length call enumerates its sequence once; all three lengths must match.
if Seq.length seq1 <> Seq.length seq2 || Seq.length seq2 <> Seq.length weights then failwithf "input arguments are not the same length"
let zero = LanguagePrimitives.GenericZero< 'T >
// NOTE(review): 'one' is not referenced anywhere in the visible code.
let one = LanguagePrimitives.GenericOne<'T>
// Weighted mean: sum(xi * wi) / sum(wi). The products are accumulated in 'T and only the
// two totals are converted to float.
let weightedMean xVal wVal =
let a = Seq.fold2 (fun acc xi wi -> acc + (xi * wi)) zero xVal wVal |> float
let b = Seq.sum wVal|> float
a / b
// Weighted covariance: sum(wi * (xi - meanX) * (yi - meanY)) / sum(wi), computed in float.
let weightedCoVariance xVal yVal wVal =
let weightedMeanXW = weightedMean xVal wVal
let weightedMeanYW = weightedMean yVal wVal
let a =
Seq.map3 (fun xi yi wi ->
(float wi) * ((float xi) - weightedMeanXW) * ((float yi) - weightedMeanYW)
) xVal yVal wVal
|> Seq.sum
let b =
Seq.sum wVal
|> float
a / b
// Correlation = cov(x, y) / sqrt(cov(x, x) * cov(y, y)); weightedCoVariance is evaluated
// three times, each call re-enumerating its inputs.
let weightedCorrelation xVal yVal wVal =
let a = weightedCoVariance xVal yVal wVal
let b =
(weightedCoVariance xVal xVal wVal) * (weightedCoVariance yVal yVal wVal)
|> sqrt
a / b
weightedCorrelation seq1 seq2 weights
// ---- replacement implementation (presumably the added side of the diff) ----
// Convert to arrays once (3 passes), then compute in 2 passes instead of the
// previous ~12 passes (3 Seq.length + 3x weightedCoVariance x 3 sub-passes each).
// Materializing also means each input sequence is enumerated exactly once.
let xs = Array.ofSeq seq1
let ys = Array.ofSeq seq2
let ws = Array.ofSeq weights
let n = xs.Length
// Same length validation and error message as the previous implementation.
if n <> ys.Length || n <> ws.Length then
failwithf "input arguments are not the same length"
// Empty input short-circuits to NaN instead of entering the loops.
if n = 0 then nan
else
// Pass 1: compute weighted means
// Every element is converted to float before multiplying, so all sums are float sums.
let mutable wSum = 0.0
let mutable wxSum = 0.0
let mutable wySum = 0.0
for i = 0 to n - 1 do
let w = float ws.[i]
wSum <- wSum + w
wxSum <- wxSum + w * float xs.[i]
wySum <- wySum + w * float ys.[i]
// No guard for wSum = 0 is present; the float divisions below then produce NaN/infinity.
let xMean = wxSum / wSum
let yMean = wySum / wSum
// Pass 2: compute weighted covariance and variances from the means
let mutable covSum = 0.0
let mutable varXSum = 0.0
let mutable varYSum = 0.0
for i = 0 to n - 1 do
let w = float ws.[i]
let dx = float xs.[i] - xMean
let dy = float ys.[i] - yMean
covSum <- covSum + w * dx * dy
varXSum <- varXSum + w * dx * dx
varYSum <- varYSum + w * dy * dy
// The wSum denominator cancels in the ratio, so the result is:
// cov(x,y) / sqrt(var(x) * var(y))
// (the three sums are deliberately left un-normalized by wSum for that reason).
covSum / sqrt (varXSum * varYSum)

/// <summary>
/// Calculates the weighted pearson correlation of two samples.
Expand Down
47 changes: 47 additions & 0 deletions tests/FSharp.Stats.Tests/Correlation.fs
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,50 @@ let spearmanCorrelationTests =
Expect.floatClose Accuracy.high testCase5 0.6887298748 "Should be equal (double precision)"
Expect.floatClose Accuracy.high testCase6 -0.632455532 "Should be equal (double precision)"
]

[<Tests>]
let pearsonWeightedTests =
    // Tests for Seq.pearsonWeighted (weighted Pearson correlation).
    //
    // Reference values verified with R:
    //   x <- c(1.1, 1.1, 1.2); y <- c(1.2, 0.9, 0.08); w <- c(0.2, 0.3, 0.5)
    //   library(weights); wtd.cor(x, y, weight = w)[1]   # -0.9764159
    //
    //   x2 <- c(1.0, 2.0, 3.0, 4.0); y2 <- c(2.0, 4.0, 6.0, 8.0); w2 <- c(1.0, 1.0, 1.0, 1.0)
    //   wtd.cor(x2, y2, weight = w2)   # 1.0 (perfectly correlated, uniform weights)
    //
    //   x3 <- c(1.0, 2.0, 3.0, 4.0); y3 <- c(8.0, 6.0, 4.0, 2.0); w3 <- c(1.0, 1.0, 1.0, 1.0)
    //   wtd.cor(x3, y3, weight = w3)   # -1.0 (perfectly anti-correlated)

    // Non-uniform weights, reference case.
    let x1 = [1.1; 1.1; 1.2]
    let y1 = [1.2; 0.9; 0.08]
    let w1 = [0.2; 0.3; 0.5]

    // Perfect positive linear relationship (y = 2x).
    let x2 = [1.0; 2.0; 3.0; 4.0]
    let y2 = [2.0; 4.0; 6.0; 8.0]
    let wUniform = [1.0; 1.0; 1.0; 1.0]

    // Perfect negative linear relationship (y = 10 - 2x).
    let x3 = [1.0; 2.0; 3.0; 4.0]
    let y3 = [8.0; 6.0; 4.0; 2.0]

    // Explicitly typed so the generic 'T of pearsonWeighted resolves to float.
    let empty : float list = []

    testList "Correlation.Seq.pearsonWeighted" [
        testCase "docstring example" <| fun () ->
            // matches the example in pearsonWeightedOfTriples docstring
            let r = Seq.pearsonWeighted x1 y1 w1
            Expect.floatClose Accuracy.high r -0.9764158959 "weighted Pearson should match reference"

        testCase "ofTriples matches pearsonWeighted" <| fun () ->
            let rDirect = Seq.pearsonWeighted x1 y1 w1
            let rTriples = Seq.zip3 x1 y1 w1 |> Seq.pearsonWeightedOfTriples
            Expect.floatClose Accuracy.veryHigh rDirect rTriples "pearsonWeighted and pearsonWeightedOfTriples should agree"

        testCase "uniform weights: perfect positive correlation" <| fun () ->
            let r = Seq.pearsonWeighted x2 y2 wUniform
            Expect.floatClose Accuracy.veryHigh r 1.0 "uniform-weighted Pearson should be 1.0 for perfectly correlated data"

        testCase "uniform weights: perfect negative correlation" <| fun () ->
            let r = Seq.pearsonWeighted x3 y3 wUniform
            Expect.floatClose Accuracy.veryHigh r -1.0 "uniform-weighted Pearson should be -1.0 for perfectly anti-correlated data"

        testCase "mismatched length throws" <| fun () ->
            Expect.throws (fun () -> Seq.pearsonWeighted [1.0; 2.0] [1.0] [1.0; 1.0] |> ignore)
                "should throw for mismatched sequence lengths"

        // The length check has two halves (seq2 and weights); cover the weights half as well.
        testCase "mismatched weights length throws" <| fun () ->
            Expect.throws (fun () -> Seq.pearsonWeighted [1.0; 2.0] [1.0; 2.0] [1.0] |> ignore)
                "should throw when weights has a different length than the data"

        // Empty input is a distinct branch in the implementation and must not throw.
        testCase "empty input yields NaN" <| fun () ->
            let r = Seq.pearsonWeighted empty empty empty
            Expect.isTrue (System.Double.IsNaN r) "weighted Pearson of empty input should be NaN"
    ]
Loading