Saturday, August 27, 2016

Balanced binary search trees

The type of "association tables" (binary search trees).

type (α, β) t =
| Empty
| Node of (α, β) t * α * β * (α, β) t * int
There are two cases: a tree is either empty or a node consisting of a left sub-tree, a key, the value associated with that key, a right sub-tree, and an integer representing the "height" of the tree (the number of nodes to traverse before reaching the most distant leaf).

The binary search tree invariant is maintained throughout: for any non-empty tree $n$, every key in the left sub-tree of $n$ is ordered less than the key at $n$ and every key in the right sub-tree of $n$ is ordered greater than the key at $n$ (in this program, keys are ordered using the function compare).
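
As a rough illustration of the invariant, here is a minimal sketch of a checker. The names is_bst and leaf are hypothetical helpers that do not appear in the post, and the ASCII type variables 'a and 'b stand in for α and β.

```ocaml
type ('a, 'b) t =
  | Empty
  | Node of ('a, 'b) t * 'a * 'b * ('a, 'b) t * int

(* [is_bst lo hi t] holds when every key in [t] lies strictly between the
   optional bounds [lo] and [hi], and each sub-tree satisfies the same
   property with its bounds tightened by the key above it.
   Hypothetical helper, not part of the original program. *)
let rec is_bst (lo : 'a option) (hi : 'a option) : ('a, 'b) t -> bool = function
  | Empty -> true
  | Node (l, k, _, r, _) ->
    let above = match lo with None -> true | Some b -> compare b k < 0 in
    let below = match hi with None -> true | Some b -> compare k b < 0 in
    above && below && is_bst lo (Some k) l && is_bst (Some k) hi r

(* A single-node tree of height 1 (hypothetical helper). *)
let leaf k d = Node (Empty, k, d, Empty, 1)

let () =
  (* Keys 1 < 2 < 3 in order: the invariant holds. *)
  let good = Node (leaf 1 "one", 2, "two", leaf 3 "three", 2) in
  (* Sub-trees swapped: the invariant is violated. *)
  let bad = Node (leaf 3 "three", 2, "two", leaf 1 "one", 2) in
  assert (is_bst None None good);
  assert (not (is_bst None None bad))
```

Threading bounds down the tree, rather than comparing each node only with its immediate children, is what catches a misplaced key deep inside a sub-tree.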

The function height extracts the height of a given tree.

let height : (α, β) t -> int = function
  | Empty -> 0
  | Node (_, _, _, _, h) -> h

The value empty is a constant: the empty tree.

let empty : (α, β) t = Empty

create l x d r creates a new non-empty tree with left sub-tree l, right sub-tree r and the binding of key x to the data d. The height of the tree created is computed from the heights of the two sub-trees.

let create (l : (α, β) t) (x : α) (d : β) (r : (α, β) t) : (α, β) t =
  let hl = height l and hr = height r in
  Node (l, x, d, r, (max hl hr) + 1)

This next function, balance, is where all the action is. Like the preceding function create, it is a factory function for interior nodes and so takes the same argument list as create. It has an additional duty, though: the tree that it produces takes balancing into consideration.

let balance (l : (α, β) t) (x : α) (d : β) (r : (α, β) t) : (α, β) t =
  let hl = height l and hr = height r in
  if hl > hr + 1 then
    match l with
In this branch, the program has determined that a node built from the given left and right sub-trees (denoted $l$ and $r$ respectively) would be unbalanced because $h(l) > h(r) + 1$ (where $h$ denotes the height function).

There are two possible reasons to account for this. They are considered in turn.

    (*Case 1*)
    | Node (ll, lv, ld, lr, _) when height ll >= height lr ->
      create ll lv ld (create lr x d r)
So here, we find that $h(l) > h(r) + 1$, because of the height of the left sub-tree of $l$.
    (*Case 2*)
    | Node (ll, lv, ld, Node (lrl, lrv, lrd, lrr, _), _) ->
      create (create ll lv ld lrl) lrv lrd (create lrr x d r)
In this case, $h(l) > h(r) + 1$ because of the height of the right sub-tree of $l$.
    | _ -> assert false
We assert false for all other patterns since, by construction, no other possibilities can arise.

We now consider the case $h(r) > h(l) + 1$, that is, the right sub-tree being "too long".

  else if hr > hl + 1 then
    match r with

There are two possible reasons.

    (*Case 3*)
    | Node (rl, rv, rd, rr, _) when height rr >= height rl ->
      create (create l x d rl) rv rd rr
Here $h(r) > h(l) + 1$ because of the right sub-tree of $r$.
    (*Case 4*)
    | Node (Node (rll, rlv, rld, rlr, _), rv, rd, rr, _) ->
      create (create l x d rll) rlv rld (create rlr rv rd rr)
Lastly, $h(r) > h(l) + 1$ because of the left sub-tree of $r$.
    | _ -> assert false
Again, all other patterns are impossible (provided the program is written according to our intentions), so we assert false for them.

In the last case, neither $h(l) > h(r) + 1$ nor $h(r) > h(l) + 1$ holds, so no rotation is required.

  else
    create l x d r
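
To see the rotations concretely, the following sketch restates balance in self-contained form and applies it to two small left-heavy inputs. It uses ASCII type variables and begin ... end around the matches; leaf and root_key are hypothetical helpers introduced only for this demonstration.

```ocaml
type ('a, 'b) t =
  | Empty
  | Node of ('a, 'b) t * 'a * 'b * ('a, 'b) t * int

let height = function Empty -> 0 | Node (_, _, _, _, h) -> h

let create l x d r = Node (l, x, d, r, max (height l) (height r) + 1)

let balance l x d r =
  let hl = height l and hr = height r in
  if hl > hr + 1 then begin
    match l with
    | Node (ll, lv, ld, lr, _) when height ll >= height lr ->
      create ll lv ld (create lr x d r)                        (* Case 1 *)
    | Node (ll, lv, ld, Node (lrl, lrv, lrd, lrr, _), _) ->
      create (create ll lv ld lrl) lrv lrd (create lrr x d r)  (* Case 2 *)
    | _ -> assert false
  end else if hr > hl + 1 then begin
    match r with
    | Node (rl, rv, rd, rr, _) when height rr >= height rl ->
      create (create l x d rl) rv rd rr                        (* Case 3 *)
    | Node (Node (rll, rlv, rld, rlr, _), rv, rd, rr, _) ->
      create (create l x d rll) rlv rld (create rlr rv rd rr)  (* Case 4 *)
    | _ -> assert false
  end else create l x d r

(* Hypothetical helpers for the demonstration. *)
let leaf k d = create Empty k d Empty

let root_key = function
  | Empty -> invalid_arg "root_key"
  | Node (_, k, _, _, _) -> k

let () =
  (* Case 1: the left sub-tree leans left (1 sits under 2). *)
  let case1 = balance (create (leaf 1 "a") 2 "b" Empty) 3 "c" Empty in
  (* Case 2: the left sub-tree leans right (2 sits under 1). *)
  let case2 = balance (create Empty 1 "a" (leaf 2 "b")) 3 "c" Empty in
  (* Both inputs are rebuilt with key 2 at the root and height 2. *)
  assert (root_key case1 = 2 && height case1 = 2);
  assert (root_key case2 = 2 && height case2 = 2)
```

In both inputs a naive create would have produced a tree of height 3; the rotation brings the middle key to the root instead.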

add x data t computes a new tree from t containing a binding of x to data. It resembles standard insertion into a binary search tree except that it propagates rotations through the tree to maintain balance after the insertion.

let rec add (x : α) (data : β) : (α, β) t -> (α, β) t = function
    | Empty -> Node (Empty, x, data, Empty, 1)
    | Node (l, v, d, r, h) ->
      let c = compare x v in
      if c = 0 then
        Node (l, x, data, r, h)
      else if c < 0 then
        balance (add x data l) v d r
      else
        balance l v d (add x data r)
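
The effect of balancing during insertion shows up most clearly when keys arrive in sorted order, which would degenerate an unbalanced tree into a list. The sketch below is self-contained (it restates the definitions with ASCII type variables); keys is a hypothetical helper used only to inspect the result.

```ocaml
type ('a, 'b) t =
  | Empty
  | Node of ('a, 'b) t * 'a * 'b * ('a, 'b) t * int

let height = function Empty -> 0 | Node (_, _, _, _, h) -> h
let create l x d r = Node (l, x, d, r, max (height l) (height r) + 1)

let balance l x d r =
  let hl = height l and hr = height r in
  if hl > hr + 1 then begin
    match l with
    | Node (ll, lv, ld, lr, _) when height ll >= height lr ->
      create ll lv ld (create lr x d r)
    | Node (ll, lv, ld, Node (lrl, lrv, lrd, lrr, _), _) ->
      create (create ll lv ld lrl) lrv lrd (create lrr x d r)
    | _ -> assert false
  end else if hr > hl + 1 then begin
    match r with
    | Node (rl, rv, rd, rr, _) when height rr >= height rl ->
      create (create l x d rl) rv rd rr
    | Node (Node (rll, rlv, rld, rlr, _), rv, rd, rr, _) ->
      create (create l x d rll) rlv rld (create rlr rv rd rr)
    | _ -> assert false
  end else create l x d r

let rec add x data = function
  | Empty -> Node (Empty, x, data, Empty, 1)
  | Node (l, v, d, r, h) ->
    let c = compare x v in
    if c = 0 then Node (l, x, data, r, h)   (* replace an existing binding *)
    else if c < 0 then balance (add x data l) v d r
    else balance l v d (add x data r)

(* In-order list of keys (hypothetical helper). *)
let rec keys = function
  | Empty -> []
  | Node (l, k, _, r, _) -> keys l @ (k :: keys r)

let () =
  (* Insert the keys 1 to 7 in ascending order. *)
  let t =
    List.fold_left (fun t k -> add k (string_of_int k) t)
      Empty [1; 2; 3; 4; 5; 6; 7] in
  assert (keys t = [1; 2; 3; 4; 5; 6; 7]);
  (* Seven nodes fit in a tree of height 3 instead of a chain of height 7. *)
  assert (height t = 3)
```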

To implement removal of nodes from a tree, we'll need a function to "merge" two binary search trees, $l$ and $r$ say, where we can assume that all the elements of $l$ are ordered before the elements of $r$.

let rec merge (l : (α, β) t) (r : (α, β) t) : (α, β) t = 
  match (l, r) with
  | Empty, t -> t
  | t, Empty -> t
  | Node (l1, v1, d1, r1, h1), Node (l2, v2, d2, r2, h2) ->
    balance l1 v1 d1 (balance (merge r1 l2) v2 d2 r2)
Again, rotations are propagated through the tree so that the merge produces a balanced tree.

With merge available, implementing remove becomes tractable.

let remove (id : α) (t : (α, β) t) : (α, β) t = 
  let rec remove_rec = function
    | Empty -> Empty
    | Node (l, k, d, r, _) ->
      let c = compare id k in
      if c = 0 then merge l r else
        if c < 0 then balance (remove_rec l) k d r
        else balance l k d (remove_rec r) in
  remove_rec t
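
As a small check of how merge and remove cooperate, the following self-contained sketch removes the root of a seven-node tree. It restates the definitions with ASCII type variables and writes remove as a directly recursive function for brevity; keys is a hypothetical helper.

```ocaml
type ('a, 'b) t =
  | Empty
  | Node of ('a, 'b) t * 'a * 'b * ('a, 'b) t * int

let height = function Empty -> 0 | Node (_, _, _, _, h) -> h
let create l x d r = Node (l, x, d, r, max (height l) (height r) + 1)

let balance l x d r =
  let hl = height l and hr = height r in
  if hl > hr + 1 then begin
    match l with
    | Node (ll, lv, ld, lr, _) when height ll >= height lr ->
      create ll lv ld (create lr x d r)
    | Node (ll, lv, ld, Node (lrl, lrv, lrd, lrr, _), _) ->
      create (create ll lv ld lrl) lrv lrd (create lrr x d r)
    | _ -> assert false
  end else if hr > hl + 1 then begin
    match r with
    | Node (rl, rv, rd, rr, _) when height rr >= height rl ->
      create (create l x d rl) rv rd rr
    | Node (Node (rll, rlv, rld, rlr, _), rv, rd, rr, _) ->
      create (create l x d rll) rlv rld (create rlr rv rd rr)
    | _ -> assert false
  end else create l x d r

let rec add x data = function
  | Empty -> Node (Empty, x, data, Empty, 1)
  | Node (l, v, d, r, h) ->
    let c = compare x v in
    if c = 0 then Node (l, x, data, r, h)
    else if c < 0 then balance (add x data l) v d r
    else balance l v d (add x data r)

(* Every key of [l] is assumed to be ordered before every key of [r]. *)
let rec merge l r =
  match (l, r) with
  | Empty, t -> t
  | t, Empty -> t
  | Node (l1, v1, d1, r1, _), Node (l2, v2, d2, r2, _) ->
    balance l1 v1 d1 (balance (merge r1 l2) v2 d2 r2)

let rec remove id = function
  | Empty -> Empty
  | Node (l, k, d, r, _) ->
    let c = compare id k in
    if c = 0 then merge l r
    else if c < 0 then balance (remove id l) k d r
    else balance l k d (remove id r)

(* In-order list of keys (hypothetical helper). *)
let rec keys = function
  | Empty -> []
  | Node (l, k, _, r, _) -> keys l @ (k :: keys r)

let () =
  let t =
    List.fold_left (fun t k -> add k (string_of_int k) t)
      Empty [1; 2; 3; 4; 5; 6; 7] in
  (* Key 4 ends up at the root of [t]; removing it merges the two sub-trees. *)
  let t' = remove 4 t in
  assert (keys t' = [1; 2; 3; 5; 6; 7]);
  assert (height t' = 3);
  (* Removing an absent key leaves the bindings unchanged. *)
  assert (keys (remove 42 t) = keys t)
```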

The remaining algorithms below are "stock" algorithms for binary search trees with no particular consideration of balancing necessary and so we won't dwell on them here.

let rec find (x : α) : (α, β) t -> β = function
  | Empty ->  raise Not_found
  | Node (l, v, d, r, _) ->
    let c = compare x v in
    if c = 0 then d
    else find x (if c < 0 then l else r)

let rec mem (x : α) : (α, β) t -> bool = function
  | Empty -> false
  | Node (l, v, d, r, _) ->
    let c = compare x v in
    c = 0 || mem x (if c < 0 then l else r)

let rec iter (f : α -> β -> unit) : (α, β) t -> unit = function
  | Empty -> ()
  | Node (l, v, d, r, _) ->
    iter f l; f v d; iter f r

let rec map (f : α -> β -> γ) : (α, β) t -> (α, γ) t = function
  | Empty -> Empty
  | Node (l, k, d, r, h) -> 
    Node (map f l, k, f k d, map f r, h)

let rec fold (f : α -> β -> γ -> γ) (m : (α, β) t) (acc : γ) : γ =
  match m with
  | Empty -> acc
  | Node (l, k, d, r, _) -> fold f r (f k d (fold f l acc))

open Format

let print 
    (print_key : formatter -> α -> unit)
    (print_data : formatter -> β -> unit)
    (ppf : formatter)
    (tbl : (α, β) t) : unit =
  let print_tbl ppf tbl =
    iter (fun k d -> 
           fprintf ppf "@[<2>%a ->@ %a;@]@ " print_key k print_data d)
      tbl in
  fprintf ppf "@[[[%a]]@]" print_tbl tbl

The source code for this post can be found in the file 'ocaml/misc/' in the OCaml source distribution. More information on balanced binary search trees including similar but different implementation techniques and complexity analyses can be found in this Cornell lecture and this one.

Friday, August 19, 2016

Even Sillier C++

The C++ try...catch construct provides a facility for discrimination of exceptions based on their types. This is a primitive "match" construct. It turns out, this is enough to encode sum types.

The program to follow uses the above idea to implement an interpreter for the language of additive expressions using exception handling for case discrimination.
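
For comparison, this is roughly the sum type and evaluator that the C++ program encodes, sketched in OCaml. The constructor names Int and Add are chosen here for illustration and do not appear in the post.

```ocaml
(* The two cases of additive expressions as an ordinary sum type. *)
type expr =
  | Int of int
  | Add of expr * expr

(* Case discrimination is done by pattern matching; the C++ version
   obtains the same effect from the types named in its catch clauses. *)
let rec eval : expr -> int = function
  | Int i -> i
  | Add (l, r) -> eval l + eval r

let () =
  (* (1 + 2) + 3 *)
  assert (eval (Add (Add (Int 1, Int 2), Int 3)) = 6)
```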

#include <iostream>
#include <cassert>
#include <exception>
#include <memory>

struct expr {
  virtual ~expr() {}

  virtual void throw_ () const = 0;
};

using expr_ptr = std::shared_ptr<expr const>;

Class expr is an abstract base class; int_ and add are derived classes corresponding to the two cases of expressions. Sub-expressions are represented as std::shared_ptr<expr const> instances.

struct int_ : expr {
  int val;
  int_ (int val) : val{val}
  {}

  void throw_ () const { throw *this; }
};

struct add : expr {
  expr_ptr left;
  expr_ptr right;

  template <class U, class V>
  add (U const& left, V const& right)
    : left {expr_ptr{new U{left}}}
    , right {expr_ptr{new V{right}}}
  {}

  void throw_ () const { throw *this; }
};

With the above machinery in place, here then is the "interpreter". It is implemented as a pair of mutually recursive functions.

int eval_rec ();

int eval (expr const& xpr) {
  try {
    xpr.throw_ ();
  }
  catch (...) {
    return eval_rec ();
  }
}

int eval_rec () {
  assert (std::current_exception());

  try {
    throw; // rethrow the exception currently being handled
  }
  catch (int_ const& i) {
    return i.val;
  }
  catch (add const& op) {
    return eval (*op.left) +  eval (*op.right);
  }
}

This little program exercises the interpreter on the expression $(1 + 2) + 3$.

int main () {

  try {
    // (1 + 2) + 3
    std::cout << eval (add{add{int_{1}, int_{2}}, int_{3}}) << std::endl;
  }
  catch (...){
    std::cerr << "Unhandled exception\n";
  }

  return 0;
}

Credit to Mathias Gaunard, who pointed out that using a virtual function to throw an expression removes the need for the explicit dynamic_cast operations used in an earlier version of this program.