diff --git a/src/complex/native.mjs b/src/complex/native.mjs index 4f63c8f..1dfb45e 100644 --- a/src/complex/native.mjs +++ b/src/complex/native.mjs @@ -4,8 +4,8 @@ export {abs} from './abs.mjs' export {absquare} from './absquare.mjs' export {add} from './add.mjs' export {associate} from './associate.mjs' -export {conjugate} from './conjugate.mjs' export {complex} from './complex.mjs' +export {conjugate} from './conjugate.mjs' export {equal} from './equal.mjs' export {gcd} from './gcd.mjs' export {invert} from './invert.mjs' diff --git a/src/core/PocomathInstance.mjs b/src/core/PocomathInstance.mjs index 6fab6c9..a698488 100644 --- a/src/core/PocomathInstance.mjs +++ b/src/core/PocomathInstance.mjs @@ -27,6 +27,7 @@ export default class PocomathInstance { 'importDependencies', 'install', 'installType', + 'joinTypes', 'name', 'typeOf', 'Types', @@ -293,7 +294,20 @@ export default class PocomathInstance { } this._typed.addTypes([{name: type, test: testFn}], beforeType) this.Types[type] = spec + this._subtypes[type] = new Set() this._priorTypes[type] = new Set() + // Update all the subtype sets of supertypes up the chain, and + // while we are at it add trivial conversions from subtypes to supertypes + // to help typed-function match signatures properly: + let nextSuper = spec.refines + while (nextSuper) { + this._typed.addConversion( + {from: type, to: nextSuper, convert: x => x}) + this._invalidateDependents(':' + nextSuper) + this._priorTypes[nextSuper].add(type) + this._subtypes[nextSuper].add(type) + nextSuper = this.Types[nextSuper].refines + } /* Now add conversions to this type */ for (const from in (spec.from || {})) { if (from in this.Types) { @@ -304,50 +318,81 @@ export default class PocomathInstance { {from, to: nextSuper, convert: spec.from[from]}) this._invalidateDependents(':' + nextSuper) this._priorTypes[nextSuper].add(from) + /* And all of the subtypes of from are now prior as well: */ + for (const subtype of this._subtypes[from]) { + this._priorTypes[nextSuper].add(subtype) + } nextSuper = this.Types[nextSuper].refines } } } /* And add conversions from this type */ for (const to in this.Types) { - if (type in (this.Types[to].from || {})) { - if (spec.refines == to || spec.refines in this._subtypes[to]) { - throw new SyntaxError( - `Conversion of ${type} to its supertype ${to} disallowed.`) - } - let nextSuper = to - while (nextSuper) { - this._typed.addConversion({ - from: type, - to: nextSuper, - convert: this.Types[to].from[type] - }) - this._invalidateDependents(':' + nextSuper) - this._priorTypes[nextSuper].add(type) - nextSuper = this.Types[nextSuper].refines + for (const fromtype in this.Types[to].from) { + if (type == fromtype + || (fromtype in this._subtypes + && this._subtypes[fromtype].has(type))) { + if (spec.refines == to || spec.refines in this._subtypes[to]) { + throw new SyntaxError( + `Conversion of ${type} to its supertype ${to} disallowed.`) + } + let nextSuper = to + while (nextSuper) { + this._typed.addConversion({ + from: type, + to: nextSuper, + convert: this.Types[to].from[fromtype] + }) + this._invalidateDependents(':' + nextSuper) + this._priorTypes[nextSuper].add(type) + nextSuper = this.Types[nextSuper].refines + } } } } - // Update all the subtype sets of supertypes up the chain, and - // while we are at it add trivial conversions from subtypes to supertypes - // to help typed-function match signatures properly: - this._subtypes[type] = new Set() - let nextSuper = spec.refines - while (nextSuper) { - this._typed.addConversion( - {from: type, to: nextSuper, convert: x => x}) - this._invalidateDependents(':' + nextSuper) - this._priorTypes[nextSuper].add(type) - this._subtypes[nextSuper].add(type) - nextSuper = this.Types[nextSuper].refines - } - // update the typeOf function const imp = {} imp[type] = {uses: new Set(), does: () => () => type} + console.log('Adding', type, 'to typeOf') this._installFunctions({typeOf: imp}) } + /* Returns the most refined type of all the types in the array, with + * '' standing for the empty type for convenience. If the second + * argument `convert` is true, a convertible type is considered a + * a subtype (defaults to false). + */ + joinTypes(types, convert) { + let join = '' + for (const type of types) { + join = this._joinTypes(join, type, convert) + } + return join + } + /* helper for above */ + _joinTypes(typeA, typeB, convert) { + if (!typeA) return typeB + if (!typeB) return typeA + if (typeA === typeB) return typeA + const subber = convert ? this._priorTypes : this._subtypes + if (subber[typeB].has(typeA)) return typeB + /* OK, so we need the most refined supertype of A that contains B: + */ + let nextSuper = typeA + while (nextSuper) { + if (subber[nextSuper].has(typeB)) return nextSuper + nextSuper = this.Types[nextSuper].refines + } + /* And if conversions are allowed, we have to search the other way too */ + if (convert) { + nextSuper = typeB + while (nextSuper) { + if (subber[nextSuper].has(typeA)) return nextSuper + nextSuper = this.Types[nextSuper].refines + } + } + return 'any' + } /* Returns a list of all types that have been mentioned in the * signatures of operations, but which have not actually been installed: */ @@ -530,30 +575,94 @@ export default class PocomathInstance { const signature = substituteInSig( trimSignature, theTemplateParam, 'any') /* The catchall signature has to detect the actual type of the call - * and add the new instantiations + * and add the new instantiations. We should really be using the + * typed-function parser to do the manipulations below, but we don't + * have access. The firs section prepares the type inference data: */ - const argTypes = trimSignature.split(',') - let exemplar = -1 - for (let i = 0; i < argTypes.length; ++i) { - const argType = argTypes[i].trim() - if (argType === theTemplateParam) { - exemplar = i - break + const parTypes = trimSignature.split(',') + const inferences = [] + const typer = entity => this.typeOf(entity) + let ambiguous = true + for (let parType of parTypes) { + parType = parType.trim() + if (parType.slice(0,3) === '...') { + parType = parType.slice(3).trim() + } + if (parType === theTemplateParam) { + inferences.push(typer) + ambiguous = false + } else { + inferences.push(false) } } - if (exemplar < 0) { + if (ambiguous) { throw new SyntaxError( `Cannot find template parameter in ${rawSignature}`) } + /* Now build the catchall implementation */ const self = this const patch = (refs) => (...args) => { - const example = args[exemplar] - const instantiateFor = self.typeOf(example) + /* First infer the type we actually should have been called for */ + let i = -1 + let j = -1 + /* collect the arg types */ + const argTypes = [] + for (const arg of args) { + ++j + // in case of rest parameter, reuse last parameter type: + if (i < inferences.length - 1) ++i + if (inferences[i]) { + const argType = inferences[i](arg) + if (!argType || argType === 'any') { + throw TypeError( + `Type inference failed for argument ${j} of ${name}`) + } + argTypes.push(argType) + } + } + if (argTypes.length === 0) { + throw TypeError('Type inference failed for' + name) + } + let usedConversions = false + let instantiateFor = self.joinTypes(argTypes) + if (instantiateFor === 'any') { + usedConversions = true + instantiateFor = self.joinTypes(argTypes, usedConversions) + if (instantiateFor === 'any') { + throw TypeError('No common type for arguments to ' + name) + } + } + /* Transform the arguments if we used any conversions: */ + if (usedConversions) { + i = - 1 + for (j = 0; j < args.length; ++j) { + if (i < parTypes.length - 1) ++i + const wantType = substituteInSig( + parTypes[i], theTemplateParam, instantiateFor) + if (wantType !== parTypes[i]) { + args[j] = self._typed.convert(args[j], wantType) + } + } + } + /* Arrange that the desired instantiation will be there next + * time so we don't have to go through that again for this type + */ refs[theTemplateParam] = instantiateFor behavior.instantiations.add(instantiateFor) self._invalidate(name) - // And for now, we have to rely on the "any" implementation. Hope - // it matches the instantiated one! + // And update refs because we now know the type we're instantiating + // for: + for (const dep of behavior.uses) { + let [func, needsig] = dep.split(/[()]/) + if (needsig && self._typed.isTypedFunction(refs[dep])) { + const subsig = substituteInSig( + needsig, theTemplateParam, instantiateFor) + if (subsig !== needsig) { + refs[dep] = self._typed.find(refs[dep], subsig) + } + } + } + // Finally ready to make the call. return behavior.does(refs)(...args) } this._addTFimplementation( diff --git a/test/custom.mjs b/test/custom.mjs index 9fb66b1..76bb88a 100644 --- a/test/custom.mjs +++ b/test/custom.mjs @@ -8,6 +8,7 @@ import * as complex from '../src/complex/all.mjs' import * as complexAdd from '../src/complex/add.mjs' import * as complexNegate from '../src/complex/negate.mjs' import * as complexComplex from '../src/complex/complex.mjs' +import * as bigintAdd from '../src/bigint/add.mjs' import * as concreteSubtract from '../src/generic/subtract.concrete.mjs' import * as genericSubtract from '../src/generic/subtract.mjs' import extendToComplex from '../src/complex/extendToComplex.mjs' @@ -112,4 +113,25 @@ describe('A custom instance', () => { math.complex(1n, -3n)) }) + it("instantiates templates correctly", () => { + const inst = new PocomathInstance('InstantiateTemplates') + inst.install(numberAdd) + inst.install({typeMerge: {'T,T': ({T}) => (t,u) => 'Merge to ' + T }}) + assert.strictEqual(inst.typeMerge(7,6.28), 'Merge to number') + assert.strictEqual(inst.typeMerge(7,6), 'Merge to NumInt') + assert.strictEqual(inst.typeMerge(7.35,6), 'Merge to number') + inst.install(complexAdd) + inst.install(complexComplex) + inst.install(bigintAdd) + assert.strictEqual( + inst.typeMerge(6n, inst.complex(3n, 2n)), + 'Merge to GaussianInteger') + assert.strictEqual( + inst.typeMerge(3, inst.complex(4.5,2.1)), + 'Merge to Complex') + // The following is the current behavior, since both `3+0i` and `3n + 0ni` + // are Complex, but it is unfortunate and hopefully it will be fixed + // with templates: + assert.strictEqual(inst.typeMerge(3, 3n), 'Merge to Complex') + }) })