Create Your Own Programming Language 10: Optimization

Published 7/9/2023 by Jason Barr

Welcome to the next article in the Create Your Own Programming Language series! In this one we're going to transform the abstract syntax tree (AST) to A Normal Form and also implement one of the coolest features of functional languages: tail call optimization.

As always, if you haven't read the previous article on iteration, do that first before continuing.

Ok, good? Let's do it.

A Normal Form

Unless you're a programming language geek, you've probably never heard of A Normal Form (ANF).

In a compiler, it's common to have multiple passes that transform the AST into intermediate forms. ANF is one such intermediate form.

In ANF, we un-nest nested expressions (like call expressions inside of other call expressions) and flatten things out.

That means we'll take an expression like this one:

(slice (+ 1 3) (* 2 5) (concat (list 1 2 3) (list 4 5 6)))

And turn it into something like this:

(def _1 (+ 1 3))
(def _2 (* 2 5))
(def _3 (list 1 2 3))
(def _4 (list 4 5 6))
(def _5 (concat _3 _4))
(slice _1 _2 _5)

This simplification of nested expressions makes it possible to do all sorts of further transformations and optimizations on the program before emitting code.
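To make the idea concrete, here's a minimal JavaScript sketch of the same flattening. It's a toy and not the transformer we'll build below. It assumes expressions are nested JS arrays whose first element is the operator name (for example ["+", 1, 3]), and it treats everything else as already trivial. The function names toANF and show, and the counter used for fresh names, are hypothetical.

```javascript
// Toy ANF conversion for s-expressions written as nested JS arrays.
// A call is an array whose first element is the operator name; anything
// else (numbers, strings standing in for symbols) is already trivial.
function toANF(expr) {
  const bindings = []; // [freshName, flattenedCall] pairs, in evaluation order
  let counter = 0;

  // Bind a nested call to a fresh name and return that name;
  // trivial values are returned unchanged.
  function flatten(e) {
    if (!Array.isArray(e)) return e;
    const [op, ...args] = e;
    const trivialArgs = args.map(flatten); // inner calls are bound first, left to right
    const name = `_${++counter}`;
    bindings.push([name, [op, ...trivialArgs]]);
    return name;
  }

  // The outermost call stays as the final expression instead of being bound.
  if (!Array.isArray(expr)) return { bindings, result: expr };
  const [op, ...args] = expr;
  const result = [op, ...args.map(flatten)];
  return { bindings, result };
}

// Print an expression back in s-expression syntax.
const show = (e) => (Array.isArray(e) ? `(${e.map(show).join(" ")})` : String(e));

const anfOut = toANF([
  "slice",
  ["+", 1, 3],
  ["*", 2, 5],
  ["concat", ["list", 1, 2, 3], ["list", 4, 5, 6]],
]);

const anfLines = [
  ...anfOut.bindings.map(([name, call]) => `(def ${name} ${show(call)})`),
  show(anfOut.result),
];
console.log(anfLines.join("\n"));
```

Running it prints the same six lines as the hand-converted example. Arguments are flattened left to right, so the bindings preserve the original evaluation order.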

Tail Call Optimization

Tail call optimization (TCO) is a feature, most often found in functional languages like Scheme, that allows essentially unbounded recursion as long as the recursive call is in tail position.

A function call is in tail position when it's the last expression evaluated on a control flow path through the function and its result is returned as-is, with no other operations performed on it.

For example, here's a naive implementation of the factorial function:

(def fact (n)
  (if (= n 1)
    n
    (* n (fact (- n 1)))))

The recursive call to fact is not in tail position here, because it's part of the expression beginning with *.

Here's a tail recursive version of fact:

(def fact (n a)
  (if (= n 1)
    a
    (fact (- n 1) (* a n))))

As you can see, in the tail recursive version we pass the running product along as a parameter to the function. This kind of parameter is called an accumulator.

Normally, a recursive call adds another frame to the call stack, and when the stack grows large enough you get a stack overflow error.

In a language like Scheme, the language standard requires implementations to optimize tail calls so that they don't add new frames to the stack.

Instead, the frame for the new function call replaces the previous frame on the stack, allowing essentially infinite tail recursion.

The technique we use to optimize tail calls will allow us to have proper tail call optimization even though most JavaScript engines don't implement TCO natively (which is a shame, because it IS actually in the ECMAScript specification).
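One common way to get proper tail calls on an engine without native TCO is a trampoline. Instead of making the tail call directly, a function returns a thunk that describes the call. A driver loop then keeps invoking thunks until a non-function value comes back. The sketch below is a generic plain-JavaScript illustration with hypothetical names (trampoline, factStep, sumStep). It isn't necessarily the exact technique this series ends up using.

```javascript
// Generic trampoline: call fn, then keep calling whatever thunk it returns
// until the result is no longer a function. Each bounce returns to this loop,
// so the stack doesn't grow with the number of "recursive" calls.
function trampoline(fn) {
  return (...args) => {
    let result = fn(...args);
    while (typeof result === "function") {
      result = result();
    }
    return result;
  };
}

// Tail-recursive factorial with an accumulator, like the Wanda version above,
// but returning a thunk instead of calling itself. BigInt keeps large results exact.
const factStep = (n, a) => (n <= 1n ? a : () => factStep(n - 1n, a * n));
const fact = trampoline(factStep);

// Tail-recursive sum of 1..n. A million direct recursive calls would exceed
// the default stack size in common engines; the trampoline uses constant stack.
const sumStep = (n, acc) => (n === 0 ? acc : () => sumStep(n - 1, acc + n));
const sum = trampoline(sumStep);

console.log(fact(5n, 1n)); // 120n
console.log(sum(1000000, 0)); // 500000500000
```

This simple version has one limitation: a trampolined function can't return a function as its final value, because the driver loop would mistake it for another bounce.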

Fixing a Bug in The Parser

Before we begin implementing the A Normal Form transformation we need to quickly fix a bug in the parser.

I've gone back and fixed this in the original post on functions, but in case you read it before I edited it to include the fix, I'll share it here too.

In parseFunction in src/parser/parse.js, there's a bug that prevents the 2nd expression in the body of a function from being parsed properly. The fix goes in the code that checks whether there's a return type annotation. Starting from the top of the function, change it to this:

  let retType, body;

  if (maybeArrow.type === TokenTypes.Symbol && maybeArrow.value === "->") {
    // has return type annotation
    retType = parseTypeAnnotation(maybeRetType);
    body = maybeBody;
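  } else {
    // no return type annotation: maybeArrow and maybeRetType (if present)
    // are really the first expressions of the function body
    retType = null;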
    body = maybeRetType
      ? [maybeArrow, maybeRetType, ...maybeBody]
      : [maybeArrow, ...maybeBody];
  }

Then the function can continue from const variadic =... to the end.

Preparing for A Normal Form

We also need to make a couple of changes to the core library in preparation for the transformation to A Normal Form. We'll be constructing new AST nodes in the transformation, including some function calls, and we need to make sure those functions are correct in the standard library.

First, we need to make a minor change to the get function in lib/js/core.js so that it also works for getting an index from a vector:

    get: rt.makeFunction(
      (n, obj) => {
        let value =
          obj.get && typeof obj.get === "function" ? obj.get(n) : obj[n];

        if (value === undefined) {
          throw new Exception(`Value for index ${n} not found on object`);
        }

        return value;
      },
      { contract: "(number, ((list any) || (vector any)) -> any)", name: "get" }
    ),

We're also going to add a new function slice at the bottom of the module:

    slice: rt.makeFunction(
      (start, end, obj = undefined) => {
        if (obj === undefined) {
          // didn't pass in an end value, which is valid:
          // the 2nd argument is actually the object to slice
          obj = end;
          end = [...obj].length;
        } else if (end < 0) {
          // a negative end counts back from the end of the list/vector
          end = [...obj].length + end;
        }

        if (isList(obj)) {
          let values = [];
          let i = start;

          while (i < end && obj.get(i) !== undefined) {
            values.push(obj.get(i));
            // advance to the next index, otherwise the loop never ends
            i++;
          }

          return Cons.from(values);
        } else if (Array.isArray(obj)) {
          return obj.slice(start, end);
        } else {
          throw new Exception(
            "slice can only take a list or vector as its argument"
          );
        }
      },
      {
        name: "slice",
        contract:
          "(number, number, ((vector any) || (list any)) -> ((vector any) || (list any)))",
      }
    ),

As you'll see when we get to the ANF transformation, there's a little white lie here: we're also going to use slice on tuples. Tuples are arrays under the hood, and the type checker has already run by the time we do the transformation, so that's not a big deal. Just remember that slice won't type check for tuples, so you can't actually use it to slice tuples in your Wanda code.

You'll also notice that we've given the 3rd parameter a default value of undefined. Without it, makeFunction would curry slice to 3 parameters, and we need to be able to call it with only 2 arguments, as you'll see below when we get to transforming variable declarations.
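The default value helps because JavaScript doesn't count a parameter with a default value (or any parameter after it) in a function's length property. The sketch below shows that, along with a hypothetical length-based curry helper. It only illustrates why the default matters. It assumes makeFunction decides its curried arity from the function's length, and it isn't the series' actual makeFunction.

```javascript
// A parameter with a default value (and any after it) is not counted in length.
const sliceWithDefault = (start, end, obj = undefined) => [start, end, obj];
const sliceWithoutDefault = (start, end, obj) => [start, end, obj];

console.log(sliceWithDefault.length); // 2
console.log(sliceWithoutDefault.length); // 3

// Hypothetical length-based curry helper (illustration only): it keeps
// collecting arguments until it has at least fn.length of them.
const curry = (fn) => {
  const curried = (...args) =>
    args.length >= fn.length
      ? fn(...args)
      : (...more) => curried(...args, ...more);
  return curried;
};

// With the default, 2 arguments are enough to make the call.
console.log(curry(sliceWithDefault)(1, "xs")); // [ 1, 'xs', undefined ]
// Without it, 2 arguments only produce a partially applied function.
console.log(typeof curry(sliceWithoutDefault)(1, "xs")); // function
```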

Transforming to A Normal Form

An expression is said to be in A Normal Form if its complex and nested subexpressions have been un-nested and bound in let expressions (we'll use variable declarations).

Primitive values and symbols are already considered to be in A Normal Form.

All arguments to a function must be trivial, which is to say they should already be terminally evaluated and be passed into the function as either primitives or symbols. Another way of putting it is that arguments to a function need to be in a form where evaluating them halts immediately instead of calling out to another function.

You can still have some nested expressions, e.g. the test expression for an if expression can be a call expression as long as its arguments have been terminally evaluated.
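As a rough illustration of the "arguments must be trivial" rule, here's a small checker. It assumes a toy representation where a call is a nested JS array with the operator first and everything else (numbers, strings standing in for symbols) is trivial. isTrivial and isANFCall are hypothetical names.

```javascript
// Toy representation: a call is an array [operator, ...args]; numbers and
// strings (standing in for symbols) are trivial values.
const isTrivial = (e) => !Array.isArray(e);

// A call is in ANF when every argument is trivial, so evaluating each
// argument finishes immediately instead of calling out to another function.
const isANFCall = (e) => Array.isArray(e) && e.slice(1).every(isTrivial);

console.log(isANFCall(["slice", "_1", "_2", "_5"])); // true
console.log(isANFCall(["slice", ["+", 1, 3], 10, "xs"])); // false

// The test of an if may itself be a call, as long as that call is in ANF.
const ifExpr = ["if", ["=", "n", 1], "n", "_2"];
console.log(isANFCall(ifExpr[1])); // true
```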

You'll start to develop an intuition for it as we go through the conversion process.

Since we're unnesting expressions that are embedded within other expressions, a number of our conversion functions will have to return arrays of expression nodes. That means we're going to be doing a lot of checks for arrays in our conversion functions.

When a transformation produces an array of expressions, the last expression in the array will always be the target expression. So for example if we convert an if expression and it unnests additional expressions, the converted if expression will always be the last member of the array.

This makes array handling more straightforward, because we can always pop the last expression from the array and then concatenate the remainder of the array to our unnested expressions.

The Program node, do expressions, and when expressions all have a body property that is an array of expressions. So when a function that processes one of those nodes gets back an array of unnested expressions, we'll use flatMap to splice those arrays into the node's body.
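Here's a small plain-JavaScript sketch of both conventions. It uses placeholder string "nodes" instead of real AST nodes. splitResult and fakeAnf are hypothetical; the transformers in this post do the same splitting inline with pop and concat. The flatMap part relies only on standard behavior: flatMap flattens a returned array by exactly one level and keeps a returned non-array value as a single element.

```javascript
// Hypothetical helper: split a transformer's result into the target node
// (always the last element of an array result) and the expressions unnested
// from it. It copies with slice instead of popping, so the input isn't mutated.
const splitResult = (result) =>
  Array.isArray(result)
    ? { node: result[result.length - 1], unnested: result.slice(0, -1) }
    : { node: result, unnested: [] };

// A fake transformer over string "nodes": anything starting with "call"
// unnests one binding, everything else comes back as a single node.
const fakeAnf = (node) =>
  node.startsWith("call") ? [`def-for-${node}`, node] : node;

const callSplit = splitResult(fakeAnf("call-f"));
console.log(callSplit.node, callSplit.unnested); // call-f [ 'def-for-call-f' ]

// For a body, flatMap splices each returned array in place (one level deep)
// and keeps a single returned node as one element.
const body = ["x", "call-f", "y"].flatMap(fakeAnf);
console.log(body); // [ 'x', 'def-for-call-f', 'call-f', 'y' ]
```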

We're also going to handle destructuring at this stage and unnest all the separate variables that are parts of the vector and record patterns, which means we'll be able to remove the code in the emitter that handles destructured assignment.

It sounds more convoluted than it actually is, so let's start with some code and you can see how it works.

The Dispatcher Function

First, here are the imports for src/transform/anf.js:

import { AST, ASTTypes } from "../parser/ast.js";
import { isPrimitive } from "../parser/utils.js";
import { Exception } from "../shared/exceptions.js";
import { makeGenSym } from "../runtime/makeSymbol.js";
import { Token } from "../lexer/Token.js";
import { TokenTypes } from "../lexer/TokenTypes.js";

We also need these 2 helper functions:

const createFreshSymbol = (srcloc) => {
  return AST.Symbol(Token.new(TokenTypes.Symbol, makeGenSym(), srcloc));
};

const isPrimitiveOrSymbol = (node) => {
  return isPrimitive(node) || node.kind === ASTTypes.Symbol;
};

As usual, we need a dispatcher function that handles each of the different node kinds:

export const anf = (node) => {
  switch (node.kind) {
    case ASTTypes.Program:
      return transformProgram(node);
    case ASTTypes.NumberLiteral:
    case ASTTypes.StringLiteral:
    case ASTTypes.BooleanLiteral:
    case ASTTypes.KeywordLiteral:
    case ASTTypes.NilLiteral:
    case ASTTypes.Symbol:
      return node;
    case ASTTypes.CallExpression:
      return transformCallExpression(node);
    case ASTTypes.LambdaExpression:
      return transformLambdaExpression(node);
    case ASTTypes.VariableDeclaration:
      return transformVariableDeclaration(node);
    case ASTTypes.SetExpression:
      return transformSetExpression(node);
    case ASTTypes.TypeAlias:
      // ignore
      return node;
    case ASTTypes.VectorLiteral:
      return transformVectorLiteral(node);
    case ASTTypes.RecordLiteral:
      return transformRecordLiteral(node);
    case ASTTypes.MemberExpression:
      return transformMemberExpression(node);
    case ASTTypes.DoExpression:
      return transformDoExpression(node);
    case ASTTypes.AsExpression:
      return transformAsExpression(node);
    case ASTTypes.IfExpression:
      return transformIfExpression(node);
    case ASTTypes.WhenExpression:
      return transformWhenExpression(node);
    case ASTTypes.LogicalExpression:
      return transformLogicalExpression(node);
    default:
      throw new Exception(`Unhandled node kind: ${node.kind}`);
  }
};

You may wonder why we're writing separate functions instead of using our generic visitor. I decided there was no benefit to using the visitor since we'd need to write a method for every single node type that's left after desugaring, so I wrote the transformer as a series of functions.

As you can see, for primitives and symbols it simply returns the node. There's nothing to do with type aliases either, so it returns that node as well. The rest of the nodes all dispatch to separate transformer functions.

Transforming The Program

transformProgram is very simple: it just produces a new body by flatMapping over the original body with anf as the map function:

const transformProgram = (node) => {
  let body = node.body.flatMap(anf);

  return { ...node, body };
};

Transforming Call Expressions

transformCallExpression is the first function to return an array of expressions rather than a single expression.

We start by creating the array. Then we transform the function itself. If the result is an array, we pop the function off the array then concatenate the rest of the array to the unnested expressions array.

Then we need to handle arguments.

Arguments can be either simple expressions or complex, nested ones, so they take a bit more work.

First, we create an array to hold the transformed arguments. Then we loop over the original node's arguments.

An argument to a function can be another call expression, in which case we need to process the subcall and unnest any of its arguments.

So if the argument is a call expression, we loop over its arguments. Arguments that are already primitives or symbols go straight onto a subArgs array. Complex arguments get bound to new variables: we push the symbol for each new variable onto subArgs, and we add the new variable declaration, along with any expressions unnested from it, to the unnested expressions array.

We take the transformed subcall arguments and construct a new call expression, using it as the expression for a new variable declaration, then add that variable to the arguments array for the original call expression. The variable declaration gets added to the unnested expressions array.

Processing a subcall serves as a base case for recursively processing call expressions.

If an argument to the main call expression is not another call expression, we simply run it through the anf function. If the result is an array, we pop the actual argument off the end and concatenate the remaining unnested expressions onto the unnested expressions array.

It sounds convoluted, but the code is fairly straightforward. Here's the code for transformCallExpression, where you'll see I've annotated everything with comments so you can relate the code to the explanation I've given above:

const transformCallExpression = (node) => {
  // create an array for unnested expressions from the call expression
  let unnestedExprs = [];
  // transform the function
  let func = anf(node.func);

  // if func has been transformed into an array, get the actual function
  // which will be the last expression in the array from the transformer
  if (Array.isArray(func)) {
    const funcExprs = func;
    func = funcExprs.pop();
    // add the remaining unnested expressions to our unnested expressions array
    // (concat returns a new array, so we have to reassign the result)
    unnestedExprs = unnestedExprs.concat(funcExprs);
  }

  let args = [];

  for (let arg of node.args) {
    // if it's a call expression, we need to unnest any subexpressions from
    // the arguments to the sub-call expression and create a new call expr
    if (arg.kind === ASTTypes.CallExpression) {
      // we'll need unnested arguments for the subcall
      let subArgs = [];
      for (let a of arg.args) {
        // primitives and symbols are already in ANF
        if (isPrimitiveOrSymbol(a)) {
          subArgs.push(a);
          // otherwise, we need to unnest the expression and bind the result to a new
          // variable, then replace the expression in the call body with that variable
        } else {
          const freshLet = AST.VariableDeclaration(
            createFreshSymbol(a.srcloc),
            a,
            a.srcloc,
            null
          );
          const transformedLet = transformVariableDeclaration(freshLet);
          // the actual declaration will always be the last node in the array
          // we're going to need to add this to the unnested expressions
          // for the parent call expression, so we don't pop it
          const actualLet = transformedLet[transformedLet.length - 1];

          // add the unnested expressions from the VariableDeclaration
          // to the unnested expressions from the call expression
          unnestedExprs = unnestedExprs.concat(transformedLet);
          // add the variable that's been assigned to
          // its relative place in the subcall args
          subArgs.push(actualLet.lhv);
        }
      }

      // create a new CallExpression with unnested sub-arguments
      let subCall = AST.CallExpression(arg.func, subArgs, arg.srcloc);
      // create a fresh variable symbol
      const callSymbol = createFreshSymbol(subCall.srcloc);
      // assign the result of the unnested call expression to the fresh variable
      const callLet = AST.VariableDeclaration(
        callSymbol,
        subCall,
        subCall.srcloc
      );

      // add the assignment to the unnested expressions
      unnestedExprs.push(callLet);
      // the argument to the parent call expression should now be the fresh variable
      arg = callSymbol;
    } else {
      arg = anf(arg);
    }

    if (Array.isArray(arg)) {
      // the actual arg will always be the last element in this array, the rest
      // are all unnested expressions and should be concatenated to that array
      args.push(arg.pop());
      unnestedExprs = unnestedExprs.concat(arg);
    } else {
      // the anfed arg is a single node
      args.push(arg);
    }
  }

  const newCallExpr = AST.CallExpression(func, args, node.srcloc);

  return [...unnestedExprs, newCallExpr];
};

Transforming Lambda Expressions

Transforming lambda expressions is extremely simple. We don't need to worry about parameters because there's no nesting there, so all we need to do is flatMap over the body with anf as the mapping function:

const transformLambdaExpression = (node) => {
  const body = node.body.flatMap(anf);
  return { ...node, body };
};

Transforming Variable Declarations

Transforming variable declarations is complicated by the need to unnest variables used in destructuring.

First, as usual, we create an array for unnested expressions. Then we ANF transform the initializer expression for the declaration node.

If the ANFed expression is an array, we pop off the actual expression then concatenate the rest to the unnested expressions array.

Whether or not it's an array, we construct a new declaration using the transformed expression.

If it's a simple variable declaration using a single symbol as the variable name, we're done here and can just return an array containing the unnested expressions and ending with the newly constructed declaration.

const transformVariableDeclaration = (node) => {
  let unnestedExprs = [];
  const anfedExpr = anf(node.expression);
  let anfedDecl;
  let expression;

  if (Array.isArray(anfedExpr)) {
    expression = anfedExpr.pop();
    anfedDecl = { ...node, expression };

    unnestedExprs = unnestedExprs.concat(anfedExpr);
  } else {
    expression = anfedExpr;
    anfedDecl = { ...node, expression };
  }
  // rest of function which handles destructured declarations
  // If we get here, it's a simple variable declaration with a symbol as LHV
  return [...unnestedExprs, anfedDecl];
};

If it's a destructured variable declaration, we've got some work to do.

Transforming Vector Destructuring

If it's vector destructuring, we need to loop over the members of the vector pattern.

For each member of the vector pattern, we construct a new variable declaration with a call to the get function for the current index of the list/vector/tuple to get the correct value for the destructured variable.

Note that this function call will throw an error if the object being destructured doesn't have enough members to satisfy the number of destructured variables.

If the last member is a rest variable, then instead of using the get function we use slice, passing it the current index and the object being destructured. This is why we needed to be able to call slice with only 2 arguments: a start index and the object, with no end index.

  if (node.lhv.kind === ASTTypes.VectorPattern) {
    // is vector pattern destructuring
    /** @type {import("../parser/ast.js").VectorPattern} */
    const pattern = node.lhv;
    // bind the object being destructured to a fresh variable first, so it's
    // evaluated only once and every get/slice call gets a trivial argument
    const objSymbol = createFreshSymbol(expression.srcloc);
    unnestedExprs.push(
      AST.VariableDeclaration(objSymbol, expression, expression.srcloc)
    );
    let i = 0;

    for (let mem of pattern.members) {
      if (i === pattern.members.length - 1 && pattern.rest) {
        // need to slice off the rest of the list/vector/tuple and assign it to the last member
        const destructuredDecl = AST.VariableDeclaration(
          mem,
          AST.CallExpression(
            AST.Symbol(Token.new(TokenTypes.Symbol, "slice", mem.srcloc)),
            [
              AST.NumberLiteral(
                Token.new(TokenTypes.Number, i.toString(), mem.srcloc)
              ),
              objSymbol,
            ],
            mem.srcloc
          ),
          mem.srcloc
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(destructuredDecl);
      } else {
        // need to get the value from the current index of the list/vector/tuple and assign it to the current pattern member
        const destructuredDecl = AST.VariableDeclaration(
          mem,
          AST.CallExpression(
            AST.Symbol(Token.new(TokenTypes.Symbol, "get", mem.srcloc)),
            [
              AST.NumberLiteral(
                Token.new(TokenTypes.Number, i.toString(), mem.srcloc)
              ),
              objSymbol,
            ],
            mem.srcloc
          ),
          mem.srcloc
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(destructuredDecl);
      }
      i++;
    }
    return unnestedExprs;
  }
  // handle record destructuring...
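To see what this lowering computes, here's a plain-JavaScript sketch that evaluates a vector pattern the same way: one get per member, and a slice from the current index for a rest member. getAt, sliceFrom, and destructureVector are hypothetical stand-ins that only handle JS arrays, not the core library's list support. The result should match JavaScript's own array destructuring.

```javascript
// Stand-in for the core library's get: throws if the index is missing,
// which is what happens when there are more pattern members than values.
const getAt = (n, obj) => {
  const value = obj[n];
  if (value === undefined) {
    throw new Error(`Value for index ${n} not found on object`);
  }
  return value;
};

// Stand-in for the 2-argument form of slice: everything from start onward.
const sliceFrom = (start, obj) => obj.slice(start);

// Bind each pattern member the way the ANF pass does: get for ordinary
// members, slice for the last member when it's a rest variable.
function destructureVector(names, hasRest, obj) {
  const bindings = {};
  names.forEach((name, i) => {
    const isRest = hasRest && i === names.length - 1;
    bindings[name] = isRest ? sliceFrom(i, obj) : getAt(i, obj);
  });
  return bindings;
}

const vecBindings = destructureVector(["a", "b", "rest"], true, [1, 2, 3, 4]);
const [first, second, ...others] = [1, 2, 3, 4]; // JavaScript's built-in equivalent
console.log(vecBindings.a === first, vecBindings.b === second); // true true
console.log(vecBindings.rest, others); // [ 3, 4 ] [ 3, 4 ]
```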

Transforming Record Destructuring

If it's record destructuring, we start by assigning the actual object being destructured to a fresh variable name and pushing the assignment node onto the unnested expressions array. We save the symbol node for this fresh variable so we can use it in what follows.

Now we loop over the properties of the record pattern, keeping track of which properties have been used in an array so we can use the unused ones to create the object assigned to the rest variable. Since this step comes after the type checker, we can get the full list of properties from the type the type checker assigned to the initializer expression.

If it's not a rest variable, we create a new variable declaration whose initializer is a member expression built from the fresh object variable and the current property, then push that declaration onto the unnested expressions array.

Then we push the property name we've just used onto the used properties array.

If it's a rest variable, we create a new array of all the unused properties by filtering over the RHV's type properties then mapping those properties to new symbol nodes.

Then we get the properties for the remainder object by reducing the unused properties array and creating a new array of Property nodes, then use that array to construct a new RecordLiteral node.

Then we create a new variable declaration assigning the remainder object to the rest variable name.

Finally, we push the final declaration onto the unnested expressions array.

  // this closing brace is the one that ends the previous code block; don't duplicate it
  } else if (node.lhv.kind === ASTTypes.RecordPattern) {
    // is record pattern destructuring
    // remember, we have the RHV's type at this point
    /** @type {import("../parser/ast.js").RecordPattern} */
    const pattern = node.lhv;
    // first we need to assign the actual object to a fresh variable name
    const objSymbol = createFreshSymbol(expression.srcloc);
    const objDecl = AST.VariableDeclaration(
      objSymbol,
      expression,
      expression.srcloc
    );

    unnestedExprs.push(objDecl);

    let i = 0;
    let used = [];

    for (let prop of pattern.properties) {
      if (i === pattern.properties.length - 1 && pattern.rest) {
        // need to get the rest of the object's properties and assign them to the rest variable
        // this maps the array of unused properties from the type to an array of Symbol
        // nodes with each property name as the node name
        const unusedProps = expression.type.properties
          .filter((p) => {
            return !used.includes(p.name);
          })
          .map((p) =>
            AST.Symbol(Token.new(TokenTypes.Symbol, p.name, prop.srcloc))
          );

        // now create an object using the properties
        const properties = unusedProps.reduce((props, p) => {
          return [
            ...props,
            AST.Property(
              p,
              AST.MemberExpression(objSymbol, p, p.srcloc),
              p.srcloc
            ),
          ];
        }, []);
        const remainingObject = AST.RecordLiteral(properties, prop.srcloc);
        // and a variable declaration using the remainder object assigning it to the rest variable
        const restDecl = AST.VariableDeclaration(
          prop,
          remainingObject,
          prop.srcloc
        );

        unnestedExprs.push(restDecl);
      } else {
        // need to assign the current variable's object property
        const currentDecl = AST.VariableDeclaration(
          prop,
          AST.MemberExpression(objSymbol, prop, prop.srcloc)
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(currentDecl);
        used.push(prop.name);
      }
      i++;
    }

    return unnestedExprs;
  }
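Here's the same idea for record patterns, again as a hypothetical plain-JavaScript sketch. destructureRecord binds each named property from the object and builds the rest object from the properties that weren't used. The ANF pass reads the full property list from the type checker's type for the initializer. This sketch instead takes that list as an explicit typeProps array of names, an assumption made to keep it self-contained.

```javascript
// Bind record pattern members like the ANF pass: a property lookup for each
// named member, and for a trailing rest member, a new object holding every
// property (from the type's property list) that hasn't been used yet.
function destructureRecord(names, hasRest, obj, typeProps) {
  const bindings = {};
  const used = [];

  names.forEach((name, i) => {
    if (hasRest && i === names.length - 1) {
      const remainder = {};
      for (const p of typeProps.filter((prop) => !used.includes(prop))) {
        remainder[p] = obj[p];
      }
      bindings[name] = remainder;
    } else {
      bindings[name] = obj[name];
      used.push(name);
    }
  });

  return bindings;
}

const point = { x: 1, y: 2, z: 3 };
const recBindings = destructureRecord(["x", "rest"], true, point, ["x", "y", "z"]);
const { x, ...rest } = point; // JavaScript's built-in equivalent
console.log(recBindings.x === x); // true
console.log(recBindings.rest, rest); // { y: 2, z: 3 } { y: 2, z: 3 }
```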

Here's the entire transformVariableDeclaration function:

const transformVariableDeclaration = (node) => {
  let unnestedExprs = [];
  const anfedExpr = anf(node.expression);
  let anfedDecl;
  let expression;

  if (Array.isArray(anfedExpr)) {
    expression = anfedExpr.pop();
    anfedDecl = { ...node, expression };

    unnestedExprs = unnestedExprs.concat(anfedExpr);
  } else {
    expression = anfedExpr;
    anfedDecl = { ...node, expression };
  }

  if (node.lhv.kind === ASTTypes.VectorPattern) {
    // is vector pattern destructuring
    /** @type {import("../parser/ast.js").VectorPattern} */
    const pattern = node.lhv;
    // bind the object being destructured to a fresh variable first, so it's
    // evaluated only once and every get/slice call gets a trivial argument
    const objSymbol = createFreshSymbol(expression.srcloc);
    unnestedExprs.push(
      AST.VariableDeclaration(objSymbol, expression, expression.srcloc)
    );
    let i = 0;

    for (let mem of pattern.members) {
      if (i === pattern.members.length - 1 && pattern.rest) {
        // need to slice off the rest of the list/vector/tuple and assign it to the last member
        const destructuredDecl = AST.VariableDeclaration(
          mem,
          AST.CallExpression(
            AST.Symbol(Token.new(TokenTypes.Symbol, "slice", mem.srcloc)),
            [
              AST.NumberLiteral(
                Token.new(TokenTypes.Number, i.toString(), mem.srcloc)
              ),
              objSymbol,
            ],
            mem.srcloc
          ),
          mem.srcloc
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(destructuredDecl);
      } else {
        // need to get the value from the current index of the list/vector/tuple and assign it to the current pattern member
        const destructuredDecl = AST.VariableDeclaration(
          mem,
          AST.CallExpression(
            AST.Symbol(Token.new(TokenTypes.Symbol, "get", mem.srcloc)),
            [
              AST.NumberLiteral(
                Token.new(TokenTypes.Number, i.toString(), mem.srcloc)
              ),
              objSymbol,
            ],
            mem.srcloc
          ),
          mem.srcloc
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(destructuredDecl);
      }
      i++;
    }
    return unnestedExprs;
  } else if (node.lhv.kind === ASTTypes.RecordPattern) {
    // is record pattern destructuring
    // remember, we have the RHV's type at this point
    /** @type {import("../parser/ast.js").RecordPattern} */
    const pattern = node.lhv;
    // first we need to assign the actual object to a fresh variable name
    const objSymbol = createFreshSymbol(expression.srcloc);
    const objDecl = AST.VariableDeclaration(
      objSymbol,
      expression,
      expression.srcloc
    );

    unnestedExprs.push(objDecl);

    let i = 0;
    let used = [];

    for (let prop of pattern.properties) {
      if (i === pattern.properties.length - 1 && pattern.rest) {
        // need to get the rest of the object's properties and assign them to the rest variable
        // this maps the array of unused properties from the type to an array of Symbol
        // nodes with each property name as the node name
        const unusedProps = expression.type.properties
          .filter((p) => {
            return !used.includes(p.name);
          })
          .map((p) =>
            AST.Symbol(Token.new(TokenTypes.Symbol, p.name, prop.srcloc))
          );

        // now create an object using the properties
        const properties = unusedProps.reduce((props, p) => {
          return [
            ...props,
            AST.Property(
              p,
              AST.MemberExpression(objSymbol, p, p.srcloc),
              p.srcloc
            ),
          ];
        }, []);
        const remainingObject = AST.RecordLiteral(properties, prop.srcloc);
        // and a variable declaration using the remainder object assigning it to the rest variable
        const restDecl = AST.VariableDeclaration(
          prop,
          remainingObject,
          prop.srcloc
        );

        unnestedExprs.push(restDecl);
      } else {
        // need to assign the current variable's object property
        const currentDecl = AST.VariableDeclaration(
          prop,
          AST.MemberExpression(objSymbol, prop, prop.srcloc)
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(currentDecl);
        used.push(prop.name);
      }
      i++;
    }

    return unnestedExprs;
  }
  // If we get here, it's a simple variable declaration with a symbol as LHV
  return [...unnestedExprs, anfedDecl];
};

Remember, we're only supporting top-level destructuring without nested patterns. Adding nested patterns would make this even more complex.

Transforming Set Expressions

The code for transformSetExpression is similar to the simple-symbol case of transformVariableDeclaration:

const transformSetExpression = (node) => {
  const anfedExpr = anf(node.expression);

  if (Array.isArray(anfedExpr)) {
    let expression = anfedExpr.pop();
    return [...anfedExpr, { ...node, expression }];
  }

  return [{ ...node, expression: anfedExpr }];
};

Transforming Vector Literals

To transform a vector literal we just need to unnest any nested expressions that make up the vector's members:

const transformVectorLiteral = (node) => {
  let unnestedExprs = [];
  let members = [];

  for (let mem of node.members) {
    let anfed = anf(mem);

    if (Array.isArray(anfed)) {
      members.push(anfed.pop());
      unnestedExprs = unnestedExprs.concat(anfed);
    } else {
      members.push(anfed);
    }
  }

  return [...unnestedExprs, { ...node, members }];
};

Transforming Record Literals

The same goes for record literals, where we unnest any nested expressions used in the record's property values:

const transformRecordLiteral = (node) => {
  let unnestedExprs = [];
  let properties = [];

  for (let prop of node.properties) {
    let anfed = anf(prop.value);

    if (Array.isArray(anfed)) {
      let value = anfed.pop();
      properties.push({ ...prop, value });
      unnestedExprs = unnestedExprs.concat(anfed);
    } else {
      properties.push({ ...prop, value: anfed });
    }
  }

  return [...unnestedExprs, { ...node, properties }];
};

Transforming Member Expressions

The only complication in transforming a member expression is that the object can be practically any node type as long as the expression produces an object:

const transformMemberExpression = (node) => {
  let unnestedExprs = [];
  let anfedObject = anf(node.object);

  if (Array.isArray(anfedObject)) {
    const object = anfedObject.pop();
    unnestedExprs = unnestedExprs.concat(anfedObject);
    return [...unnestedExprs, { ...node, object }];
  }

  return [{ ...node, object: anfedObject }];
};

Transforming Do Expressions

Transforming a do expression is pretty much the same as transforming the Program:

const transformDoExpression = (node) => {
  const body = node.body.flatMap(anf);
  return { ...node, body };
};

Transforming As Expressions

In transforming as expressions, we're just going to transform the expression part and return it. That means we'll no longer need to handle as expressions in the emitter, just like we no longer need to handle destructuring in the emitter:

const transformAsExpression = (node) => {
  const anfed = anf(node.expression);

  if (Array.isArray(anfed)) {
    return anfed;
  }

  return [anfed];
};

Transforming If Expressions

If expressions are straightforward. Transform the test, then append any unnested expressions to the unnested expressions array. Then do the same with the consequent (then) and alternate (else) branches:

const transformIfExpression = (node) => {
  let unnestedExprs = [];
  let transformedCondition = anf(node.test);
  let test;

  if (Array.isArray(transformedCondition)) {
    test = transformedCondition.pop();
    unnestedExprs = unnestedExprs.concat(transformedCondition);
  } else {
    test = transformedCondition;
  }

  let transformedThen = anf(node.then);
  let then;

  if (Array.isArray(transformedThen)) {
    then = transformedThen.pop();
    unnestedExprs = unnestedExprs.concat(transformedThen);
  } else {
    then = transformedThen;
  }

  let transformedElse = anf(node.else);
  let elseBranch;

  if (Array.isArray(transformedElse)) {
    elseBranch = transformedElse.pop();
    unnestedExprs = unnestedExprs.concat(transformedElse);
  } else {
    elseBranch = transformedElse;
  }

  return [...unnestedExprs, { ...node, test, then, else: elseBranch }];
};
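The "pop the last expression, hoist the rest" pattern appears three times in this function alone. One way to factor it out is sketched below with a hypothetical splitAnfed helper; anf is passed in as a parameter only so the snippet runs on its own. It keeps the same ordering of hoisted expressions as the code above (test first, then the then branch, then the else branch) and uses slice instead of pop so the array returned by anf is not mutated.

```javascript
// Hypothetical helper: normalizes an anf result into the hoisted
// expressions plus the final expression that replaces the original node.
const splitAnfed = (anfed) => {
  if (!Array.isArray(anfed)) {
    return { hoisted: [], value: anfed };
  }
  return { hoisted: anfed.slice(0, -1), value: anfed[anfed.length - 1] };
};

// transformIfExpression rewritten with the helper (anf is injected).
const transformIfExpressionWith = (anf) => (node) => {
  const test = splitAnfed(anf(node.test));
  const then = splitAnfed(anf(node.then));
  const elseBranch = splitAnfed(anf(node.else));

  return [
    ...test.hoisted,
    ...then.hoisted,
    ...elseBranch.hoisted,
    { ...node, test: test.value, then: then.value, else: elseBranch.value },
  ];
};

// Usage with a stub anf: nodes flagged `nested` hoist one Def each.
const stub = (n) =>
  n.nested
    ? [{ kind: "Def", name: `${n.name}_tmp` }, { kind: "Symbol", name: n.name }]
    : n;

const ifNode = {
  kind: "If",
  test: { kind: "Symbol", name: "t", nested: true },
  then: { kind: "Symbol", name: "x" },
  else: { kind: "Symbol", name: "y", nested: true },
};

const out = transformIfExpressionWith(stub)(ifNode);
// out: [Def t_tmp, Def y_tmp, If { test: t, then: x, else: y }]
```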

Transforming When Expressions

When expressions are like a combination of if and do expressions. Transform the test as in an if expression, then transform the body as in a do expression:

const transformWhenExpression = (node) => {
  let unnestedExprs = [];
  let transformedCondition = anf(node.test);
  let test;

  if (Array.isArray(transformedCondition)) {
    test = transformedCondition.pop();
    unnestedExprs = unnestedExprs.concat(transformedCondition);
  } else {
    // The test was already flat, so use it as-is
    test = transformedCondition;
  }

  let body = node.body.flatMap(anf);

  return [...unnestedExprs, { ...node, test, body }];
};

Transforming Logical Expressions

Logical expressions are straightforward too. Transform the left operand, transform the right operand, and add any expressions unnested from either side to the unnested expressions array:

const transformLogicalExpression = (node) => {
  let unnestedExprs = [];
  let transformedLeft = anf(node.left);
  let left;

  if (Array.isArray(transformedLeft)) {
    left = transformedLeft.pop();
    unnestedExprs = unnestedExprs.concat(transformedLeft);
  } else {
    left = transformedLeft;
  }

  let transformedRight = anf(node.right);
  let right;

  if (Array.isArray(transformedRight)) {
    right = transformedRight.pop();
    unnestedExprs = unnestedExprs.concat(transformedRight);
  } else {
    right = transformedRight;
  }

  return [...unnestedExprs, { ...node, left, right }];
};

That's all the transformation functions!

Figuring out how to do AST transformations vexed me for a long time, so I really hope this helps you understand how transforming the tree works.

Tail Call Optimization

Now let's take the transformed tree and use it to detect recursive tail calls.

To detect recursive tail calls we traverse the whole transformed AST. When we come to a lambda expression, we check if it has a name property (which gets added when function declarations are desugared into variable declarations with lambdas).

If the last (tail position) expression in the lambda body is a call expression, or if it's an if, do, or logical expression that contains a call expression in tail position, we check to see if the func property of that call expression is a Symbol node with the same name. If it is, then it's a recursive tail call and we can mark it as such.

We'll mark the call expression and the lambda expression (and, if the tail call is part of an if expression, the if expression too) with a boolean isTailRec flag indicating whether it's tail recursive.

We'll write this as a class that extends Visitor, because we only need custom methods for a couple of node types but still need to traverse the whole tree.

First the imports, in src/transform/tco.js:

import { ASTTypes } from "../parser/ast.js";
import { Visitor } from "../visitor/Visitor.js";

We need a couple of helper functions. First, we need one to swap out the last node in an expression body, so we can replace tail call nodes with new nodes that mark them as tail calls. Note that the first argument to this function is the node with the body property, not the body itself:

const swapLastExpr = (node, expr) => {
  node.body[node.body.length - 1] = expr;
};

We also need a function for checking if expressions (IfExpression nodes). We'll use a standalone function for this instead of the regular visitor method because we only need to check if expressions that are in tail position in the body of a lambda or do expression. All other if expressions can be handled as normal.

This function recursively checks any if expressions nested in the then or else branch of the if expression being checked:

const checkIfExpression = (node, name, visitor) => {
  let isTailRec = false;

  if (
    node.then.kind === ASTTypes.CallExpression &&
    node.then.func.name === name
  ) {
    node.then = visitor.visitCallExpression(node.then, true);
    isTailRec = true;
  }

  if (
    node.else.kind === ASTTypes.CallExpression &&
    node.else.func.name === name
  ) {
    node.else = visitor.visitCallExpression(node.else, true);
    isTailRec = true;
  }

  if (node.then.kind === ASTTypes.IfExpression) {
    node.then = checkIfExpression(node.then, name, visitor);
    if (node.then.isTailRec) {
      isTailRec = true;
    }
  }

  if (node.else.kind === ASTTypes.IfExpression) {
    node.else = checkIfExpression(node.else, name, visitor);
    if (node.else.isTailRec) {
      isTailRec = true;
    }
  }

  return { ...node, isTailRec };
};
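For comparison, here is a compact, detection-only sketch of the same idea over a hypothetical mini-AST (the "Call", "If", "Do" and "Lambda" kinds and the names hasTailCallTo and isTailRecursive are made up for illustration). Unlike the transformer, it only answers yes or no instead of marking nodes, and it leaves logical expressions out to stay short, but it does follow do blocks nested inside if branches:

```javascript
// Hypothetical node shapes, loosely mirroring the article's AST:
//   { kind: "Symbol", name }   { kind: "Call", func, args }
//   { kind: "If", test, then, else }   { kind: "Do", body: [...] }
//   { kind: "Lambda", name, params, body: [...] }

// True if `expr`, which is itself in tail position, ends in a call to `name`.
const hasTailCallTo = (expr, name) => {
  switch (expr.kind) {
    case "Call":
      return expr.func.kind === "Symbol" && expr.func.name === name;
    case "If":
      // Both branches are in tail position; the test is not.
      return hasTailCallTo(expr.then, name) || hasTailCallTo(expr.else, name);
    case "Do":
      // Only the last expression of a do block is in tail position.
      return (
        expr.body.length > 0 &&
        hasTailCallTo(expr.body[expr.body.length - 1], name)
      );
    default:
      return false;
  }
};

// Only named lambdas can call themselves by name.
const isTailRecursive = (lambda) =>
  Boolean(lambda.name) &&
  lambda.body.length > 0 &&
  hasTailCallTo(lambda.body[lambda.body.length - 1], lambda.name);

// Usage: the two versions of fact from this article.
const sym = (name) => ({ kind: "Symbol", name });
const call = (fn, ...args) => ({ kind: "Call", func: sym(fn), args });

const naiveFact = {
  kind: "Lambda",
  name: "fact",
  params: ["n"],
  body: [{
    kind: "If",
    test: call("=", sym("n"), 1),
    then: sym("n"),
    else: call("*", sym("n"), call("fact", call("-", sym("n"), 1))),
  }],
};

const tailFact = {
  kind: "Lambda",
  name: "fact",
  params: ["n", "a"],
  body: [{
    kind: "If",
    test: call("=", sym("n"), 1),
    then: sym("a"),
    else: call("fact", call("-", sym("n"), 1), call("*", sym("a"), sym("n"))),
  }],
};

isTailRecursive(naiveFact); // false: the recursive call is an argument to *
isTailRecursive(tailFact); // true
```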

Ok, first let's stub out the class for the TCO transformer:


class TCOTransformer extends Visitor {
  constructor(program) {
    super(program);
  }

  static new(program) {
    return new TCOTransformer(program);
  }
}

Now let's add an override for visitCallExpression that takes an extra, optional isTailRec parameter defaulting to false. It needs the default because only our visitLambdaExpression override and checkIfExpression call it with true; every other call passes just the node:

  visitCallExpression(node, isTailRec = false) {
    return { ...node, isTailRec };
  }

Ok, now let's turn to visitLambdaExpression. It's a lot of code, but it's straightforward. If the lambda node has a name property, save it to a variable. Then get the last expression from the lambda body.

If the last expression is a call expression we check to see if its func property is a Symbol node with the same name as the lambda.

If it's a do expression, if expression, or logical expression, we check whether the last expression (of a do expression) or a branch (of an if or logical expression) is a call in tail position to a function with the same name.

If we detect a recursive tail call, we mark the call expression and lambda expression both with flags isTailRec set to true.

  visitLambdaExpression(node) {
    const name = node.name ?? "";
    const lastExpr = node.body[node.body.length - 1];

    // Could be tail recursive: CallExpression, DoExpression, IfExpression, LogicalExpression
    if (lastExpr.kind === ASTTypes.CallExpression) {
      // should only work if func is symbol
      if (lastExpr.func.name === name) {
        const newCall = this.visitCallExpression(lastExpr, true);
        swapLastExpr(node, newCall);
        return { ...node, isTailRec: true };
      }

      lastExpr.isTailRec = false;
      return { ...node, isTailRec: false };
    } else if (lastExpr.kind === ASTTypes.DoExpression) {
      const lastBodyExpr = lastExpr.body[lastExpr.body.length - 1];

      if (
        lastBodyExpr.kind === ASTTypes.CallExpression &&
        lastBodyExpr.func.name === name
      ) {
        const newCall = this.visitCallExpression(lastBodyExpr, true);
        swapLastExpr(lastExpr, newCall);
        return { ...node, isTailRec: true };
      } else if (lastBodyExpr.kind === ASTTypes.IfExpression) {
        const newIf = checkIfExpression(lastBodyExpr, name, this);
        swapLastExpr(lastExpr, newIf);
        return { ...node, isTailRec: newIf.isTailRec };
      } else if (lastBodyExpr.kind === ASTTypes.LogicalExpression) {
        // Only the right operand is in tail position: the value of the left
        // operand still has to be tested before the expression can return,
        // so a recursive call there must not be turned into rt.recur
        if (
          lastBodyExpr.right.kind === ASTTypes.CallExpression &&
          lastBodyExpr.right.func.name === name
        ) {
          lastBodyExpr.right = this.visitCallExpression(
            lastBodyExpr.right,
            true
          );
          return { ...node, isTailRec: true };
        }

        return { ...node, isTailRec: false };
      } else {
        lastExpr.isTailRec = false;
        return { ...node, isTailRec: false };
      }
    } else if (lastExpr.kind === ASTTypes.IfExpression) {
      let newIf = checkIfExpression(lastExpr, name, this);
      swapLastExpr(node, newIf);
      return { ...node, isTailRec: newIf.isTailRec };
    } else if (lastExpr.kind === ASTTypes.LogicalExpression) {
      // As above, only the right operand is in tail position
      if (
        lastExpr.right.kind === ASTTypes.CallExpression &&
        lastExpr.right.func.name === name
      ) {
        lastExpr.right = this.visitCallExpression(lastExpr.right, true);
        return { ...node, isTailRec: true };
      }

      return { ...node, isTailRec: false };
    }

    return { ...node, isTailRec: false };
  }

Finally, we export a function that constructs the TCO transformer and runs it on a program:

export const tco = (program) => TCOTransformer.new(program).visit();

Now we need a function to handle all transformations, in src/transform/transform.js:

import { anf } from "./anf.js";
import { tco } from "./tco.js";

export const transform = (program) => tco(anf(program));

Changes to The Runtime

Now that we've added the ability to detect tail recursive calls, we need to optimize them.

The simplest way to do that in JavaScript is with a trampoline.

A trampoline is a loop that handles recursive calls. We're going to trampoline our tail recursive functions by making the tail call return a special object that includes the function and arguments to the tail call as properties on the object. You'll see how we rewrite the tail recursive call to achieve this when we get to the emitter.

The setup has two parts: a recur function that returns the special object, and a trampoline function that wraps a function and returns a new function that runs the loop.

It sounds complicated, but the two functions are actually very simple, in src/runtime/trampoline.js:

import { makeWandaValue } from "./conversion.js";

// based on this Stack Overflow answer: https://stackoverflow.com/a/50493099

export const recur = (f, ...args) => ({ tag: recur, f, args });

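export const trampoline = (f) => {
  const trampolined = (...args) => {
    let t = f(...args);

    // Keep bouncing as long as the function returns a recur object
    while (t && t.tag === recur) {
      t = t.f(...t.args);
    }

    // Return the final value even if the first call already hit the base
    // case, converting it here since tail recursive functions skip
    // makeWandaValue (see the tailRec option for makeFunction)
    return makeWandaValue(t);
  };

  // Keep a reference to the un-trampolined function for rt.recur
  trampolined.f = f;

  return trampolined;
};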
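Here is a self-contained sketch of the same recur/trampoline pair in plain JavaScript, minus makeWandaValue, so you can watch it run outside the compiler. The sumTo function is a hypothetical example; the final value is returned after the loop, so a call that reaches its base case immediately still returns it.

```javascript
// recur returns a tagged "bounce" object instead of making the call.
const recur = (f, ...args) => ({ tag: recur, f, args });

// trampoline loops while the function keeps returning bounces, so the
// JavaScript call stack never grows past one frame of f at a time.
const trampoline = (f) => {
  const trampolined = (...args) => {
    let t = f(...args);
    while (t && t.tag === recur) {
      t = t.f(...t.args);
    }
    return t; // the real runtime converts this with makeWandaValue
  };
  trampolined.f = f; // the raw function, used by recursive bounces
  return trampolined;
};

// Hypothetical tail-recursive sum: the recursive step goes through recur
// and the raw function sumTo.f, never through the trampolined wrapper.
const sumTo = trampoline((n, acc) =>
  n === 0 ? acc : recur(sumTo.f, n - 1, acc + n)
);

sumTo(1000000, 0); // 500000500000, with no stack overflow
```

Calling sumTo itself inside the loop instead of sumTo.f would start a fresh trampoline on every step, nesting one loop inside another and growing the stack again.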

We also need to make a minor change to our makeFunction function in the runtime.

Currently, the function that makeFunction creates passes its return value through makeWandaValue. That means that if we rewrite the recursive call to use rt.recur, the special object it returns will be converted into a Wanda object, and the trampoline function will no longer be able to recognize and process it.

So let's add a tailRec property to the options argument, defaulting to false. When we detect a tail recursive function we set it to true, which stops makeFunction from converting the object returned by the tail call. The trampoline function will handle converting the final value to a Wanda value instead.

Here's our new makeFunction function in src/runtime/makeFunction.js:

import objectHash from "object-hash";
import { curryN } from "ramda";
import { makeWandaValue } from "./conversion.js";
import { addMetaField } from "./object.js";
import { parseContract } from "./parseContract.js";

export const makeFunction = (
  func,
  { contract = "", name = "", tailRec = false } = {}
) => {
  let fn = curryN(func.length, (...args) => {
    const val = tailRec ? func(...args) : makeWandaValue(func(...args));

    if (typeof val === "function") {
      return makeFunction(val);
    }

    return val;
  });
  const hash = objectHash(func);

  addMetaField(fn, "wanda", true);
  addMetaField(fn, "arity", func.length);
  addMetaField(fn, "name", name || hash);

  if (contract !== "") {
    Object.defineProperty(fn, "contract", {
      enumerable: false,
      configurable: false,
      writable: false,
      value: parseContract(contract),
    });
  }

  Object.defineProperty(fn, "name", {
    enumerable: false,
    configurable: false,
    writable: false,
    value: name || hash,
  });

  return fn;
};

Then we need to add trampoline and recur to makeRuntime in src/runtime/makeRuntime.js:

// other imports
import { trampoline, recur } from "./trampoline.js";

export const makeRuntime = () => {
  return {
    // other members
    trampoline,
    recur,
  };
};

Now that our trampoline is in place, we need to modify the emitter to use it.

Changes to The Emitter

The first thing we're going to do to our emitter is remove everything related to handling destructuring and as expressions, since thanks to our ANF transformation they won't be making it to the emitter anymore.

You can delete the emitAsExpression method and its case in the emit method. That leaves emit in src/emitter/Emitter.js looking like this:

  emit(node = this.program, ns = this.ns) {
    switch (node.kind) {
      case ASTTypes.Program:
        return this.emitProgram(node, ns);
      case ASTTypes.NumberLiteral:
        return this.emitNumber(node, ns);
      case ASTTypes.StringLiteral:
        return this.emitString(node, ns);
      case ASTTypes.BooleanLiteral:
        return this.emitBoolean(node, ns);
      case ASTTypes.KeywordLiteral:
        return this.emitKeyword(node, ns);
      case ASTTypes.NilLiteral:
        return this.emitNil(node, ns);
      case ASTTypes.Symbol:
        return this.emitSymbol(node, ns);
      case ASTTypes.CallExpression:
        return this.emitCallExpression(node, ns);
      case ASTTypes.VariableDeclaration:
        return this.emitVariableDeclaration(node, ns);
      case ASTTypes.SetExpression:
        return this.emitSetExpression(node, ns);
      case ASTTypes.DoExpression:
        return this.emitDoExpression(node, ns);
      case ASTTypes.TypeAlias:
        return this.emitTypeAlias(node, ns);
      case ASTTypes.MemberExpression:
        return this.emitMemberExpression(node, ns);
      case ASTTypes.RecordLiteral:
        return this.emitRecordLiteral(node, ns);
      case ASTTypes.RecordPattern:
        return this.emitRecordPattern(node, ns);
      case ASTTypes.VectorLiteral:
        return this.emitVectorLiteral(node, ns);
      case ASTTypes.VectorPattern:
        return this.emitVectorPattern(node, ns);
      case ASTTypes.LambdaExpression:
        return this.emitLambdaExpression(node, ns);
      case ASTTypes.LogicalExpression:
        return this.emitLogicalExpression(node, ns);
      case ASTTypes.IfExpression:
        return this.emitIfExpression(node, ns);
      case ASTTypes.WhenExpression:
        return this.emitWhenExpression(node, ns);
      default:
        throw new SyntaxException(node.kind, node.srcloc);
    }
  }

You can also delete the emitVariableDeclarationAssignment method since we won't be using it anymore.

Finally, you can vastly simplify emitVariableDeclaration so it looks like this:

  emitVariableDeclaration(node, ns) {
    const name = node.lhv.name;
    const translatedName = makeSymbol(name);

    if (ns.has(name)) {
      throw new ReferenceException(
        `Name ${name} has already been accessed in the current namespace; cannot access name before its definition`,
        node.srcloc
      );
    }

    ns.set(name, translatedName);

    return `var ${translatedName} = ${this.emit(node.expression, ns)}`;
  }

Now to handle tail recursion and the trampoline.

In emitLambdaExpression, change the last line that begins with code += so it handles the new tailRec option for rt.makeFunction:

    code += `${
      node.name
        ? `, { name: "${node.name}"${node.isTailRec ? ", tailRec: true" : ""} }`
        : ""
    })`;

Also change the return statement to use the trampoline if node.isTailRec is true:

    return node.isTailRec ? `rt.trampoline(${code})` : code;

Now, in emitCallExpression, we emit the call expression's function and its arguments separately, since rt.recur takes them as separate arguments. If the call is a tail call, we emit a call to rt.recur and pass it the original function (NOT the trampolined version) followed by the arguments. The original function is stored in the f property that trampoline adds to the trampolined function. We need it because if the trampolined function called itself, each step would start a new trampoline loop on top of the previous one and we'd run into the same recursion limits. Here's the new version of emitCallExpression:

  emitCallExpression(node, ns) {
    const func = `(${this.emit(node.func, ns)})`;
    const args = `${node.args.map((a) => this.emit(a, ns)).join(", ")}`;

    return node.isTailRec ? `rt.recur(${func}.f, ${args})` : `${func}(${args})`;
  }
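To make the end result concrete, here is a hand-written approximation of the JavaScript that the tail-recursive fact used in the "Trying It Out" section compiles to. The real emitter's output for if expressions, = and the arithmetic will look different, and the rt object here is a minimal stand-in whose makeFunction simply returns the function; only the rt.trampoline wrapper, the tailRec: true option, and the rt.recur((fact).f, ...) call reflect the emitter changes described in this section.

```javascript
// Minimal stand-in for the runtime (hypothetical): makeFunction ignores
// its options and returns the function unchanged.
const rt = {
  makeFunction: (fn, _opts) => fn,
  recur: (f, ...args) => ({ tag: rt.recur, f, args }),
  trampoline: (f) => {
    const trampolined = (...args) => {
      let t = f(...args);
      while (t && t.tag === rt.recur) t = t.f(...t.args);
      return t;
    };
    trampolined.f = f;
    return trampolined;
  },
};

// Approximate emitted code for:
//   (def fact (n a) (if (= n 1) a (fact (- n 1) (* a n))))
// The lambda is marked isTailRec, so it is wrapped in rt.trampoline, and
// the recursive call is emitted as rt.recur with the raw function (fact).f.
var fact = rt.trampoline(
  rt.makeFunction(
    (n, a) => (n === 1 ? a : rt.recur((fact).f, n - 1, a * n)),
    { name: "fact", tailRec: true }
  )
);

fact(5, 1); // 120
```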

That's it for the emitter! Now we just need to add the transformation to our compilation pipeline and it will be done.

Changes to The CLI

In src/cli/compile.js we need to import our transform function:

import { transform } from "../transform/transform.js";

And finally we need to add transform to the compile function. Here's the new compile function:

export const compile = (
  input,
  file = "stdin",
  ns = undefined,
  typeEnv = undefined
) =>
  emit(
    transform(
      desugar(typecheck(parse(expand(read(tokenize(input, file)))), typeEnv))
    ),
    ns
  );

And with that we have implemented TCO in our compiler. Were you expecting it to be more difficult? I was certainly surprised at how simple and straightforward it turned out to be.

Trying It Out

Ok, let's fire up a REPL and try it out. First, try the naive factorial function:

(def fact (n)
  (if (= n 1)
    n
    (* n (fact (- n 1)))))

Try it with (fact 1000). Ok, yeah, the answer is Infinity (1000! is far larger than the biggest value a JavaScript number can represent), so we're not going to get better numeric answers for bigger numbers, but that's not the point.

Try it with (fact 2000). If your Node instance is configured like mine, you just got a stack overflow error.

Close your REPL and then open it again (note to self: add a command to refresh the REPL state), then try the tail recursive version of fact:

(def fact (n a)
  (if (= n 1)
    a
    (fact (- n 1) (* a n))))

Start with (fact 1000 1). Now keep incrementing by 1000 and see what happens.

I got to 10,000 and then decided to do something crazy: (fact 100000 1).

Infinity.

No stack overflow error, even with 100,000 calls!

I'd say the trampoline works pretty damned well.

Conclusion

I'm thrilled that I was able to get this working the way I wanted to. Now we have tail call optimization for self-recursive functions, with virtually unlimited tail recursion! I'm excited about where we are right now.

Like I said, AST transformations are something I struggled with conceptually for a long time before I finally figured out how to make this work.

As always, you can see the current state of the code at the relevant tag in the GitHub repo.

This is currently the last planned post in the series, though it's possible I'll come back and add more later. I hope you've had as much fun reading as I've had writing!

I also hope this inspires you to go out and create your own languages and language related tools. Let me know if you make something you think is cool!
