fix(templates): Enhance the catchall to call the correct specialization

Also enhances type inference in preparation for template types.
This commit is contained in:
Glen Whitney 2022-08-03 09:35:11 -07:00
parent 880efac15b
commit 5dab7d64e7
3 changed files with 174 additions and 43 deletions

View File

@ -4,8 +4,8 @@ export {abs} from './abs.mjs'
export {absquare} from './absquare.mjs'
export {add} from './add.mjs'
export {associate} from './associate.mjs'
export {conjugate} from './conjugate.mjs'
export {complex} from './complex.mjs'
export {conjugate} from './conjugate.mjs'
export {equal} from './equal.mjs'
export {gcd} from './gcd.mjs'
export {invert} from './invert.mjs'

View File

@ -27,6 +27,7 @@ export default class PocomathInstance {
'importDependencies',
'install',
'installType',
'joinTypes',
'name',
'typeOf',
'Types',
@ -293,7 +294,20 @@ export default class PocomathInstance {
}
this._typed.addTypes([{name: type, test: testFn}], beforeType)
this.Types[type] = spec
this._subtypes[type] = new Set()
this._priorTypes[type] = new Set()
// Update all the subtype sets of supertypes up the chain, and
// while we are at it add trivial conversions from subtypes to supertypes
// to help typed-function match signatures properly:
let nextSuper = spec.refines
while (nextSuper) {
this._typed.addConversion(
{from: type, to: nextSuper, convert: x => x})
this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(type)
this._subtypes[nextSuper].add(type)
nextSuper = this.Types[nextSuper].refines
}
/* Now add conversions to this type */
for (const from in (spec.from || {})) {
if (from in this.Types) {
@ -304,50 +318,81 @@ export default class PocomathInstance {
{from, to: nextSuper, convert: spec.from[from]})
this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(from)
/* And all of the subtypes of from are now prior as well: */
for (const subtype of this._subtypes[from]) {
this._priorTypes[nextSuper].add(subtype)
}
nextSuper = this.Types[nextSuper].refines
}
}
}
/* And add conversions from this type */
for (const to in this.Types) {
if (type in (this.Types[to].from || {})) {
if (spec.refines == to || spec.refines in this._subtypes[to]) {
throw new SyntaxError(
`Conversion of ${type} to its supertype ${to} disallowed.`)
}
let nextSuper = to
while (nextSuper) {
this._typed.addConversion({
from: type,
to: nextSuper,
convert: this.Types[to].from[type]
})
this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(type)
nextSuper = this.Types[nextSuper].refines
for (const fromtype in this.Types[to].from) {
if (type == fromtype
|| (fromtype in this._subtypes
&& this._subtypes[fromtype].has(type))) {
if (spec.refines == to || spec.refines in this._subtypes[to]) {
throw new SyntaxError(
`Conversion of ${type} to its supertype ${to} disallowed.`)
}
let nextSuper = to
while (nextSuper) {
this._typed.addConversion({
from: type,
to: nextSuper,
convert: this.Types[to].from[fromtype]
})
this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(type)
nextSuper = this.Types[nextSuper].refines
}
}
}
}
// Update all the subtype sets of supertypes up the chain, and
// while we are at it add trivial conversions from subtypes to supertypes
// to help typed-function match signatures properly:
this._subtypes[type] = new Set()
let nextSuper = spec.refines
while (nextSuper) {
this._typed.addConversion(
{from: type, to: nextSuper, convert: x => x})
this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(type)
this._subtypes[nextSuper].add(type)
nextSuper = this.Types[nextSuper].refines
}
// update the typeOf function
const imp = {}
imp[type] = {uses: new Set(), does: () => () => type}
console.log('Adding', type, 'to typeOf')
this._installFunctions({typeOf: imp})
}
/* Returns the most refined type of all the types in the array, with
* '' standing for the empty type for convenience. If the second
* argument `convert` is true, a convertible type is considered a
* a subtype (defaults to false).
*/
joinTypes(types, convert) {
let join = ''
for (const type of types) {
join = this._joinTypes(join, type, convert)
}
return join
}
/* helper for above */
_joinTypes(typeA, typeB, convert) {
if (!typeA) return typeB
if (!typeB) return typeA
if (typeA === typeB) return typeA
const subber = convert ? this._priorTypes : this._subtypes
if (subber[typeB].has(typeA)) return typeB
/* OK, so we need the most refined supertype of A that contains B:
*/
let nextSuper = typeA
while (nextSuper) {
if (subber[nextSuper].has(typeB)) return nextSuper
nextSuper = this.Types[nextSuper].refines
}
/* And if conversions are allowed, we have to search the other way too */
if (convert) {
nextSuper = typeB
while (nextSuper) {
if (subber[nextSuper].has(typeA)) return nextSuper
nextSuper = this.Types[nextSuper].refines
}
}
return 'any'
}
/* Returns a list of all types that have been mentioned in the
* signatures of operations, but which have not actually been installed:
*/
@ -530,30 +575,94 @@ export default class PocomathInstance {
const signature = substituteInSig(
trimSignature, theTemplateParam, 'any')
/* The catchall signature has to detect the actual type of the call
* and add the new instantiations
* and add the new instantiations. We should really be using the
* typed-function parser to do the manipulations below, but we don't
* have access. The firs section prepares the type inference data:
*/
const argTypes = trimSignature.split(',')
let exemplar = -1
for (let i = 0; i < argTypes.length; ++i) {
const argType = argTypes[i].trim()
if (argType === theTemplateParam) {
exemplar = i
break
const parTypes = trimSignature.split(',')
const inferences = []
const typer = entity => this.typeOf(entity)
let ambiguous = true
for (let parType of parTypes) {
parType = parType.trim()
if (parType.slice(0,3) === '...') {
parType = parType.slice(3).trim()
}
if (parType === theTemplateParam) {
inferences.push(typer)
ambiguous = false
} else {
inferences.push(false)
}
}
if (exemplar < 0) {
if (ambiguous) {
throw new SyntaxError(
`Cannot find template parameter in ${rawSignature}`)
}
/* Now build the catchall implementation */
const self = this
const patch = (refs) => (...args) => {
const example = args[exemplar]
const instantiateFor = self.typeOf(example)
/* First infer the type we actually should have been called for */
let i = -1
let j = -1
/* collect the arg types */
const argTypes = []
for (const arg of args) {
++j
// in case of rest parameter, reuse last parameter type:
if (i < inferences.length - 1) ++i
if (inferences[i]) {
const argType = inferences[i](arg)
if (!argType || argType === 'any') {
throw TypeError(
`Type inference failed for argument ${j} of ${name}`)
}
argTypes.push(argType)
}
}
if (argTypes.length === 0) {
throw TypeError('Type inference failed for' + name)
}
let usedConversions = false
let instantiateFor = self.joinTypes(argTypes)
if (instantiateFor === 'any') {
usedConversions = true
instantiateFor = self.joinTypes(argTypes, usedConversions)
if (instantiateFor === 'any') {
throw TypeError('No common type for arguments to ' + name)
}
}
/* Transform the arguments if we used any conversions: */
if (usedConversions) {
i = - 1
for (j = 0; j < args.length; ++j) {
if (i < parTypes.length - 1) ++i
const wantType = substituteInSig(
parTypes[i], theTemplateParam, instantiateFor)
if (wantType !== parTypes[i]) {
args[j] = self._typed.convert(args[j], wantType)
}
}
}
/* Arrange that the desired instantiation will be there next
* time so we don't have to go through that again for this type
*/
refs[theTemplateParam] = instantiateFor
behavior.instantiations.add(instantiateFor)
self._invalidate(name)
// And for now, we have to rely on the "any" implementation. Hope
// it matches the instantiated one!
// And update refs because we now know the type we're instantiating
// for:
for (const dep of behavior.uses) {
let [func, needsig] = dep.split(/[()]/)
if (needsig && self._typed.isTypedFunction(refs[dep])) {
const subsig = substituteInSig(
needsig, theTemplateParam, instantiateFor)
if (subsig !== needsig) {
refs[dep] = self._typed.find(refs[dep], subsig)
}
}
}
// Finally ready to make the call.
return behavior.does(refs)(...args)
}
this._addTFimplementation(

View File

@ -8,6 +8,7 @@ import * as complex from '../src/complex/all.mjs'
import * as complexAdd from '../src/complex/add.mjs'
import * as complexNegate from '../src/complex/negate.mjs'
import * as complexComplex from '../src/complex/complex.mjs'
import * as bigintAdd from '../src/bigint/add.mjs'
import * as concreteSubtract from '../src/generic/subtract.concrete.mjs'
import * as genericSubtract from '../src/generic/subtract.mjs'
import extendToComplex from '../src/complex/extendToComplex.mjs'
@ -112,4 +113,25 @@ describe('A custom instance', () => {
math.complex(1n, -3n))
})
it("instantiates templates correctly", () => {
const inst = new PocomathInstance('InstantiateTemplates')
inst.install(numberAdd)
inst.install({typeMerge: {'T,T': ({T}) => (t,u) => 'Merge to ' + T }})
assert.strictEqual(inst.typeMerge(7,6.28), 'Merge to number')
assert.strictEqual(inst.typeMerge(7,6), 'Merge to NumInt')
assert.strictEqual(inst.typeMerge(7.35,6), 'Merge to number')
inst.install(complexAdd)
inst.install(complexComplex)
inst.install(bigintAdd)
assert.strictEqual(
inst.typeMerge(6n, inst.complex(3n, 2n)),
'Merge to GaussianInteger')
assert.strictEqual(
inst.typeMerge(3, inst.complex(4.5,2.1)),
'Merge to Complex')
// The following is the current behavior, since both `3+0i` and `3n + 0ni`
// are Complex, but it is unfortunate and hopefully it will be fixed
// with templates:
assert.strictEqual(inst.typeMerge(3, 3n), 'Merge to Complex')
})
})