// idlize — 606 lines · 22.0 KB
/*
* Copyright (c) 2022-2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import * as ts from 'typescript';
import { AbstractVisitor } from "./AbstractVisitor"
import { Rewrite } from './transformation-context';
import {
    FunctionKind,
    Tracer,
    idPlusKey,
    PositionalIdTracker,
    wrapInCompute,
    RuntimeNames,
    runtimeIdentifier,
    isMemoKind,
    hiddenParameters,
    createComputeScope,
    isVoidOrNotSpecified,
    localStateStatement,
    isTrackableParameter,
    hasStaticModifier,
    isMethodOfStableClass,
} from "./util"

/**
 * Returns true when `node` is the `__context` intrinsic, i.e. its text
 * equals `RuntimeNames.__CONTEXT`.
 */
function callIsIntrinsicContext(node: ts.Identifier): boolean {
    return ts.idText(node) === RuntimeNames.__CONTEXT
}

/**
 * Returns true when `node` is the `__id` intrinsic, i.e. its text
 * equals `RuntimeNames.__ID`.
 */
function callIsIntrinsicId(node: ts.Identifier): boolean {
    return ts.idText(node) === RuntimeNames.__ID
}

/**
 * Returns true when `node` is the `__key` intrinsic, i.e. its text
 * equals `RuntimeNames.__KEY`.
 */
function callIsIntrinsicKey(node: ts.Identifier): boolean {
    return ts.idText(node) === RuntimeNames.__KEY
}

/**
 * Replaces a `__context()` intrinsic call with the runtime context
 * identifier (`runtimeIdentifier(RuntimeNames.CONTEXT)`).
 * The call node itself is not inspected.
 */
function transformIntrinsicContextCall(node: ts.CallExpression): ts.Identifier {
    const contextId = runtimeIdentifier(RuntimeNames.CONTEXT)
    return contextId
}

/**
 * Replaces an `__id()` intrinsic call with the runtime id identifier
 * (`runtimeIdentifier(RuntimeNames.ID)`).
 * The call node itself is not inspected.
 */
function transformIntrinsicIdCall(node: ts.CallExpression): ts.Identifier {
    const idName = RuntimeNames.ID
    return runtimeIdentifier(idName)
}

/**
 * Replaces a `__key()` intrinsic call with the expression that
 * `positionalIdTracker.id(RuntimeNames.__KEY)` produces for the current position.
 * The call node itself is not inspected.
 */
function transformIntrinsicKeyCall(positionalIdTracker: PositionalIdTracker, node: ts.CallExpression): ts.Expression {
    const keyName = RuntimeNames.__KEY
    const keyExpression = positionalIdTracker.id(keyName)
    return keyExpression
}

// Given a lambda type (foo: XXX, bar: YYY) => ZZZ,
// returns the type name XXX of its first parameter.
// Returns undefined if there is no such name: the input is missing or not a
// function type, it has no first parameter, that parameter has no type
// annotation, or the annotation is not a reference to a plain identifier.
// NOTE(review): no caller of this helper appears in this file.
function firstArgTypeName(lambdaType: ts.TypeNode | undefined): string | undefined {
    if (lambdaType === undefined || !ts.isFunctionTypeNode(lambdaType)) return undefined
    const argType = lambdaType.parameters[0]?.type
    if (!argType || !ts.isTypeReferenceNode(argType)) return undefined
    const typeName = argType.typeName
    return ts.isIdentifier(typeName) ? ts.idText(typeName) : undefined
}

/**
 * AST visitor that rewrites memo functions, methods, lambdas, function types,
 * method signatures and call sites. What gets rewritten is decided by the
 * kinds recorded in `rewrite.functionTable` and `rewrite.callTable`, looked up
 * by original (pre-transformation) node.
 */
export class FunctionTransformer extends AbstractVisitor {
    constructor(
        public tracer: Tracer,
        public typechecker: ts.TypeChecker,
        public sourceFile: ts.SourceFile,
        public rewrite: Rewrite,
        ctx: ts.TransformationContext
    ) {
        super(ctx)
    }

    // Forwards `msg` to `this.tracer`.
    trace(msg: any) {
        this.tracer.trace(msg)
    }

    /*
    Design notes: two candidate rewrites of a memo function.

    // The context.compute rewrite with lambda

    function foo0(p1, p2) {
        const ps1 = context.parState()
        const ps2 = context.parState()
        return context.compute(__context, __id + "id", () => {
            if (a) return "1"
            return "2"
            // body
        })
    }

    // The lambda-less rewrite

    function foo1(p1, p2): R {
        const scope = __context().scope<R>(__id() + "id", 2)
        const ps1 = scope.param(0, p1, "name p1")
        const ps2 = scope.param(1, p2, "name p2")
        if (scope.unchanged) return scope.cached

        // body
        if (a) return scope.recache("1")
        return scope.recache("2")
    }

    transformComputeCallable below emits the prologue of the lambda-less form
    (scope, parameter states, early return) and passes the original body
    statements through unchanged. Rewriting the body's own `return`s is
    presumably done by another pass — verify.
    */

    // Given a function parameter
    //
    //     /** @memo */
    //     function foo(width: number)
    //
    // produces a tracking state statement for it. The exact shape is decided by
    // `localStateStatement`; conceptually:
    //
    //     const __memo_state_width = context.param(width)
    //
    // Returns undefined when the parameter name is not a plain identifier
    // (e.g. a destructuring pattern), because there is no single name to track.
    parameterStateStatement(parameter: ts.ParameterDeclaration, parameterIndex: number): ts.Statement | undefined {
        if (!ts.isIdentifier(parameter.name)) return undefined
        const parameterNameString = ts.idText(parameter.name)
        const parameterName = ts.factory.createIdentifier(parameterNameString)
        return localStateStatement(parameterNameString, parameterName, parameterIndex)
    }

    // Builds one state statement per trackable parameter. When `hasThis` is set
    // it adds one more for `this`. `this` occupies slot 0 and the parameters
    // follow it.
    //
    // Slots are numbered densely over the parameters that actually yield a
    // statement. The caller sizes the scope with `statements.length`, so a
    // skipped (non-identifier) parameter must not leave a gap in the numbering.
    // A gap would push later indices past the slot count.
    parameterStateStatements(
        originalParameters: ReadonlyArray<ts.ParameterDeclaration>,
        hasThis: boolean
    ): ts.Statement[] {
        const firstSlot = hasThis ? 1 : 0
        const statements: ts.Statement[] = []
        for (const parameter of originalParameters) {
            if (!isTrackableParameter(this.sourceFile, parameter)) continue
            const statement = this.parameterStateStatement(parameter, firstSlot + statements.length)
            if (statement !== undefined) statements.push(statement)
        }

        if (hasThis) {
            statements.push(localStateStatement("this", ts.factory.createThis(), 0))
        }
        return statements
    }

    // Emits the early exit that is taken when the scope already holds a value:
    //
    //     if (<SCOPE>.<INTERNAL_VALUE_OK>) return <SYNTHETIC_RETURN_MARK>(<SCOPE>.<INTERNAL_VALUE>)
    //
    // For a void (or unspecified) return type, the cached value is still read as
    // an expression statement, and the synthetic return carries no argument.
    createEarlyReturn(originalType: ts.TypeNode | undefined): ts.Statement {
        // Fresh nodes are created for every use; AST nodes must not be shared.
        const cachedValue = () => ts.factory.createPropertyAccessExpression(
            runtimeIdentifier(RuntimeNames.SCOPE),
            runtimeIdentifier(RuntimeNames.INTERNAL_VALUE),
        )
        const condition = ts.factory.createPropertyAccessExpression(
            runtimeIdentifier(RuntimeNames.SCOPE),
            runtimeIdentifier(RuntimeNames.INTERNAL_VALUE_OK),
        )
        const thenStatement = isVoidOrNotSpecified(originalType)
            ? ts.factory.createBlock([
                ts.factory.createExpressionStatement(cachedValue()),
                createSyntheticReturn(undefined)
            ])
            : createSyntheticReturn(cachedValue())
        return ts.factory.createIfStatement(condition, thenStatement)
    }

    // Shared rewrite for every memo callable. It prepends the hidden parameters
    // from `hiddenParameters(this.rewrite)`. When there is a body, the body
    // becomes
    //
    //     <scope>; <parameter states...>; <early return>; <original statements...>
    //
    // A block body that does not already end in `return` gets a trailing bare
    // `return`. A concise (expression) body becomes `return <expr>`.
    // `hasThis` requests an extra state slot for `this` (see parameterStateStatements).
    transformComputeCallable(
        originalParameters: ts.NodeArray<ts.ParameterDeclaration>,
        originalBody: ts.ConciseBody | undefined,
        originalType: ts.TypeNode | undefined,
        hasThis: boolean = false
    ): {
        newParameters: ts.ParameterDeclaration[],
        newBody: ts.FunctionBody | undefined
    } {
        const additionalParameters = hiddenParameters(this.rewrite)
        const newParameters = additionalParameters.concat(originalParameters)

        if (!originalBody) return { newParameters, newBody: undefined }

        const parameterStates = this.parameterStateStatements(originalParameters, hasThis)
        const scope = createComputeScope(parameterStates.length, idPlusKey(this.rewrite.positionalIdTracker), originalType)
        const earlyReturn = this.createEarlyReturn(originalType)

        let bodyStatements: ts.Statement[]
        if (ts.isBlock(originalBody)) {
            const statements = originalBody.statements
            const lastStatement = statements[statements.length - 1]
            bodyStatements = lastStatement && ts.isReturnStatement(lastStatement)
                ? [...statements]
                : [...statements, ts.factory.createReturnStatement()]
        } else {
            // Concise arrow body: a single expression.
            bodyStatements = [ts.factory.createReturnStatement(originalBody)]
        }

        const newBody = ts.factory.createBlock(
            [scope, ...parameterStates, earlyReturn, ...bodyStatements],
            true
        )
        return { newParameters, newBody }
    }

    // Memo functions must declare their return type explicitly; omitting it is
    // only allowed for void functions. A missing annotation is therefore treated
    // as `void`.
    getReturnType(node: ts.FunctionLikeDeclaration): ts.TypeNode | undefined {
        return node.type ?? ts.factory.createKeywordTypeNode(ts.SyntaxKind.VoidKeyword)
    }

    // Applies the compute rewrite to a method. Non-static methods of classes
    // that are not stable also track `this`.
    transformComputeMethod(originalMethod: ts.MethodDeclaration): ts.MethodDeclaration {
        const hasThis = !hasStaticModifier(originalMethod) && !isMethodOfStableClass(this.sourceFile, originalMethod)
        const { newParameters, newBody } = this.transformComputeCallable(
            originalMethod.parameters,
            originalMethod.body,
            this.getReturnType(originalMethod),
            hasThis
        )
        return ts.factory.updateMethodDeclaration(
            originalMethod,
            originalMethod.modifiers,
            originalMethod.asteriskToken,
            originalMethod.name,
            originalMethod.questionToken,
            originalMethod.typeParameters,
            newParameters,
            originalMethod.type,
            newBody
        )
    }

    // Applies the compute rewrite to an arrow function.
    transformComputeArrow(originalArrow: ts.ArrowFunction): ts.ArrowFunction {
        const { newParameters, newBody } = this.transformComputeCallable(
            originalArrow.parameters,
            originalArrow.body,
            this.getReturnType(originalArrow)
        )
        return ts.factory.updateArrowFunction(
            originalArrow,
            originalArrow.modifiers,
            originalArrow.typeParameters,
            newParameters,
            originalArrow.type,
            ts.factory.createToken(ts.SyntaxKind.EqualsGreaterThanToken),
            // An arrow always has a body, so transformComputeCallable returned one.
            newBody!
        )
    }

    // Applies the compute rewrite to a function expression.
    transformComputeFunctionExpression(original: ts.FunctionExpression): ts.FunctionExpression {
        const { newParameters, newBody } = this.transformComputeCallable(
            original.parameters,
            original.body,
            this.getReturnType(original)
        )
        return ts.factory.updateFunctionExpression(
            original,
            original.modifiers,
            original.asteriskToken,
            original.name,
            original.typeParameters,
            newParameters,
            original.type,
            // A function expression always has a body, so transformComputeCallable returned one.
            newBody!
        )
    }

    // Rewrites a memo function type: only the hidden parameters are prepended.
    transformComputeFunctionType(original: ts.FunctionTypeNode): ts.FunctionTypeNode {
        const { newParameters } = this.transformComputeCallable(original.parameters, undefined, undefined)
        return ts.factory.updateFunctionTypeNode(
            original,
            original.typeParameters,
            ts.factory.createNodeArray(newParameters),
            original.type
        )
    }

    // Returns `original` unchanged unless it is a function type, which is then rewritten.
    maybeTransformFunctionType(original: ts.TypeNode): ts.TypeNode {
        if (!ts.isFunctionTypeNode(original)) return original
        return this.transformComputeFunctionType(original)
    }

    // Memo getter: only a function-typed return annotation is rewritten; the
    // body is left as-is.
    transformGetter(original: ts.GetAccessorDeclaration): ts.GetAccessorDeclaration {
        return ts.factory.updateGetAccessorDeclaration(
            original,
            original.modifiers,
            original.name,
            original.parameters,
            original.type ? this.maybeTransformFunctionType(original.type) : undefined,
            original.body
        )
    }

    // Memo setter: rebuilt from its own children unchanged. No memo rewrite is
    // applied to setters here.
    transformSetter(original: ts.SetAccessorDeclaration): ts.SetAccessorDeclaration {
        return ts.factory.updateSetAccessorDeclaration(
            original,
            original.modifiers,
            original.name,
            original.parameters,
            original.body
        )
    }

    // Rewrites a memo method signature: only the hidden parameters are prepended.
    transformComputeMethodSignature(original: ts.MethodSignature): ts.MethodSignature {
        const { newParameters } = this.transformComputeCallable(original.parameters, undefined, undefined)
        return ts.factory.updateMethodSignature(
            original,
            original.modifiers,
            original.name,
            original.questionToken,
            original.typeParameters,
            ts.factory.createNodeArray(newParameters),
            original.type
        )
    }

    // Applies the compute rewrite to a function declaration.
    transformComputeFunction(originalFunction: ts.FunctionDeclaration): ts.FunctionDeclaration {
        const { newParameters, newBody } = this.transformComputeCallable(
            originalFunction.parameters,
            originalFunction.body,
            this.getReturnType(originalFunction)
        )
        return ts.factory.updateFunctionDeclaration(
            originalFunction,
            originalFunction.modifiers,
            originalFunction.asteriskToken,
            originalFunction.name,
            originalFunction.typeParameters,
            newParameters,
            originalFunction.type,
            newBody
        )
    }

    // Memo-intrinsic method: the hidden parameters are prepended and the body is kept verbatim.
    transformIntrinsicMethod(originalMethod: ts.MethodDeclaration): ts.MethodDeclaration {
        return ts.factory.updateMethodDeclaration(
            originalMethod,
            originalMethod.modifiers,
            originalMethod.asteriskToken,
            originalMethod.name,
            originalMethod.questionToken,
            originalMethod.typeParameters,
            hiddenParameters(this.rewrite).concat(originalMethod.parameters),
            originalMethod.type,
            originalMethod.body
        )
    }

    // Memo-intrinsic function: the hidden parameters are prepended and the body is kept verbatim.
    transformIntrinsicFunction(originalFunction: ts.FunctionDeclaration): ts.FunctionDeclaration {
        return ts.factory.updateFunctionDeclaration(
            originalFunction,
            originalFunction.modifiers,
            originalFunction.asteriskToken,
            originalFunction.name,
            originalFunction.typeParameters,
            hiddenParameters(this.rewrite).concat(originalFunction.parameters),
            originalFunction.type,
            originalFunction.body
        )
    }

    // Prepends the hidden arguments to a memo call:
    // `f(a, b)` becomes `f(<CONTEXT>, <id + key>, a, b)`.
    // `sourceFile` and `name` are currently unused.
    transformCallWithName(node: ts.CallExpression, sourceFile: ts.SourceFile, name: string): ts.CallExpression {
        const idExpression = idPlusKey(this.rewrite.positionalIdTracker)
        const contextExpression = runtimeIdentifier(RuntimeNames.CONTEXT)
        return ts.factory.updateCallExpression(
            node,
            node.expression,
            node.typeArguments,
            [contextExpression, idExpression, ...node.arguments]
        )
    }

    // Rewrites a memo call whose callee is a plain identifier.
    // Throws a string message if the callee is anything else.
    transformCall(node: ts.CallExpression, sourceFile: ts.SourceFile): ts.CallExpression {
        const identifier = node.expression
        if (!identifier || !ts.isIdentifier(identifier)) {
            // Report the syntax kind; interpolating the node itself prints "[object Object]".
            throw `Could not transform @memo call to ${identifier && ts.SyntaxKind[identifier.kind]}`
        }
        const name = ts.idText(identifier)
        return this.transformCallWithName(node, sourceFile, name)
    }

    // Rewrites a memo call whose callee is a property access `obj.member(...)`.
    // Throws a string message for any other callee shape or for a private (`#`) member name.
    transformMethodCall(node: ts.CallExpression, sourceFile: ts.SourceFile): ts.CallExpression {
        if (!ts.isPropertyAccessExpression(node.expression)) {
            throw `Unexpected expression kind ${ts.SyntaxKind[node.expression.kind]}`
        }
        const member = node.expression.name
        if (!member || !ts.isIdentifier(member)) {
            // Report the syntax kind; interpolating the node itself prints "[object Object]".
            throw `Could not transform @memo call to ${member && ts.SyntaxKind[member.kind]}`
        }
        const name = ts.idText(member)
        return this.transformCallWithName(node, sourceFile, name)
    }

    // Wraps `node` via `wrapInCompute`, keyed by the current positional id + key.
    wrapObjectLiteralInCompute(node: ts.ObjectLiteralExpression | ts.ArrowFunction): ts.Expression {
        return wrapInCompute(node, idPlusKey(this.rewrite.positionalIdTracker))
    }

    // Replaces the type reference with plain `number` and drops its type arguments.
    transformTransformedType(node: ts.TypeReferenceNode): ts.TypeReferenceNode {
        return ts.factory.updateTypeReferenceNode(
            node,
            ts.factory.createIdentifier("number"), // Yes, just number.
            undefined
        )
    }

    // Rewrites a call expression as follows:
    //  - the `__context()`, `__id()` and `__key()` intrinsics become the runtime
    //    context, the runtime id, and the positional key expression respectively;
    //  - calls recorded in `rewrite.callTable`, whether called by plain name or
    //    through a property access, get the hidden arguments prepended.
    // Any other call is returned unchanged.
    visitCall(node: ts.CallExpression): ts.Expression {
        const callee = node.expression
        if (ts.isIdentifier(callee)) {
            if (callIsIntrinsicContext(callee)) return transformIntrinsicContextCall(node)
            if (callIsIntrinsicId(callee)) return transformIntrinsicIdCall(node)
            if (callIsIntrinsicKey(callee)) return transformIntrinsicKeyCall(this.rewrite.positionalIdTracker, node)
            const kind = this.rewrite.callTable.get(ts.getOriginalNode(node, ts.isCallExpression))
            return kind ? this.transformCall(node, this.sourceFile) : node
        }
        if (ts.isPropertyAccessExpression(callee)) {
            const kind = this.rewrite.callTable.get(ts.getOriginalNode(node, ts.isCallExpression))
            return kind ? this.transformMethodCall(node, this.sourceFile) : node
        }
        return node
    }

    // Visits the children first, then rewrites the node itself according to the
    // kind recorded for its original node. Returns the node unchanged when no
    // rule applies.
    visitor(beforeChildren: ts.Node): ts.Node {
        // visitEachChild can create a new node if its children have changed,
        // so the tables are always queried with ts.getOriginalNode(...).
        const node = this.visitEachChild(beforeChildren)
        let transformed: ts.Node | undefined = undefined

        if (ts.isCallExpression(node)) {
            transformed = this.visitCall(node)
        } else if (ts.isFunctionDeclaration(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isFunctionDeclaration)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeFunction(node)
                    break
                case FunctionKind.MEMO_INTRINSIC:
                    transformed = this.transformIntrinsicFunction(node)
                    break
            }
        } else if (ts.isMethodDeclaration(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isMethodDeclaration)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeMethod(node)
                    break
                case FunctionKind.MEMO_INTRINSIC:
                    transformed = this.transformIntrinsicMethod(node)
                    break
            }
        } else if (ts.isArrowFunction(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isArrowFunction)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeArrow(node)
                    break
                case FunctionKind.MEMO_INTRINSIC:
                    // TODO: no memo:intrinsic lambdas for now.
                    break
                default: {
                    // A non-memo lambda whose parent is a memo call gets wrapped in a compute.
                    const originalParent = originalNode.parent
                    if (originalParent && ts.isCallExpression(originalParent)
                        && isMemoKind(this.rewrite.callTable.get(originalParent))) {
                        transformed = this.wrapObjectLiteralInCompute(node)
                    }
                }
            }
        } else if (ts.isFunctionExpression(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isFunctionExpression)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeFunctionExpression(node)
                    break
                // TODO: no memo:intrinsic lambdas for now.
            }
        } else if (ts.isFunctionTypeNode(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isFunctionTypeNode)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeFunctionType(node)
                    break
                // TODO: no memo:intrinsic lambdas for now.
            }
        } else if (ts.isMethodSignature(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isMethodSignature)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeMethodSignature(node)
                    break
                // TODO: no memo:intrinsic signatures for now.
            }
        } else if (ts.isGetAccessor(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isGetAccessorDeclaration)
            if (this.rewrite.functionTable.get(originalNode) === FunctionKind.MEMO) {
                transformed = this.transformGetter(node)
            }
        } else if (ts.isSetAccessor(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isSetAccessorDeclaration)
            if (this.rewrite.functionTable.get(originalNode) === FunctionKind.MEMO) {
                transformed = this.transformSetter(node)
            }
        } else if (ts.isObjectLiteralExpression(node)) {
            // An object literal whose parent is a memo call gets wrapped in a compute.
            const originalNode = ts.getOriginalNode(node, ts.isObjectLiteralExpression)
            const originalParent = originalNode.parent
            if (originalParent && ts.isCallExpression(originalParent)
                && isMemoKind(this.rewrite.callTable.get(originalParent))) {
                transformed = this.wrapObjectLiteralInCompute(node)
            }
        } else if (ts.isTypeReferenceNode(node)) {
            if (ts.isIdentifier(node.typeName) && ts.idText(node.typeName) === RuntimeNames.TRANSFORMED_TYPE) {
                transformed = this.transformTransformedType(node)
            }
        }
        return transformed ?? node
    }
}

/**
 * Builds `return <SYNTHETIC_RETURN_MARK>(node)`, where the callee is
 * `runtimeIdentifier(RuntimeNames.SYNTHETIC_RETURN_MARK)`.
 * When `node` is absent, the marker is called with no arguments.
 */
function createSyntheticReturn(node: ts.Expression | undefined): ts.ReturnStatement {
    const marker = runtimeIdentifier(RuntimeNames.SYNTHETIC_RETURN_MARK)
    const markerArgs = node ? [node] : undefined
    const markedCall = ts.factory.createCallExpression(marker, undefined, markerArgs)
    return ts.factory.createReturnStatement(markedCall)
}
