Thursday, April 15, 2010

First Adventures in ScalaTest - Part II

Following up on my last post, where I showed some test code written with ScalaTest for a simple PortfolioManager demo, I'll show the PortfolioManager itself in this post. As mentioned in the last post, the code is largely (almost entirely) based on the last chapter of Programming Scala by Venkat Subramaniam.

The PortfolioManager is an abstract class whose primary constructor takes a path to an XML file and loads its contents:

abstract class PortfolioManager(fileName: String) extends StockPriceFinder {
  val stocksAndUnitsXml = scala.xml.XML.load(fileName)
  // more...
}


The data in the XML file is expected to look something like this:

<symbols>
<symbol ticker="APPL"><units>200</units></symbol>
<symbol ticker="ADBE"><units>125</units></symbol>
</symbols>

and is parsed with XPath-like queries, applied to the XML with the '\' operator. This method on the PortfolioManager does the parsing:

def tickersAndUnits: Map[String, Int] =
  (Map[String, Int]() /: (stocksAndUnitsXml \ "symbol")) {
    (map, symbolNode) =>
      val ticker = (symbolNode \ "@ticker").text
      val units = (symbolNode \ "units").text.toInt
      // the map is immutable: '+' returns a new map holding the added entry
      map + (ticker -> units)
  }

So what goes on there is that we get the list of symbol XML elements from the stocksAndUnitsXml value using the '\' operator from the Scala standard library. We then iterate over that list with '/:' (aka foldLeft), building up a map from strings to ints that maps each stock symbol to its number of units. Inside the fold, the ticker attribute and the units element of each symbol node are again pulled out of the XML with '\'.
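To see the fold on its own, here is a minimal sketch that skips the XML part. It assumes the ticker and units strings have already been extracted into a hypothetical list of pairs (symbolEntries), and it uses foldLeft, the named equivalent of '/:' (the symbolic form is deprecated as of Scala 2.13):

```scala
// Hypothetical sample data: (ticker, units) pairs as they would come out
// of the <symbol> elements in the XML example above.
val symbolEntries: List[(String, String)] = List(("APPL", "200"), ("ADBE", "125"))

// foldLeft is the named form of '/:'; the empty map is the starting value.
val tickersAndUnits: Map[String, Int] =
  symbolEntries.foldLeft(Map.empty[String, Int]) { (map, entry) =>
    val (ticker, unitsText) = entry
    // '+' returns a new immutable map, which becomes the next accumulator
    map + (ticker -> unitsText.toInt)
  }

println(tickersAndUnits) // Map(APPL -> 200, ADBE -> 125)
```

Each step of the fold hands a new map to the next step, so nothing is mutated along the way.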

That's all just a bit of warm-up. What the PortfolioManager is supposed to do is calculate and report the net worth of the stock portfolio described in the XML. In itself that's not a big deal, but I think it's fun to see how it's handled in Scala; it turns out to be rather neat.

The net worth is calculated by fetching the latest price of each stock symbol, multiplying it by the number of units, and summing up the results. To get the stock prices the code must query some external source, e.g. download.finance.yahoo.com, but I don't want to do that directly, because I want to be able to run my tests without depending on, or waiting for, Yahoo. That's why the PortfolioManager is abstract: the call to get the latest price for a symbol is factored out into the StockPriceFinder trait, where it is abstract:

trait StockPriceFinder {
  protected def getLatestClosingPriceFor(tickerSymbol: String) : Double
}
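As a minimal sketch of how the abstract method gets supplied, the snippet below uses a hypothetical SimplePortfolio (a stand-in for PortfolioManager that takes its holdings directly instead of reading XML) and a hypothetical FixedPriceFinder fake with sample prices. Mixing the fake in at instantiation time makes the class concrete without any network access:

```scala
trait StockPriceFinder {
  protected def getLatestClosingPriceFor(tickerSymbol: String): Double
}

// Hypothetical, simplified stand-in for PortfolioManager: holdings are passed
// in directly instead of being read from an XML file.
abstract class SimplePortfolio(holdings: Map[String, Int]) extends StockPriceFinder {
  def netWorth: Double =
    holdings.foldLeft(0d) { case (sum, (ticker, units)) =>
      sum + getLatestClosingPriceFor(ticker) * units
    }
}

// Hypothetical fake finder: fixed sample prices, no network access.
trait FixedPriceFinder extends StockPriceFinder {
  private val prices = Map("ADBE" -> 34.875, "IBM" -> 131.05)
  protected def getLatestClosingPriceFor(tickerSymbol: String): Double =
    prices.getOrElse(tickerSymbol, 0d)
}

// The abstract method is supplied by mixing the trait in at instantiation time.
val portfolio = new SimplePortfolio(Map("ADBE" -> 125, "IBM" -> 215)) with FixedPriceFinder
println(portfolio.netWorth)
```

Swapping FixedPriceFinder for a trait that queries a real price source would change only the `with ...` part of the instantiation.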

Since PortfolioManager extends StockPriceFinder without implementing getLatestClosingPriceFor, it has to be abstract. Client code has to either instantiate a concrete subclass or mix in a trait implementing getLatestClosingPriceFor at instantiation time. That's what the tests from the last post did, and that's what makes it easy to switch between a fake implementation for unit tests and a real implementation for "production" code and integration tests. I'll show an integration test towards the end of the post, but for now let's return to calculating the net worth and printing a simple report to an output stream:

def generateTotalNetWorthAndSimpleReport = {
  val report = new java.io.ByteArrayOutputStream
  Console.withOut(report) {

    println("Today is " + new java.util.Date)
    println("Ticket  Units  Closing Price($)  Total value($)")

    val startTime = System.nanoTime
    val netWorth = (0d /: tickersAndUnits) {
      (cumulativeWorth, symbolAndUnitsPair) =>
        val (symbol, units) = symbolAndUnitsPair
        val lastClosingPrice = getLatestClosingPriceFor(symbol)
        val value = lastClosingPrice * units
        println("%-7s  %-5d %-16f  %-16f".format(symbol, units, lastClosingPrice, value))
        cumulativeWorth + value
    }
    val endTime = System.nanoTime

    println("Total value of investments is $" + netWorth)
    println("Report took %f seconds to generate".format( (endTime - startTime)/1000000000.0 ))

    (netWorth, report)
  }
}

Things to notice in the code above:
  • Console.withOut(report) redirects the output of the println calls in the block passed to it into the output stream 'report'
  • The method returns two values packaged up in a tuple, simply by ending with the expression '(netWorth, report)'
  • The fold left operation is used again, this time to iterate over the map from symbols to units while accumulating the worth
  • The price of a symbol is found by calling the abstract getLatestClosingPriceFor
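The first two points combine into a small reusable idiom. Here is a minimal sketch of it, assuming a hypothetical helper named captureOutput: it runs a block with Console output redirected into a buffer and returns the block's result together with the captured text as a tuple:

```scala
import java.io.ByteArrayOutputStream

// Hypothetical helper: runs `body` with Console output redirected into a
// buffer, then returns both the body's result and the captured text.
def captureOutput[A](body: => A): (A, String) = {
  val buffer = new ByteArrayOutputStream
  val result = Console.withOut(buffer) {
    body // println inside here writes to `buffer`, not to the terminal
  }
  (result, buffer.toString)
}

// The tuple is taken apart with a pattern, just like in the integration test.
val (value, captured) = captureOutput {
  println("adding up")
  1 + 2
}
```

Unlike generateTotalNetWorthAndSimpleReport, this sketch converts the buffer to a String before returning it; returning the stream itself, as the original does, works too.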

To use the code, we provide an implementation of StockPriceFinder and a path to the XML file, and then call generateTotalNetWorthAndSimpleReport:

// ScalaTest 1.x package names
import org.scalatest.FlatSpec
import org.scalatest.matchers.ShouldMatchers

class IntegrationSpec extends FlatSpec with ShouldMatchers {
  // Builds a PortfolioManager backed by the real Yahoo price finder and hands it to the test body
  def withYahooStockPriceFinder(fileName: String)(testFunctionBody: PortfolioManager => Unit): Unit = {
    testFunctionBody(new PortfolioManager(fileName) with YahooStockPriceFinder)
  }

  "A PortfolioManager with the YahooStockPriceFinder" should "produce an asset report and calculate the total net worth" in {
    withYahooStockPriceFinder("src/configuration/stocks.xml") {
      pm =>
        val (totalNetWorth, report) = pm.generateTotalNetWorthAndSimpleReport
        totalNetWorth should be(85000d plusOrMinus 10000)

        print(report)
    }
  }
}

which produces this output:

Today is Wed Apr 14 20:54:02 CEST 2010
Ticket  Units  Closing Price($)  Total value($)
XRX      240   10.540000         2529.600000   
NSM      200   15.650000         3130.000000   
SYMC     230   17.015000         3913.450000   
ADBE     125   34.875000         4359.375000   
VRSN     200   26.870000         5374.000000   
CSCO     250   26.870100         6717.525000   
TXN      190   26.737500         5080.125000   
ALU      150   3.400000          510.000000     
IBM      215   131.050000        28175.750000   
INTC     160   23.470000         3755.200000   
ORCL     200   26.290000         5258.000000   
APPL     200   0.000000          0.000000      
HPQ      225   54.530000         12269.250000   
AMD      150   9.940000          1491.000000   
MSFT     190   30.730000         5838.700000   
Total value of investments is $88401.975
Report took 6.151573 seconds to generate
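The withYahooStockPriceFinder method in the spec is what the ScalaTest documentation calls a loan-fixture method: it builds the object under test and lends it to the test body. As a minimal sketch without ScalaTest, and assuming a hypothetical in-memory stand-in for PortfolioManager plus a hypothetical fake that returns the same price for every ticker, the same shape lets a test body run against a fake instead of Yahoo:

```scala
trait StockPriceFinder {
  protected def getLatestClosingPriceFor(tickerSymbol: String): Double
}

// Hypothetical in-memory stand-in for PortfolioManager (no XML file).
abstract class InMemoryPortfolio(holdings: Map[String, Int]) extends StockPriceFinder {
  def netWorth: Double =
    holdings.foldLeft(0d) { case (sum, (ticker, units)) =>
      sum + getLatestClosingPriceFor(ticker) * units
    }
}

// Hypothetical fake: every ticker costs 10.0.
trait ConstantPriceFinder extends StockPriceFinder {
  protected def getLatestClosingPriceFor(tickerSymbol: String): Double = 10.0
}

// Loan-fixture method: creates the fixture, then lends it to the test body.
def withConstantPriceFinder(holdings: Map[String, Int])(testBody: InMemoryPortfolio => Unit): Unit =
  testBody(new InMemoryPortfolio(holdings) with ConstantPriceFinder)

withConstantPriceFinder(Map("XRX" -> 240, "NSM" -> 200)) { pm =>
  // 440 units at a constant price of 10.0
  assert(pm.netWorth == 4400.0)
}
```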

In the next post I'll introduce some actors into the code to parallelize it, and see whether that speeds it up.