
Hacking Scala


April 28, 2013 at 12:46am

Essential Scala Collection Functions

Let me ask you a question: how many times have you written code like this?

import scala.collection.mutable.ListBuffer

case class User(id: Int, userName: String)

val users: List[User] = ??? // ....
val resultUsers = ListBuffer.empty[User] // mutable, so that += compiles

for (i <- 0 until users.size) {
    if (users(i).userName != "test") {
        resultUsers += users(i)
    }
}

Many imperative languages, like Java, C and C++, take a very similar approach to iteration. Some of them add a bit of syntactic sugar (like Java’s for-each loop). There are many problems with this approach, and I think the most important annoyance is that it’s hard to tell what this code actually does. Of course, if you have 10+ years of Java experience, you can tell within a second, but what if the body of the for loop is more involved and manipulates some state? Let me show you another example:

def formatUsers(users: List[User]): String = {
    val result = new StringBuilder 

    for (i <- 0 until users.size) {
        val userName = users(i).userName

        if (userName != "test") {
            result append userName

            if (i < users.size - 1) {
                result append ", "
            }
        }
    }

    return result.toString
}

Was it easy to spot the purpose of this code? I think a more important question is: how do you read this code and try to understand it? To understand it, you need to iterate in your head: you build a simple trace table to find out how the values of i and result change with each iteration. And now the most important question: is this code correct? If you think it is, then you are wrong! It has an error, and I leave it to you to find it. That’s a pretty large amount of code for such a simple task, isn’t it? I even managed to make an error in it.

I’m going to show you how this code can be improved, so that its intent is clearer and there is much less room for errors. I’m not going to show you the complete Scala collection API. Instead, I will show you the 6 most essential collection functions. These 6 functions can take you pretty far in your day-to-day job and can handle most of your iteration needs. I’m going to cover the following functions: filter, map, flatMap, foldLeft, foreach and mkString.

Also note that these functions work on any collection type, like List, Set or even Map (in which case they operate on (key, value) tuples).

filter

filter is an easy one. Let me show you how it can simplify my first example:

val resultUsers = users filter (_.userName != "test")

As you can see, what’s left is the most essential part of my algorithm. filter takes a predicate function as its argument (a function that returns Boolean; User => Boolean in my example) and keeps only the elements for which the predicate returns true.
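Here is a self-contained sketch of the filter example that can be run as is. The sample users are made up for illustration; the post leaves the contents of users unspecified.

```scala
// A minimal, runnable sketch of the filter example above.
// The sample users are hypothetical data, not taken from the post.
object FilterExample {
  case class User(id: Int, userName: String)

  val users: List[User] = List(User(1, "bob"), User(2, "test"), User(3, "john"))

  // Keep every user whose name is not "test"; the predicate has type User => Boolean.
  val resultUsers: List[User] = users.filter(_.userName != "test")

  def main(args: Array[String]): Unit =
    println(resultUsers) // List(User(1,bob), User(3,john))
}
```

Because filter returns a new list and leaves users untouched, there is no mutable buffer and no index bookkeeping.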

map

map converts one collection into another by applying your function to each element. Here is an example:

val userNames = users map (_.userName)

In this code I’m just converting my list of users into a list of user names. As you can guess, userNames has type List[String].
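A minimal runnable sketch of map with hypothetical sample data. The second call (on name lengths) is only there to show that map can change the element type while keeping the number of elements.

```scala
// A runnable sketch of map, using hypothetical sample users.
object MapExample {
  case class User(id: Int, userName: String)

  val users: List[User] = List(User(1, "bob"), User(2, "john"))

  // map applies the function to every element: List[User] becomes List[String].
  val userNames: List[String] = users.map(_.userName)

  // The result always has as many elements as the input.
  val nameLengths: List[Int] = userNames.map(_.length)

  def main(args: Array[String]): Unit = {
    println(userNames)   // List(bob, john)
    println(nameLengths) // List(3, 4)
  }
}
```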

What if I have two kinds of users, AdminUser and NormalUser, and I want to use different logic for each of them? I don’t want to leak the user names of my admins! Instead of a regular function I can use something called a partial function: a special function that takes a single argument and may be defined only for a subset of its possible values. The best thing about it is its special syntax, which lets you pattern match on the argument. Here is an example of what it looks like:

case class AdminUser(userName: String)
case class NormalUser(userName: String)

val users = List(AdminUser("bob"), NormalUser("john"))
val userNames = users map {
    case NormalUser(userName) => userName
    case _ => "secret"
}

assert(userNames == List("secret", "john"))
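To see why the catch-all `case _` matters, here is a sketch that gives both user classes a common trait (`Account`, a hypothetical name added so the list has a precise element type). When a block of `case` clauses is passed to map, an element that matches none of them makes map throw a scala.MatchError at runtime, which the sketch captures with Try:

```scala
import scala.util.{Failure, Try}

object PartialFunctionExample {
  // Hypothetical common supertype; deliberately not sealed, so the
  // incomplete match below compiles without an exhaustivity warning.
  trait Account
  case class AdminUser(userName: String) extends Account
  case class NormalUser(userName: String) extends Account

  val users: List[Account] = List(AdminUser("bob"), NormalUser("john"))

  // With a catch-all case, every element is handled.
  val userNames: List[String] = users.map {
    case NormalUser(userName) => userName
    case _                    => "secret"
  }

  // Without it, AdminUser("bob") matches no clause and map throws a MatchError.
  val withoutCatchAll: Try[List[String]] = Try(users.map {
    case NormalUser(userName) => userName
  })

  def main(args: Array[String]): Unit = {
    println(userNames) // List(secret, john)
    withoutCatchAll match {
      case Failure(e) => println(e.getClass.getSimpleName) // MatchError
      case other      => println(other)
    }
  }
}
```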

This feature can also be useful when you work with a Map. During iteration, a Map gives you a tuple containing the key and the value. A first attempt at using the map function can look like this:

val users = Map(1 -> User(1, "bob"), 2 -> User(2, "john"))
val userNames = users map (_._2.userName)

_2 is the tuple field that holds its second element. As you can see, it’s a little cryptic. But you can improve it by using a partial function:

val userNames = users map {case (id, user) => user.userName}
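A runnable sketch of both Map versions, with the same sample data. One detail worth knowing: since the function returns a String rather than a (key, value) pair, the result is a general Iterable rather than a Map, so the sketch converts it with toList. If the function returns pairs, the result stays a Map.

```scala
object MapOverMapExample {
  case class User(id: Int, userName: String)

  val users: Map[Int, User] = Map(1 -> User(1, "bob"), 2 -> User(2, "john"))

  // Tuple-accessor version: _2 is the value of each (key, value) pair.
  val viaTuple: List[String] = users.map(_._2.userName).toList

  // Pattern-matching version: destructure the pair and name its parts.
  val viaPattern: List[String] = users.map { case (_, user) => user.userName }.toList

  // Returning pairs keeps the result a Map (id -> upper-cased name).
  val upperCased: Map[Int, String] =
    users.map { case (id, user) => id -> user.userName.toUpperCase }

  def main(args: Array[String]): Unit = {
    println(viaPattern) // List(bob, john)
    println(upperCased) // Map(1 -> BOB, 2 -> JOHN)
  }
}
```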

flatMap

flatMap is very similar to map. It also converts one collection into another; the difference is that it also flattens the result. So if your function returns some kind of container, like a List or a Set, flatMap puts all of its elements into the resulting collection. Let me show you it in action:

val suffixes = List("ain", "o")
val words = suffixes flatMap (suffix => List("P" + suffix, "G" + suffix))

assert(words == List("Pain", "Gain", "Po", "Go"))

As you can see, everything is merged into one resulting list. If you used map instead, the result would be List(List(Pain, Gain), List(Po, Go)). Here is another, more practical, example:

import java.io.File

def downloadAllImages(url: String): List[File] = ??? // ....

val urls = List("http://www.google.com", "http://www.example.com")
val images: List[File] = urls flatMap downloadAllImages
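Going back to the suffixes example, this sketch puts map and flatMap side by side and checks that flatMap gives the same result as map followed by flatten. The helper name variants is made up for the example.

```scala
object FlatMapExample {
  val suffixes: List[String] = List("ain", "o")

  // Hypothetical helper: returns a List for each suffix.
  def variants(suffix: String): List[String] = List("P" + suffix, "G" + suffix)

  // map keeps the nesting: one inner list per input element.
  val nested: List[List[String]] = suffixes.map(variants)

  // flatMap concatenates the inner lists into a single list.
  val words: List[String] = suffixes.flatMap(variants)

  def main(args: Array[String]): Unit = {
    println(nested) // List(List(Pain, Gain), List(Po, Go))
    println(words)  // List(Pain, Gain, Po, Go)
  }
}
```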

Now that you know the map and filter functions, you probably already have some ideas about how formatUsers can be simplified. We can partly rewrite it using map and filter:

val safeUserNames: List[String] =
    users map (_.userName) filter (_ != "test")

But what if I have AdminUser and NormalUser like in the previous examples, and this time I want to show only normal users? To achieve this in a single flatMap we can use the partial function syntax mentioned before. I also need some way to tell flatMap that I don’t want some of the items from the original collection, namely the admin users. We can do this with a very useful class: Option. It’s also a container, which holds either a single value (in which case it’s a Some[T]) or nothing at all (in which case it’s None). As you can guess, flatMap simply skips all None values, because they contain nothing:

val users = List(AdminUser("bob"), NormalUser("test"), NormalUser("john"))
val userNames = users flatMap {
    case NormalUser(userName) if userName != "test" => Some(userName)
    case _ => None
}

assert(userNames == List("john"))
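A runnable version of the Option example. As before, the common trait Account is a hypothetical addition that gives the list a precise element type. The extra println calls show why None disappears: an Option behaves like a collection of zero or one elements.

```scala
object OptionFlatMapExample {
  // Hypothetical common supertype for the two user kinds.
  trait Account
  case class AdminUser(userName: String) extends Account
  case class NormalUser(userName: String) extends Account

  val users: List[Account] = List(AdminUser("bob"), NormalUser("test"), NormalUser("john"))

  // Some(name) contributes one element to the result, None contributes nothing.
  val userNames: List[String] = users.flatMap {
    case NormalUser(userName) if userName != "test" => Some(userName)
    case _                                           => None
  }

  def main(args: Array[String]): Unit = {
    println(Some("john").toList) // List(john)
    println(None.toList)         // List()
    println(userNames)           // List(john)
  }
}
```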

foldLeft

foldLeft combines all values of the collection into a single resulting value. For example, you can sum all numbers in a collection, starting from an initial value of 0:

val nums = List(1, 2, 3, 4, 5)
val sum = nums.foldLeft(0)(_ + _)

foldLeft has 2 argument lists: the first is the initial value, and the second is your function, which tells foldLeft how to combine the value accumulated so far with the next element. The following imperative code demonstrates how foldLeft works and produces the same result:

var sum = 0

for (i <- 0 until nums.size) {
    sum = sum + nums(i)
}
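A sketch that runs the foldLeft sum next to the loop above and also records every intermediate accumulator. It uses scanLeft, a standard collection method not covered in this post, which works like foldLeft but keeps each intermediate result, so you can see the trace table you would otherwise build in your head.

```scala
object FoldLeftExample {
  val nums: List[Int] = List(1, 2, 3, 4, 5)

  // foldLeft(initial)(op): op receives the accumulator so far and the next element.
  val sum: Int = nums.foldLeft(0)((acc, n) => acc + n)

  // The same computation with a mutable variable, as in the imperative version.
  val loopSum: Int = {
    var acc = 0
    for (n <- nums) acc = acc + n
    acc
  }

  // Every intermediate accumulator, starting with the initial value.
  val trace: List[Int] = nums.scanLeft(0)(_ + _)

  def main(args: Array[String]): Unit = {
    println(sum)   // 15
    println(trace) // List(0, 1, 3, 6, 10, 15)
  }
}
```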

foldLeft can be very useful, especially when you want to pass some value through the whole iteration process. Let’s continue our implementation of formatUsers and combine all user names into a nice comma-separated string:

userNames.tail.foldLeft(userNames.head)(_ + ", " + _)

As you can see, I used 2 new small functions: head, which returns the first element of the collection, and tail, which returns a collection of the same type with all elements except the first one (both throw an exception if the collection is empty). But this kind of operation can be simplified, so I encourage you to read on.
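A sketch of the head/tail fold with made-up names. It also shows a pitfall worth knowing: on an empty list, tail and head throw exceptions, so this version needs a guard (joinNamesSafely is a hypothetical helper), whereas mkString simply returns an empty string.

```scala
import scala.util.Try

object HeadTailFoldExample {
  // Seed the fold with the first name, then append ", " + name for each remaining one.
  def joinNames(userNames: List[String]): String =
    userNames.tail.foldLeft(userNames.head)(_ + ", " + _)

  val joined: String = joinNames(List("bob", "john", "alice"))

  // On an empty list, tail (evaluated first) throws, so the call fails.
  val emptyAttempt: Try[String] = Try(joinNames(Nil))

  // Hypothetical guarded variant: returns "" for an empty list.
  def joinNamesSafely(userNames: List[String]): String =
    if (userNames.isEmpty) "" else joinNames(userNames)

  // mkString needs no guard at all.
  val viaMkString: String = List.empty[String].mkString(", ")

  def main(args: Array[String]): Unit = {
    println(joined)                 // bob, john, alice
    println(emptyAttempt.isFailure) // true
  }
}
```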

foreach

foreach just iterates through the whole collection and executes your function on each element. foreach itself returns Unit, which means it doesn’t return anything of interest, and whatever your function returns is simply ignored. This makes it useful for side effects. For example, I can print all elements of a collection like this:

users foreach println

Its behavior is very similar to Java’s own for-each loop. Here is the equivalent in Java:

// Java code
for (User user : users) {
    System.out.println(user);
}
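A small sketch of foreach used purely for its side effects. Printing is hard to check, so besides println the example appends to a ListBuffer (an arbitrary choice for illustration); the value returned by += is discarded by foreach.

```scala
import scala.collection.mutable.ListBuffer

object ForeachExample {
  case class User(id: Int, userName: String)

  val users: List[User] = List(User(1, "bob"), User(2, "john"))

  // A buffer that records the side effects, so they can be inspected later.
  val log: ListBuffer[String] = ListBuffer.empty[String]

  // foreach returns Unit; the buffer returned by += is simply discarded.
  users.foreach(user => log += s"visited ${user.userName}")

  def main(args: Array[String]): Unit = {
    users.foreach(println) // prints User(1,bob), then User(2,john), one per line
    println(log.toList)    // List(visited bob, visited john)
  }
}
```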

mkString

And last, but not least: mkString. It’s a small, useful function that joins all elements of the collection into one string. You can also provide a separator to be placed between the elements. Let’s continue with our new formatUsers implementation and format the user names:

userNames mkString ", "

mkString is a very simple function, and people sometimes wonder whether it can be useful for more complicated tasks. Even in our simple case we don’t start with a list of user names; we have a list of users, so mkString alone is not enough. But we can combine it with the functions we already know to achieve our goal, a nice list of user names built from a list of users:

users map (_.userName) mkString ", "

This kind of function composition is a very nice way to work with collections.
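A sketch of mkString with the same kind of made-up users. Besides the separator-only form used above, mkString has a three-argument overload that also takes a prefix and a suffix, and it returns an empty string (or just the prefix and suffix) for an empty collection:

```scala
object MkStringExample {
  case class User(id: Int, userName: String)

  val users: List[User] = List(User(1, "bob"), User(2, "john"))

  // Separator only, as in the post.
  val plain: String = users.map(_.userName).mkString(", ")

  // mkString(start, sep, end) also wraps the result in a prefix and a suffix.
  val bracketed: String = users.map(_.userName).mkString("[", ", ", "]")

  // Empty collections need no special handling.
  val empty: String = List.empty[String].mkString(", ")
  val emptyBracketed: String = List.empty[String].mkString("[", ", ", "]")

  def main(args: Array[String]): Unit = {
    println(plain)     // bob, john
    println(bracketed) // [bob, john]
  }
}
```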

Improving formatUsers

Now that you know all of the essential tools that the collection framework provides, it should be straightforward to write a better implementation of the formatUsers method. Here is the whole implementation:

def formatUsers(users: List[User]): String = 
    users map (_.userName) filter (_ != "test") mkString ", "

The first and most important thing to notice is that it’s correct. There simply wasn’t enough room to make the same error. I was also able to extract the most essential parts of the algorithm and put them right in front of you. Boilerplate and language ceremony are also reduced dramatically.
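A runnable sketch of the new formatUsers with a few hypothetical inputs, including one where the last user is named "test" and an empty list, both of which the one-liner handles without special cases:

```scala
// A runnable sketch of the functional formatUsers; the inputs are hypothetical.
object FormatUsersExample {
  case class User(id: Int, userName: String)

  // map, then filter, then join: each step returns a complete, usable value.
  def formatUsers(users: List[User]): String =
    users.map(_.userName).filter(_ != "test").mkString(", ")

  def main(args: Array[String]): Unit = {
    // The last user is "test": no trailing separator is produced.
    println(formatUsers(List(User(1, "bob"), User(2, "john"), User(3, "test")))) // bob, john
    // No users at all: the result is an empty string.
    println(formatUsers(Nil).isEmpty) // true
  }
}
```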

Another important aspect of this implementation is how you read and write it. You no longer think in terms of iterations and variable mutation; instead, you concentrate on a series of collection transformations. Each transformation step produces a correct and meaningful result. In my example there are three distinct steps: map, filter and mkString. Each of them produces something valid and self-contained.

In the imperative version there were only 2 meaningful states: at the beginning and at the end of the method. Everything in between is a broken or incomplete state that can’t be used for anything else. To understand it, I need to keep the whole algorithm in my head to see how it works and what results it can produce. In contrast, in the functional version each step is self-contained and can be analyzed separately.

In my new implementation I can, without any code modifications, extract users map (_.userName), userNames filter (_ != "test") or safeUserNames mkString ", " into separate methods, and they can potentially be reused by other code. So the reusability of this code also increases greatly.

for comprehension

map, flatMap, filter and foreach are so common that Scala even has special syntax for them, and you have already seen it: for. In Scala, for is not the for loop you know from languages like Java and C. It’s much more powerful and lets you do three things: iterate over collections, filter them and bind intermediate values. Here is an example:

val safeUserNames = for (user <- users; userName = user.userName; if userName != "test") yield userName

The same can also be written without semicolons, using curly braces:

val safeUserNames = 
    for {
        user <- users 
        userName = user.userName
        if userName != "test"
    } yield userName
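A sketch that checks the for comprehension against a hand-written chain of method calls. The chain is a simplified equivalent, not the exact code the compiler generates (the compiler also carries the userName binding along, as part of a tuple):

```scala
object ForComprehensionExample {
  case class User(id: Int, userName: String)

  // Hypothetical sample users.
  val users: List[User] = List(User(1, "bob"), User(2, "test"), User(3, "john"))

  val viaFor: List[String] =
    for {
      user <- users
      userName = user.userName
      if userName != "test"
    } yield userName

  // Roughly what the comprehension amounts to: a filter step followed by a map.
  val viaMethods: List[String] =
    users.withFilter(_.userName != "test").map(_.userName)

  def main(args: Array[String]): Unit =
    println(viaFor) // List(bob, john)
}
```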

What’s interesting about for is that under the hood it is translated into a series of map, flatMap, withFilter (a lazy variant of filter) and foreach calls. It provides very nice syntax for nested loops like this one:

val allUserPrivileges = 
    for {
        user <- users
        role <- user.roles
        privilege <- role.privileges
    } yield privilege
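A runnable sketch of the nested comprehension, assuming hypothetical Role and User types in which a user has roles and each role has a list of privileges. The second value spells out the translation: every generator except the last becomes a flatMap, and the last one becomes a map.

```scala
object NestedForExample {
  // Hypothetical data model, assumed for illustration only.
  case class Role(name: String, privileges: List[String])
  case class User(userName: String, roles: List[Role])

  val users: List[User] = List(
    User("bob", List(Role("admin", List("read", "write")), Role("auditor", List("audit")))),
    User("john", List(Role("reader", List("read"))))
  )

  val viaFor: List[String] =
    for {
      user <- users
      role <- user.roles
      privilege <- role.privileges
    } yield privilege

  // Outer generators become flatMap, the innermost one becomes map.
  val viaMethods: List[String] =
    users.flatMap(user => user.roles.flatMap(role => role.privileges.map(privilege => privilege)))

  def main(args: Array[String]): Unit =
    println(viaFor) // List(read, write, audit, read)
}
```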

Performance considerations

Some of you may be concerned about performance, and that’s a valid concern. Clearly, each step now iterates through the whole collection. In my concrete example this isn’t a problem, because the imperative version has to visit every element anyway. But imperative algorithms often use break to stop iterating once the needed value is found or enough result elements have been collected, and continue to skip unnecessary work. With methods like map or filter, the collection is always processed completely.

There are two nice ways to deal with this. The first one is laziness. If you call view on a collection, you get a lazy version of it. Here is an example:

val safeUserNames = users.view map (_.userName) filter (_ != "test")

At this point neither map nor filter has been evaluated. safeUserNames is evaluated only when you explicitly ask for its elements. You can use the force method to evaluate it completely, in which case map and filter are applied to every element of the collection; mkString and foldLeft also traverse the whole view. But you don’t have to evaluate all of it. For example, you can take the first 5 safe user names like this:

safeUserNames take 5
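A sketch that counts how often the mapping function runs, using 1,000 made-up users and a counter variable added only for this demonstration. Building the view runs nothing at all; the work happens only when the result of take(5) is converted with toList. How many elements get examined while forcing depends on the view implementation of your Scala version, so the example prints that count instead of promising a number.

```scala
object ViewExample {
  case class User(id: Int, userName: String)

  // 1,000 hypothetical users; every tenth one is named "test".
  val users: List[User] =
    (1 to 1000).toList.map(i => User(i, if (i % 10 == 0) "test" else s"user$i"))

  // Counter added only to observe when the mapping function actually runs.
  var mapCalls: Int = 0

  // Building the lazy pipeline: map and filter are recorded, not executed.
  val safeUserNames = users.view.map { u => mapCalls += 1; u.userName }.filter(_ != "test")

  val callsBeforeForcing: Int = mapCalls

  // take(5) is still lazy; toList forces the evaluation.
  val firstFive: List[String] = safeUserNames.take(5).toList

  def main(args: Array[String]): Unit = {
    println(callsBeforeForcing) // 0
    println(firstFive)          // List(user1, user2, user3, user4, user5)
    println(s"map ran $mapCalls times while forcing")
  }
}
```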

If you don’t know how many elements you need, but you know when to stop, use takeWhile:

safeUserNames takeWhile (un => un(0) == 'a')

Assuming the list is sorted, this code returns only the user names that start with ‘a’, because takeWhile stops at the first name that doesn’t.
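A sketch of takeWhile on a made-up, alphabetically sorted list. It uses startsWith("a") instead of un(0) == 'a', since indexing into an empty string would throw. The last value shows why sorting matters: takeWhile stops at the first element that fails the test, even if later elements would pass.

```scala
object TakeWhileExample {
  // Hypothetical user names, already sorted alphabetically.
  val sortedNames: List[String] = List("adam", "alice", "anna", "bob", "carol", "test")

  // A lazy pipeline: drop "test", then keep names while they start with 'a'.
  val namesStartingWithA: List[String] =
    sortedNames.view.filter(_ != "test").takeWhile(_.startsWith("a")).toList

  // takeWhile stops at the first failing element, so an unsorted list can yield nothing.
  val unsortedResult: List[String] = List("bob", "alice").takeWhile(_.startsWith("a"))

  def main(args: Array[String]): Unit = {
    println(namesStartingWithA) // List(adam, alice, anna)
    println(unsortedResult)     // List()
  }
}
```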

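The second way to improve performance is to use parallel collections. All you need to do is call the par method on a collection, which returns a parallel version of it (since Scala 2.13, parallel collections live in the separate scala-parallel-collections module, and par requires import scala.collection.parallel.CollectionConverters._). So this map operation processes its elements in parallel: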

users.par map (_.userName)

Conclusion

I encourage you to use a more functional style when you work with collections. It has a lot of advantages over the imperative style, but I think the most important one is that it makes you think in terms of collection transformations, which makes the intent of the code much easier to spot. This makes it easier for you to write code and for other people to read it.

I hope you find this post interesting and inspiring. If you have any questions, suggestions or anything else to tell me, just drop a comment; I would be very happy :) Your feedback is highly appreciated!
