diff --git a/src/__test__/numbers.spec.js b/src/__test__/numbers.spec.js index 95f3658..f1a8602 100644 --- a/src/__test__/numbers.spec.js +++ b/src/__test__/numbers.spec.js @@ -10,4 +10,9 @@ describe('the numbers-only bundle', () => { assert.strictEqual(math.isnan(-16.5), 0) assert.strictEqual(math.isnan(NaN), 1) }) + it('takes sqrt with NaN for negative', () => { + assert.strictEqual(math.sqrt(25), 5) + assert(math.isnan(math.sqrt(-25))) + assert(math.isnan(math.sqrt(NaN))) + }) }) diff --git a/src/number/__test__/arithmetic.spec.js b/src/number/__test__/arithmetic.spec.js index 7d368cb..6d718ef 100644 --- a/src/number/__test__/arithmetic.spec.js +++ b/src/number/__test__/arithmetic.spec.js @@ -1,5 +1,6 @@ import assert from 'assert' import math from '#nanomath' +import {ReturnTyping} from '#core/Type.js' describe('number arithmetic', () => { it('supports basic operations', () => { @@ -14,4 +15,18 @@ describe('number arithmetic', () => { assert.strictEqual(math.subtract(4, 2), 2) assert.strictEqual(math.quotient(7, 3), 2) }) + it('takes square root of numbers appropriately', () => { + assert(isNaN(math.sqrt(NaN))) + assert.strictEqual(math.sqrt(4), 2) + assert.deepStrictEqual(math.sqrt(-4), math.complex(0, 2)) + math.config.returnTyping = ReturnTyping.conservative + assert(isNaN(math.sqrt(NaN))) + assert.strictEqual(math.sqrt(4), 2) + assert(isNaN(math.sqrt(-4))) + math.config.returnTyping = ReturnTyping.full + assert(isNaN(math.sqrt(NaN))) + assert.deepStrictEqual(math.sqrt(4), math.complex(2, 0)) + assert.deepStrictEqual(math.sqrt(-4), math.complex(0, 2)) + math.config.returnTyping = ReturnTyping.free + }) }) diff --git a/src/number/arithmetic.js b/src/number/arithmetic.js index 662a157..0b5c66f 100644 --- a/src/number/arithmetic.js +++ b/src/number/arithmetic.js @@ -1,4 +1,10 @@ import {plain} from './helpers.js' +import {NumberT} from './NumberT.js' +import {OneOf, Returns, ReturnTyping} from '#core/Type.js' +import {match} from '#core/TypePatterns.js' +import {Complex} from '#complex/Complex.js' + +const {conservative, full} = ReturnTyping export const abs = plain(Math.abs) export const absquare = plain(a => a*a) @@ -18,5 +24,25 @@ export const cbrt = plain(a => { export const invert = plain(a => 1/a) export const multiply = plain((a, b) => a * b) export const negate = plain(a => -a) + +export const sqrt = match(NumberT, (math, _N, strategy) => { + if (!math.types.Complex || strategy === conservative) { + return Returns(NumberT, Math.sqrt) + } + const cplx = math.complex.resolve([NumberT, NumberT], full) + if (strategy === full) { + const cnan = math.nan(Complex(NumberT)) + return Returns(Complex(NumberT), a => { + if (isNaN(a)) return cnan + return a >= 0 ? cplx(Math.sqrt(a), 0) : cplx(0, Math.sqrt(-a)) + }) + } + // strategy === free, return "best" type + return Returns(OneOf(NumberT, Complex(NumberT)), a => { + if (isNaN(a)) return NaN + return a >= 0 ? Math.sqrt(a) : cplx(0, Math.sqrt(-a)) + }) +}) + export const subtract = plain((a, b) => a - b) export const quotient = plain((a,b) => Math.floor(a/b)) diff --git a/src/package.json b/src/package.json index 290fb63..2e7d812 100644 --- a/src/package.json +++ b/src/package.json @@ -1,8 +1,9 @@ { "imports" : { "#nanomath": "./nanomath.js", - "#boolean/*.js": "./boolean/*.js", "#core/*.js": "./core/*.js", + "#boolean/*.js": "./boolean/*.js", + "#complex/*.js": "./complex/*.js", "#generic/*.js": "./generic/*.js", "#number/*.js": "./number/*.js" },