diff --git a/src/number/arithmetic.js b/src/number/arithmetic.js index dd03106..9ad8ac0 100644 --- a/src/number/arithmetic.js +++ b/src/number/arithmetic.js @@ -27,7 +27,11 @@ export const cbrt = plain(a => { return negate ? -result : result }) export const invert = plain(a => 1/a) -export const multiply = plain((a, b) => a * b) +export const multiply = [ + plain((a, b) => a * b), + match([Undefined, NumberT], Returns(NumberT, () => NaN)), + match([NumberT, Undefined], Returns(NumberT, () => NaN)) +] export const negate = plain(a => -a) export const sqrt = match(NumberT, (math, _N, strategy) => { diff --git a/src/vector/__test__/arithmetic.spec.js b/src/vector/__test__/arithmetic.spec.js index 77337b3..5dcc3ca 100644 --- a/src/vector/__test__/arithmetic.spec.js +++ b/src/vector/__test__/arithmetic.spec.js @@ -24,6 +24,21 @@ describe('Vector arithmetic functions', () => { assert.deepStrictEqual( add([[1, 2], [4, 2]], [0, -1]), [[1, 1], [4, 1]]) }) + it('multiplies vectors and matrices', () => { + const mult = math.multiply + const pyth = [3, 4, 5] + assert.deepStrictEqual(mult(pyth, 2), [6, 8, 10]) + assert.deepStrictEqual(mult(-3, pyth), [-9, -12, -15]) + assert.strictEqual(mult(pyth, pyth), 50) + const mat23 = [[1, 2, 3], [-3, -2, -1]] + assert.deepStrictEqual(mult(mat23, pyth), [26, -22]) + const mat32 = math.transpose(mat23) + assert.deepStrictEqual(mult(pyth, mat32), [26, -22]) + assert.deepStrictEqual(mult(mat23, mat32), [[14, -10], [-10, 14]]) + assert.deepStrictEqual( + mult(mat32, [[1, 2], [3, 4]]), + [[-8, -10], [-4, -4], [0, 2]]) + }) it('negates a vector', () => { assert.deepStrictEqual(math.negate([-3, 4, -5]), [3, -4, 5]) }) diff --git a/src/vector/__test__/type.spec.js b/src/vector/__test__/type.spec.js index e50228c..369033a 100644 --- a/src/vector/__test__/type.spec.js +++ b/src/vector/__test__/type.spec.js @@ -15,4 +15,11 @@ describe('Vector type functions', () => { ReturnType(vec.resolve([NumberT, BooleanT])), Vector(Unknown)) }) + it('can transpose vectors and matrices', () => { + const tsp = math.transpose + assert.deepStrictEqual(tsp([3, 4, 5]), [[3], [4], [5]]) + assert.deepStrictEqual(tsp([[1, 2], [3, 4]]), [[1, 3], [2, 4]]) + assert.deepStrictEqual( + tsp([[1, 2, 3], [4, 5, 6]]), [[1, 4], [2, 5], [3, 6]]) + }) }) diff --git a/src/vector/arithmetic.js b/src/vector/arithmetic.js index 4803148..1687138 100644 --- a/src/vector/arithmetic.js +++ b/src/vector/arithmetic.js @@ -1,4 +1,6 @@ -import {promoteBinary, promoteUnary} from './helpers.js' +import { + distributeFirst, distributeSecond, promoteBinary, promoteUnary +} from './helpers.js' import {Vector} from './Vector.js' import {ReturnType} from '#core/Type.js' @@ -16,6 +18,39 @@ export const normsq = match(Vector, (math, V) => { export const abs = promoteUnary('abs') export const add = promoteBinary('add') +export const dotMultiply = promoteBinary('multiply') +export const multiply = [ + distributeFirst('multiply'), + distributeSecond('multiply'), + match([Vector, Vector], (math, [V, W], strategy) => { + const VComp = V.Component + if (W.vectorDepth === 1) { + if (V.vectorDepth === 1) { + const eltWise = math.dotMultiply.resolve([V, W], strategy) + const sum = math.sum.resolve(ReturnType(eltWise)) + return ReturnsAs(sum, (v, w) => sum(eltWise(v, w))) + } + const compMult = math.multiply.resolve([VComp, W], strategy) + return ReturnsAs( + Vector(ReturnType(compMult)), + (v, w) => v.map(f => compMult(f, w))) + } + const transpose = math.transpose.resolve(W, strategy) + const wrapV = V.vectorDepth === 1 + const RowV = wrapV ? V : VComp + const rowMult = math.multiply.resolve([RowV, W.Component], strategy) + let RetType = Vector(ReturnType(rowMult)) + if (!wrapV) RetType = Vector(RetType) + return ReturnsAs(RetType, (v, w) => { + if (wrapV) v = [v] + w = transpose(w) + let retval = v.map(vrow => w.map(wcol => rowMult(vrow, wcol))) + if (wrapV) retval = retval[0] + return retval + }) + }) +] + export const negate = promoteUnary('negate') export const subtract = promoteBinary('subtract') diff --git a/src/vector/helpers.js b/src/vector/helpers.js index e0ee78f..b70bc30 100644 --- a/src/vector/helpers.js +++ b/src/vector/helpers.js @@ -7,17 +7,25 @@ export const promoteUnary = name => match(Vector, (math, V, strategy) => { return Returns(Vector(ReturnType(compOp)), v => v.map(elt => compOp(elt))) }) -export const promoteBinary = name => [ - match([Vector, Any], (math, [V, E], strategy) => { +export const distributeFirst = name => match( + [Vector, Any], + (math, [V, E], strategy) => { const compOp = math.resolve(name, [V.Component, E], strategy) return Returns( Vector(ReturnType(compOp)), (v, e) => v.map(f => compOp(f, e))) - }), - match([Any, Vector], (math, [E, V], strategy) => { + }) + +export const distributeSecond = name => match( + [Any, Vector], + (math, [E, V], strategy) => { const compOp = math.resolve(name, [E, V.Component], strategy) return Returns( Vector(ReturnType(compOp)), (e, v) => v.map(f => compOp(e, f))) - }), + }) + +export const promoteBinary = name => [ + distributeFirst(name), + distributeSecond(name), match([Vector, Vector], (math, [V, W], strategy) => { const VComp = V.Component const WComp = W.Component diff --git a/src/vector/type.js b/src/vector/type.js index 912c5a7..88cbe3f 100644 --- a/src/vector/type.js +++ b/src/vector/type.js @@ -9,4 +9,16 @@ export const vector = match(Multiple(Any), (math, [TV]) => { return Returns(Vector(CompType), v => v) }) - +export const transpose = match(Vector, (_math, V) => { + const wrapV = V.vectorDepth === 1 + const Mat = wrapV ? Vector(V) : V + return Returns(Mat, v => { + if (wrapV) v = [v] + const cols = v.length ? v[0].length : 0 + const retval = [] + for (let ix = 0; ix < cols; ++ix) { + retval.push(v.map(row => row[ix])) + } + return retval + }) +})