-
Notifications
You must be signed in to change notification settings - Fork 297
Add ComplexDotProduct to SVE microbenchmark #5045
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
9c562f7
Add ComplexDotProduct to SVE microbenchmark
ylpoonlg 2a4f6f3
Merge branch 'main' into github-complexdot
ylpoonlg ecd6f2c
Add SVE Category
ylpoonlg 67e0d51
Disable when version < net10.0
ylpoonlg 67e09fd
Use local function and fix verify size
ylpoonlg 29e2cad
Merge branch 'main' into github-complexdot
LoopedBard3 69345c2
Merge branch 'main' into github-complexdot
DrewScoggins File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,285 @@ | ||
| using System; | ||
| using System.Diagnostics; | ||
| using System.Numerics; | ||
| using System.Runtime.Intrinsics; | ||
| using System.Runtime.Intrinsics.Arm; | ||
| using BenchmarkDotNet.Attributes; | ||
| using BenchmarkDotNet.Extensions; | ||
| using BenchmarkDotNet.Configs; | ||
| using BenchmarkDotNet.Filters; | ||
| using MicroBenchmarks; | ||
|
|
||
| namespace SveBenchmarks | ||
| { | ||
| [BenchmarkCategory(Categories.Sve)] | ||
| [OperatingSystemsArchitectureFilter(allowed: true, System.Runtime.InteropServices.Architecture.Arm64)] | ||
| [Config(typeof(Config))] | ||
| public class ComplexDotProduct | ||
| { | ||
| private class Config : ManualConfig | ||
| { | ||
| public Config() | ||
| { | ||
| AddFilter(new SimpleFilter(_ => Sve2.IsSupported)); | ||
ylpoonlg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
|
|
||
| [Params(15, 127, 527, 10015)] | ||
| public int Size; | ||
|
|
||
| private sbyte[] _input1; | ||
| private sbyte[] _input2; | ||
| private int[] _output; | ||
|
|
||
| [GlobalSetup] | ||
| public virtual void Setup() | ||
| { | ||
| _input1 = new sbyte[Size * 4]; | ||
| _input2 = new sbyte[Size * 4]; | ||
|
|
||
| sbyte[] vals = ValuesGenerator.Array<sbyte>(Size * 8); | ||
| for (int i = 0; i < Size * 4; i++) | ||
| { | ||
| _input1[i] = vals[i]; | ||
| _input2[i] = vals[Size * 4 + i]; | ||
| } | ||
|
|
||
| _output = new int[Size * 2]; | ||
| } | ||
|
|
||
| [GlobalCleanup] | ||
| public virtual void Verify() | ||
| { | ||
| int[] current = (int[])_output.Clone(); | ||
| Setup(); | ||
| Scalar(); | ||
| int[] scalar = (int[])_output.Clone(); | ||
|
|
||
| // Check that the result is the same as scalar. | ||
| for (int i = 0; i < Size; i++) | ||
ylpoonlg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| { | ||
| Debug.Assert(current[i] == scalar[i]); | ||
| } | ||
| } | ||
|
|
||
| // The following algorithms are adapted from the Arm simd-loops repository: | ||
| // https://gitlab.arm.com/architecture/simd-loops/-/blob/main/loops/loop_110.c | ||
|
|
||
| [Benchmark] | ||
| public unsafe void Scalar() | ||
| { | ||
| fixed (sbyte* a = _input1, b = _input2) | ||
| fixed (int* c = _output) | ||
| { | ||
| for (int i = 0; i < Size; i++) | ||
| { | ||
| // Real part. | ||
| c[i * 2] = (int)(((a[4 * i] * b[4 * i]) - (a[4 * i + 1] * b[4 * i + 1])) + | ||
| ((a[4 * i + 2] * b[4 * i + 2]) - (a[4 * i + 3] * b[4 * i + 3]))); | ||
| // Imaginary part. | ||
| c[i * 2 + 1] = (int)(((a[4 * i + 1] * b[4 * i]) + (a[4 * i] * b[4 * i + 1])) + | ||
| ((a[4 * i + 3] * b[4 * i + 2]) + (a[4 * i + 2] * b[4 * i + 3]))); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| [Benchmark] | ||
| public unsafe void Vector128ComplexDotProduct() | ||
| { | ||
| fixed (sbyte* a = _input1, b = _input2) | ||
| fixed (int* c = _output) | ||
| { | ||
|
|
||
| // Helper function for calculating dot product. | ||
| Func<Vector128<int>, Vector128<int>, Vector128<int>, Vector128<int>, | ||
| Vector128<int>, Vector128<int>, Vector128<int>, Vector128<int>, | ||
| (Vector128<int>, Vector128<int>)> dotProduct = (a0, a1, a2, a3, b0, b1, b2, b3) => { | ||
|
|
||
| Vector128<int> cr = Vector128<int>.Zero; | ||
| Vector128<int> ci = Vector128<int>.Zero; | ||
|
|
||
| // Compute dot product of the real part. | ||
| cr = AdvSimd.MultiplyAdd(cr, a0, b0); | ||
| cr = AdvSimd.MultiplySubtract(cr, a1, b1); | ||
| cr = AdvSimd.MultiplyAdd(cr, a2, b2); | ||
| cr = AdvSimd.MultiplySubtract(cr, a3, b3); | ||
| // Compute dot product of the imaginary part. | ||
| ci = AdvSimd.MultiplyAdd(ci, a0, b1); | ||
| ci = AdvSimd.MultiplyAdd(ci, a1, b0); | ||
| ci = AdvSimd.MultiplyAdd(ci, a2, b3); | ||
| ci = AdvSimd.MultiplyAdd(ci, a3, b2); | ||
|
|
||
| return (cr, ci); | ||
| }; | ||
ylpoonlg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| int i = 0; | ||
| int lmt = Size - 16; | ||
| for (; i <= lmt; i += 16) | ||
| { | ||
| // Load inputs as sbytes. | ||
| (Vector128<sbyte> a0, Vector128<sbyte> a1, Vector128<sbyte> a2, Vector128<sbyte> a3) = AdvSimd.Arm64.Load4xVector128AndUnzip(a + 4 * i); | ||
| (Vector128<sbyte> b0, Vector128<sbyte> b1, Vector128<sbyte> b2, Vector128<sbyte> b3) = AdvSimd.Arm64.Load4xVector128AndUnzip(b + 4 * i); | ||
| Vector128<int> cr, ci; | ||
|
|
||
| // Extend sbyte to short for each input component. | ||
| Vector128<short> a0l = AdvSimd.SignExtendWideningLower(a0.GetLower()); | ||
| Vector128<short> a0h = AdvSimd.SignExtendWideningUpper(a0); | ||
| Vector128<short> a1l = AdvSimd.SignExtendWideningLower(a1.GetLower()); | ||
| Vector128<short> a1h = AdvSimd.SignExtendWideningUpper(a1); | ||
| Vector128<short> a2l = AdvSimd.SignExtendWideningLower(a2.GetLower()); | ||
| Vector128<short> a2h = AdvSimd.SignExtendWideningUpper(a2); | ||
| Vector128<short> a3l = AdvSimd.SignExtendWideningLower(a3.GetLower()); | ||
| Vector128<short> a3h = AdvSimd.SignExtendWideningUpper(a3); | ||
| Vector128<short> b0l = AdvSimd.SignExtendWideningLower(b0.GetLower()); | ||
| Vector128<short> b0h = AdvSimd.SignExtendWideningUpper(b0); | ||
| Vector128<short> b1l = AdvSimd.SignExtendWideningLower(b1.GetLower()); | ||
| Vector128<short> b1h = AdvSimd.SignExtendWideningUpper(b1); | ||
| Vector128<short> b2l = AdvSimd.SignExtendWideningLower(b2.GetLower()); | ||
| Vector128<short> b2h = AdvSimd.SignExtendWideningUpper(b2); | ||
| Vector128<short> b3l = AdvSimd.SignExtendWideningLower(b3.GetLower()); | ||
| Vector128<short> b3h = AdvSimd.SignExtendWideningUpper(b3); | ||
|
|
||
| // Compute the lower half of the lower half. | ||
| Vector128<int> a0ll = AdvSimd.SignExtendWideningLower(a0l.GetLower()); | ||
| Vector128<int> a1ll = AdvSimd.SignExtendWideningLower(a1l.GetLower()); | ||
| Vector128<int> a2ll = AdvSimd.SignExtendWideningLower(a2l.GetLower()); | ||
| Vector128<int> a3ll = AdvSimd.SignExtendWideningLower(a3l.GetLower()); | ||
| Vector128<int> b0ll = AdvSimd.SignExtendWideningLower(b0l.GetLower()); | ||
| Vector128<int> b1ll = AdvSimd.SignExtendWideningLower(b1l.GetLower()); | ||
| Vector128<int> b2ll = AdvSimd.SignExtendWideningLower(b2l.GetLower()); | ||
| Vector128<int> b3ll = AdvSimd.SignExtendWideningLower(b3l.GetLower()); | ||
| (cr, ci) = dotProduct(a0ll, a1ll, a2ll, a3ll, b0ll, b1ll, b2ll, b3ll); | ||
| AdvSimd.Arm64.StoreVectorAndZip(c + 2 * i, (cr, ci)); | ||
|
|
||
| // Compute the upper half of the lower half. | ||
| Vector128<int> a0lh = AdvSimd.SignExtendWideningUpper(a0l); | ||
| Vector128<int> a1lh = AdvSimd.SignExtendWideningUpper(a1l); | ||
| Vector128<int> a2lh = AdvSimd.SignExtendWideningUpper(a2l); | ||
| Vector128<int> a3lh = AdvSimd.SignExtendWideningUpper(a3l); | ||
| Vector128<int> b0lh = AdvSimd.SignExtendWideningUpper(b0l); | ||
| Vector128<int> b1lh = AdvSimd.SignExtendWideningUpper(b1l); | ||
| Vector128<int> b2lh = AdvSimd.SignExtendWideningUpper(b2l); | ||
| Vector128<int> b3lh = AdvSimd.SignExtendWideningUpper(b3l); | ||
| (cr, ci) = dotProduct(a0lh, a1lh, a2lh, a3lh, b0lh, b1lh, b2lh, b3lh); | ||
| AdvSimd.Arm64.StoreVectorAndZip(c + 2 * i + 8, (cr, ci)); | ||
|
|
||
| // Compute the lower half of the upper half. | ||
| Vector128<int> a0hl = AdvSimd.SignExtendWideningLower(a0h.GetLower()); | ||
| Vector128<int> a1hl = AdvSimd.SignExtendWideningLower(a1h.GetLower()); | ||
| Vector128<int> a2hl = AdvSimd.SignExtendWideningLower(a2h.GetLower()); | ||
| Vector128<int> a3hl = AdvSimd.SignExtendWideningLower(a3h.GetLower()); | ||
| Vector128<int> b0hl = AdvSimd.SignExtendWideningLower(b0h.GetLower()); | ||
| Vector128<int> b1hl = AdvSimd.SignExtendWideningLower(b1h.GetLower()); | ||
| Vector128<int> b2hl = AdvSimd.SignExtendWideningLower(b2h.GetLower()); | ||
| Vector128<int> b3hl = AdvSimd.SignExtendWideningLower(b3h.GetLower()); | ||
| (cr, ci) = dotProduct(a0hl, a1hl, a2hl, a3hl, b0hl, b1hl, b2hl, b3hl); | ||
| AdvSimd.Arm64.StoreVectorAndZip(c + 2 * i + 16, (cr, ci)); | ||
|
|
||
| // Compute the upper half of the upper half. | ||
| Vector128<int> a0hh = AdvSimd.SignExtendWideningUpper(a0h); | ||
| Vector128<int> a1hh = AdvSimd.SignExtendWideningUpper(a1h); | ||
| Vector128<int> a2hh = AdvSimd.SignExtendWideningUpper(a2h); | ||
| Vector128<int> a3hh = AdvSimd.SignExtendWideningUpper(a3h); | ||
| Vector128<int> b0hh = AdvSimd.SignExtendWideningUpper(b0h); | ||
| Vector128<int> b1hh = AdvSimd.SignExtendWideningUpper(b1h); | ||
| Vector128<int> b2hh = AdvSimd.SignExtendWideningUpper(b2h); | ||
| Vector128<int> b3hh = AdvSimd.SignExtendWideningUpper(b3h); | ||
| (cr, ci) = dotProduct(a0hh, a1hh, a2hh, a3hh, b0hh, b1hh, b2hh, b3hh); | ||
| AdvSimd.Arm64.StoreVectorAndZip(c + 2 * i + 24, (cr, ci)); | ||
|
|
||
| } | ||
| // Handle remaining elements. | ||
| for (; i < Size; i++) | ||
| { | ||
| // Real part. | ||
| c[i * 2] = (int)(((a[4 * i] * b[4 * i]) - (a[4 * i + 1] * b[4 * i + 1])) + | ||
| ((a[4 * i + 2] * b[4 * i + 2]) - (a[4 * i + 3] * b[4 * i + 3]))); | ||
| // Imaginary part. | ||
| c[i * 2 + 1] = (int)(((a[4 * i + 1] * b[4 * i]) + (a[4 * i] * b[4 * i + 1])) + | ||
| ((a[4 * i + 3] * b[4 * i + 2]) + (a[4 * i + 2] * b[4 * i + 3]))); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| [Benchmark] | ||
| public unsafe void SveComplexDotProduct() | ||
| { | ||
| fixed (sbyte* a = _input1, b = _input2) | ||
| fixed (int* c = _output) | ||
| { | ||
| int i = 0; | ||
| int cntw = (int)Sve.Count32BitElements(); | ||
|
|
||
| // Create mask for the imaginary half of a word. | ||
| Vector<short> imMask = (Vector<short>)(new Vector<uint>(0xFFFF0000u)); | ||
| Vector<int> pTrue = Sve.CreateTrueMaskInt32(); | ||
| Vector<int> pLoop = (Vector<int>)Sve.CreateWhileLessThanMask32Bit(0, Size); | ||
| while (Sve.TestFirstTrue(pTrue, pLoop)) | ||
| { | ||
| // Load inputs. | ||
| Vector<sbyte> a1 = (Vector<sbyte>)Sve.LoadVector(pLoop, (int*)(a + 4 * i)); | ||
| Vector<sbyte> b1 = (Vector<sbyte>)Sve.LoadVector(pLoop, (int*)(b + 4 * i)); | ||
|
|
||
| // Unpack and extend the lower and upper halves to short. | ||
| Vector<short> a1l = Sve.SignExtendWideningLower(a1); | ||
| Vector<short> a1h = Sve.SignExtendWideningUpper(a1); | ||
| Vector<short> b1l = Sve.SignExtendWideningLower(b1); | ||
| Vector<short> b1h = Sve.SignExtendWideningUpper(b1); | ||
|
|
||
| // Negate the imaginary components. | ||
| Vector<short> b1l_re = Sve.ConditionalSelect(imMask, Sve.Negate(b1l), b1l); | ||
| Vector<short> b1h_re = Sve.ConditionalSelect(imMask, Sve.Negate(b1h), b1h); | ||
| // Swap the real and imaginary components from each word. | ||
| Vector<short> b1l_im = (Vector<short>)Sve.ReverseElement16((Vector<int>)b1l); | ||
| Vector<short> b1h_im = (Vector<short>)Sve.ReverseElement16((Vector<int>)b1h); | ||
|
|
||
| // Compute dot products (real and imaginary, low and high bits). | ||
| Vector<long> c1l_re = Sve.DotProduct(Vector<long>.Zero, a1l, b1l_re); | ||
| Vector<long> c1h_re = Sve.DotProduct(Vector<long>.Zero, a1h, b1h_re); | ||
| Vector<long> c1l_im = Sve.DotProduct(Vector<long>.Zero, a1l, b1l_im); | ||
| Vector<long> c1h_im = Sve.DotProduct(Vector<long>.Zero, a1h, b1h_im); | ||
|
|
||
| // Combine low and high parts back together. | ||
| Vector<int> cr = Sve.UnzipEven((Vector<int>)c1l_re, (Vector<int>)c1h_re); | ||
| Vector<int> ci = Sve.UnzipEven((Vector<int>)c1l_im, (Vector<int>)c1h_im); | ||
|
|
||
| // Interleaved store real and imaginary parts of the result. | ||
| Sve.StoreAndZip(pLoop, c + 2 * i, (cr, ci)); | ||
|
|
||
| // Handle loop. | ||
| i += cntw; | ||
| pLoop = (Vector<int>)Sve.CreateWhileLessThanMask32Bit(i, Size); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| [Benchmark] | ||
| public unsafe void Sve2ComplexDotProduct() | ||
| { | ||
| fixed (sbyte* a = _input1, b = _input2) | ||
| fixed (int* c = _output) | ||
| { | ||
| int i = 0; | ||
| int cntw = (int)Sve.Count32BitElements(); | ||
|
|
||
| Vector<int> pTrue = Sve.CreateTrueMaskInt32(); | ||
| Vector<int> pLoop = (Vector<int>)Sve.CreateWhileLessThanMask32Bit(0, Size); | ||
| while (Sve.TestFirstTrue(pTrue, pLoop)) | ||
| { | ||
| Vector<sbyte> a1 = (Vector<sbyte>)Sve.LoadVector(pLoop, (int*)(a + 4 * i)); | ||
| Vector<sbyte> b1 = (Vector<sbyte>)Sve.LoadVector(pLoop, (int*)(b + 4 * i)); | ||
|
|
||
| Vector<int> cr = Sve2.DotProductRotateComplex(Vector<int>.Zero, a1, b1, 0); | ||
| Vector<int> ci = Sve2.DotProductRotateComplex(Vector<int>.Zero, a1, b1, 1); | ||
|
|
||
| Sve.StoreAndZip(pLoop, c + 2 * i, (cr, ci)); | ||
|
|
||
| // Handle loop. | ||
| i += cntw; | ||
| pLoop = (Vector<int>)Sve.CreateWhileLessThanMask32Bit(i, Size); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.