diff --git a/src/function/logical/nullish.js b/src/function/logical/nullish.js index cd3a7af52d..bed2173665 100644 --- a/src/function/logical/nullish.js +++ b/src/function/logical/nullish.js @@ -5,12 +5,12 @@ import { createMatAlgo13xDD } from '../../type/matrix/utils/matAlgo13xDD.js' import { DimensionError } from '../../error/DimensionError.js' const name = 'nullish' -const dependencies = ['typed', 'matrix', 'size', 'flatten', 'deepEqual'] +const dependencies = ['typed', 'matrix', 'size', 'deepEqual'] export const createNullish = /* #__PURE__ */ factory( name, dependencies, - ({ typed, matrix, size, flatten, deepEqual }) => { + ({ typed, matrix, size, deepEqual }) => { const matAlgo03xDSf = createMatAlgo03xDSf({ typed }) const matAlgo14xDs = createMatAlgo14xDs({ typed }) const matAlgo13xDD = createMatAlgo13xDD({ typed }) @@ -57,24 +57,35 @@ export const createNullish = /* #__PURE__ */ factory( // SparseMatrix-first with collection RHS: enforce exact shape match 'SparseMatrix, Array | Matrix': (x, y) => { - const sx = size(x) - const sy = size(y) - if (deepEqual(sx, sy)) return x - throw new DimensionError(sx, sy) + _validateSize(x, y) + return x }, // DenseMatrix-first handlers (no broadcasting between collections) - 'DenseMatrix, DenseMatrix': typed.referToSelf(self => (x, y) => matAlgo13xDD(x, y, self)), + 'DenseMatrix, DenseMatrix': typed.referToSelf(self => (x, y) => _matAlgo13xDDnoBroadcast(x, y, self) + ), 'DenseMatrix, SparseMatrix': typed.referToSelf(self => (x, y) => matAlgo03xDSf(x, y, self, false)), - 'DenseMatrix, Array': typed.referToSelf(self => (x, y) => matAlgo13xDD(x, matrix(y), self)), + 'DenseMatrix, Array': typed.referToSelf(self => (x, y) => _matAlgo13xDDnoBroadcast(x, y, self)), 'DenseMatrix, any': typed.referToSelf(self => (x, y) => matAlgo14xDs(x, y, self, false)), // Array-first handlers (bridge via matrix() where needed) - 'Array, Array': typed.referToSelf(self => (x, y) => matAlgo13xDD(matrix(x), matrix(y), self).valueOf()), - 'Array, DenseMatrix': typed.referToSelf(self => (x, y) => matAlgo13xDD(matrix(x), y, self)), + 'Array, Array': typed.referToSelf(self => (x, y) => _matAlgo13xDDnoBroadcast(x, y, self)), + 'Array, DenseMatrix': typed.referToSelf(self => (x, y) => matrix(_matAlgo13xDDnoBroadcast(x, y, self))), 'Array, SparseMatrix': typed.referToSelf(self => (x, y) => matAlgo03xDSf(matrix(x), y, self, false)), 'Array, any': typed.referToSelf(self => (x, y) => matAlgo14xDs(matrix(x), y, self, false).valueOf()) } ) + function _validateSize (x, y) { + const sx = size(x) + const sy = size(y) + if (!deepEqual(sx, sy)) { + throw new DimensionError(sx, sy) + } + } + // Remve broadcasting from matAlgo13xDD + function _matAlgo13xDDnoBroadcast (x, y, callback) { + _validateSize(x, y) + return matAlgo13xDD(x, y, callback) + } } ) diff --git a/src/type/matrix/utils/broadcast.js b/src/type/matrix/utils/broadcast.js index 6be714741b..85ad856bbe 100644 --- a/src/type/matrix/utils/broadcast.js +++ b/src/type/matrix/utils/broadcast.js @@ -5,15 +5,15 @@ import { deepStrictEqual } from '../../../utils/object.js' * Broadcasts two matrices, and return both in an array * It checks if it's possible with broadcasting rules * -* @param {Matrix} A First Matrix -* @param {Matrix} B Second Matrix +* @param {Matrix|Array} A First Matrix +* @param {Matrix|Array} B Second Matrix * -* @return {Matrix[]} [ broadcastedA, broadcastedB ] +* @return {Matrix|Array} [ broadcastedA, broadcastedB ] */ -export function broadcast (A, B) { +export function broadcastMatrices (A, B) { if (deepStrictEqual(A.size(), B.size())) { - // If matrices have the same size return them + // If matrices have the same size return them as such return [A, B] } @@ -38,3 +38,62 @@ function _broadcastTo (M, size) { } return M.create(broadcastTo(M.valueOf(), size), M.datatype()) } + +/** + * Recursively maps two arrays assuming they are rectangular, if a size is provided it's assumed + * to be validated already. No index is provided to the callback. + * + * @param {Array} array1 First array to broadcast + * @param {Array} array2 Second array to broadcast + * @param {number[]} [size1] Size of the first array + * @param {number[]} [size2] Size of the second array + * @param {function} callback The callback function to apply to each pair of elements + * @returns {Array} The resulting array after applying the callback to each pair of elements + */ +export function broadcast (array1, array2, size1, size2, callback) { + if (![array1, array2, size1, size2].every(Array.isArray)) { + throw new Error('Arrays and their sizes must be provided') + } + if (typeof callback !== 'function') { + throw new Error('Callback must be a function') + } + + if (size1.length <= 0 || size2.length <= 0) { + return { data: [], size: [0] } + } + + const finalSize = broadcastSizes(size1, size2) + const offset1 = finalSize.length - size1.length + const offset2 = finalSize.length - size2.length + const maxDepth = finalSize.length - 1 + return { data: iterate(array1, array2), size: finalSize } + + function iterate (array1, array2, depth = 0) { + const currentDimensionSize = finalSize[depth] + const result = Array(currentDimensionSize) + if (depth < maxDepth) { + for (let i = 0; i < currentDimensionSize; i++) { + const nextArray1 = offset1 > depth + ? array1 + : (array1.length === 1 ? array1[0] : array1[i]) + const nextArray2 = offset2 > depth + ? array2 + : (array2.length === 1 ? array2[0] : array2[i]) + result[i] = iterate( + nextArray1, + nextArray2, + depth + 1 + ) + } + } else { + for (let i = 0; i < currentDimensionSize; i++) { + result[i] = callback( + array1.length === 1 ? array1[0] : array1[i], + array2.length === 1 ? array2[0] : array2[i] + ) + } + } + + return result + } +} diff --git a/src/type/matrix/utils/matAlgo13xDD.js b/src/type/matrix/utils/matAlgo13xDD.js index 3ea14fd1c4..4e0d86f58e 100644 --- a/src/type/matrix/utils/matAlgo13xDD.js +++ b/src/type/matrix/utils/matAlgo13xDD.js @@ -1,5 +1,7 @@ import { factory } from '../../../utils/factory.js' -import { DimensionError } from '../../../error/DimensionError.js' +import { broadcast } from './broadcast.js' +import { isMatrix } from '../../../utils/is.js' +import { arraySize, validate } from '../../../utils/array.js' const name = 'matAlgo13xDD' const dependencies = ['typed'] @@ -11,36 +13,28 @@ export const createMatAlgo13xDD = /* #__PURE__ */ factory(name, dependencies, ({ * * C(i,j,...z) = f(Aij..z, Bij..z) * - * @param {Matrix} a The DenseMatrix instance (A) - * @param {Matrix} b The DenseMatrix instance (B) + * @param {Matrix|Array} a The DenseMatrix instance (A) + * @param {Matrix|Array} b The DenseMatrix instance (B) * @param {Function} callback The f(Aij..z,Bij..z) operation to invoke * - * @return {Matrix} DenseMatrix (C) + * @return {Matrix|Array} DenseMatrix (C) * * https://github.com/josdejong/mathjs/pull/346#issuecomment-97658658 */ return function matAlgo13xDD (a, b, callback) { // a arrays - const adata = a._data - const asize = a._size - const adt = a._datatype - // b arrays - const bdata = b._data - const bsize = b._size - const bdt = b._datatype - // c arrays - const csize = [] - - // validate dimensions - if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) } + const aIsMatrix = isMatrix(a) + const adata = aIsMatrix ? a._data : a + const asize = aIsMatrix ? a._size : arraySize(a) + if (!aIsMatrix) validate(adata, asize) + const adt = aIsMatrix ? a._datatype : undefined - // validate each one of the dimension sizes - for (let s = 0; s < asize.length; s++) { - // must match - if (asize[s] !== bsize[s]) { throw new RangeError('Dimension mismatch. Matrix A (' + asize + ') must match Matrix B (' + bsize + ')') } - // update dimension in c - csize[s] = asize[s] - } + // b arrays + const bIsMatrix = isMatrix(b) + const bdata = bIsMatrix ? b._data : b + const bsize = bIsMatrix ? b._size : arraySize(b) + if (!bIsMatrix) validate(bdata, bsize) + const bdt = bIsMatrix ? b._datatype : undefined // datatype let dt @@ -56,34 +50,16 @@ export const createMatAlgo13xDD = /* #__PURE__ */ factory(name, dependencies, ({ } // populate cdata, iterate through dimensions - const cdata = csize.length > 0 ? _iterate(cf, 0, csize, csize[0], adata, bdata) : [] - - // c matrix - return a.createDenseMatrix({ - data: cdata, - size: csize, - datatype: dt - }) - } + const cdata = broadcast(adata, bdata, asize, bsize, cf) - // recursive function - function _iterate (f, level, s, n, av, bv) { - // initialize array for this level - const cv = [] - // check we reach the last level - if (level === s.length - 1) { - // loop arrays in last level - for (let i = 0; i < n; i++) { - // invoke callback and store value - cv[i] = f(av[i], bv[i]) - } + if (aIsMatrix || bIsMatrix) { + const cMatrix = aIsMatrix ? a.createDenseMatrix() : b.createDenseMatrix() + cMatrix._data = cdata.data + cMatrix._size = cdata.size + cMatrix._datatype = dt + return cMatrix } else { - // iterate current level - for (let j = 0; j < n; j++) { - // iterate next level - cv[j] = _iterate(f, level + 1, s, s[level + 1], av[j], bv[j]) - } + return cdata.data } - return cv } }) diff --git a/src/type/matrix/utils/matrixAlgorithmSuite.js b/src/type/matrix/utils/matrixAlgorithmSuite.js index 216935b20e..ed627c4ec8 100644 --- a/src/type/matrix/utils/matrixAlgorithmSuite.js +++ b/src/type/matrix/utils/matrixAlgorithmSuite.js @@ -2,7 +2,7 @@ import { factory } from '../../../utils/factory.js' import { extend } from '../../../utils/object.js' import { createMatAlgo13xDD } from './matAlgo13xDD.js' import { createMatAlgo14xDs } from './matAlgo14xDs.js' -import { broadcast } from './broadcast.js' +import { broadcastMatrices } from './broadcast.js' const name = 'matrixAlgorithmSuite' const dependencies = ['typed', 'matrix'] @@ -36,71 +36,71 @@ export const createMatrixAlgorithmSuite = /* #__PURE__ */ factory( if (elop) { // First the dense ones matrixSignatures = { - 'DenseMatrix, DenseMatrix': (x, y) => matAlgo13xDD(...broadcast(x, y), elop), + 'DenseMatrix, DenseMatrix': (x, y) => matAlgo13xDD(x, y, elop), 'Array, Array': (x, y) => - matAlgo13xDD(...broadcast(matrix(x), matrix(y)), elop).valueOf(), - 'Array, DenseMatrix': (x, y) => matAlgo13xDD(...broadcast(matrix(x), y), elop), - 'DenseMatrix, Array': (x, y) => matAlgo13xDD(...broadcast(x, matrix(y)), elop) + matAlgo13xDD(matrix(x), matrix(y), elop).valueOf(), + 'Array, DenseMatrix': (x, y) => matAlgo13xDD(matrix(x), y, elop), + 'DenseMatrix, Array': (x, y) => matAlgo13xDD(x, matrix(y), elop) } // Now incorporate sparse matrices if (options.SS) { matrixSignatures['SparseMatrix, SparseMatrix'] = - (x, y) => options.SS(...broadcast(x, y), elop, false) + (x, y) => options.SS(...broadcastMatrices(x, y), elop, false) } if (options.DS) { matrixSignatures['DenseMatrix, SparseMatrix'] = - (x, y) => options.DS(...broadcast(x, y), elop, false) + (x, y) => options.DS(...broadcastMatrices(x, y), elop, false) matrixSignatures['Array, SparseMatrix'] = - (x, y) => options.DS(...broadcast(matrix(x), y), elop, false) + (x, y) => options.DS(...broadcastMatrices(matrix(x), y), elop, false) } if (SD) { matrixSignatures['SparseMatrix, DenseMatrix'] = - (x, y) => SD(...broadcast(y, x), elop, true) + (x, y) => SD(...broadcastMatrices(y, x), elop, true) matrixSignatures['SparseMatrix, Array'] = - (x, y) => SD(...broadcast(matrix(y), x), elop, true) + (x, y) => SD(...broadcastMatrices(matrix(y), x), elop, true) } } else { // No elop, use this // First the dense ones matrixSignatures = { 'DenseMatrix, DenseMatrix': typed.referToSelf(self => (x, y) => { - return matAlgo13xDD(...broadcast(x, y), self) + return matAlgo13xDD(x, y, self) }), 'Array, Array': typed.referToSelf(self => (x, y) => { - return matAlgo13xDD(...broadcast(matrix(x), matrix(y)), self).valueOf() + return matAlgo13xDD(matrix(x), matrix(y), self).valueOf() }), 'Array, DenseMatrix': typed.referToSelf(self => (x, y) => { - return matAlgo13xDD(...broadcast(matrix(x), y), self) + return matAlgo13xDD(matrix(x), y, self) }), 'DenseMatrix, Array': typed.referToSelf(self => (x, y) => { - return matAlgo13xDD(...broadcast(x, matrix(y)), self) + return matAlgo13xDD(x, matrix(y), self) }) } // Now incorporate sparse matrices if (options.SS) { matrixSignatures['SparseMatrix, SparseMatrix'] = typed.referToSelf(self => (x, y) => { - return options.SS(...broadcast(x, y), self, false) + return options.SS(...broadcastMatrices(x, y), self, false) }) } if (options.DS) { matrixSignatures['DenseMatrix, SparseMatrix'] = typed.referToSelf(self => (x, y) => { - return options.DS(...broadcast(x, y), self, false) + return options.DS(...broadcastMatrices(x, y), self, false) }) matrixSignatures['Array, SparseMatrix'] = typed.referToSelf(self => (x, y) => { - return options.DS(...broadcast(matrix(x), y), self, false) + return options.DS(...broadcastMatrices(matrix(x), y), self, false) }) } if (SD) { matrixSignatures['SparseMatrix, DenseMatrix'] = typed.referToSelf(self => (x, y) => { - return SD(...broadcast(y, x), self, true) + return SD(...broadcastMatrices(y, x), self, true) }) matrixSignatures['SparseMatrix, Array'] = typed.referToSelf(self => (x, y) => { - return SD(...broadcast(matrix(y), x), self, true) + return SD(...broadcastMatrices(matrix(y), x), self, true) }) } } diff --git a/src/utils/array.js b/src/utils/array.js index 3c8b8813fa..293885cb6e 100644 --- a/src/utils/array.js +++ b/src/utils/array.js @@ -738,7 +738,7 @@ export function broadcastSizes (...sizes) { const dim = dimensions[i] for (let j = 0; j < dim; j++) { const n = N - dim + j - if (size[j] > sizeMax[n]) { + if (sizeMax[n] === null || size[j] > sizeMax[n]) { sizeMax[n] = size[j] } } @@ -900,6 +900,57 @@ export function deepMap (array, callback, skipIndex = false) { } } +/** + * Recursively maps multiple arrays assuming they are rectangular, if a size is provided it's assumed + * to be validated already. + */ +export function deepMapMultiple (arrays, sizes = [], callback, skipIndex = false) { + arrays.forEach((array, arrayIndex) => { + if (!sizes[arrayIndex]) { + const size = arraySize(array) + validate(array, size) + sizes[arrayIndex] = size + } + }) + + const finalSize = broadcastSizes(...sizes) + const offsets = sizes.map((size) => finalSize.length - size.length) + const maxDepth = finalSize.length - 1 + const callbackUsesIndex = skipIndex || callback.length > 1 + const index = callbackUsesIndex ? [] : null + const resultsArray = iterate(arrays, 0) + return resultsArray + + function iterate (arrays, depth = 0) { + const currentDimensionSize = finalSize[depth] + const result = Array(currentDimensionSize) + if (depth < maxDepth) { + for (let i = 0; i < currentDimensionSize; i++) { + if (index) index[depth] = i + result[i] = iterate( + arrays.map((array, arrayIndex) => + offsets[arrayIndex] > depth + ? array + : array.length === 1 + ? array[0] + : array[i] + ), + depth + 1 + ) + } + } else { + for (let i = 0; i < currentDimensionSize; i++) { + if (index) index[depth] = i + result[i] = callback( + arrays.map((a) => (a.length === 1 ? a[0] : a[i])), + index ? index.slice() : undefined + ) + } + } + return result + } +} + /** * Recursively iterates over each element in a multi-dimensional array and applies a callback function. * diff --git a/test/unit-tests/expression/function/evaluate.test.js b/test/unit-tests/expression/function/evaluate.test.js index 72e9af11f0..27d4ee06a8 100644 --- a/test/unit-tests/expression/function/evaluate.test.js +++ b/test/unit-tests/expression/function/evaluate.test.js @@ -206,8 +206,8 @@ describe('evaluate', function () { assert.deepStrictEqual(math.evaluate(math.matrix(['null ?? 1', '2 ?? null'])), math.matrix([1, 2])) // Test shape mismatch with empty array - assert.throws(() => math.evaluate('[] ?? [7, 8]'), /RangeError/) - assert.throws(() => math.evaluate('[1] ?? [7, 8]'), /RangeError/) + assert.throws(() => math.evaluate('[] ?? [7, 8]'), /Dimension mismatch/) + assert.throws(() => math.evaluate('[1] ?? [7, 8]'), /Dimension mismatch/) }) it('should handle nullish coalescing with function calls', function () { diff --git a/test/unit-tests/function/algebra/sylvester.test.js b/test/unit-tests/function/algebra/sylvester.test.js index 72bb42aa55..bf1fdd1122 100644 --- a/test/unit-tests/function/algebra/sylvester.test.js +++ b/test/unit-tests/function/algebra/sylvester.test.js @@ -62,7 +62,7 @@ describe('sylvester', function () { math2.config({ legacySubset: true }) // Test legacy syntax with sylvester - // This is not strictly necessary and shoudl be removed after the deprecation period + // This is not strictly necessary and should be removed after the deprecation period const sylvesterA = [[-5.3, -1.4, -0.2, 0.7], [-0.4, -1.0, -0.1, -1.2], [0.3, 0.7, -2.5, 0.7], diff --git a/test/unit-tests/function/logical/nullish.test.js b/test/unit-tests/function/logical/nullish.test.js index 407809e7dd..5fead7330d 100644 --- a/test/unit-tests/function/logical/nullish.test.js +++ b/test/unit-tests/function/logical/nullish.test.js @@ -123,7 +123,7 @@ describe('nullish', function () { describe('shape handling and sparse matrices', function () { it('should throw on mismatched shapes', function () { assert.throws(() => nullish([1], [7, 8]), /Dimension mismatch/) - assert.throws(() => nullish(matrix([1]), matrix([7, 8])), /RangeError/) + assert.throws(() => nullish(matrix([1]), matrix([7, 8])), /Dimension mismatch/) assert.throws(() => nullish(sparse([[1]]), matrix([7, 8])), /DimensionError/) }) diff --git a/test/unit-tests/function/matrix/column.test.js b/test/unit-tests/function/matrix/column.test.js index 6c6e3fed57..1af0c240d9 100644 --- a/test/unit-tests/function/matrix/column.test.js +++ b/test/unit-tests/function/matrix/column.test.js @@ -124,7 +124,7 @@ describe('column', function () { [0, 0, 0, 6, 0] ] // Test column with legacySubset syntax - // This is not strictly necessary and shoudl be removed after the deprecation period + // This is not strictly necessary and should be removed after the deprecation period assert.deepStrictEqual( math2.column(a, 4).valueOf(), [[0], [4], [0], [0], [0]] diff --git a/test/unit-tests/type/matrix/utils/broadcast.test.js b/test/unit-tests/type/matrix/utils/broadcast.test.js index 5e98290bc8..7813ce7009 100644 --- a/test/unit-tests/type/matrix/utils/broadcast.test.js +++ b/test/unit-tests/type/matrix/utils/broadcast.test.js @@ -1,76 +1,118 @@ import assert from 'assert' import math from '../../../../../src/defaultInstance.js' -import { broadcast } from '../../../../../src/type/matrix/utils/broadcast.js' +import { broadcastMatrices, broadcast } from '../../../../../src/type/matrix/utils/broadcast.js' const matrix = math.matrix -describe('broadcast', function () { - it('should return matrices as such if they are the same size', function () { - const A = matrix([1, 2]) - const B = matrix([3, 4]) - const r = broadcast(A, B) - assert.deepStrictEqual(r[0].valueOf(), A.valueOf()) - assert.deepStrictEqual(r[1].valueOf(), B.valueOf()) - }) +describe('broadcast utils', function () { + describe('broadcastMatrices', function () { + it('should return matrices as such if they are the same size', function () { + const A = matrix([1, 2]) + const B = matrix([3, 4]) + const r = broadcastMatrices(A, B) + assert.deepStrictEqual(r[0].valueOf(), A.valueOf()) + assert.deepStrictEqual(r[1].valueOf(), B.valueOf()) + }) - it('should throw an error if they are not broadcastable', function () { - const A = matrix([1, 2]) - const B = matrix([3, 4, 5]) - assert.throws(function () { broadcast(A, B) }) - }) + it('should throw an error if they are not broadcastable', function () { + const A = matrix([1, 2]) + const B = matrix([3, 4, 5]) + assert.throws(function () { broadcastMatrices(A, B) }) + }) - it('should not mutate the original matrices', function () { - const A = matrix([1, 2]) - const B = matrix([[3], [4]]) - broadcast(A, B) - assert.deepStrictEqual(A.valueOf(), [1, 2]) - assert.deepStrictEqual(B.valueOf(), [[3], [4]]) - }) + it('should not mutate the original matrices', function () { + const A = matrix([1, 2]) + const B = matrix([[3], [4]]) + broadcastMatrices(A, B) + assert.deepStrictEqual(A.valueOf(), [1, 2]) + assert.deepStrictEqual(B.valueOf(), [[3], [4]]) + }) - it('should broadcast the first matrix', function () { - const A = matrix([1, 2]) - const B = matrix([[3, 3], [4, 4]]) - const r = broadcast(A, B) - assert.deepStrictEqual(r[0].valueOf(), [[1, 2], [1, 2]]) - assert.deepStrictEqual(r[1].valueOf(), B.valueOf()) - }) + it('should broadcast the first matrix', function () { + const A = matrix([1, 2]) + const B = matrix([[3, 3], [4, 4]]) + const r = broadcastMatrices(A, B) + assert.deepStrictEqual(r[0].valueOf(), [[1, 2], [1, 2]]) + assert.deepStrictEqual(r[1].valueOf(), B.valueOf()) + }) - it('should broadcast the second matrix', function () { - const A = matrix([[1, 2], [1, 2]]) - const B = matrix([[3], [4]]) - const r = broadcast(A, B) - assert.deepStrictEqual(r[0].valueOf(), A.valueOf()) - assert.deepStrictEqual(r[1].valueOf(), [[3, 3], [4, 4]]) - }) + it('should broadcast the second matrix', function () { + const A = matrix([[1, 2], [1, 2]]) + const B = matrix([[3], [4]]) + const r = broadcastMatrices(A, B) + assert.deepStrictEqual(r[0].valueOf(), A.valueOf()) + assert.deepStrictEqual(r[1].valueOf(), [[3, 3], [4, 4]]) + }) - it('should broadcast both matrices', function () { - const A = matrix([1, 2]) - const B = matrix([[3], [4]]) - const r = broadcast(A, B) - assert.deepStrictEqual(r[0].valueOf(), [[1, 2], [1, 2]]) - assert.deepStrictEqual(r[1].valueOf(), [[3, 3], [4, 4]]) - }) + it('should broadcast both matrices', function () { + const A = matrix([1, 2]) + const B = matrix([[3], [4]]) + const r = broadcastMatrices(A, B) + assert.deepStrictEqual(r[0].valueOf(), [[1, 2], [1, 2]]) + assert.deepStrictEqual(r[1].valueOf(), [[3, 3], [4, 4]]) + }) - it('should broadcast a scalar and a column vector', function () { - const A = matrix([1]) - const B = matrix([[3], [4]]) - const r = broadcast(A, B) - assert.deepStrictEqual(r[0].valueOf(), [[1], [1]]) - assert.deepStrictEqual(r[1].valueOf(), B.valueOf()) - }) + it('should broadcast a scalar and a column vector', function () { + const A = matrix([1]) + const B = matrix([[3], [4]]) + const r = broadcastMatrices(A, B) + assert.deepStrictEqual(r[0].valueOf(), [[1], [1]]) + assert.deepStrictEqual(r[1].valueOf(), B.valueOf()) + }) - it('should broadcast a row vector and a scalar', function () { - const A = matrix([1, 2]) - const B = matrix([3]) - const r = broadcast(A, B) - assert.deepStrictEqual(r[0].valueOf(), A.valueOf()) - assert.deepStrictEqual(r[1].valueOf(), [3, 3]) - }) + it('should broadcast a row vector and a scalar', function () { + const A = matrix([1, 2]) + const B = matrix([3]) + const r = broadcastMatrices(A, B) + assert.deepStrictEqual(r[0].valueOf(), A.valueOf()) + assert.deepStrictEqual(r[1].valueOf(), [3, 3]) + }) + + it('should broadcast higher dimensions', function () { + const A = matrix([[[1, 2]]]) + const B = matrix([[[3]], [[4]]]) + const r = broadcastMatrices(A, B) + assert.deepStrictEqual(r[0].valueOf(), [[[1, 2]], [[1, 2]]]) + assert.deepStrictEqual(r[1].valueOf(), [[[3, 3]], [[4, 4]]]) + }) + + describe('broadcast', function () { + it('should apply a broadcasting function on two arrays of the same size', function () { + const a = [[1, 2], [3, 4]] + const b = [[5, 6], [7, 8]] + const result = broadcast(a, b, [2, 2], [2, 2], (x, y) => x + y) + assert.deepStrictEqual(result.data, [[6, 8], [10, 12]]) + assert.deepStrictEqual(result.size, [2, 2]) + }) + + it('should apply a broadcasting function on two arrays of different sizes', function () { + const a = [[1, 2], [3, 4]] + const b = [10, 20] + const result = broadcast(a, b, [2, 2], [2], (x, y) => x + y) + assert.deepStrictEqual(result.data, [[11, 22], [13, 24]]) + assert.deepStrictEqual(result.size, [2, 2]) + }) + + it('should throw an error if arrays and sizes are not provided', function () { + assert.throws(function () { broadcast([1, 2], [3, 4], [2], null, (x, y) => x + y) }) + }) + + it('should throw an error if callback is not a function', function () { + assert.throws(function () { broadcast([1, 2], [3, 4], [2], [2], null) }) + }) + + it('should handle broadcasting with empty arrays', function () { + const a = [] + const b = [] + const result = broadcast(a, b, [0], [0], (x, y) => x + y) + assert.deepStrictEqual(result.data, []) + assert.deepStrictEqual(result.size, [0]) + }) - it('should broadcast higher dimensions', function () { - const A = matrix([[[1, 2]]]) - const B = matrix([[[3]], [[4]]]) - const r = broadcast(A, B) - assert.deepStrictEqual(r[0].valueOf(), [[[1, 2]], [[1, 2]]]) - assert.deepStrictEqual(r[1].valueOf(), [[[3, 3]], [[4, 4]]]) + it('should throw an error if arrays are not broadcastable', function () { + const a = [[1, 2], [3, 4]] + const b = [10, 20, 30] + assert.throws(function () { broadcast(a, b, [2, 2], [3], (x, y) => x + y) }) + }) + }) }) })