Programming with Immutable Datastructures

An immutable datastructure is one that cannot be changed after it has been created. Some examples you've already seen are strings, tuples, numbers, and true and false. In contrast, a mutable datastructure is one that can be changed after it has been created. Some examples are arrays, vectors, and sequences.

If something is guaranteed not to change, then there are two details that you no longer have to worry about.

  1. You don't have to care about which object it is. There is no difference between the value 42 and the value 20 + 22. They are the same value. You can replace every occurrence of 42 in your program with 20 + 22 and it will still behave the same way. Similarly, you can replace every occurrence of
    "Timon and Pumbaa"
    in your program with
    append("Timon", " and Pumbaa")
    without changing its behaviour. In contrast, consider the following call for adding a number to a vector.
    add(xs, 42)
    Now you do need to pay very close attention to which vector xs is referring to. It would be an error to add 42 to the wrong vector.
  2. You don't have to think about when to do something to an object. Consider the following code for popping an item from the vector xs and then adding two new items to it.
    pop(xs)
    add(xs, 42)
    add(xs, 43)
    The ordering of those expressions is critically important. Reordering them can change both the final contents of xs and the item that pop returns. Notice that this sort of thinking is never needed with strings, tuples, or numbers, simply because there's nothing that can be done to them except to create new objects out of them.
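
To make the second point concrete, here is a small sketch that runs the same three operations in two different orders. The package name and the helper functions run-a and run-b are hypothetical names chosen for this illustration, and the vector is seeded with a single item so that pop has something to remove.

```stanza
defpackage ordering-demo :
   import core
   import collections

;pop first, then add 42 and 43
defn run-a () :
   val xs = Vector<Int>()
   add(xs, 1)
   pop(xs)
   add(xs, 42)
   add(xs, 43)
   to-list(xs)

;add 42, pop it right back off, then add 43
defn run-b () :
   val xs = Vector<Int>()
   add(xs, 1)
   add(xs, 42)
   pop(xs)
   add(xs, 43)
   to-list(xs)

println(run-a())   ;(42 43)
println(run-b())   ;(1 43)
```

The same three calls leave different contents behind depending on their order, which is exactly the kind of bookkeeping that immutable values spare you.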

Lists

A List object represents a singly linked list of objects. A list is Stanza's most basic immutable datastructure and cannot be changed once created. Here is how to create an empty list.

List()

Here is how to create a list containing a single item.

List(42)

Here is how to create a list containing two items.

List(42, "Timon")

This works for lists containing up to four items. For creating lists containing more than four items, you may use the to-list function to convert sequences into lists.

to-list([1, 2, 3, 4, 5, "Timon", "and", "Pumbaa"])

You may also use cons (short for construct) to create a new list by tacking a new item to the beginning of an existing list.

val xs = List(1, 2, 3)
val ys = cons(42, xs)

cons allows you to tack on up to three items.

val xs = List(1, 2, 3)
val ys0 = cons(42, xs)
val ys1 = cons(42, 43, xs)
val ys2 = cons(42, 43, 44, xs)

To add more than three items to the beginning of a list, use the append function. It returns a new list containing the items of its first argument followed by the items of the second.

val xs = List(1, 2, 3)
val ys = append([42, 43, 44, 45, 46, 47], xs)
println(ys)

Compiling and running the above prints out

(42 43 44 45 46 47 1 2 3)
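
Because lists are immutable, neither cons nor append changes the list it is given. The following sketch (the package name is a hypothetical choice for this example) prints the original list after building two new lists from it.

```stanza
defpackage list-immutability :
   import core

val xs = List(1, 2, 3)
val ys = cons(42, xs)
val zs = append([7, 8], xs)

println(xs)   ;(1 2 3)
println(ys)   ;(42 1 2 3)
println(zs)   ;(7 8 1 2 3)
```

The list xs still holds the same three items after both calls.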

Fundamental Operations

A list is defined by three fundamental operations.

  1. You can check whether the list is empty.
  2. You can retrieve the first element in the list.
  3. You can retrieve a list containing all the elements after the first one.

Assuming that xs is a list, here is how to check whether xs is empty.

empty?(xs)

Here is how to retrieve the first element in xs.

head(xs)

And here is how to retrieve all the elements after the first one, as another list.

tail(xs)
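
These three operations are enough to write recursive functions over lists. As a minimal sketch, here is a function that adds up the integers in a list. The name sum-list and the package name are hypothetical and used only for this illustration.

```stanza
defpackage sum-list-demo :
   import core

;The sum of an empty list is 0. Otherwise it is the first
;item plus the sum of the remaining items.
defn sum-list (xs:List<Int>) -> Int :
   if empty?(xs) : 0
   else : head(xs) + sum-list(tail(xs))

println(sum-list(List(1, 2, 3)))   ;6
```

The coin counting example that follows uses the same shape: test for the empty list, then recurse on the tail.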

Example: Coin Counting

Suppose you have access to pennies, nickels, dimes, quarters, and loonies, and the poutine you bought costs $1.17. (Loonies are Canadian coins worth 100 cents each.) How many different combinations of coins are there that total up to $1.17?

Here is our algorithm for calculating it. num-coin-combos takes two arguments: cents, the amount of money you wish to make, expressed in cents, and coins, a list of the cent values of the coins you can use.

defn num-coin-combos (cents:Int, coins:List<Int>) -> Int :
   if cents == 0 :
      1
   else if cents < 0 :
      0
   else if empty?(coins) :
      0
   else :
      val with-first-coin = num-coin-combos(cents - head(coins), coins)
      val without-first-coin = num-coin-combos(cents, tail(coins))
      with-first-coin + without-first-coin

Let's read through each case of the algorithm one by one. The first case is

if cents == 0 :
   1

There is only one way to make 0 cents, and that is to not use any coins at all. Makes sense.

The second case is

else if cents < 0 :
   0

There is no way to make a negative cent value. Makes sense.

The third case is

else if empty?(coins) :
   0

If we're not allowed to use any kind of coin, then there's also no way to make our total. Makes sense as well.

The real work of the algorithm is done by the fourth case.

val with-first-coin = num-coin-combos(cents - head(coins), coins)
val without-first-coin = num-coin-combos(cents, tail(coins))
with-first-coin + without-first-coin

Consider the next type of coin in our list. Suppose it's a loonie. There are two choices we can now make.

  1. We can account for 100 cents by using the loonie, and count the number of ways to make cents - 100. This is calculated as with-first-coin.
  2. We can choose not to use the loonie, and count the number of ways to make cents without using loonies. This is calculated as without-first-coin.

The total number of combinations is the sum of the results of the two possible choices we can make.
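
Before tackling $1.17, it can help to try the function on a total small enough to check by hand. The sketch below repeats the definition of num-coin-combos from above so that it runs on its own (the package name is a hypothetical choice), and counts the ways to make 10 cents from dimes, nickels, and pennies.

```stanza
defpackage small-coin-check :
   import core

defn num-coin-combos (cents:Int, coins:List<Int>) -> Int :
   if cents == 0 :
      1
   else if cents < 0 :
      0
   else if empty?(coins) :
      0
   else :
      val with-first-coin = num-coin-combos(cents - head(coins), coins)
      val without-first-coin = num-coin-combos(cents, tail(coins))
      with-first-coin + without-first-coin

;One dime, two nickels, one nickel and five pennies, or ten pennies.
println(num-coin-combos(10, List(10, 5, 1)))   ;4
```

Using the dime accounts for one combination (with-first-coin), and the remaining three come from making 10 cents out of nickels and pennies alone (without-first-coin).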

Let's now use our num-coin-combos function to answer the original question.

defn main () :
   val coins = [100, 25, 10, 5, 1]
   val num-combos = num-coin-combos(117, to-list(coins))
   println("There are %_ coin combinations that total to 117 cents." %
           [num-combos])

main()

which prints out

There are 349 coin combinations that total to 117 cents.

Strange Lands

Suppose we find ourselves in strange lands with a strange currency. The currency is made up of buzzles, with a value of 57 cents, moozles (26 cents), foogs (10 cents), goofs (5 cents), and tents (3 cents). Now how many ways are there to make the $1.17 needed to buy poutine? (Though the currency may be strange, poutine is fairly universal.)

Let's adapt our main function to calculate with the new currency.

defn main () :
   val coins = [57, 26, 10, 5, 3]
   val num-combos = num-coin-combos(117, to-list(coins))
   println("There are %_ coin combinations that total to 117 cents." %
           [num-combos])

main()

which prints out

There are 137 coin combinations that total to 117 cents.

indicating that buzzles and foogs are a little less flexible than Canadian currency.

SICP

This exercise is adapted from the best book on computer science ever written, Structure and Interpretation of Computer Programs by Abelson and Sussman. I highly recommend it to anyone interested in the deep connections between languages and computation. And since Stanza is a (highly modified) Scheme dialect at heart, all the exercises can easily be done in Stanza as well.

List Library

List is a subtype of Collection and so all of Stanza's sequence library also works on lists. The core library also includes a few functions specifically for managing lists. You've been introduced to a few of them already: head, tail, append, cons. Here are a few more.

get

The get function allows you to retrieve the element at a specific index in a list.

val xs = to-list(0 to 1000 by 3)
get(xs, 11)

Using Stanza's built-in operator, the above could also be written as

val xs = to-list(0 to 1000 by 3)
xs[11]

headn

headn returns a list containing the first n items in a list.

val xs = to-list(0 to 1000 by 3)
headn(xs, 10)

tailn

tailn returns a new list containing the items following the first n items in a list.

val xs = to-list(0 to 1000 by 3)
tailn(xs, 10)

reverse

reverse takes an argument list and returns a new list containing the same items in reversed order.

val xs = to-list(0 to 1000 by 3)
reverse(xs)

last

last takes an argument list and returns the last item in it. The list must not be empty.

val xs = to-list(0 to 1000 by 3)
last(xs)

but-last

but-last takes an argument list and returns a new list containing all the items from the argument list except the last one.

val xs = to-list(0 to 1000 by 3)
but-last(xs)
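
The list of multiples of three used above is too long to inspect by eye, so here is a sketch that applies the same functions to a four-item list, with the expected results in the comments. The package name is a hypothetical choice for this example.

```stanza
defpackage list-library-demo :
   import core

val xs = List(10, 20, 30, 40)

println(xs[1])          ;20
println(headn(xs, 2))   ;(10 20)
println(tailn(xs, 2))   ;(30 40)
println(reverse(xs))    ;(40 30 20 10)
println(last(xs))       ;40
println(but-last(xs))   ;(10 20 30)
```

Each call returns a new value and leaves xs itself untouched.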

map

map is the most commonly used function on lists. It takes an argument function, f, and a list, xs, and returns a new list containing the results of calling f on each item.

Here is an example that calculates the lengths of all the strings in the list xs.

val xs = to-list(["Timon" "and" "Pumbaa"])
val lengths = map(length, xs)

map is also an operating function, and it can be used together with the for construct. Here is an example of doubling every integer in the list xs.

val xs = to-list(0 to 1000 by 3)
val doubled = for x in xs map :
   x * 2

Example: More Coin Counting

One limitation of our previous algorithm for coin counting is that it calculated the number of ways we can make a certain total, but it never told us what these combinations actually were. You may be (as I was) actually quite curious about how to make $1.17 using buzzles and foogs.

Let's write a function called coin-combos that does that. Like num-coin-combos, coin-combos takes two arguments: cents, which represents the number of cents you wish to make, and coins, a list of the cent values of the coins. The difference is that coin-combos returns a list of combinations. Each combination is a list containing the number of times each coin is used.

defn coin-combos (cents:Int, coins:List<Int>) -> List<List<Int>> :
   if cents == 0 :
      List(map({0}, coins))
   else if cents < 0 :
      List()
   else if empty?(coins) :
      List()
   else :
      defn head+1 (xs:List<Int>) : cons(head(xs) + 1, tail(xs))
      defn cons-0 (xs:List<Int>) : cons(0, xs)
      val with-first-coin = map(head+1, coin-combos(cents - head(coins), coins))
      val without-first-coin = map(cons-0, coin-combos(cents, tail(coins)))
      append(with-first-coin, without-first-coin)

Let's examine each case separately.

if cents == 0 :
   List(map({0}, coins))

There is only one way to make up 0 cents, and that is by using no coins at all. So return a list with a single combination indicating that each coin is used 0 times.

else if cents < 0 :
   List()

There is no way to make a negative total so return an empty list.

else if empty?(coins) :
   List()

If we're not allowed to use any kind of coin, then there's also no way to make our total. Return an empty list.

And finally, we're at the last case again.

val with-first-coin = map(head+1, coin-combos(cents - head(coins), coins))
val without-first-coin = map(cons-0, coin-combos(cents, tail(coins)))
append(with-first-coin, without-first-coin)

All the combinations resulting from choosing to use the first coin are computed in with-first-coin. And all the combinations resulting from choosing not to use the first coin are computed in without-first-coin. We then append both lists to get the complete list of combinations.

The fourth case relies upon two helper functions, head+1, which adds 1 to the head of a list, and cons-0, which tacks 0 on to the beginning of a list.

defn head+1 (xs:List<Int>) : cons(head(xs) + 1, tail(xs))
defn cons-0 (xs:List<Int>) : cons(0, xs)
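
To see what the two helpers do to a single combination, here is a small sketch that defines them at the top level instead of inside coin-combos and applies each one to a hand-written list. The package name and the sample lists are assumptions made for this illustration.

```stanza
defpackage combo-helpers-demo :
   import core

defn head+1 (xs:List<Int>) : cons(head(xs) + 1, tail(xs))
defn cons-0 (xs:List<Int>) : cons(0, xs)

;Use the first coin one more time.
println(head+1(List(0, 2, 1)))   ;(1 2 1)

;Record that a new first coin is used zero times.
println(cons-0(List(2, 1)))      ;(0 2 1)
```

head+1 keeps the combination the same length, because the recursive call was made with the same coins. cons-0 makes it one item longer, because the recursive call was made with tail(coins) and so knows nothing about the first coin.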

Let's now update our main function to report all the different ways we can use buzzles and foogs to make $1.17. Recall that buzzles are worth 57 cents, moozles are 26 cents, foogs are 10 cents, goofs are 5 cents, and tents are 3 cents.

defn main () :
   val coins = [57, 26, 10, 5, 3]
   val combos = coin-combos(117, to-list(coins))
   println("There are %_ coin combinations that total to 117 cents." % [
      length(combos)])
   do(println, combos)

main()

Compiling and running the above prints out

There are 137 coin combinations that total to 117 cents.
(2 0 0 0 1)
(1 2 0 1 1)
(1 1 2 1 3)
(1 1 1 3 3)
...
(0 0 0 6 29)
(0 0 0 3 34)
(0 0 0 0 39)

Thus we can pay for our $1.17 poutine using two buzzles and a tent. Or if we don't mind holding up the line, we can hunt around for thirty-nine tents.

Readable Combos

For the sake of readability, let's write a printing function for formatting the combinations in a readable way. print-combo takes as arguments a combination, combo, and a collection representing the names of the coins, names.

defn print-combo (combo:List<Int>, names:Collection<String>) :
   val parts = for (c in combo, n in names) seq? :
      if c == 0 : None()
      else if c == 1 : One("%_ %_" % [c, n])
      else : One("%_ %_s" % [c, n])
   println-all(join(parts, ", "))   

You are encouraged to read the reference documentation for a description of what seq? does. You should be able to understand it now.
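
As a rough sketch of the idea, seq? calls its body on each item and keeps only the results that are wrapped in One, skipping every item for which the body returns None(). The example below uses the same for form as print-combo. The package name and the sample data are assumptions made for this illustration.

```stanza
defpackage seq-demo :
   import core

;Keep the even numbers, multiplied by ten. Drop the odd ones.
val parts = for x in [1, 2, 3, 4] seq? :
   if x % 2 == 0 : One(x * 10)
   else : None()

println(to-list(parts))   ;(20 40)
```

In print-combo this is what lets coins that are used zero times disappear from the printed line.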

Now update the final call to print in the main function.

val coin-names = ["buzzle", "moozle", "foog", "goof", "tent"]
do(print-combo{_, coin-names}, combos)

Compiling and running the program now prints out

There are 137 coin combinations that total to 117 cents.
2 buzzles, 1 tent
1 buzzle, 2 moozles, 1 goof, 1 tent
1 buzzle, 1 moozle, 2 foogs, 1 goof, 3 tents
1 buzzle, 1 moozle, 1 foog, 3 goofs, 3 tents
...
15 goofs, 14 tents
12 goofs, 19 tents
9 goofs, 24 tents
6 goofs, 29 tents
3 goofs, 34 tents
39 tents

Extended Example: Automatic Differentiation

In your own programming, you are encouraged to define and use immutable datastructures when possible. Uses of mutation and stateful objects should serve a clear purpose. In this example, we define an immutable datastructure for manipulating algebra expressions and write a function for automatically differentiating expressions.

Symbols

Symbol objects are used to represent a unique constant object in Stanza. For example, the following creates and assigns symbols to x and to y.

val x = `Timon
val y = `Pumbaa

Symbols are created by prefixing an identifier with the backtick (`) operator. Very little can be done with a symbol except check whether it is equal to another symbol. The following

println(x == `Timon)
println(y == `Timon)

prints out

true
false

and represents the most common use case for symbols. In this respect they are used in much the same way as enumerated constants in other languages. We will use symbols to represent the names of variables in our algebraic expressions.

The Expression Datastructure

We will first declare a type, Exp, to refer to an algebraic expression.

deftype Exp
defstruct Const <: Exp : (value:Int)
defstruct Variable <: Exp : (name:Symbol)
defstruct Add <: Exp : (a:Exp, b:Exp)
defstruct Subtract <: Exp : (a:Exp, b:Exp)
defstruct Multiply <: Exp : (a:Exp, b:Exp)
defstruct Divide <: Exp : (a:Exp, b:Exp)
defstruct Power <: Exp : (a:Exp, b:Exp)
defstruct Log <: Exp : (a:Exp)

A handful of different types of expressions are supported. Const represents constant integer literals, Variable represents a named variable, and the standard arithmetic operators are represented by Add, Subtract, Multiply, and Divide. Power represents one expression raised to the power of another, and Log represents the natural logarithm of an expression. Notice that many of the expressions contain fields that are themselves of type Exp. So a value of type Exp can contain other values of type Exp. We call such a type a recursive type.
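
Functions over a recursive type are usually recursive themselves. As a minimal sketch, the following counts the nodes in an expression tree. It uses a cut-down Exp with only three kinds of expression, and the function name size and the package name are hypothetical choices for this illustration.

```stanza
defpackage exp-size-demo :
   import core

deftype Exp
defstruct Const <: Exp : (value:Int)
defstruct Variable <: Exp : (name:Symbol)
defstruct Add <: Exp : (a:Exp, b:Exp)

;An Add node counts itself plus everything beneath it.
;Every other expression is a leaf and counts as one node.
defn size (e:Exp) -> Int :
   match(e) :
      (e:Add) : 1 + size(a(e)) + size(b(e))
      (e) : 1

;1 + (x + 2) has two Add nodes and three leaves.
println(size(Add(Const(1), Add(Variable(`x), Const(2)))))   ;5
```

The printing, differentiation, and simplification functions in the rest of this example all follow this pattern of matching on the kind of expression and recursing into its fields.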

Printing an Expression

As usual, we will provide a custom print method for the Exp type to allow us to print it out.

defmethod print (o:OutputStream, e:Exp) :
   defn print-operator (a:Exp, op:String, b:Exp) :
      print(o, a)
      print(o, op)
      print(o, b)
   match(e) :
      (e:Const) : print(o, value(e))
      (e:Variable) : print(o, name(e))
      (e:Log) : print(o, "ln(%_)" % [a(e)])
      (e:Add) : print-operator(a(e), " + ", b(e))
      (e:Subtract) : print-operator(a(e), " - ", b(e))
      (e:Multiply) : print-operator(a(e), " * ", b(e))
      (e:Divide) : print-operator(a(e), " / ", b(e))
      (e:Power) : print-operator(a(e), " ^ ", b(e))

Let's now create an expression and print it out. The expression we will create is

2 * x ^ 2 + (1 + 3) * x + ln(x + 4)

Here is our main function.

defn main () :
   val term1 = Multiply(Const(2), Power(Variable(`x), Const(2)))
   val term2 = Multiply(Add(Const(1), Const(3)), Variable(`x))
   val term3 = Log(Add(Variable(`x), Const(4)))
   val exp = Add(Add(term1, term2), term3)
   println(exp)  

main()

Compiling and running the above prints out

2 * x ^ 2 + 1 + 3 * x + ln(x + 4)

We're off to a great start!

Handling Precedence

Our printing method for expressions is close, but it doesn't handle precedences correctly. The 1 + 3 in the printed expression should be surrounded by parentheses. Otherwise the meaning is different from what was intended.

Let's add a mechanism to handle precedences properly. Here's the basic algorithm. Every type of expression is associated with a number that represents its precedence. Const, Log, and Variable expressions have the highest precedence 3. Power has precedence 2. Multiply and Divide have precedence 1. And Add and Subtract have the lowest precedence 0.

defn precedence (e:Exp) :
   match(e) :
      (e:Add|Subtract) : 0
      (e:Multiply|Divide) : 1
      (e:Power) : 2
      (e:Const|Variable|Log) : 3

The basic rule is that if a lower precedence expression appears as a child of a higher precedence expression, then the lower precedence expression needs to be surrounded by parentheses when printed out. So we'll define a new nested function within print to help us print nested expressions in the context of expression e.

defn print-nested (ne:Exp) :
   if precedence(ne) < precedence(e) :
      print(o, "(%_)" % [ne])
   else :
      print(o, ne)

If the nested expression ne has lower precedence than e, then ne is printed with surrounding parentheses. Otherwise ne is just printed directly.

The print-operator function also needs to be updated to call print-nested.

defn print-operator (a:Exp, op:String, b:Exp) :
   print-nested(a)
   print(o, op)
   print-nested(b)

Those are all the changes needed to handle precedence. Here is the full print method.

defmethod print (o:OutputStream, e:Exp) :
   defn print-nested (ne:Exp) :
      if precedence(ne) < precedence(e) :
         print(o, "(%_)" % [ne])
      else :
         print(o, ne)
   defn print-operator (a:Exp, op:String, b:Exp) :
      print-nested(a)
      print(o, op)
      print-nested(b)
   match(e) :
      (e:Const) : print(o, value(e))
      (e:Variable) : print(o, name(e))
      (e:Log) : print(o, "ln(%_)" % [a(e)])
      (e:Add) : print-operator(a(e), " + ", b(e))
      (e:Subtract) : print-operator(a(e), " - ", b(e))
      (e:Multiply) : print-operator(a(e), " * ", b(e))
      (e:Divide) : print-operator(a(e), " / ", b(e))
      (e:Power) : print-operator(a(e), " ^ ", b(e))

If you compile and run the program again, it should now correctly print out

2 * x ^ 2 + (1 + 3) * x + ln(x + 4)
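
The example expression does not exercise every precedence level. The sketch below checks one more case, a product raised to a power, using a cut-down Exp and the precedence levels given in the text (Power at 2, Multiply and Divide at 1). The package name and the reduced set of expression types are assumptions made to keep the example short.

```stanza
defpackage precedence-demo :
   import core

deftype Exp
defstruct Const <: Exp : (value:Int)
defstruct Variable <: Exp : (name:Symbol)
defstruct Multiply <: Exp : (a:Exp, b:Exp)
defstruct Power <: Exp : (a:Exp, b:Exp)

defn precedence (e:Exp) :
   match(e) :
      (e:Multiply) : 1
      (e:Power) : 2
      (e:Const|Variable) : 3

defmethod print (o:OutputStream, e:Exp) :
   defn print-nested (ne:Exp) :
      if precedence(ne) < precedence(e) :
         print(o, "(%_)" % [ne])
      else :
         print(o, ne)
   defn print-operator (a:Exp, op:String, b:Exp) :
      print-nested(a)
      print(o, op)
      print-nested(b)
   match(e) :
      (e:Const) : print(o, value(e))
      (e:Variable) : print(o, name(e))
      (e:Multiply) : print-operator(a(e), " * ", b(e))
      (e:Power) : print-operator(a(e), " ^ ", b(e))

;The product has lower precedence than the power around it,
;so it is printed with parentheses.
println(Power(Multiply(Variable(`x), Const(2)), Const(3)))   ;(x * 2) ^ 3
```

Without the parentheses the output would read x * 2 ^ 3, which means something different.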

Operator Overloading

The code we used to construct the expression

val term1 = Multiply(Const(2), Power(Variable(`x), Const(2)))
val term2 = Multiply(Add(Const(1), Const(3)), Variable(`x))
val term3 = Log(Add(Variable(`x), Const(4)))
val exp = Add(Add(term1, term2), term3)

is quite verbose. Let's overload some operators to help us with that.

Recall that the operators +, -, *, /, and ^ are just syntactic shorthands for calling the functions plus, minus, times, divide, and bit-xor. Thus all we need to do is define those functions on Exp objects.

defn plus (a:Exp, b:Exp) : Add(a, b)
defn minus (a:Exp, b:Exp) : Subtract(a, b)
defn times (a:Exp, b:Exp) : Multiply(a, b)
defn divide (a:Exp, b:Exp) : Divide(a, b)
defn bit-xor (a:Exp, b:Exp) : Power(a, b)
defn ln (a:Exp) : Log(a)

Now let's rewrite our main function using the new operators.

defn main () :
   val x = Variable(`x)
   val [c1, c2, c3, c4] = [Const(1), Const(2), Const(3), Const(4)]
   val exp = c2 * x ^ c2 + (c1 + c3) * x + ln(x + c4)
   println(exp)

Much better! If we overlook the little c's in front of each constant it's essentially identical to our printed expression.

The Differentiation Algorithm

Now we can implement the differentiation algorithm! The function differentiate takes two arguments: the expression to differentiate, e, and the variable with respect to which it will differentiate, x.

The actual formulas used to do the differentiation are standard, and we won't explain how to derive them. If you have taken a course on calculus, you can break open your old textbook and copy the formulas here. If you haven't taken a course on calculus, then armed with this program, you'll never have to manually differentiate again.

defn differentiate (e:Exp, x:Symbol) -> Exp :
   defn ddx (e:Exp) : differentiate(e, x)
  
   match(e) :
      (e:Const) :
         Const(0)
      (e:Variable) :
         if name(e) == x : Const(1)
         else : Const(0)
      (e:Add) :
         ddx(a(e)) + ddx(b(e))
      (e:Subtract) :
         ddx(a(e)) - ddx(b(e))
      (e:Multiply) :
         a(e) * ddx(b(e)) + b(e) * ddx(a(e))
      (e:Divide) :
         val num = b(e) * ddx(a(e)) - a(e) * ddx(b(e))
         val den = b(e) ^ Const(2)
         num / den
      (e:Power) :
         e * (b(e) * ddx(a(e)) / a(e) + ln(a(e)) * ddx(b(e)))
      (e:Log) :
         ddx(a(e)) / a(e)
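
For readers who want to see where the Power case comes from, here is a short derivation. It assumes the base a is positive so that its logarithm is defined, and writes a and b for the two fields of the Power expression.

```latex
\[
a^{b} = e^{\,b \ln a}
\quad\Longrightarrow\quad
\frac{d}{dx}\,a^{b}
  = e^{\,b \ln a}\,\frac{d}{dx}\bigl(b \ln a\bigr)
  = a^{b}\left(\frac{b}{a}\,\frac{da}{dx} + \ln a\,\frac{db}{dx}\right)
\]
```

The right-hand side matches the code term by term: e is the Power expression itself, b(e) * ddx(a(e)) / a(e) is the first term in the parentheses, and ln(a(e)) * ddx(b(e)) is the second.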

Let's try differentiating our example expression now.

defn main () :
   val x = Variable(`x)
   val [c1, c2, c3, c4] = [Const(1), Const(2), Const(3), Const(4)]
   val exp = c2 * x ^ c2 + (c1 + c3) * x + ln(x + c4)
   val dexp = differentiate(exp, `x)
  
   println("Original Expression: %_" % [exp])
   println("Differentiated Expression: %_" % [dexp])

Compiling and running the program prints out

Original Expression: 2 * x ^ 2 + (1 + 3) * x + ln(x + 4)
Differentiated Expression: 2 * x ^ 2 * (2 * 1 / x + ln(x) * 0) +
                           x ^ 2 * 0 + (1 + 3) * 1 +
                           x * (0 + 0) + (1 + 0) / (x + 4)

If you check the result, it does work! The only problem is that the result contains a lot of expressions that can be trivially simplified. We'll fix that later. But this isn't bad at all for a 22-line algorithm.

Simplification

The only thing left to do now is simplify the resulting expression. We will write a very simple simplifier that simply looks for patterns like adding an expression to zero, or dividing by one, et cetera. But before we introduce the simplification algorithm, we need to first write a very useful helper function.

defn map (f: Exp -> Exp, e:Exp) -> Exp :
   match(e) :
      (e:Add) : Add(f(a(e)), f(b(e)))
      (e:Subtract) : Subtract(f(a(e)), f(b(e)))
      (e:Multiply) : Multiply(f(a(e)), f(b(e)))
      (e:Divide) : Divide(f(a(e)), f(b(e)))
      (e:Power) : Power(f(a(e)), f(b(e)))
      (e:Log) : Log(f(a(e)))
      (e) : e

map takes an argument function, f, and an expression, e, and returns a new expression resulting from calling f on each immediate subexpression of e. Expressions with no subexpressions, such as Const and Variable, are returned unchanged. Its behaviour is analogous to the map function for lists. Calling map on a list maps f onto every element in the list. Similarly, calling map on an expression maps f onto every direct child of the expression.

We're now ready to write the simplify function. It takes an expression as its argument, and returns a simplified version of the expression by replacing specific patterns with simpler expressions.

defn simplify (e:Exp) -> Exp :
   defn const? (e:Exp, v:Int) :
      match(e) :
         (e:Const) : value(e) == v
         (e) : false
   defn one? (e:Exp) : const?(e, 1)
   defn zero? (e:Exp) : const?(e, 0)  

   match(map(simplify, e)) :
      (e:Add) :
         if zero?(a(e)) : b(e)
         else if zero?(b(e)) : a(e)
         else : e
      (e:Subtract) :
         if zero?(a(e)) : Const(-1) * b(e)
         else if zero?(b(e)) : a(e)
         else : e        
      (e:Multiply) :
         if one?(a(e)) : b(e)
         else if one?(b(e)) : a(e)
         else if zero?(a(e)) or zero?(b(e)) : Const(0)
         else : e
      (e:Divide) :
         if zero?(a(e)) : Const(0)
         else if one?(b(e)) : a(e)
         else : e
      (e:Power) :
         if one?(a(e)) : Const(1)
         else if zero?(b(e)) : Const(1)
         else : e
      (e:Log) :
         if one?(a(e)) : Const(0)
         else : e
      (e) : e

Most of the work of the simplifier is done in the branches of the match expression; you can read through them to understand which patterns are being simplified and what they're being simplified to. However, the most magical part of the function is the call to map.

match(map(simplify, e)) :
   (e:Add) :
   ...

In English, that pattern says: first simplify all the nested subexpressions in e and then look for these patterns and replace them with simpler ones.

Let's update our main function now to simplify the differentiated expression.

defn main () :
   val x = Variable(`x)
   val [c1, c2, c3, c4] = [Const(1), Const(2), Const(3), Const(4)]
   val exp = c2 * x ^ c2 + (c1 + c3) * x + ln(x + c4)
   val dexp = differentiate(exp, `x)
   val sexp = simplify(dexp)

   println("Original Expression: %_" % [exp])
   println("Differentiated Expression: %_" % [dexp])
   println("Simplified Expression: %_" % [sexp])   

When compiled and run, it prints out

Original Expression: 2 * x ^ 2 + (1 + 3) * x + ln(x + 4)
Differentiated Expression: 2 * x ^ 2 * (2 * 1 / x + ln(x) * 0) +
                           x ^ 2 * 0 + (1 + 3) * 1 +
                           x * (0 + 0) + (1 + 0) / (x + 4)
Simplified Expression: 2 * x ^ 2 * 2 / x + 1 + 3 + 1 / (x + 4)
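
As a sanity check, the simplified expression can be compared by hand against the derivative computed with the usual calculus rules. Assuming x is nonzero, so that the division by x is defined, the two agree.

```latex
\[
2x^{2}\cdot\frac{2}{x} + 1 + 3 + \frac{1}{x+4}
  \;=\; 4x + 4 + \frac{1}{x+4}
  \;=\; \frac{d}{dx}\Bigl(2x^{2} + (1+3)\,x + \ln(x+4)\Bigr)
\]
```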

The simplified expression is much cleaner now! This concludes our automatic differentiation example. The simplicity of both the differentiation and the simplification algorithm stems from the fact that Exp is an immutable datastructure. In fact, the programming language Lisp, which strongly emphasized computation with immutable list structures and also heavily influenced the design of Stanza, was invented in part for writing computer algebra systems. John McCarthy started writing differentiation algorithms in Lisp even before the language was running!

Program Listing

Here's a full program listing of the example.

defpackage calculus :
   import core

;Expression definition
deftype Exp
defstruct Const <: Exp : (value:Int)
defstruct Variable <: Exp : (name:Symbol)
defstruct Add <: Exp : (a:Exp, b:Exp)
defstruct Subtract <: Exp : (a:Exp, b:Exp)
defstruct Multiply <: Exp : (a:Exp, b:Exp)
defstruct Divide <: Exp : (a:Exp, b:Exp)
defstruct Power <: Exp : (a:Exp, b:Exp)
defstruct Log <: Exp : (a:Exp)

;Precedences
defn precedence (e:Exp) :
   match(e) :
      (e:Add|Subtract) : 0
      (e:Multiply|Divide) : 1
      (e:Power) : 2
      (e:Const|Variable|Log) : 3

;Print behaviour for expressions
defmethod print (o:OutputStream, e:Exp) :
   defn print-nested (ne:Exp) :
      if precedence(ne) < precedence(e) :
         print(o, "(%_)" % [ne])
      else :
         print(o, ne)
   defn print-operator (a:Exp, op:String, b:Exp) :
      print-nested(a)
      print(o, op)
      print-nested(b)
   match(e) :
      (e:Const) : print(o, value(e))
      (e:Variable) : print(o, name(e))
      (e:Log) : print(o, "ln(%_)" % [a(e)])
      (e:Add) : print-operator(a(e), " + ", b(e))
      (e:Subtract) : print-operator(a(e), " - ", b(e))
      (e:Multiply) : print-operator(a(e), " * ", b(e))
      (e:Divide) : print-operator(a(e), " / ", b(e))
      (e:Power) : print-operator(a(e), " ^ ", b(e))

;Overloaded operators
defn plus (a:Exp, b:Exp) : Add(a, b)
defn minus (a:Exp, b:Exp) : Subtract(a, b)
defn times (a:Exp, b:Exp) : Multiply(a, b)
defn divide (a:Exp, b:Exp) : Divide(a, b)
defn bit-xor (a:Exp, b:Exp) : Power(a, b)
defn ln (a:Exp) : Log(a)

;Differentiation algorithm
defn differentiate (e:Exp, x:Symbol) -> Exp :
   defn ddx (e:Exp) : differentiate(e, x)
  
   match(e) :
      (e:Const) :
         Const(0)
      (e:Variable) :
         if name(e) == x : Const(1)
         else : Const(0)
      (e:Add) :
         ddx(a(e)) + ddx(b(e))
      (e:Subtract) :
         ddx(a(e)) - ddx(b(e))
      (e:Multiply) :
         a(e) * ddx(b(e)) + b(e) * ddx(a(e))
      (e:Divide) :
         val num = b(e) * ddx(a(e)) - a(e) * ddx(b(e))
         val den = b(e) ^ Const(2)
         num / den
      (e:Power) :
         e * (b(e) * ddx(a(e)) / a(e) + ln(a(e)) * ddx(b(e)))
      (e:Log) :
         ddx(a(e)) / a(e)

;Map helper
defn map (f: Exp -> Exp, e:Exp) -> Exp :
   match(e) :
      (e:Add) : Add(f(a(e)), f(b(e)))
      (e:Subtract) : Subtract(f(a(e)), f(b(e)))
      (e:Multiply) : Multiply(f(a(e)), f(b(e)))
      (e:Divide) : Divide(f(a(e)), f(b(e)))
      (e:Power) : Power(f(a(e)), f(b(e)))
      (e:Log) : Log(f(a(e)))
      (e) : e

;Simplification algorithm
defn simplify (e:Exp) -> Exp :
   defn const? (e:Exp, v:Int) :
      match(e) :
         (e:Const) : value(e) == v
         (e) : false
   defn one? (e:Exp) : const?(e, 1)
   defn zero? (e:Exp) : const?(e, 0)  

   match(map(simplify, e)) :
      (e:Add) :
         if zero?(a(e)) : b(e)
         else if zero?(b(e)) : a(e)
         else : e
      (e:Subtract) :
         if zero?(a(e)) : Const(-1) * b(e)
         else if zero?(b(e)) : a(e)
         else : e        
      (e:Multiply) :
         if one?(a(e)) : b(e)
         else if one?(b(e)) : a(e)
         else if zero?(a(e)) or zero?(b(e)) : Const(0)
         else : e
      (e:Divide) :
         if zero?(a(e)) : Const(0)
         else if one?(b(e)) : a(e)
         else : e
      (e:Power) :
         if one?(a(e)) : Const(1)
         else if zero?(b(e)) : Const(1)
         else : e
      (e:Log) :
         if one?(a(e)) : Const(0)
         else : e
      (e) : e

;Main program
defn main () :
   val x = Variable(`x)
   val [c1, c2, c3, c4] = [Const(1), Const(2), Const(3), Const(4)]
   val exp = c2 * x ^ c2 + (c1 + c3) * x + ln(x + c4)
   val dexp = differentiate(exp, `x)
   val sexp = simplify(dexp)

   println("Original Expression: %_" % [exp])
   println("Differentiated Expression: %_" % [dexp])
   println("Simplified Expression: %_" % [sexp])  

;Start!
main()

Exercises

  1. Our differentiation algorithm is general enough to always give the right answer (for the types of expressions it supports), but it's often too general. This is most obvious in the differentiation rule for Power expressions. The current rule handles the case where both the base and exponent are functions of x, but typically only one of the two is a function of x and the other is a constant expression. Look for these special cases and handle them more intelligently.
  2. Extend the simplifier to be able to simplify 1 + 3 to 4.
  3. Extend the simplifier to be able to simplify 1 + x + 3 to 4 + x.
  4. Extend the simplifier to be able to simplify x - x to 0.
  5. Extend the simplifier to be able to simplify x + 1 - x to 1.
  6. Extend the simplifier to be able to simplify x * x to x ^ 2.
  7. Extend the simplifier to be able to simplify x / x to 1.
  8. Extend the simplifier to be able to simplify x ^ 2 / x to x.
  9. Extend the simplifier to be able to simplify (x + 1) ^ 2 / (x + 1) to x + 1.