One of the many possible definitions of the binomial coefficient is the following:
$$\binom{n}{m} = \binom{n-1}{m-1} + \binom{n-1}{m}$$
(with $\binom{n}{0} = \binom{n}{n} = 1$)
This can be trivially implemented using the following approach:
```python
def choose(n, m):
    if n == m or m == 0:
        return 1
    return choose(n - 1, m - 1) + choose(n - 1, m)
```
Unfortunately, this calculates many values multiple times, as the call tree for $\binom{5}{3}$ shows:
```text
                       ⎛5⎞
               ________⎝3⎠________
              /                   \
            ⎛4⎞                   ⎛4⎞
       _____⎝2⎠_____              ⎝3⎠
      /             \            /   \
    ⎛3⎞             ⎛3⎞        ⎛3⎞   ⎛3⎞
    ⎝1⎠             ⎝2⎠        ⎝2⎠   ⎝3⎠
   /   \           /   \      /   \
 ⎛2⎞   ⎛2⎞       ⎛2⎞   ⎛2⎞  ⎛2⎞   ⎛2⎞
 ⎝0⎠   ⎝1⎠       ⎝1⎠   ⎝2⎠  ⎝1⎠   ⎝2⎠
      /   \     /   \      /   \
    ⎛1⎞   ⎛1⎞ ⎛1⎞   ⎛1⎞  ⎛1⎞   ⎛1⎞
    ⎝0⎠   ⎝1⎠ ⎝0⎠   ⎝1⎠  ⎝0⎠   ⎝1⎠
```
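To put a number on the duplication, here is a small sketch that instruments the recursive version with a call counter. The names `choose_counted` and `calls` are hypothetical helpers for illustration only. They are not part of the original code.

```python
from collections import Counter

# Hypothetical helper: records how often each (n, m) pair is evaluated.
calls = Counter()

def choose_counted(n, m):
    """Naive recursive binomial coefficient that counts every call."""
    calls[(n, m)] += 1
    if n == m or m == 0:
        return 1
    return choose_counted(n - 1, m - 1) + choose_counted(n - 1, m)
```

For $\binom{5}{3}$ on an empty counter, `sum(calls.values())` comes out at 19 calls (one per node in the tree above), while `len(calls)` is only 11 distinct pairs. For example, $\binom{2}{1}$ alone is evaluated three times.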
When we combine the common sub-trees, we get this structure:
```text
         ⎛5⎞
         ⎝3⎠
        /   \
      ⎛4⎞   ⎛4⎞
      ⎝2⎠   ⎝3⎠
     /   \ /   \
   ⎛3⎞   ⎛3⎞   ⎛3⎞
   ⎝1⎠   ⎝2⎠   ⎝3⎠
  /   \ /   \
⎛2⎞   ⎛2⎞   ⎛2⎞
⎝0⎠   ⎝1⎠   ⎝2⎠
     /   \
   ⎛1⎞   ⎛1⎞
   ⎝0⎠   ⎝1⎠
```
That looks peculiarly well-ordered! Note that the upper number always corresponds to the “level” of the calculation, and that the lower number stays the same along a falling diagonal. It also seems to have a fairly rectangular shape.
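Merging common sub-trees like this is what memoization (top-down dynamic programming) does automatically. As a point of comparison, and not the approach this post develops, here is a minimal sketch using the standard library's `functools.lru_cache`. The name `memo_choose` is hypothetical.

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # unbounded cache: each (n, m) pair is computed once
def memo_choose(n, m):
    if n == m or m == 0:
        return 1
    return memo_choose(n - 1, m - 1) + memo_choose(n - 1, m)
```

After calling `memo_choose(5, 3)` on a freshly cleared cache, `memo_choose.cache_info()` reports 11 misses (one per node in the merged diagram) and 2 hits, for the repeated $\binom{3}{2}$ and $\binom{2}{1}$. The cache still holds one entry per distinct subproblem, and the recursion depth grows with $n$, so Python's recursion limit applies. Both are reasons to prefer an iterative formulation.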
We can interpret the numbers $\binom{n}{m}$ as coordinates $(n, m)$. We can then apply the coordinate transform $(n, m) \mapsto (x, y) = (m, n - m)$, which “straightens” the above diagram:
```text
2| ⎛2⎞_⎛3⎞_⎛4⎞_⎛5⎞ result
 | ⎝0⎠ ⎝1⎠ ⎝2⎠ ⎝3⎠
 |      |   |   |
1| ⎛1⎞_⎛2⎞_⎛3⎞_⎛4⎞
 | ⎝0⎠ ⎝1⎠ ⎝2⎠ ⎝3⎠
 |      |   |   |
0|     ⎛1⎞ ⎛2⎞ ⎛3⎞
 |     ⎝1⎠ ⎝2⎠ ⎝3⎠
 |________________
    0   1   2   3
```
Whereas in the above tree the value of each node was the sum of the two lower nodes, in this straightened grid the value is the sum of the lower node and the left node. The bottom row and the left-most column are the termination cases, and are always equal to 1.
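The rule can be checked directly by writing each grid field in terms of the original numbers. Since $(x, y) = (m, n - m)$ means $n = x + y$, the field at $(x, y)$ holds $\binom{x+y}{x}$, and the recursive definition turns into a left-plus-below rule:

$$\underbrace{\binom{x+y}{x}}_{(x,\,y)} = \underbrace{\binom{x+y-1}{x-1}}_{\text{left: }(x-1,\,y)} + \underbrace{\binom{x+y-1}{x}}_{\text{below: }(x,\,y-1)}$$

On the boundaries, $\binom{y}{0} = 1$ (left-most column) and $\binom{x}{x} = 1$ (bottom row), which are exactly the termination cases.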
This rectangular form is very well suited for an iterative calculation: define a 2D array, fill in the edges, then calculate the remaining fields from their neighbours.
What are the bounds of that array? The height is determined by $n - m$: that is how often I can decrement the upper value until both values are equal and I hit a termination case. The width is determined by $m$: that is how often I can decrement the lower value until I hit zero. Counting the row and column at index zero, the array therefore has $n - m + 1$ rows and $m + 1$ columns. In the above example, the lower-left field is missing. This field is unreachable from its neighbours, since both of its neighbours are base cases of the recursion. Unless we draw the table for a base case, it has more than one row and more than one column, and the lower-left field will always be the only missing one.
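Here is a minimal sketch of this table-filling approach before any space optimization. It uses a 2D list indexed as `grid[y][x]` with $n - m + 1$ rows and $m + 1$ columns. The function name `choose_table` is hypothetical, it assumes $0 \le m \le n$, and the missing lower-left field is simply left as `None`.

```python
def choose_table(n, m):
    """Fill the straightened grid for C(n, m); assumes 0 <= m <= n."""
    if n == m or m == 0:
        return 1
    height, width = n - m + 1, m + 1
    # grid[y][x] ends up holding C(x + y, x); the lower-left field is never read
    grid = [[None] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            if x == 0 and y == 0:
                continue  # the unreachable field stays None
            if x == 0 or y == 0:
                grid[y][x] = 1  # left-most column and bottom row: base cases
            else:
                # sum of the left neighbour and the lower neighbour
                grid[y][x] = grid[y][x - 1] + grid[y - 1][x]
    # the result sits in the top-right field
    return grid[-1][-1]
```

The rest of the post reduces the memory of this version from the whole table to a single row.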
When we calculate the grid as described above from the bottom upwards, we notice that computing a row only ever needs the row directly below it. Once that row is done, the lower row is never read again. Since each node only accesses the node to its left and the node below it, we can calculate the next-upper row by modifying a given row in place. We go through the row from left to right, skip the left-most element, and add the left element to the current element:
```python
# calculate the next row
for i in range(1, len(row)):
    row[i] += row[i - 1]
```
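To see the update in action, here is a short trace for $m = 3$ (the $\binom{5}{3}$ example), starting from an array of ones. Each element becomes the sum of itself and the already-updated element to its left, so the loop computes a running (prefix) sum. The standard library's `itertools.accumulate` therefore produces the same row. The helper name `next_row` is hypothetical.

```python
from itertools import accumulate

def next_row(row):
    """Turn the row into the next-upper row in place, as in the loop above."""
    for i in range(1, len(row)):
        row[i] += row[i - 1]
    return row

row = [1, 1, 1, 1]          # m + 1 = 4 ones
trace = [list(row)]
for _ in range(2):          # n - m = 2 transitions for C(5, 3)
    trace.append(list(next_row(row)))

# trace == [[1, 1, 1, 1], [1, 2, 3, 4], [1, 3, 6, 10]]
# the same prefix sum, computed without mutation:
same_top_row = list(accumulate([1, 2, 3, 4]))   # [1, 3, 6, 10]
```

The top row `[1, 3, 6, 10]` matches row 2 of the straightened grid, with $\binom{5}{3} = 10$ at the right.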
How is the first row initialized? For row zero, no calculations are needed, so that row can be skipped. The first row to be calculated is the row at vertical index 1. When we initialize it, the left-most item will be 1, and the other items will be taken from the lower row, which are also all 1. Therefore, the row is initialized as `[1] * (m + 1)` (an array containing the number 1, $m + 1$ times). Note that this also avoids touching the missing element in the zeroth row.
Since we skip the zeroth row, we will only calculate $n - m$ rows. In the last row, the value of our binomial coefficient will be found in the rightmost item. Our code then becomes:
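```python
def choose(n, m):
    if n == m or m == 0:
        return 1
    # row 0, with its missing left-most field replaced by the 1 of row 1
    row = [1] * (m + 1)
    # each pass turns the row into the next-upper row of the grid
    for _ in range(n - m):
        for i in range(1, len(row)):
            row[i] += row[i - 1]
    # the rightmost field of the top row holds the result
    return row[-1]
```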
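One quick way to gain confidence in this version is to compare it against Python's built-in `math.comb` (available since Python 3.8) for all small inputs. The helper name `check_against_comb` is hypothetical. The iterative `choose` from above is repeated so the snippet runs on its own.

```python
import math

def choose(n, m):
    if n == m or m == 0:
        return 1
    row = [1] * (m + 1)
    for _ in range(n - m):
        for i in range(1, len(row)):
            row[i] += row[i - 1]
    return row[-1]

def check_against_comb(limit):
    """Return the (n, m) pairs with 0 <= m <= n <= limit where choose disagrees."""
    return [(n, m)
            for n in range(limit + 1)
            for m in range(n + 1)
            if choose(n, m) != math.comb(n, m)]
```

`check_against_comb(30)` returns an empty list. That is evidence of correctness, not a proof.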
This version of `choose` always performs $(n - m) \cdot m$ additions, so our time complexity is $O((n - m) \cdot m)$. The space complexity depends on the size of the row, which in the above code is $m + 1$, therefore our space complexity is $O(m)$. However, the binomial coefficient is symmetric: $\binom{n}{m} = \binom{n}{n - m}$. We can use this to minimize the size of the array, and get $O(\min(m, n - m))$ space complexity:
```python
def choose(n, m):
    if n == m or m == 0:
        return 1
    # use the symmetry C(n, m) == C(n, n - m) to keep the row short
    if n - m < m:
        m = n - m
    row = [1] * (m + 1)
    for _ in range(n - m):
        for i in range(1, len(row)):
            row[i] += row[i - 1]
    return row[-1]
```
Note that Python's `range(b)` function corresponds to the half-open interval $[0, b)$, and the two-argument version `range(a, b)` corresponds to $[a, b)$.
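Applied to the $\binom{5}{3}$ example ($m = 3$, so a row of length 4), the two loops of `choose` iterate over these values:

```python
n, m = 5, 3
row = [1] * (m + 1)
outer = list(range(n - m))        # range(b): 0 <= i < b, i.e. n - m passes
inner = list(range(1, len(row)))  # range(a, b): a <= i < b, skips index 0
# outer == [0, 1], inner == [1, 2, 3]
```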
Correctness
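Proving the correctness of this function is a bit tedious. Central to the correctness of this algorithm is the structure of each row. The row at vertical index $y$ can be described as

$$\text{row}_y = \left[\binom{y}{0}, \binom{y+1}{1}, \ldots, \binom{y+m}{m}\right], \quad \text{i.e.} \quad \text{row}_y[x] = \binom{x+y}{x}.$$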
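This is obviously true for the initial state ($y = 0$), since it is

$$\text{row}_0 = [1, 1, \ldots, 1] = \left[\binom{0}{0}, \binom{1}{1}, \ldots, \binom{m}{m}\right].$$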
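Applying the row transition must yield a row of the same structure, but for $y + 1$:

$$\text{row}_{y+1} = \left[\binom{y+1}{0}, \binom{y+2}{1}, \ldots, \binom{y+1+m}{m}\right].$$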
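This can be shown by induction over the indices $x$ within the row. The left-most element is skipped by the loop and stays $\binom{y}{0} = \binom{y+1}{0} = 1$. For $x \ge 1$, the element to the left has already been updated, so the loop computes

$$\text{row}_{y+1}[x] = \text{row}_{y+1}[x-1] + \text{row}_y[x] = \binom{x+y}{x-1} + \binom{x+y}{x} = \binom{x+y+1}{x},$$

which is exactly the recursive definition of the binomial coefficient.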
If the initial state and the row transition are correct, the other important factor for correctness is the number of row transitions $t$. This number is correct when, after $t$ transitions from the initial state, the rightmost field is $\binom{n}{m}$. We know that the initial value of the rightmost field is $\binom{m}{m}$, and that each transition modifies this field as $\binom{y+m}{m} \to \binom{y+1+m}{m}$. Therefore, after $t$ transitions we have $\binom{t+m}{m}$. We choose $t = n - m$, and when we substitute that, we see that the algorithm terminates with $\binom{n}{m}$ in the rightmost field.
While this is not a strict proof, it seems very likely that the algorithm is correct.
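As an empirical complement to the argument above (not a substitute for a proof), the row invariant can be checked mechanically. After $y$ transitions, element $x$ of the row should equal $\binom{x+y}{x}$. Here is a minimal sketch using `math.comb`. The helper name `invariant_holds` is hypothetical.

```python
import math

def invariant_holds(m, max_y):
    """Check row_y[x] == C(x + y, x) for y = 0..max_y, starting from m + 1 ones."""
    row = [1] * (m + 1)  # initial state: row_0 = [C(0,0), C(1,1), ..., C(m,m)]
    for y in range(max_y + 1):
        if row != [math.comb(x + y, x) for x in range(m + 1)]:
            return False
        # apply one row transition in place
        for i in range(1, len(row)):
            row[i] += row[i - 1]
    return True
```

For instance, `invariant_holds(3, 2)` checks the rows `[1, 1, 1, 1]`, `[1, 2, 3, 4]` and `[1, 3, 6, 10]` from the $\binom{5}{3}$ example.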