fix(templates): Enhance the catchall to call the correct specialization

Also enhances type inference in preparation for template types.
This commit is contained in:
Glen Whitney 2022-08-03 09:35:11 -07:00
parent 880efac15b
commit 5dab7d64e7
3 changed files with 174 additions and 43 deletions

View File

@ -4,8 +4,8 @@ export {abs} from './abs.mjs'
export {absquare} from './absquare.mjs' export {absquare} from './absquare.mjs'
export {add} from './add.mjs' export {add} from './add.mjs'
export {associate} from './associate.mjs' export {associate} from './associate.mjs'
export {conjugate} from './conjugate.mjs'
export {complex} from './complex.mjs' export {complex} from './complex.mjs'
export {conjugate} from './conjugate.mjs'
export {equal} from './equal.mjs' export {equal} from './equal.mjs'
export {gcd} from './gcd.mjs' export {gcd} from './gcd.mjs'
export {invert} from './invert.mjs' export {invert} from './invert.mjs'

View File

@ -27,6 +27,7 @@ export default class PocomathInstance {
'importDependencies', 'importDependencies',
'install', 'install',
'installType', 'installType',
'joinTypes',
'name', 'name',
'typeOf', 'typeOf',
'Types', 'Types',
@ -293,7 +294,20 @@ export default class PocomathInstance {
} }
this._typed.addTypes([{name: type, test: testFn}], beforeType) this._typed.addTypes([{name: type, test: testFn}], beforeType)
this.Types[type] = spec this.Types[type] = spec
this._subtypes[type] = new Set()
this._priorTypes[type] = new Set() this._priorTypes[type] = new Set()
// Update all the subtype sets of supertypes up the chain, and
// while we are at it add trivial conversions from subtypes to supertypes
// to help typed-function match signatures properly:
let nextSuper = spec.refines
while (nextSuper) {
this._typed.addConversion(
{from: type, to: nextSuper, convert: x => x})
this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(type)
this._subtypes[nextSuper].add(type)
nextSuper = this.Types[nextSuper].refines
}
/* Now add conversions to this type */ /* Now add conversions to this type */
for (const from in (spec.from || {})) { for (const from in (spec.from || {})) {
if (from in this.Types) { if (from in this.Types) {
@ -304,50 +318,81 @@ export default class PocomathInstance {
{from, to: nextSuper, convert: spec.from[from]}) {from, to: nextSuper, convert: spec.from[from]})
this._invalidateDependents(':' + nextSuper) this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(from) this._priorTypes[nextSuper].add(from)
/* And all of the subtypes of from are now prior as well: */
for (const subtype of this._subtypes[from]) {
this._priorTypes[nextSuper].add(subtype)
}
nextSuper = this.Types[nextSuper].refines nextSuper = this.Types[nextSuper].refines
} }
} }
} }
/* And add conversions from this type */ /* And add conversions from this type */
for (const to in this.Types) { for (const to in this.Types) {
if (type in (this.Types[to].from || {})) { for (const fromtype in this.Types[to].from) {
if (spec.refines == to || spec.refines in this._subtypes[to]) { if (type == fromtype
throw new SyntaxError( || (fromtype in this._subtypes
`Conversion of ${type} to its supertype ${to} disallowed.`) && this._subtypes[fromtype].has(type))) {
} if (spec.refines == to || spec.refines in this._subtypes[to]) {
let nextSuper = to throw new SyntaxError(
while (nextSuper) { `Conversion of ${type} to its supertype ${to} disallowed.`)
this._typed.addConversion({ }
from: type, let nextSuper = to
to: nextSuper, while (nextSuper) {
convert: this.Types[to].from[type] this._typed.addConversion({
}) from: type,
this._invalidateDependents(':' + nextSuper) to: nextSuper,
this._priorTypes[nextSuper].add(type) convert: this.Types[to].from[fromtype]
nextSuper = this.Types[nextSuper].refines })
this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(type)
nextSuper = this.Types[nextSuper].refines
}
} }
} }
} }
// Update all the subtype sets of supertypes up the chain, and
// while we are at it add trivial conversions from subtypes to supertypes
// to help typed-function match signatures properly:
this._subtypes[type] = new Set()
let nextSuper = spec.refines
while (nextSuper) {
this._typed.addConversion(
{from: type, to: nextSuper, convert: x => x})
this._invalidateDependents(':' + nextSuper)
this._priorTypes[nextSuper].add(type)
this._subtypes[nextSuper].add(type)
nextSuper = this.Types[nextSuper].refines
}
// update the typeOf function // update the typeOf function
const imp = {} const imp = {}
imp[type] = {uses: new Set(), does: () => () => type} imp[type] = {uses: new Set(), does: () => () => type}
console.log('Adding', type, 'to typeOf')
this._installFunctions({typeOf: imp}) this._installFunctions({typeOf: imp})
} }
/* Returns the most refined type of all the types in the array, with
* '' standing for the empty type for convenience. If the second
* argument `convert` is true, a convertible type is considered a
* a subtype (defaults to false).
*/
joinTypes(types, convert) {
let join = ''
for (const type of types) {
join = this._joinTypes(join, type, convert)
}
return join
}
/* helper for above */
_joinTypes(typeA, typeB, convert) {
if (!typeA) return typeB
if (!typeB) return typeA
if (typeA === typeB) return typeA
const subber = convert ? this._priorTypes : this._subtypes
if (subber[typeB].has(typeA)) return typeB
/* OK, so we need the most refined supertype of A that contains B:
*/
let nextSuper = typeA
while (nextSuper) {
if (subber[nextSuper].has(typeB)) return nextSuper
nextSuper = this.Types[nextSuper].refines
}
/* And if conversions are allowed, we have to search the other way too */
if (convert) {
nextSuper = typeB
while (nextSuper) {
if (subber[nextSuper].has(typeA)) return nextSuper
nextSuper = this.Types[nextSuper].refines
}
}
return 'any'
}
/* Returns a list of all types that have been mentioned in the /* Returns a list of all types that have been mentioned in the
* signatures of operations, but which have not actually been installed: * signatures of operations, but which have not actually been installed:
*/ */
@ -530,30 +575,94 @@ export default class PocomathInstance {
const signature = substituteInSig( const signature = substituteInSig(
trimSignature, theTemplateParam, 'any') trimSignature, theTemplateParam, 'any')
/* The catchall signature has to detect the actual type of the call /* The catchall signature has to detect the actual type of the call
* and add the new instantiations * and add the new instantiations. We should really be using the
* typed-function parser to do the manipulations below, but we don't
* have access. The firs section prepares the type inference data:
*/ */
const argTypes = trimSignature.split(',') const parTypes = trimSignature.split(',')
let exemplar = -1 const inferences = []
for (let i = 0; i < argTypes.length; ++i) { const typer = entity => this.typeOf(entity)
const argType = argTypes[i].trim() let ambiguous = true
if (argType === theTemplateParam) { for (let parType of parTypes) {
exemplar = i parType = parType.trim()
break if (parType.slice(0,3) === '...') {
parType = parType.slice(3).trim()
}
if (parType === theTemplateParam) {
inferences.push(typer)
ambiguous = false
} else {
inferences.push(false)
} }
} }
if (exemplar < 0) { if (ambiguous) {
throw new SyntaxError( throw new SyntaxError(
`Cannot find template parameter in ${rawSignature}`) `Cannot find template parameter in ${rawSignature}`)
} }
/* Now build the catchall implementation */
const self = this const self = this
const patch = (refs) => (...args) => { const patch = (refs) => (...args) => {
const example = args[exemplar] /* First infer the type we actually should have been called for */
const instantiateFor = self.typeOf(example) let i = -1
let j = -1
/* collect the arg types */
const argTypes = []
for (const arg of args) {
++j
// in case of rest parameter, reuse last parameter type:
if (i < inferences.length - 1) ++i
if (inferences[i]) {
const argType = inferences[i](arg)
if (!argType || argType === 'any') {
throw TypeError(
`Type inference failed for argument ${j} of ${name}`)
}
argTypes.push(argType)
}
}
if (argTypes.length === 0) {
throw TypeError('Type inference failed for' + name)
}
let usedConversions = false
let instantiateFor = self.joinTypes(argTypes)
if (instantiateFor === 'any') {
usedConversions = true
instantiateFor = self.joinTypes(argTypes, usedConversions)
if (instantiateFor === 'any') {
throw TypeError('No common type for arguments to ' + name)
}
}
/* Transform the arguments if we used any conversions: */
if (usedConversions) {
i = - 1
for (j = 0; j < args.length; ++j) {
if (i < parTypes.length - 1) ++i
const wantType = substituteInSig(
parTypes[i], theTemplateParam, instantiateFor)
if (wantType !== parTypes[i]) {
args[j] = self._typed.convert(args[j], wantType)
}
}
}
/* Arrange that the desired instantiation will be there next
* time so we don't have to go through that again for this type
*/
refs[theTemplateParam] = instantiateFor refs[theTemplateParam] = instantiateFor
behavior.instantiations.add(instantiateFor) behavior.instantiations.add(instantiateFor)
self._invalidate(name) self._invalidate(name)
// And for now, we have to rely on the "any" implementation. Hope // And update refs because we now know the type we're instantiating
// it matches the instantiated one! // for:
for (const dep of behavior.uses) {
let [func, needsig] = dep.split(/[()]/)
if (needsig && self._typed.isTypedFunction(refs[dep])) {
const subsig = substituteInSig(
needsig, theTemplateParam, instantiateFor)
if (subsig !== needsig) {
refs[dep] = self._typed.find(refs[dep], subsig)
}
}
}
// Finally ready to make the call.
return behavior.does(refs)(...args) return behavior.does(refs)(...args)
} }
this._addTFimplementation( this._addTFimplementation(

View File

@ -8,6 +8,7 @@ import * as complex from '../src/complex/all.mjs'
import * as complexAdd from '../src/complex/add.mjs' import * as complexAdd from '../src/complex/add.mjs'
import * as complexNegate from '../src/complex/negate.mjs' import * as complexNegate from '../src/complex/negate.mjs'
import * as complexComplex from '../src/complex/complex.mjs' import * as complexComplex from '../src/complex/complex.mjs'
import * as bigintAdd from '../src/bigint/add.mjs'
import * as concreteSubtract from '../src/generic/subtract.concrete.mjs' import * as concreteSubtract from '../src/generic/subtract.concrete.mjs'
import * as genericSubtract from '../src/generic/subtract.mjs' import * as genericSubtract from '../src/generic/subtract.mjs'
import extendToComplex from '../src/complex/extendToComplex.mjs' import extendToComplex from '../src/complex/extendToComplex.mjs'
@ -112,4 +113,25 @@ describe('A custom instance', () => {
math.complex(1n, -3n)) math.complex(1n, -3n))
}) })
it("instantiates templates correctly", () => {
const inst = new PocomathInstance('InstantiateTemplates')
inst.install(numberAdd)
inst.install({typeMerge: {'T,T': ({T}) => (t,u) => 'Merge to ' + T }})
assert.strictEqual(inst.typeMerge(7,6.28), 'Merge to number')
assert.strictEqual(inst.typeMerge(7,6), 'Merge to NumInt')
assert.strictEqual(inst.typeMerge(7.35,6), 'Merge to number')
inst.install(complexAdd)
inst.install(complexComplex)
inst.install(bigintAdd)
assert.strictEqual(
inst.typeMerge(6n, inst.complex(3n, 2n)),
'Merge to GaussianInteger')
assert.strictEqual(
inst.typeMerge(3, inst.complex(4.5,2.1)),
'Merge to Complex')
// The following is the current behavior, since both `3+0i` and `3n + 0ni`
// are Complex, but it is unfortunate and hopefully it will be fixed
// with templates:
assert.strictEqual(inst.typeMerge(3, 3n), 'Merge to Complex')
})
}) })