idlize

Форк
0
/
function-transformer.ts 
606 строк · 22.0 Кб
1
/*
2
 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3
 * Licensed under the Apache License, Version 2.0 (the "License");
4
 * you may not use this file except in compliance with the License.
5
 * You may obtain a copy of the License at
6
 *
7
 * http://www.apache.org/licenses/LICENSE-2.0
8
 *
9
 * Unless required by applicable law or agreed to in writing, software
10
 * distributed under the License is distributed on an "AS IS" BASIS,
11
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
 * See the License for the specific language governing permissions and
13
 * limitations under the License.
14
 */
15

16
import * as ts from 'typescript';
17
import { AbstractVisitor } from "./AbstractVisitor"
18
import { Rewrite } from './transformation-context';
19
import {
20
    FunctionKind,
21
    Tracer,
22
    idPlusKey,
23
    PositionalIdTracker,
24
    wrapInCompute,
25
    RuntimeNames,
26
    runtimeIdentifier,
27
    isMemoKind,
28
    hiddenParameters,
29
    createComputeScope,
30
    isVoidOrNotSpecified,
31
    localStateStatement,
32
    isTrackableParameter,
33
    hasStaticModifier,
34
    isMethodOfStableClass,
35
} from "./util"
36

37

38
/**
 * Reports whether `node` spells the `__context` intrinsic name (RuntimeNames.__CONTEXT).
 */
function callIsIntrinsicContext(node: ts.Identifier): boolean {
    const calleeName = ts.idText(node)
    return calleeName === RuntimeNames.__CONTEXT
}
41

42
/**
 * Reports whether `node` spells the `__id` intrinsic name (RuntimeNames.__ID).
 */
function callIsIntrinsicId(node: ts.Identifier): boolean {
    const calleeName = ts.idText(node)
    return calleeName === RuntimeNames.__ID
}
45

46
/**
 * Reports whether `node` spells the `__key` intrinsic name (RuntimeNames.__KEY).
 */
function callIsIntrinsicKey(node: ts.Identifier): boolean {
    const calleeName = ts.idText(node)
    return calleeName === RuntimeNames.__KEY
}
49

50
/**
 * Replaces a `__context()` intrinsic call with the runtime context identifier
 * (RuntimeNames.CONTEXT).
 *
 * The call node itself is not inspected. NOTE(review): it is presumably accepted only to
 * keep the signature parallel with the other intrinsic rewriters.
 */
function transformIntrinsicContextCall(node: ts.CallExpression): ts.Identifier {
    const contextIdentifier = runtimeIdentifier(RuntimeNames.CONTEXT)
    return contextIdentifier
}
53

54
/**
 * Replaces an `__id()` intrinsic call with the runtime id identifier (RuntimeNames.ID).
 *
 * The call node itself is not inspected.
 */
function transformIntrinsicIdCall(node: ts.CallExpression): ts.Identifier {
    const idIdentifier = runtimeIdentifier(RuntimeNames.ID)
    return idIdentifier
}
57

58
/**
 * Replaces a `__key()` intrinsic call with the expression produced by
 * `positionalIdTracker.id(RuntimeNames.__KEY)`.
 *
 * The call node itself is not inspected.
 */
function transformIntrinsicKeyCall(positionalIdTracker: PositionalIdTracker, node: ts.CallExpression): ts.Expression {
    const keyExpression = positionalIdTracker.id(RuntimeNames.__KEY)
    return keyExpression
}
61

62
/**
 * Given a lambda type such as `(foo: XXX, bar: YYY) => ZZZ`, returns the name `XXX` of the
 * first parameter's type.
 *
 * Returns undefined when any of the following holds:
 * - `lambdaType` is absent or is not a function type node;
 * - the function type has no parameters;
 * - the first parameter has no type annotation;
 * - the annotation is not a type reference;
 * - the referenced name is qualified (e.g. `ns.XXX`) rather than a plain identifier.
 *
 * NOTE(review): this helper is not referenced anywhere else in this file.
 */
function firstArgTypeName(lambdaType: ts.TypeNode|undefined): string|undefined {
    if (lambdaType === undefined || !ts.isFunctionTypeNode(lambdaType)) return undefined

    const firstArgType = lambdaType.parameters[0]?.type
    if (firstArgType === undefined || !ts.isTypeReferenceNode(firstArgType)) return undefined

    const referencedName = firstArgType.typeName
    return ts.isIdentifier(referencedName) ? ts.idText(referencedName) : undefined
}
75

76
export class FunctionTransformer extends AbstractVisitor {
    constructor(
        public tracer: Tracer,
        public typechecker: ts.TypeChecker,
        public sourceFile: ts.SourceFile,
        public rewrite: Rewrite,
        ctx: ts.TransformationContext
    ) {
        super(ctx)
    }

    /** Forwards `msg` to the configured tracer. */
    trace(msg: any) {
        this.tracer.trace(msg)
    }

    /*
    // The context.compute rewrite with lambda

    function foo0(p1, p2) {
        const ps1 = context.parState()
        const ps2 = context.parState()
        return context.compute(__context, __id + "id", () => {
            if (a) return "1"
            return "2"
            // body
        })
    }

    // The lambda-less rewrite

    function foo1(p1, p2): R {
        const scope = __context().scope<R>(__id() + "id", 2)
        const ps1 = scope.param(0, p1, "name p1")
        const ps2 = scope.param(1, p2, "name p2")
        if (scope.unchanged) return scope.cached

        // body
        if (a) return scope.recache("1")
        return scope.recache("2")
    }
    */

    /**
     * Produces a tracking state statement for one function parameter, e.g. for
     *
     *     /** @memo *\/
     *     function foo(width: number)
     *
     * something along the lines of `const __memo_state_width = context.param(width)`.
     * NOTE(review): the exact emitted shape is decided by `localStateStatement` — confirm there.
     *
     * @param parameter the parameter declaration to track
     * @param parameterIndex the slot index passed to `localStateStatement`
     * @returns the statement, or undefined when the parameter name is not a plain identifier
     *          (e.g. a destructuring pattern)
     */
    parameterStateStatement(parameter: ts.ParameterDeclaration, parameterIndex: number): ts.Statement|undefined {
        if (!ts.isIdentifier(parameter.name)) return undefined
        const parameterNameString = ts.idText(parameter.name)
        const parameterName = ts.factory.createIdentifier(parameterNameString)
        return localStateStatement(
            parameterNameString,
            parameterName,
            parameterIndex
        )
    }

    /**
     * Produces the state statements for all tracked parameters and, when `hasThis` is set,
     * one extra statement for `this`.
     *
     * `this` always occupies slot 0 when present; parameters take the following slots in
     * declaration order. Slot indices are contiguous, so they always stay below the number of
     * returned statements. That number is what `transformComputeCallable` passes to
     * `createComputeScope`.
     *
     * @param originalParameters the function's declared parameters (hidden ones not included)
     * @param hasThis whether to also track `this`
     */
    parameterStateStatements(
        originalParameters: ReadonlyArray<ts.ParameterDeclaration>,
        hasThis: boolean
    ): ts.Statement[] {
        // Parameters without an identifier name get no state statement (see
        // parameterStateStatement). Drop them *before* numbering. Numbering first and filtering
        // afterwards left holes in the slot indices, so an index could reach or exceed the
        // slot count handed to createComputeScope.
        const trackable = originalParameters.filter(parameter =>
            isTrackableParameter(this.sourceFile, parameter) && ts.isIdentifier(parameter.name)
        )
        const firstParameterSlot = hasThis ? 1 : 0

        const statements = trackable
            .map((parameter, index) => this.parameterStateStatement(parameter, firstParameterSlot + index))
            .filter((statement): statement is ts.Statement => statement !== undefined)

        if (hasThis) {
            statements.push(
                localStateStatement("this", ts.factory.createThis(), 0)
            )
        }
        return statements
    }

    /**
     * Builds the guard that follows the scope and parameter-state statements:
     *
     *     if (<SCOPE>.<INTERNAL_VALUE_OK>) <return cached value>
     *
     * For a void or unspecified return type, the guard body first evaluates
     * `<SCOPE>.<INTERNAL_VALUE>` as an expression statement and then returns without a value.
     * NOTE(review): the read is presumably kept for a runtime side effect — confirm.
     * Otherwise the guard returns `<SCOPE>.<INTERNAL_VALUE>`. Both returns go through
     * createSyntheticReturn.
     */
    createEarlyReturn(originalType: ts.TypeNode | undefined): ts.Statement {
        return ts.factory.createIfStatement(
            ts.factory.createPropertyAccessExpression(
                runtimeIdentifier(RuntimeNames.SCOPE),
                runtimeIdentifier(RuntimeNames.INTERNAL_VALUE_OK),
            ),
            isVoidOrNotSpecified(originalType)
                ? ts.factory.createBlock([
                    ts.factory.createExpressionStatement(
                        ts.factory.createPropertyAccessExpression(
                            runtimeIdentifier(RuntimeNames.SCOPE),
                            runtimeIdentifier(RuntimeNames.INTERNAL_VALUE),
                        )),
                    createSyntheticReturn(undefined)
                ])
                : createSyntheticReturn(
                    ts.factory.createPropertyAccessExpression(
                        runtimeIdentifier(RuntimeNames.SCOPE),
                        runtimeIdentifier(RuntimeNames.INTERNAL_VALUE),
                    )
                )
        )
    }

    /**
     * Core rewrite shared by all memo callables.
     *
     * The hidden parameters from `hiddenParameters(this.rewrite)` are prepended to the
     * parameter list. When a body is present, it is replaced by a block containing:
     * 1. the compute-scope declaration (`createComputeScope`, sized by the number of state
     *    statements);
     * 2. the parameter-state statements (plus `this` when `hasThis` is set);
     * 3. the early-return guard;
     * 4. the original body:
     *    - a block body is spliced in as-is, with a bare `return` appended unless its last
     *      statement is already a return;
     *    - an expression body becomes `return <expression>`.
     *
     * @returns the new parameter list, plus the new body or undefined when there was no body
     */
    transformComputeCallable(
        originalParameters: ts.NodeArray<ts.ParameterDeclaration>,
        originalBody: ts.ConciseBody | undefined,
        originalType: ts.TypeNode | undefined,
        hasThis: boolean = false
    ): {
        newParameters: ts.ParameterDeclaration[],
        newBody: ts.FunctionBody | undefined
    } {
        const additionalParameters = hiddenParameters(this.rewrite)
        const newParameters = additionalParameters.concat(originalParameters)

        if (!originalBody) return { newParameters, newBody: undefined }

        // Order matters here: createComputeScope consumes the next positional id via idPlusKey.
        const parameterStates = this.parameterStateStatements(originalParameters, hasThis)
        const scope = createComputeScope(parameterStates.length, idPlusKey(this.rewrite.positionalIdTracker), originalType)
        const earlyReturn = this.createEarlyReturn(originalType)
        const prologue: ts.Statement[] = [scope, ...parameterStates, earlyReturn]

        let bodyStatements: ts.Statement[]
        if (ts.isBlock(originalBody)) {
            const lastStatement = originalBody.statements[originalBody.statements.length - 1]
            const endsWithReturn = lastStatement !== undefined && ts.isReturnStatement(lastStatement)
            bodyStatements = endsWithReturn
                ? [...originalBody.statements]
                : [...originalBody.statements, ts.factory.createReturnStatement()]
        } else {
            // An expression body of an arrow function.
            bodyStatements = [ts.factory.createReturnStatement(originalBody)]
        }

        const newBody = ts.factory.createBlock([...prologue, ...bodyStatements], true)
        return { newParameters, newBody }
    }

    /**
     * Returns the declared return type, or a synthesized `void` keyword type when none is
     * declared. Memo functions are required to declare an explicit return type; only the void
     * case may omit it.
     */
    getReturnType(node: ts.FunctionLikeDeclaration): ts.TypeNode | undefined {
        return node.type ?? ts.factory.createKeywordTypeNode(ts.SyntaxKind.VoidKeyword)
    }

    /**
     * Rewrites a memo method.
     *
     * `this` is tracked unless the method is static or belongs to a stable class
     * (`isMethodOfStableClass`).
     */
    transformComputeMethod(originalMethod: ts.MethodDeclaration): ts.MethodDeclaration {
        const { newParameters, newBody } = this.transformComputeCallable(
            originalMethod.parameters,
            originalMethod.body,
            this.getReturnType(originalMethod),
            (!hasStaticModifier(originalMethod) && !isMethodOfStableClass(this.sourceFile, originalMethod))
        )
        return ts.factory.updateMethodDeclaration(
            originalMethod,
            originalMethod.modifiers,
            originalMethod.asteriskToken,
            originalMethod.name,
            originalMethod.questionToken,
            originalMethod.typeParameters,
            newParameters,
            originalMethod.type,
            newBody
        )
    }

    /** Rewrites a memo arrow function. `this` is not tracked. */
    transformComputeArrow(originalArrow: ts.ArrowFunction): ts.ArrowFunction {
        const { newParameters, newBody } = this.transformComputeCallable(
            originalArrow.parameters,
            originalArrow.body,
            this.getReturnType(originalArrow)
        )
        return ts.factory.updateArrowFunction(
            originalArrow,
            originalArrow.modifiers,
            originalArrow.typeParameters,
            newParameters,
            originalArrow.type,
            ts.factory.createToken(ts.SyntaxKind.EqualsGreaterThanToken),
            // An arrow function always has a body, so transformComputeCallable returned one.
            newBody!
        )
    }

    /** Rewrites a memo function expression. `this` is not tracked. */
    transformComputeFunctionExpression(original: ts.FunctionExpression): ts.FunctionExpression {
        const { newParameters, newBody } = this.transformComputeCallable(
            original.parameters,
            original.body,
            this.getReturnType(original)
        )
        return ts.factory.updateFunctionExpression(
            original,
            original.modifiers,
            original.asteriskToken,
            original.name,
            original.typeParameters,
            newParameters,
            original.type,
            // A function expression always has a block body, so a new body was produced.
            newBody!
        )
    }

    /** Adds the hidden parameters to a memo function type node (there is no body to rewrite). */
    transformComputeFunctionType(original: ts.FunctionTypeNode): ts.FunctionTypeNode {
        const { newParameters } = this.transformComputeCallable(
            original.parameters,
            undefined,
            undefined
        )
        return ts.factory.updateFunctionTypeNode(
            original,
            original.typeParameters,
            ts.factory.createNodeArray(newParameters),
            original.type
        )
    }

    /** Applies transformComputeFunctionType when `original` is a function type; otherwise returns it unchanged. */
    maybeTransformFunctionType(original: ts.TypeNode): ts.TypeNode {
        if (!ts.isFunctionTypeNode(original)) return original
        return this.transformComputeFunctionType(original)
    }

    /**
     * Rewrites a memo getter.
     *
     * Only the declared type changes: if it is a function type, it gains the hidden
     * parameters. The getter body is left untouched.
     */
    transformGetter(original: ts.GetAccessorDeclaration): ts.GetAccessorDeclaration {
        return ts.factory.updateGetAccessorDeclaration(
            original,
            original.modifiers,
            original.name,
            original.parameters,
            original.type ? this.maybeTransformFunctionType(original.type) : undefined,
            original.body
        )
    }

    /**
     * Rewrites a memo setter. Currently this rebuilds the declaration with all of its parts
     * unchanged.
     *
     * NOTE(review): unlike transformGetter, the setter's parameter type is not passed through
     * maybeTransformFunctionType. Confirm whether the getter/setter types are meant to diverge.
     */
    transformSetter(original: ts.SetAccessorDeclaration): ts.SetAccessorDeclaration {
        return ts.factory.updateSetAccessorDeclaration(
            original,
            original.modifiers,
            original.name,
            original.parameters,
            original.body
        )
    }

    /** Adds the hidden parameters to a memo method signature (interfaces / type literals). */
    transformComputeMethodSignature(original: ts.MethodSignature): ts.MethodSignature {
        const { newParameters } = this.transformComputeCallable(
            original.parameters,
            undefined,
            undefined
        )
        return ts.factory.updateMethodSignature(
            original,
            original.modifiers,
            original.name,
            original.questionToken,
            original.typeParameters,
            ts.factory.createNodeArray(newParameters),
            original.type
        )
    }

    /** Rewrites a memo function declaration. `this` is not tracked. */
    transformComputeFunction(originalFunction: ts.FunctionDeclaration): ts.FunctionDeclaration {
        const { newParameters, newBody } = this.transformComputeCallable(
            originalFunction.parameters,
            originalFunction.body,
            this.getReturnType(originalFunction)
        )
        return ts.factory.updateFunctionDeclaration(
            originalFunction,
            originalFunction.modifiers,
            originalFunction.asteriskToken,
            originalFunction.name,
            originalFunction.typeParameters,
            newParameters,
            originalFunction.type,
            newBody
        )
    }

    /** Rewrites a memo-intrinsic method: prepends the hidden parameters and leaves the body as-is. */
    transformIntrinsicMethod(originalMethod: ts.MethodDeclaration): ts.MethodDeclaration {
        return ts.factory.updateMethodDeclaration(
            originalMethod,
            originalMethod.modifiers,
            originalMethod.asteriskToken,
            originalMethod.name,
            originalMethod.questionToken,
            originalMethod.typeParameters,
            hiddenParameters(this.rewrite).concat(originalMethod.parameters),
            originalMethod.type,
            originalMethod.body
        )
    }

    /** Rewrites a memo-intrinsic function: prepends the hidden parameters and leaves the body as-is. */
    transformIntrinsicFunction(originalFunction: ts.FunctionDeclaration): ts.FunctionDeclaration {
        return ts.factory.updateFunctionDeclaration(
            originalFunction,
            originalFunction.modifiers,
            originalFunction.asteriskToken,
            originalFunction.name,
            originalFunction.typeParameters,
            hiddenParameters(this.rewrite).concat(originalFunction.parameters),
            originalFunction.type,
            originalFunction.body
        )
    }

    /**
     * Rewrites a call to a memo callable as `callee(<CONTEXT>, <id + key>, ...originalArgs)`.
     *
     * `sourceFile` and `name` are not used by the rewrite; they are kept for signature
     * compatibility with existing callers.
     */
    transformCallWithName(node: ts.CallExpression, sourceFile: ts.SourceFile, name: string): ts.CallExpression {
        // Take the positional id first, matching the original evaluation order of this rewrite.
        const idExpression = idPlusKey(this.rewrite.positionalIdTracker)
        const args: ts.Expression[] = [
            runtimeIdentifier(RuntimeNames.CONTEXT),
            idExpression,
            ...node.arguments,
        ]
        return ts.factory.updateCallExpression(
            node,
            node.expression,
            node.typeArguments,
            args
        )
    }

    /**
     * Rewrites a memo call whose callee is a plain identifier.
     *
     * @throws string when the callee is not an identifier. The thrown value stays a string for
     *         compatibility with existing handlers.
     */
    transformCall(node: ts.CallExpression, sourceFile: ts.SourceFile): ts.CallExpression {
        const identifier = node.expression
        if (!ts.isIdentifier(identifier)) {
            // Interpolating the node itself printed "[object Object]"; report its kind instead.
            throw `Could not transform @memo call to ${ts.SyntaxKind[identifier.kind]}`
        }
        const name = ts.idText(identifier)
        return this.transformCallWithName(node, sourceFile, name)
    }

    /**
     * Rewrites a memo call whose callee is a property access (`receiver.member(...)`).
     *
     * @throws string when the callee is not a property access, or when the member is not a
     *         plain identifier (a `#private` name). The thrown value stays a string for
     *         compatibility with existing handlers.
     */
    transformMethodCall(node: ts.CallExpression, sourceFile: ts.SourceFile): ts.CallExpression {
        if (!ts.isPropertyAccessExpression(node.expression)) {
            throw `Unexpected expression kind ${ts.SyntaxKind[node.expression.kind]}`
        }
        const member = node.expression.name
        if (!ts.isIdentifier(member)) {
            // `member` is a PrivateIdentifier here; print its text instead of "[object Object]".
            throw `Could not transform @memo call to ${ts.idText(member)}`
        }
        const name = ts.idText(member)
        return this.transformCallWithName(node, sourceFile, name)
    }

    /**
     * Wraps an object literal or arrow function via `wrapInCompute`, consuming the next
     * positional id. Arrows are wrapped too, despite the method name.
     */
    wrapObjectLiteralInCompute(node: ts.ObjectLiteralExpression|ts.ArrowFunction): ts.Expression {
        return wrapInCompute(node, idPlusKey(this.rewrite.positionalIdTracker))
    }

    /** Replaces a reference to the RuntimeNames.TRANSFORMED_TYPE marker type with `number`. */
    transformTransformedType(node: ts.TypeReferenceNode): ts.TypeReferenceNode {
        return ts.factory.updateTypeReferenceNode(
            node,
            ts.factory.createIdentifier("number"), // Yes, just number.
            undefined
        )
    }

    /**
     * Rewrites a call expression.
     *
     * - Calls of the `__context`, `__id` and `__key` intrinsics are replaced by the
     *   corresponding runtime expressions.
     * - Calls registered in `rewrite.callTable` (looked up by original node) get the
     *   hidden arguments.
     * - Any other call is returned unchanged.
     */
    visitCall(node: ts.CallExpression): ts.Expression {
        let transformed: ts.Expression | undefined = undefined
        if (ts.isIdentifier(node.expression)) {
            const identifier = node.expression
            if (callIsIntrinsicContext(identifier)) {
                transformed = transformIntrinsicContextCall(node)
            } else if (callIsIntrinsicId(identifier)) {
                transformed = transformIntrinsicIdCall(node)
            } else if (callIsIntrinsicKey(identifier)) {
                transformed = transformIntrinsicKeyCall(this.rewrite.positionalIdTracker, node)
            } else {
                const kind = this.rewrite.callTable.get(ts.getOriginalNode(node, ts.isCallExpression))
                if (kind) {
                    transformed = this.transformCall(node, this.sourceFile)
                }
            }
        } else if (ts.isPropertyAccessExpression(node.expression)) {
            const kind = this.rewrite.callTable.get(ts.getOriginalNode(node, ts.isCallExpression))
            if (kind) {
                transformed = this.transformMethodCall(node, this.sourceFile)
            }
        }
        return transformed ?? node
    }

    /**
     * Post-order visitor: children are visited first, then the resulting node is rewritten
     * according to `rewrite.functionTable` / `rewrite.callTable`. Both tables are keyed by
     * original nodes, so every lookup goes through `ts.getOriginalNode`. Nodes that need no
     * rewrite are returned as produced by the child visit.
     */
    visitor(beforeChildren: ts.Node): ts.Node {
        // Note that visitEachChild can create a new node if its children have changed.
        const node = this.visitEachChild(beforeChildren)
        let transformed: ts.Node | undefined = undefined

        if (ts.isCallExpression(node)) {
            transformed = this.visitCall(node)
        } else if (ts.isFunctionDeclaration(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isFunctionDeclaration)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeFunction(node)
                    break
                case FunctionKind.MEMO_INTRINSIC:
                    transformed = this.transformIntrinsicFunction(node)
                    break
            }
        } else if (ts.isMethodDeclaration(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isMethodDeclaration)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeMethod(node)
                    break
                case FunctionKind.MEMO_INTRINSIC:
                    transformed = this.transformIntrinsicMethod(node)
                    break
            }
        } else if (ts.isArrowFunction(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isArrowFunction)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeArrow(node)
                    break
                case FunctionKind.MEMO_INTRINSIC:
                    // TODO: no memo:intrinsic lambdas for now.
                    break
                default: {
                    // A non-memo arrow passed directly as an argument to a memo call is wrapped.
                    const originalParent = originalNode.parent
                    if (originalParent && ts.isCallExpression(originalParent)) {
                        if (isMemoKind(this.rewrite.callTable.get(originalParent)))
                            transformed = this.wrapObjectLiteralInCompute(node)
                    }
                }
            }
        } else if (ts.isFunctionExpression(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isFunctionExpression)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeFunctionExpression(node)
                    break
                // TODO: no memo:intrinsic lambdas for now.
            }
        } else if (ts.isFunctionTypeNode(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isFunctionTypeNode)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeFunctionType(node)
                    break
                // TODO: no memo:intrinsic lambdas for now.
            }
        } else if (ts.isMethodSignature(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isMethodSignature)
            switch (this.rewrite.functionTable.get(originalNode)) {
                case FunctionKind.MEMO:
                    transformed = this.transformComputeMethodSignature(node)
                    break
                // TODO: no memo:intrinsic signatures for now.
            }
        } else if (ts.isGetAccessor(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isGetAccessorDeclaration)
            if (this.rewrite.functionTable.get(originalNode) === FunctionKind.MEMO) {
                transformed = this.transformGetter(node)
            }
        } else if (ts.isSetAccessor(node)) {
            const originalNode = ts.getOriginalNode(node, ts.isSetAccessorDeclaration)
            if (this.rewrite.functionTable.get(originalNode) === FunctionKind.MEMO) {
                transformed = this.transformSetter(node)
            }
        } else if (ts.isObjectLiteralExpression(node)) {
            // An object literal passed directly as an argument to a memo call is wrapped.
            const originalNode = ts.getOriginalNode(node, ts.isObjectLiteralExpression)
            const originalParent = originalNode.parent
            if (originalParent && ts.isCallExpression(originalParent)) {
                if (isMemoKind(this.rewrite.callTable.get(originalParent)))
                    transformed = this.wrapObjectLiteralInCompute(node)
            }
        } else if (ts.isTypeReferenceNode(node)) {
            if (ts.isIdentifier(node.typeName) && ts.idText(node.typeName) === RuntimeNames.TRANSFORMED_TYPE) {
                transformed = this.transformTransformedType(node)
            }
        }
        return transformed ?? node
    }
}
597

598
/**
 * Builds a return statement whose value is a call to the runtime synthetic-return marker:
 * `return <SYNTHETIC_RETURN_MARK>(node)`, or `return <SYNTHETIC_RETURN_MARK>()` when `node`
 * is undefined.
 */
function createSyntheticReturn(node: ts.Expression | undefined): ts.ReturnStatement {
    const markArguments = node === undefined ? undefined : [node]
    const markCall = ts.factory.createCallExpression(
        runtimeIdentifier(RuntimeNames.SYNTHETIC_RETURN_MARK),
        /* typeArguments */ undefined,
        markArguments
    )
    return ts.factory.createReturnStatement(markCall)
}
607

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.