Exercise 5.2
A major use of inheritance is in writing code that’s meant to be
extended or customized in various ways, especially in libraries or
frameworks. To illustrate, start by adding the following function to
your stock.py program:
# stock.py
...
def print_portfolio(portfolio):
    '''
    Make a nicely formatted table showing portfolio contents.
    '''
    headers = ('Name', 'Shares', 'Price')
    # Each column is right-aligned in a 10-character field
    print(' '.join('%10s' % h for h in headers))
    print(('-' * 10 + ' ') * len(headers))
    for s in portfolio:
        print('%10s %10d %10.2f' % (s.name, s.shares, s.price))
Add a little testing section to the bottom of your stock.py file that
runs the above function:
if __name__ == '__main__':
    portfolio = read_portfolio('Data/portfolio.csv')
    print_portfolio(portfolio)
When you run your stock.py program, you should get this output:
      Name     Shares      Price
---------- ---------- ----------
        AA        100      32.20
       IBM         50      91.10
       CAT        150      83.44
      MSFT        200      51.23
        GE         95      40.37
      MSFT         50      65.10
       IBM        100      70.44
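If you don't have read_portfolio() and the Stock class from the earlier exercises at hand, you can still try print_portfolio(). The Python 3 sketch below assumes a hypothetical namedtuple called Stock as a stand-in (print_portfolio() only needs the .name, .shares, and .price attributes) and hard-codes two rows copied from the expected output instead of reading Data/portfolio.csv. It captures the printed table with contextlib.redirect_stdout so the result can be inspected.

```python
import io
from collections import namedtuple
from contextlib import redirect_stdout

# Hypothetical stand-in for the Stock class from earlier exercises
Stock = namedtuple('Stock', ['name', 'shares', 'price'])

def print_portfolio(portfolio):
    '''
    Make a nicely formatted table showing portfolio contents.
    '''
    headers = ('Name', 'Shares', 'Price')
    # Each column is right-aligned in a 10-character field
    print(' '.join('%10s' % h for h in headers))
    print(('-' * 10 + ' ') * len(headers))
    for s in portfolio:
        print('%10s %10d %10.2f' % (s.name, s.shares, s.price))

# Sample rows copied from the expected output (not read from a file)
portfolio = [Stock('AA', 100, 32.20), Stock('IBM', 50, 91.10)]

# Capture everything print_portfolio() writes to stdout
buf = io.StringIO()
with redirect_stdout(buf):
    print_portfolio(portfolio)
table = buf.getvalue()
print(table, end='')
```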
(a) An Extensibility Problem
Suppose that you wanted to modify the print_portfolio() function to
support a variety of different output formats such as plain text,
HTML, CSV, or XML. You could try to write one gigantic function that
did everything, but doing so would likely lead to an unmaintainable
mess. This is a perfect opportunity to use inheritance instead.
To start, focus on the steps that are involved in creating a table.
At the top of the table is a set of table headers. After that, rows
of table data appear. Let’s take those steps and put them into their
own class. Create a file called tableformat.py and define the
following class:
# tableformat.py

class TableFormatter:
    def headings(self, headers):
        '''
        Emit the table headings.
        '''
        raise NotImplementedError()

    def row(self, rowdata):
        '''
        Emit a single row of table data.
        '''
        raise NotImplementedError()
This class does nothing useful by itself (each method just raises
NotImplementedError), but it serves as a kind of design specification
for the additional classes that will be defined shortly.
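Raising NotImplementedError is one way to write such a specification. Python's standard abc module offers a stricter alternative, which this exercise does not require: methods marked with @abstractmethod must be overridden before a subclass can be instantiated at all. A minimal sketch, where AbstractTableFormatter, IncompleteFormatter, CompleteFormatter, and can_instantiate() are hypothetical names used only for illustration:

```python
from abc import ABC, abstractmethod

class AbstractTableFormatter(ABC):
    '''Same interface as TableFormatter, enforced by the abc module.'''

    @abstractmethod
    def headings(self, headers):
        '''Emit the table headings.'''

    @abstractmethod
    def row(self, rowdata):
        '''Emit a single row of table data.'''

class IncompleteFormatter(AbstractTableFormatter):
    # Overrides only headings(); row() is still abstract
    def headings(self, headers):
        print(headers)

class CompleteFormatter(AbstractTableFormatter):
    # Overrides both abstract methods
    def headings(self, headers):
        print(headers)

    def row(self, rowdata):
        print(rowdata)

def can_instantiate(cls):
    '''Return True if cls() succeeds, False if abc rejects it.'''
    try:
        cls()
    except TypeError:
        return False
    return True

print(can_instantiate(IncompleteFormatter))  # False
print(can_instantiate(CompleteFormatter))    # True
```

The trade-off: with the NotImplementedError version, a missing method is only discovered when it's called; with the abc version, it's reported as soon as you try to create the object.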
Modify the print_portfolio() function so that it accepts a
TableFormatter object as input and invokes methods on it to produce
the output. For example:
# stock.py
...
def print_portfolio(portfolio, formatter):
    '''
    Make a nicely formatted table showing portfolio contents.
    '''
    formatter.headings(['Name', 'Shares', 'Price'])
    for s in portfolio:
        # Form a row of output data (as strings)
        rowdata = [s.name, str(s.shares), '%0.2f' % s.price]
        formatter.row(rowdata)
Finally, try your new class by modifying the main program like this:
# stock.py
...
if __name__ == '__main__':
    from tableformat import TableFormatter
    portfolio = read_portfolio('Data/portfolio.csv')
    formatter = TableFormatter()
    print_portfolio(portfolio, formatter)
When you run this new code, your program will immediately crash with
a NotImplementedError exception. That’s not too exciting, but
continue to the next part.
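To see exactly where the crash comes from, the Python 3 sketch below hands the base TableFormatter to print_portfolio(). It assumes a hypothetical namedtuple Stock in place of the objects returned by read_portfolio(). The very first method call, formatter.headings(), raises the exception, so no row is ever emitted.

```python
from collections import namedtuple

# Hypothetical stand-in for the Stock objects from read_portfolio()
Stock = namedtuple('Stock', ['name', 'shares', 'price'])

class TableFormatter:
    def headings(self, headers):
        '''Emit the table headings.'''
        raise NotImplementedError()

    def row(self, rowdata):
        '''Emit a single row of table data.'''
        raise NotImplementedError()

def print_portfolio(portfolio, formatter):
    formatter.headings(['Name', 'Shares', 'Price'])
    for s in portfolio:
        rowdata = [s.name, str(s.shares), '%0.2f' % s.price]
        formatter.row(rowdata)

try:
    print_portfolio([Stock('AA', 100, 32.20)], TableFormatter())
    crashed = False
except NotImplementedError:
    # Raised by TableFormatter.headings(), before any row is emitted
    crashed = True

print('crashed:', crashed)  # crashed: True
```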
(b) Using Inheritance to Produce Different Output
The TableFormatter class you defined in part (a) is meant to be
extended via inheritance. In fact, that’s the whole idea. To
illustrate, define a class TextTableFormatter like this:
# tableformat.py
...
class TextTableFormatter(TableFormatter):
    '''
    Emit a table in plain-text format
    '''
    def headings(self, headers):
        print(' '.join('%10s' % h for h in headers))
        print(('-' * 10 + ' ') * len(headers))

    def row(self, rowdata):
        # Right-align each (string) item in a 10-character field
        print(' '.join('%10s' % d for d in rowdata))
Modify your main program in stock.py like this and try it:
# stock.py
...
if __name__ == '__main__':
    from tableformat import TextTableFormatter
    portfolio = read_portfolio('Data/portfolio.csv')
    formatter = TextTableFormatter()
    print_portfolio(portfolio, formatter)
This should produce the same output as before:
      Name     Shares      Price
---------- ---------- ----------
        AA        100      32.20
       IBM         50      91.10
       CAT        150      83.44
      MSFT        200      51.23
        GE         95      40.37
      MSFT         50      65.10
       IBM        100      70.44
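Why is the output identical? Each row item is already a string, and '%10s' right-aligns a string exactly the way '%10d' and '%10.2f' right-align the numbers they format. The Python 3 sketch below checks this directly. It assumes a hypothetical namedtuple Stock and three sample rows copied from the output above; original_table() is a hypothetical helper that rebuilds the table the first version of print_portfolio() printed.

```python
import io
from collections import namedtuple
from contextlib import redirect_stdout

# Hypothetical stand-in for the Stock objects from read_portfolio()
Stock = namedtuple('Stock', ['name', 'shares', 'price'])

class TableFormatter:
    def headings(self, headers):
        raise NotImplementedError()

    def row(self, rowdata):
        raise NotImplementedError()

class TextTableFormatter(TableFormatter):
    '''
    Emit a table in plain-text format
    '''
    def headings(self, headers):
        print(' '.join('%10s' % h for h in headers))
        print(('-' * 10 + ' ') * len(headers))

    def row(self, rowdata):
        # Every item is already a string, so %10s right-aligns it
        print(' '.join('%10s' % d for d in rowdata))

def print_portfolio(portfolio, formatter):
    formatter.headings(['Name', 'Shares', 'Price'])
    for s in portfolio:
        formatter.row([s.name, str(s.shares), '%0.2f' % s.price])

def original_table(portfolio):
    '''Hypothetical helper: the hard-coded table from the first version.'''
    lines = [' '.join('%10s' % h for h in ('Name', 'Shares', 'Price')),
             ('-' * 10 + ' ') * 3]
    for s in portfolio:
        lines.append('%10s %10d %10.2f' % (s.name, s.shares, s.price))
    return '\n'.join(lines) + '\n'

portfolio = [Stock('AA', 100, 32.20), Stock('IBM', 50, 91.10),
             Stock('CAT', 150, 83.44)]

buf = io.StringIO()
with redirect_stdout(buf):
    print_portfolio(portfolio, TextTableFormatter())
text_output = buf.getvalue()

print(text_output == original_table(portfolio))  # True
```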
Now let’s change the output to something else. Define a new class
CSVTableFormatter that produces output in CSV format:
# tableformat.py
...
class CSVTableFormatter(TableFormatter):
    '''
    Output portfolio data in CSV format.
    '''
    def headings(self, headers):
        print(','.join(headers))

    def row(self, rowdata):
        print(','.join(rowdata))
Modify your main program as follows:
# stock.py
...
if __name__ == '__main__':
    from tableformat import CSVTableFormatter
    portfolio = read_portfolio('Data/portfolio.csv')
    formatter = CSVTableFormatter()
    print_portfolio(portfolio, formatter)
You should now see CSV output like this:
Name,Shares,Price
AA,100,32.20
IBM,50,91.10
CAT,150,83.44
MSFT,200,51.23
GE,95,40.37
MSFT,50,65.10
IBM,100,70.44
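One caveat with ','.join(): if a field itself contains a comma or a quote character, the result is no longer valid CSV. If that matters for your data, a variant built on the standard library's csv.writer handles quoting for you. In this sketch, QuotingCSVTableFormatter and its optional file argument are hypothetical (it writes to sys.stdout by default), and the 'Acme, Inc.' row is made-up sample data.

```python
import csv
import io
import sys

class TableFormatter:
    def headings(self, headers):
        raise NotImplementedError()

    def row(self, rowdata):
        raise NotImplementedError()

class QuotingCSVTableFormatter(TableFormatter):
    '''
    Output table data as CSV, quoting fields only when needed.
    '''
    def __init__(self, file=None):
        # csv.writer defaults to '\r\n' line endings; use '\n' instead
        out = file if file is not None else sys.stdout
        self.writer = csv.writer(out, lineterminator='\n')

    def headings(self, headers):
        self.writer.writerow(headers)

    def row(self, rowdata):
        self.writer.writerow(rowdata)

# Usage: write into a StringIO so the result can be inspected
buf = io.StringIO()
formatter = QuotingCSVTableFormatter(buf)
formatter.headings(['Name', 'Shares', 'Price'])
formatter.row(['AA', '100', '32.20'])
formatter.row(['Acme, Inc.', '10', '1.50'])   # made-up row with a comma
print(buf.getvalue(), end='')
```

For the portfolio data above, this produces the same output as CSVTableFormatter; only fields containing commas, quotes, or newlines get wrapped in double quotes.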
Using a similar idea, define a class HTMLTableFormatter that produces
a table with the following output:
<tr> <th>Name</th> <th>Shares</th> <th>Price</th> </tr>
<tr> <td>AA</td> <td>100</td> <td>32.20</td> </tr>
<tr> <td>IBM</td> <td>50</td> <td>91.10</td> </tr>
Test your code by modifying the main program to create an
HTMLTableFormatter object instead of a CSVTableFormatter object.
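Here is one possible HTMLTableFormatter, written as a Python 3 sketch you can compare with your own solution after trying the exercise. The html.escape() call is an addition the expected output doesn't require; for plain values like these it changes nothing, but it keeps characters such as < or & from breaking the markup. The sketch repeats a minimal base class so it runs on its own and drives the formatter directly rather than through print_portfolio().

```python
import html
import io
from contextlib import redirect_stdout

class TableFormatter:
    def headings(self, headers):
        raise NotImplementedError()

    def row(self, rowdata):
        raise NotImplementedError()

class HTMLTableFormatter(TableFormatter):
    '''
    Emit table rows as HTML <tr> elements.
    '''
    def headings(self, headers):
        # html.escape() guards against values containing <, >, &, or quotes
        cells = ' '.join('<th>%s</th>' % html.escape(h) for h in headers)
        print('<tr> %s </tr>' % cells)

    def row(self, rowdata):
        cells = ' '.join('<td>%s</td>' % html.escape(d) for d in rowdata)
        print('<tr> %s </tr>' % cells)

# Capture the output so it can be compared with the expected lines
buf = io.StringIO()
with redirect_stdout(buf):
    formatter = HTMLTableFormatter()
    formatter.headings(['Name', 'Shares', 'Price'])
    formatter.row(['AA', '100', '32.20'])
    formatter.row(['IBM', '50', '91.10'])
html_output = buf.getvalue()
print(html_output, end='')
```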
(c) Polymorphism in Action
A major feature of object-oriented programming is that you can simply
plug an object into a program and it will work without having to
change any of the existing code, as long as the object provides the
methods the program expects. For example, if you wrote a program that
expected to use a TableFormatter object, it would work no matter what
kind of TableFormatter you actually gave it. This behavior is
sometimes referred to as "polymorphism."
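To see polymorphism at work, the Python 3 sketch below plugs a formatter that doesn't print at all into an unchanged print_portfolio(). ListTableFormatter is a hypothetical class that collects rows in a list, which can be handy for testing; Stock is a hypothetical namedtuple standing in for your own class.

```python
from collections import namedtuple

# Hypothetical stand-in for the Stock objects from read_portfolio()
Stock = namedtuple('Stock', ['name', 'shares', 'price'])

class TableFormatter:
    def headings(self, headers):
        raise NotImplementedError()

    def row(self, rowdata):
        raise NotImplementedError()

class ListTableFormatter(TableFormatter):
    '''
    Hypothetical formatter that collects rows in a list instead of printing.
    '''
    def __init__(self):
        self.rows = []

    def headings(self, headers):
        self.rows.append(list(headers))

    def row(self, rowdata):
        self.rows.append(list(rowdata))

def print_portfolio(portfolio, formatter):
    # Unchanged: it relies only on headings() and row()
    formatter.headings(['Name', 'Shares', 'Price'])
    for s in portfolio:
        formatter.row([s.name, str(s.shares), '%0.2f' % s.price])

collector = ListTableFormatter()
print_portfolio([Stock('AA', 100, 32.20), Stock('IBM', 50, 91.10)], collector)
print(collector.rows)
# [['Name', 'Shares', 'Price'], ['AA', '100', '32.20'], ['IBM', '50', '91.10']]
```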
One remaining problem is making it easy for users to pick the
formatter they want. A helper function can often solve this. In the
tableformat.py file, add a function create_formatter(name) that
creates a formatter given an output name such as 'txt', 'csv', or
'html'. For example:
# stock.py
...
if __name__ == '__main__':
    from tableformat import create_formatter
    portfolio = read_portfolio('Data/portfolio.csv')
    formatter = create_formatter('csv')
    print_portfolio(portfolio, formatter)
When you run this program, you’ll see output such as this:
Name,Shares,Price
AA,100,32.20
IBM,50,91.10
CAT,150,83.44
MSFT,200,51.23
GE,95,40.37
MSFT,50,65.10
IBM,100,70.44
Try changing the format to 'txt' and 'html' just to make sure your
code is working correctly.
If the user provides a bad output format to the create_formatter()
function, have it raise a RuntimeError exception. For example:
>>> from tableformat import create_formatter
>>> formatter = create_formatter('xls')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "tableformat.py", line 68, in create_formatter
    raise RuntimeError('Unknown table format %s' % name)
RuntimeError: Unknown table format xls
>>>
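A dictionary lookup is one simple way to implement create_formatter(); an if/elif chain works just as well. The Python 3 sketch below repeats compact versions of the three formatter classes so it runs on its own. The _formatters table is a hypothetical name, not part of the exercise. Raising with "from None" suppresses the chained KeyError, so the traceback ends with just the RuntimeError, as in the session above.

```python
class TableFormatter:
    def headings(self, headers):
        raise NotImplementedError()

    def row(self, rowdata):
        raise NotImplementedError()

class TextTableFormatter(TableFormatter):
    def headings(self, headers):
        print(' '.join('%10s' % h for h in headers))
        print(('-' * 10 + ' ') * len(headers))

    def row(self, rowdata):
        print(' '.join('%10s' % d for d in rowdata))

class CSVTableFormatter(TableFormatter):
    def headings(self, headers):
        print(','.join(headers))

    def row(self, rowdata):
        print(','.join(rowdata))

class HTMLTableFormatter(TableFormatter):
    def headings(self, headers):
        print('<tr> %s </tr>' % ' '.join('<th>%s</th>' % h for h in headers))

    def row(self, rowdata):
        print('<tr> %s </tr>' % ' '.join('<td>%s</td>' % d for d in rowdata))

# Hypothetical lookup table: output name -> formatter class
_formatters = {
    'txt': TextTableFormatter,
    'csv': CSVTableFormatter,
    'html': HTMLTableFormatter,
}

def create_formatter(name):
    '''
    Create a formatter given an output name such as 'txt', 'csv', or 'html'.
    '''
    try:
        cls = _formatters[name]
    except KeyError:
        raise RuntimeError('Unknown table format %s' % name) from None
    return cls()
```

With this design, supporting a new format means writing one more subclass and adding one entry to the dictionary; print_portfolio() doesn't change.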