Nov 17, 2009

Learning Scala: Euler's Method

I've been messing around more with Scala to see some of its more interesting features, two of which are pattern matching and currying. These are really handy in certain situations, and I decided to share them with you today.

The example I'm using is the Euler method for approximating a solution to a differential equation. This is probably not the most intuitive way of doing the Euler method, but it uses a lot of Scala's interesting features and I think it is a good way of illustrating them.

The Euler method works like this. You have a differential equation in the form y' = f(x, y) and an initial value (x0, y0). You calculate the slope at the initial value, which is f(x0, y0). You then take a step of size h in the x-direction following that slope, which brings you to (x0 + h, y0 + h * f(x0, y0)). This becomes your new point, and you repeat until you are satisfied with the results. A smaller h generally leads to more accurate results, at the cost of more computation.
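To make the update concrete, here is a single Euler step as a small Scala function, applied twice to y' = x^2 + y^2 starting from (0, 1) with h = 0.1. The names EulerStep and eulerStep are just illustrative and aren't part of the program below. By hand: the slope at (0, 1) is 1, so the first step lands on (0.1, 1.1); the slope there is 0.01 + 1.21 = 1.22, so the second step lands on (0.2, 1.222).

```scala
// A minimal sketch of one Euler step; EulerStep/eulerStep are hypothetical names.
object EulerStep {
  // Given y' = f(x, y), move from (x, y) a distance h in x, following the slope f(x, y).
  def eulerStep(f: (Double, Double) => Double, h: Double)(x: Double, y: Double): (Double, Double) =
    (x + h, y + h * f(x, y))
}
```

Applying eulerStep over and over, feeding each result back in, is all the Euler method is; the rest of this post is about different ways of doing that repetition.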

Here is the not-so-intuitive code:
object Euler {
  def main(args: Array[String]): Unit = {
    // Use the differential equation y' = x^2 + y^2
    val my_de = solve_de((x, y) => math.pow(x, 2) + math.pow(y, 2)) _

    // Use a step size of 0.1, and do 10 steps
    val generator = my_de(0.1, 10)

    // Try with different initial conditions
    println(generator(0, 1))
    println(generator(1, 1))

    // Create another generator with a smaller step size
    val precise_generator = my_de(0.05, 20)
    println(precise_generator(0, 1))
    println(precise_generator(1, 1))
  }

  // /: is a left fold (an alias for foldLeft, deprecated since Scala 2.13).
  // The accumulator is List(slopes so far, List(x, y)); .head returns the slopes.
  def solve_de(func: (Double, Double) => Double)(step: Double, iterations: Int)(x_0: Double, y_0: Double) =
    (List(List(), List(x_0, y_0)) /: (0 until iterations))((s, _) => s match {
      case List(res, List(x, y)) => {
        val fxy = func(x, y)
        List(res ::: List(fxy), List(x + step, y + step * fxy))
      }
    }).head

}
Look first at the definition of solve_de. It has three argument lists! The first takes a function of type (Double, Double) => Double, that is, a function that takes two Doubles and returns a Double; this represents the differential equation we are approximating. The second takes a Double and an Int: the parameters of the approximation, namely the step size and the number of steps. The third takes two Doubles: the initial values.

If you look back at the main function, you'll see that we call solve_de with only its first argument list, followed by a _! This returns a new function with the first parameter list filled in. (Strictly speaking, the multiple parameter lists are what make solve_de curried; supplying only some of them is called partial application.) We save this as my_de (note that we use val, so my_de can't be reassigned), which is a solution generator for the specific differential equation described above. The _ at the end tells Scala that you are deliberately applying only part of solve_de and want the remaining parameter lists turned into a function.
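Here is the same trick on a smaller scale, in case solve_de's three parameter lists are a bit much at first. The method scale and the value doubler are hypothetical names made up for this sketch:

```scala
object PartialApplication {
  // A method with two parameter lists (a made-up example).
  def scale(factor: Double)(x: Double): Double = factor * x

  // Supplying only the first list and adding `_` turns the rest into a function value,
  // just like solve_de(...) _ in main above.
  val doubler: Double => Double = scale(2.0) _
}
```

Calling PartialApplication.doubler(21.0) gives 42.0, the same as calling scale(2.0)(21.0) directly.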

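Next, we call my_de with the values 0.1 and 10. This creates another function, which solves our differential equation using a step size of 0.1 and 10 steps. We can then call this function with different initial conditions to get different solutions to the differential equation. Each time we call it, it returns a list of 10 values: the slope f(x, y) evaluated at each of the 10 points visited along the approximate solution curve, starting with the initial point. The points themselves are computed along the way but not returned, so if we wanted to plot the curve with some graphing library, we would need to collect the (x, y) pairs instead.
Note: here you don't have to call my_de with a _ at the end. That is because my_de is already a function value rather than a method, so applying it to (0.1, 10) simply returns another function value. The _ is only needed to turn a method such as solve_de into a function.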
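A small sketch of that difference, using a made-up method add3 with three parameter lists (the same shape as solve_de):

```scala
object FunctionValues {
  // A method, like solve_de: it isn't a value until we eta-expand it.
  def add3(a: Int)(b: Int)(c: Int): Int = a + b + c

  // `_` converts the partially applied method into a function value of type Int => (Int => Int).
  val addOne: Int => Int => Int = add3(1) _

  // addOne is already a function value, so applying it needs no `_`:
  // the result is simply another function value.
  val addOneTwo: Int => Int = addOne(2)
}
```

So FunctionValues.addOneTwo(3) is 6, and at no point after the first step did we need another _.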

After that, we create a more precise generator with half the step size. I double the number of steps so that if we were to graph this alongside the first list, they would have the same range of x-values.

The next bit is the solve_de function itself, which illustrates some of Scala's more interesting features. The first one (which probably isn't that interesting) is that there are no curly brackets around the body of solve_de. If the body of a function is a single expression, even one spread over several lines as it is here, you can just write:
def foo(x: Int) = ...
You don't need to include curly brackets.
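For example, both of these made-up methods have single-expression bodies, so neither needs curly brackets, even though the second one spans several lines:

```scala
object ExpressionBodies {
  // One-line, single-expression body
  def square(x: Double): Double = x * x

  // Still a single expression (an if/else chain), just spread over several lines
  def sign(x: Int): Int =
    if (x < 0) -1
    else if (x == 0) 0
    else 1
}
```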
Next, we have a fold using the /: operator. If you've used inject() in Ruby then you'll know what I'm talking about; otherwise take a look here for a description of what fold (aka reduce[1]) is. In Scala you can write this:
(0 /: myList)(f)
This does a left fold of myList with the initial value 0, using the function f, i.e. f(f(f(0, myList(0)), myList(1)), myList(2))... and so on to the end of the list. The same thing can also be written as myList.foldLeft(0)(f).
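A quick sketch of /: next to foldLeft. The second fold builds a string instead of a number so you can see the order in which f gets applied (the list and names here are just for illustration):

```scala
object Folds {
  val myList = List(1, 2, 3, 4)

  // (z /: xs)(f) is the same as xs.foldLeft(z)(f); /: is deprecated since Scala 2.13
  val sumWithSlash = (0 /: myList)(_ + _)
  val sumWithFoldLeft = myList.foldLeft(0)(_ + _)

  // A non-commutative "f" that just records how it was called
  val trace = myList.foldLeft("0")((acc, x) => s"f($acc, $x)")
}
```

Both sums are 10, and trace comes out as "f(f(f(f(0, 1), 2), 3), 4)", which is exactly the nesting described above.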
The initial value is a list that looks like this:
[ [], [x0, y0] ]
The first element of this list is where we collect the value computed at each step (the slope f(x, y)), and the second element is our current point.
We fold over (0 until iterations), using this list as the initial value. (0 until iterations) is a range equivalent to 0..(iterations - 1) in Ruby. This is a very interesting piece of Scala, because it shows off some of the fancier features. Scala is a pure object-oriented language, so the 0 there is actually an object of class Int. Effectively what we are doing is calling 0.until(iterations), which returns a Range object that we can fold over. In Scala, a method that takes only one argument can be called without the . or the brackets.
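A short sketch of the infix call and of the difference between until and to (the object name Ranges is just for illustration):

```scala
object Ranges {
  val iterations = 5

  // Infix notation: `0 until iterations` is the same call as `0.until(iterations)`
  val infix = 0 until iterations
  val dotted = 0.until(iterations)

  // `until` excludes the upper bound, `to` includes it
  val exclusive = (0 until 3).toList // List(0, 1, 2)
  val inclusive = (0 to 3).toList    // List(0, 1, 2, 3)
}
```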
However, Int has no until() method of its own. So where does until() come from? Scala has a feature called implicit conversions, which convert one type into another automatically, a bit like the auto-boxing between int and Integer in Java. Behind the scenes there is a class called RichInt which supplies the until() method and a bunch of other handy things (you could also write 0 to 5 if you like). When you call until() on 0, Scala first looks to see whether Int has an until() method. Since it doesn't, it looks for an implicit conversion in scope from Int to a type that does have one. The standard library provides one (Predef.intWrapper, which wraps the Int in a RichInt), so the compiler rewrites your expression into something like intWrapper(0).until(iterations). If more than one conversion applied equally well, Scala would report an ambiguity error and you would have to apply the conversion explicitly. The main difference between this and auto-boxing in Java is that you can define your own implicit conversions between any types you like, provided they don't create ambiguities.
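Here is a sketch of both halves of that. The first part builds the RichInt by hand, which is roughly what the compiler does for you; the second defines a made-up conversion of our own using an implicit class (available since Scala 2.10, so newer than this post). Squaring and squared are hypothetical names:

```scala
object Implicits {
  // Roughly what `0 until 3` turns into: wrap the Int in a RichInt, then call until
  val viaRichInt = new scala.runtime.RichInt(0).until(3)

  // Our own implicit conversion: Ints gain a `squared` method
  implicit class Squaring(val n: Int) {
    def squared: Int = n * n
  }

  // The compiler rewrites this to new Squaring(4).squared
  val sixteen = 4.squared
}
```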

The next step is to provide a function to the fold operator to use for folding. After the => we see
s match {
This matches the variable s (the "accumulator") against a set of patterns, a Scala feature called pattern matching. This example doesn't really do it justice since we only have one pattern here, used just as a neat way of pulling the values out of s without calling head and tail. I might post something more detailed on pattern matching in the future. Anyway, the pattern List(res, List(x, y)) matches s and extracts our accumulated values into res and the current position into x and y. We then compute f(x, y) once and store it in fxy (it is used twice, so this avoids evaluating func twice), and then return:
[ new res, [ x + step, y + step * f(x, y) ] ]
The new res value is just res with f(x, y) stuck on the end (that's what the ::: operator does, it concatenates two lists).
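The extraction step on its own looks like this. It's a sketch using an accumulator shaped like the one in solve_de; unpack and appendSlope are hypothetical helper names, and the second case is only there so a malformed accumulator fails with a clear message instead of a MatchError:

```scala
object Extraction {
  // Pull res, x and y out of List(res, List(x, y)) with a single pattern
  def unpack(s: List[List[Double]]): (List[Double], Double, Double) = s match {
    case List(res, List(x, y)) => (res, x, y)
    case other => throw new IllegalArgumentException(s"unexpected accumulator shape: $other")
  }

  // ::: concatenates two lists, so this appends fxy to the end of res
  def appendSlope(res: List[Double], fxy: Double): List[Double] = res ::: List(fxy)
}
```

One thing to keep in mind: ::: has to copy the list on its left, so appending to the end of res on every step makes the whole fold quadratic in the number of steps. That doesn't matter for 10 or 20 steps, but it would for a lot more.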

Two small syntactic things to note:
- There is not a single semicolon in this program. Scala doesn't need the semicolons at the end of lines, although you can include them if you like.
- There are no return keywords in this program, even though we have functions that return values. Scala doesn't require the return keyword: the value of a function body is simply the value of its last expression.
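For example, in this made-up method the block's last expression is the result, with no return anywhere:

```scala
object LastExpression {
  def hypotenuse(a: Double, b: Double): Double = {
    val sumOfSquares = a * a + b * b
    math.sqrt(sumOfSquares) // the last expression is the value of the block, and so the result
  }
}
```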

So I'm not sure I'd recommend actually writing Euler's method like this. In Scala you would probably write it something more like this:
def solve_de(func: (Double, Double) => Double)(step: Double, iterations: Int)(x_0: Double, y_0: Double) = {
  var x = x_0
  var y = y_0
  var res = List[Double]() // slopes computed at each step, as in the fold version

  for (_ <- 0 until iterations) {
    val fxy = func(x, y)
    x = x + step
    y = y + step * fxy
    res = res ::: List(fxy)
  }
  res
}
However, written that way you wouldn't get to use all those fun little toys that Scala gives you, so I did it differently.
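If what you actually want is the points on the curve (say, for plotting), one more option is scanLeft, which is like a fold that keeps every intermediate accumulator. This is a sketch of a hypothetical variant, solvePoints, not the version above: it returns iterations + 1 points, starting with (x_0, y_0), and since it never appends with ::: it also avoids the quadratic copying.

```scala
object EulerPoints {
  // Hypothetical variant of solve_de that returns the (x, y) points instead of the slopes.
  def solvePoints(func: (Double, Double) => Double)(step: Double, iterations: Int)(x_0: Double, y_0: Double): List[(Double, Double)] =
    (0 until iterations).scanLeft((x_0, y_0)) { case ((x, y), _) =>
      // One Euler step from the previous point
      (x + step, y + step * func(x, y))
    }.toList
}
```

For y' = x^2 + y^2 from (0, 1) with step 0.1 and 2 iterations, this gives approximately (0, 1), (0.1, 1.1), (0.2, 1.222), matching the hand calculation from the start of the post.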

[1] In Scala, fold and reduce are two different things: fold takes an initial element, whereas reduce uses the first element of the list as the initial element. Reduce throws an exception when used on an empty list, whereas fold just returns the initial element. In many other languages a single function covers both; Ruby's inject(), for example, behaves like fold when you pass it an initial value and uses the first element as the initial value when you don't.
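A small sketch of that difference (the object and value names are just for illustration). Wrapping the failing reduce in Try keeps the example from crashing, and reduceLeftOption is the standard non-throwing alternative:

```scala
import scala.util.Try

object FoldVsReduce {
  val empty: List[Int] = Nil
  val nums = List(1, 2, 3)

  // fold starts from the given initial value, so an empty list is fine
  val foldEmpty = empty.foldLeft(0)(_ + _) // 0
  val foldNums = nums.foldLeft(0)(_ + _)   // 6

  // reduce starts from the first element, so an empty list has nothing to start from
  val reduceNums = nums.reduceLeft(_ + _)            // 6
  val reduceEmpty = Try(empty.reduceLeft(_ + _))     // a Failure wrapping UnsupportedOperationException
  val reduceEmptyOption = empty.reduceLeftOption(_ + _) // None
}
```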
