Wednesday, April 28, 2010

First Adventures in ScalaTest - Part III: Introducing Actors in the NetWorth Sample

This is the third post about my little net worth Scala sample. It really doesn't have a lot to do with ScalaTest, so the title is somewhat misleading. What this post is about is how to introduce some parallelism in the net worth calculation code by using Scala actors. The first post in the series showed some tests written in ScalaTest, the second one showed the net worth sample itself, and this one starts where that post ended. The code for calculating the net worth based on an XML file of stock symbols and units is:
val stocksAndUnitsXml = scala.xml.XML.load(fileName)

 val tickersAndUnits =
    (Map[String, Int]() /: (stocksAndUnitsXml \ "symbol")) {
      (map, symbolNode) =>
        val ticker = (symbolNode \ "@ticker").toString
        val units = (symbolNode \ "units").text.toInt
        map(ticker) = units
    }

 def generateTotalNetWorthAndSimpleReport = {
    val report = new java.io.ByteArrayOutputStream
    Console.withOut(report) {

      println("Today is " + new java.util.Date)
      println("Ticket  Units  Closing Price($)  Total value($)")

      val startTime = System.nanoTime

      val netWorth = (0d /: tickersAndUnits) {
        (cumulativeWorth, symbolAndUnitsPair) =>
          val (symbol, units) = symbolAndUnitsPair
          val lastClosingPrice = getLatestClosingPriceFor(symbol)
          val value = lastClosingPrice * units

          println("%-7s  %-5d %-16f  %-16f".format(symbol, units, lastClosingPrice, value))

          cumulativeWorth + value
      }

      val endTime = System.nanoTime

      println("Total value of investments is $" + netWorth)
      println("Report took %f seconds to generate".format( (endTime - startTime)/1000000000.0 ))

      (netWorth, report)
    }
  }
That code is explained in my last post.
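
As a reminder of what the two '/:' folds above do: '(z /: xs)(f)' is another way of writing 'xs.foldLeft(z)(f)'. Here is a minimal sketch of the net worth fold in that spelling. The symbols and numbers are made up, and the 'prices' map is a hypothetical stand-in for 'getLatestClosingPriceFor', so no web service is involved:

```scala
object FoldSketch {
  // Made-up holdings and prices, standing in for the XML file and the web service.
  val tickersAndUnits = Map("AAA" -> 10, "BBB" -> 5)
  val prices = Map("AAA" -> 2.0, "BBB" -> 4.0)

  // Same shape as (0d /: tickersAndUnits) { ... }, written with foldLeft.
  def netWorth: Double =
    tickersAndUnits.foldLeft(0d) {
      case (cumulativeWorth, (symbol, units)) =>
        cumulativeWorth + prices(symbol) * units
    }
}
```

Each step takes the worth accumulated so far plus one (symbol, units) pair and returns the new total, which is exactly the part that runs sequentially in the code above.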

Now I want to parallelize that calculation. Specifically, I want to parallelize the web service calls that get the latest price for each stock symbol, so I'm going to fetch each price in a separate actor, send the prices back to the main thread, and accumulate the net worth there.

The code that gets the price for a single symbol and sends it to another actor called 'caller' is:
    caller ! (symbol, getLatestClosingPriceFor(symbol))
To do that in a separate actor for each symbol I call this method:
  // needs: import scala.actors.Actor and import scala.actors.Actor._
  private[this] def getLatestPricesForAllSymbols(caller: Actor) =
    tickersAndUnits.keys.foreach {
      symbol =>
        actor {
          caller ! (symbol, getLatestClosingPriceFor(symbol))
        }
    }
The call to 'actor' starts a new actor and executes the function given to it asynchronously. The '!' method sends its argument as a message to the actor it is called on, which here is 'caller'.

To receive a single symbol and price I do:
  receiveWithin(10000) {
    case (symbol: String, lastClosingPrice: Double) =>
      val units = tickersAndUnits(symbol)
      val value = lastClosingPrice * units
  }
'receiveWithin(10000)' blocks until a message arrives for the current actor, or times out after 10 seconds, so if this code is in the 'caller' actor it will receive one of the symbol/price pairs sent using '!' in the code above.
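
To see the send and the receive working together, here is a small self-contained sketch of one round trip. It assumes a Scala version where 'scala.actors' is still part of the standard library (2.8 to 2.10). The holdings and the hard-coded price are made up, and 'receiveOneValue' is a hypothetical helper, not part of the sample. Unlike the snippet above, it also matches 'TIMEOUT', so a reply that never arrives gives 'None' rather than leaving the timeout unhandled:

```scala
import scala.actors.Actor
import scala.actors.Actor._
import scala.actors.TIMEOUT

object ActorRoundTripSketch {
  // Made-up holdings for illustration.
  val tickersAndUnits = Map("MSFT" -> 190)

  // Waits for one (symbol, price) message sent to the current actor.
  // Returns the value of the holding, or None if nothing arrives in time.
  def receiveOneValue(timeoutMs: Long): Option[Double] =
    receiveWithin(timeoutMs) {
      case (symbol: String, price: Double) =>
        Some(price * tickersAndUnits(symbol))
      case TIMEOUT =>
        None
    }

  def main(args: Array[String]): Unit = {
    val caller: Actor = self          // the current thread, viewed as an actor
    actor { caller ! ("MSFT", 2.0) }  // hard-coded price instead of a web service call
    println(receiveOneValue(10000))   // picks up the message sent by the actor above
    println(receiveOneValue(100))     // no second message is sent, so this one times out
  }
}
```

Returning an 'Option' is just one way to surface the timeout; the sample in this post keeps the simpler version without a 'TIMEOUT' case.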

In order to accumulate the whole net worth, I place the above in a function that also takes a 'cumulativeWorth' and returns an updated cumulative worth:
   private[this] def receiveAndProcessOneSymbol(cumulativeWorth: Double) = {
    receiveWithin(10000) {
      case (symbol: String, lastClosingPrice: Double) =>
        val units = tickersAndUnits(symbol)
        val value = lastClosingPrice * units

        println("%-7s  %-5d %-16f  %-16f".format(symbol, units, lastClosingPrice, value))

        cumulativeWorth + value
    }
  }
and to receive all the symbol/price pairs I call that method as many times as there are symbols in the 'tickersAndUnits' map:
  private[this] def receiveAndAccumulateWorthForAllSymbol =
    (0d /: (1 to tickersAndUnits.size)) {
      (cumulativeWorth, index) =>
        receiveAndProcessOneSymbol(cumulativeWorth)
    }
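
The "receive exactly as many messages as were requested" pattern does not depend on the actors library. As a rough sketch of the same idea with a swapped-in technique, here plain threads post to a 'java.util.concurrent.LinkedBlockingQueue' that plays the role of the actor mailbox. This is not the code of the sample: the data is made up and 'fakePrice' is a hypothetical stand-in for 'getLatestClosingPriceFor':

```scala
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

object MailboxSketch {
  // Made-up holdings; fakePrice stands in for the real price lookup.
  val tickersAndUnits = Map("AAA" -> 10, "BBB" -> 5)
  def fakePrice(symbol: String): Double = if (symbol == "AAA") 2.0 else 4.0

  def netWorth: Double = {
    // The queue plays the role of the receiving actor's mailbox.
    val mailbox = new LinkedBlockingQueue[(String, Double)]

    // One thread per symbol, each posting its (symbol, price) result.
    tickersAndUnits.keys.foreach { symbol =>
      new Thread(new Runnable {
        def run(): Unit = mailbox.put((symbol, fakePrice(symbol)))
      }).start()
    }

    // Take exactly as many messages as there are symbols, in arrival order.
    (1 to tickersAndUnits.size).foldLeft(0d) { (cumulativeWorth, _) =>
      val message = mailbox.poll(10, TimeUnit.SECONDS) // null if the wait times out
      if (message == null) sys.error("timed out waiting for a price")
      val (symbol, price) = message
      cumulativeWorth + price * tickersAndUnits(symbol)
    }
  }
}
```

As in the actor version, the fold does not care which symbol arrives first; it only counts how many results it still expects.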
Putting all that together, 'generateTotalNetWorthAndSimpleReport' becomes:
  def generateTotalNetWorthAndSimpleReport = {
    val report = new ByteArrayOutputStream
    Console.withOut(report) {
      calculateNetWorthAndWriteReportTo(report)
    }
  }

  private def calculateNetWorthAndWriteReportTo(report: ByteArrayOutputStream) = {
      writeReportHeader

      val startTime = System.nanoTime
      getLatestPricesForAllSymbols(self)
      val netWorth = receiveAndAccumulateWorthForAllSymbol
      val endTime = System.nanoTime

      writeReportFooter(netWorth, endTime, startTime)

      (netWorth, report)
    }

  private def writeReportHeader = {
      println("Today is " + new java.util.Date)
      println("Ticket  Units  Closing Price($)  Total value($)")
    }

  private def writeReportFooter(netWorth: Double, endTime: Long, startTime: Long): Unit = {
    println("Total value of investments is $" + netWorth)
    println("Report took %f seconds to generate".format((endTime - startTime) / 1000000000.0))
  }

  private[this] def getLatestPricesForAllSymbols(caller: Actor) =
    tickersAndUnits.keys.foreach {
      symbol =>
        actor {
          caller ! (symbol, getLatestClosingPriceFor(symbol))
        }
    }

  private[this] def receiveAndAccumulateWorthForAllSymbol =
    (0d /: (1 to tickersAndUnits.size)) {
      (cumulativeWorth, index) =>
        receiveAndProcessOneSymbol(cumulativeWorth)
    }

  private[this] def receiveAndProcessOneSymbol(cumulativeWorth: Double) = {
    receiveWithin(10000) {
      case (symbol: String, lastClosingPrice: Double) =>
        val units = tickersAndUnits(symbol)
        val value = lastClosingPrice * units

        println("%-7s  %-5d %-16f  %-16f".format(symbol, units, lastClosingPrice, value))

        cumulativeWorth + value
    }
  }
Now the calculation is done by fetching prices in parallel using the lightweight Scala actors, and sending the results back as one-way messages that are aggregated into the final result in the main thread. That was easy, don't you think?

Oh, and finally: even in this little sample and on a small dataset, this actually does speed things up. If I rerun the integration test shown in the last post:
class IntegrationSpec extends FlatSpec with ShouldMatchers {
  def withYahooStockPriceFinder(fileName: String)(testFunctionBody: (PortfolioManager) => Unit) {
    testFunctionBody(new PortfolioManager(fileName) with YahooStockPriceFinder)
  }
 
  "A PortFolioManager with the YahooStockPriceFinder" should "produce a asset report and calculate the total net worth" in {
    withYahooStockPriceFinder("src/configuration/stocks.xml") {
      pm =>
        val (totalNetWorth, report) = pm.generateTotalNetWorthAndSimpleReport
        totalNetWorth should be(85000d plusOrMinus 10000)

        print(report)
    }
  }
}
I get:
Today is Wed Apr 28 20:37:38 CEST 2010
Ticket  Units  Closing Price($)  Total value($)
MSFT     190   30,930000         5876,700000    
INTC     160   23,220000         3715,200000    
ALU      150   3,150000          472,500000     
ORCL     200   25,940000         5188,000000    
NSM      200   14,970000         2994,000000    
CSCO     250   27,170000         6792,500000    
AMD      150   9,550000          1432,500000    
IBM      215   130,060000        27962,900000   
VRSN     200   26,840000         5368,000000    
XRX      240   10,980000         2635,200000    
APPL     200   0,000000          0,000000       
HPQ      225   53,330000         11999,250000   
ADBE     125   35,420000         4427,500000    
SYMC     230   17,270000         3972,100000    
TXN      190   26,380000         5012,200000    
Total value of investments is $87848.55
Report took 0,837941 seconds to generate
which shows that the calculation took less than a second now, whereas the calculation took over 6 seconds in the sequential version.
Oh, and also notice that the symbols appear in a different order in the report than they did in the last post. The input data is exactly the same. The difference is that they are printed in the order that the main thread receives prices for them, and that ordering is not deterministic any more because it depends on how fast each web service call returns.
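
If a stable report order matters, one option is to collect the rows first and sort them before printing. The following is only a sketch of that idea, not something the sample does; 'Row' and 'formatRows' are hypothetical names, and the two rows in the usage note are copied from the report above:

```scala
object SortedReportSketch {
  // One report row: symbol, units and closing price.
  case class Row(symbol: String, units: Int, price: Double) {
    def value: Double = price * units
  }

  // Formats the rows sorted by symbol, so the output order
  // no longer depends on the order in which prices arrived.
  def formatRows(rows: Seq[Row]): Seq[String] =
    rows.sortBy(_.symbol).map { r =>
      "%-7s  %-5d %-16f  %-16f".format(r.symbol, r.units, r.price, r.value)
    }
}
```

For example, 'SortedReportSketch.formatRows(Seq(Row("MSFT", 190, 30.93), Row("ALU", 150, 3.15)))' lists ALU before MSFT regardless of which price came back first. The price of this approach is that nothing can be printed until all the prices are in.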

And that's it. I'm still having fun with learning Scala :-)