Saturday, October 31, 2015

C++ : Folds over variadic templates

C++ : Folds over variadic templates

Code like the following motivates the need to compute conjunctions (and disjunctions) of predicate packs.

// Excerpt from a recursive union type: members that are not relevant to the
// example are elided with "// ...".
template <class T, class... Ts>
struct recursive_union {

  // ...

  // Constructor for the case where the requested alternative 'U' is not 'T'
  // itself, but 'T' is a recursive wrapper whose contained type
  // (unwrap_recursive_wrapper_t<T>) is 'U'.
  // The defaulted 'int' template parameter enables this overload only when
  // both predicates hold, so a compile-time conjunction is needed.
  // NOTE(review): 'and_' is presumably a conjunction metafunction such as
  // pgs::and_ defined later in this post; confirm which one is in scope.
  // The unnamed constructor<U> parameter presumably acts as a tag selecting
  // 'U'. The remaining arguments are forwarded to the initializer of member
  // 'v', which is presumably declared in an elided part. The constructor is
  // noexcept exactly when 'U' is nothrow-constructible from 'Args...'.
  template <class U, class... Args,
  std::enable_if_t<
     and_<
      is_recursive_wrapper<T>
    , std::is_same<U, unwrap_recursive_wrapper_t<T>>>::value, int> = 0
  >
  explicit recursive_union (constructor<U>, Args&&... args)
    noexcept (std::is_nothrow_constructible<U, Args...>::value)
  : v (std::forward<Args>(args)...)
  {}

  // ...

};
I was much helped by a suitable implementation of and_<> credited to Jonathan Wakely. A more general approach is to use a fold.
#include <type_traits>

namespace pgs {

namespace detail {

// One step of a left fold: combine the accumulator 'Acc' with the element 'T'
// through the nested 'apply' template of the metafunction class 'F'.
template <class F, class Acc, class T>
using fold_step_t = typename F::template apply<Acc, T>::type;

}//namespace detail

// fold_left<F, Acc, T1, ..., Tn> derives from
//   F::apply<...F::apply<F::apply<Acc, T1>::type, T2>::type..., Tn>::type,
// i.e. it folds the pack T1..Tn from the left, starting at 'Acc'. With an
// empty pack it derives from 'Acc' itself.
//
// The whole pack is always walked (there is no short-circuit), so every
// element is handed to F::apply.
template <class F, class Acc, class... Ts>
struct fold_left;

// Empty pack: the accumulator is the result.
template <class F, class Acc>
struct fold_left<F, Acc> : Acc {
};

// Non-empty pack: fold the head into the accumulator, then recurse on the tail.
template <class F, class Acc, class T, class... Ts>
struct fold_left<F, Acc, T, Ts...>
  : fold_left<F, detail::fold_step_t<F, Acc, T>, Ts...> {
};

//or

// Metafunction class for disjunction: apply<Acc, T> is true when either the
// accumulator's or the element's 'value' is true.
struct or_helper {
  template <class Acc, class T>
  struct apply
    : std::integral_constant<bool, (Acc::value || T::value)> {};
};

// or_<Ts...>::value is true iff some Ts::value is true (false for no Ts).
template <class... Ts>
struct or_ : fold_left<or_helper, std::false_type, Ts...> {};

//and

// Metafunction class for conjunction: apply<Acc, T> is true when both the
// accumulator's and the element's 'value' are true.
struct and_helper {
  template <class Acc, class T>
  struct apply
    : std::integral_constant<bool, (Acc::value && T::value)> {};
};

// and_<Ts...>::value is true iff every Ts::value is true (true for no Ts).
template <class... Ts>
struct and_ : fold_left<and_helper, std::true_type, Ts...> {};

}//namespace pgs

Friday, October 9, 2015

Expression algebras (Python)

This is a Python version of the C++ program presented in this earlier blog entry on expression algebras.

import functools
# _isconst (xs) -> True when every node in the list 'xs' is a _const
# (vacuously True for an empty list).
_isconst = lambda x : all (isinstance (c, _const) for c in x)

class _float :
  """Base class of all expression nodes; supplies the arithmetic operators.

  Each operator builds a new node and folds eagerly. When both operands are
  _const the result is a _const, and the identities 0 + x, x + 0, 0 - x,
  x - 0, 1 * x, x * 1 and x / 1 are simplified away. The '.f' payload is read
  only from operands already checked to be _const.
  """
  def __neg__ (self) :
    return _const (-self.f) if _isconst ([self]) else _neg (self)
  def __add__ (self, x) :
      return _const (self.f + x.f) if _isconst ([self, x]) else \
      x if _isconst ([self]) and self.f == 0 else               \
      self if _isconst ([x]) and x.f == 0 else _add (self, x)
  def __sub__ (self, x) :
      # 0 - x == -x. On this branch 'x' is not a constant (both-constant was
      # folded first), so negate the expression itself rather than reading
      # x.f, which non-constant nodes such as _add do not have.
      return _const (self.f - x.f) if _isconst ([self, x]) else \
      -x if _isconst ([self]) and self.f == 0 else              \
      self if _isconst ([x]) and x.f == 0.0 else _sub (self, x)
  def __mul__ (self, x) :
      return _const (self.f * x.f) if _isconst ([self, x]) else \
      x if _isconst ([self]) and self.f == 1 else               \
      self if _isconst ([x]) and x.f == 1 else _mul (self, x)
  def __div__ (self, x) :
      return _const (self.f / x.f) if _isconst ([self, x]) else \
      self if _isconst([x]) and x.f == 1 else _div (self, x)
  # Python 3 dispatches '/' to __truediv__ (Python 2 uses __div__ unless true
  # division is enabled); alias it so both versions reach the same code.
  __truediv__ = __div__

class _neg (_float):
  """Negation node holding the operand 'f'; renders as -(f)."""
  def __init__ (self, f) :
    self.f = f
  def __str__ (self) :
    return "".join (["-", "(", str (self.f), ")"])
class _fix (_float):
  """Fixing node pairing 'd' with 'f' (an observation when built by fix()).

  Renders as fix(d, f).
  """
  def __init__ (self, d, f) :
    self.d, self.f = d, f
  def __str__ (self) :
    return "fix(" + ", ".join ([str (self.d), str (self.f)]) + ")"
class _add (_float) :
  """Sum node; renders as 'lhs + rhs' (operands are not parenthesized)."""
  def __init__ (self, lhs, rhs) :
    self.lhs, self.rhs = lhs, rhs
  def __str__ (self) :
    return " + ".join ([str (self.lhs), str (self.rhs)])
class _sub (_float):
  """Difference node; renders as 'lhs - rhs' (operands are not parenthesized)."""
  def __init__ (self, lhs, rhs) :
    self.lhs, self.rhs = lhs, rhs
  def __str__ (self) :
    return " - ".join ([str (self.lhs), str (self.rhs)])
class _mul (_float):
  """Product node; renders as 'lhs * rhs' (operands are not parenthesized)."""
  def __init__ (self, lhs, rhs) :
    self.lhs, self.rhs = lhs, rhs
  def __str__ (self) :
    return " * ".join ([str (self.lhs), str (self.rhs)])
class _div (_float):
  """Quotient node; renders as 'lhs / rhs' (operands are not parenthesized)."""
  def __init__ (self, lhs, rhs) :
    self.lhs, self.rhs = lhs, rhs
  def __str__ (self) :
    return " / ".join ([str (self.lhs), str (self.rhs)])
class _const (_float):
  """Constant leaf holding the value 'f'; renders as str(f)."""
  def __init__ (self, f) :
    self.f = f
  def __str__ (self) :
    text = str (self.f)
    return text
class _obs (_float):
  """Observation leaf identified by 'tag'.

  Renders as the word observation followed by the tag in double quotes.
  """
  def __init__ (self, tag) :
    self.tag = tag
  def __str__ (self) :
    return "".join (["observation \"", str (self.tag), "\""])
class _max (_float):
  """Maximum of two expressions; renders as max(lhs, rhs)."""
  def __init__ (self, lhs, rhs) :
    self.lhs, self.rhs = lhs, rhs
  def __str__ (self) :
    return "max(" + ", ".join ([str (self.lhs), str (self.rhs)]) + ")"
class _min (_float):
  """Minimum of two expressions; renders as min(lhs, rhs)."""
  def __init__ (self, lhs, rhs) :
    self.lhs, self.rhs = lhs, rhs
  def __str__ (self) :
    return "min(" + ", ".join ([str (self.lhs), str (self.rhs)]) + ")"

def visit (f, acc, xpr):
  """Dispatch 'xpr' to the visitor 'f' according to its node class.

  Calls the method of 'f' named after the node class (for example
  f._add (acc, xpr) for an _add node) and returns its result.
  Raises RuntimeError when 'xpr' is none of the known node types.
  """
  handlers = ((_const, "_const"), (_neg, "_neg"), (_fix, "_fix"),
              (_obs, "_obs"), (_add, "_add"), (_sub, "_sub"),
              (_mul, "_mul"), (_div, "_div"), (_max, "_max"),
              (_min, "_min"))
  for node_type, method in handlers :
    if isinstance (xpr, node_type) :
      return getattr (f, method) (acc, xpr)
  raise RuntimeError ("Expression match failure")

# Public constructors for the leaf nodes and for max/min, which have no
# operator syntax.
def const (c) : return _const (c)
def observation (s) : return _obs (s)
def max_ (a, b) : return _max (a, b)
def min_ (a, b) : return _min (a, b)

def fix (d, x):
  """Return expression 'x' with every observation leaf wrapped as fix(d, obs).

  Constants and nodes that are already _fix are returned as-is. Every other
  node is rebuilt around its visited children.
  """

  class _fix_visitor:
    def __init__ (self, d) :
      self.d = d
    def _const (self, _, xpr) :
      return xpr
    def _obs (self, _, xpr) :
      return _fix (self.d, xpr)
    def _fix (self, _, xpr) : return xpr
    def _neg (self, _, xpr) :
      return _neg (visit (self, _, xpr.f))
    def _add (self, _, xpr) :
      return _add (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _sub (self, _, xpr) :
      return _sub (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _mul (self, _, xpr) :
      return _mul (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _div (self, _, xpr) :
      return _div (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _max (self, _, xpr) :
      return _max (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _min (self, _, xpr) :
      return _min (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))

  # Run the traversal from the function body; a 'return' placed inside the
  # class body would be a SyntaxError.
  return visit (_fix_visitor (d), None, x)

def simplify (fs, x):
  """Substitute known fixings into expression 'x', then fold constants.

  'fs' is a sequence of (tag, d, value) triples. A _fix node whose
  observation tag and 'd' match a triple is replaced by _const (value), and
  the first match wins. The result is then folded bottom-up: a node whose
  operands all simplify to constants becomes a single constant. Otherwise
  the node is rebuilt around its simplified operands, so partial
  simplifications below it are kept.
  """

  class _apply_fixings_visitor :
    def __init__(self, fs) : self.fs = fs
    def _const (self, _, xpr) : return xpr
    def _obs (self, _, xpr) : return xpr
    def _fix (self, _, xpr) :
      fs = [f for f in self.fs if f[0] == xpr.f.tag and f[1] == xpr.d]
      return xpr if len (fs) == 0 else _const (fs[0][2])
    def _neg (self, _, xpr) :
      return _neg (visit (self, _, xpr.f))
    def _add (self, _, xpr) :
      return _add (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _sub (self, _, xpr) :
      return _sub (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _mul (self, _, xpr) :
      return _mul (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _div (self, _, xpr) :
      return _div (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _max (self, _, xpr) :
      return _max (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))
    def _min (self, _, xpr) :
      return _min (visit (self, _, xpr.lhs), visit (self, _, xpr.rhs))

  class _simplify_visitor:
    def _const (self, _, xpr) :
      return xpr
    def _fix (self, _, xpr) :
      return xpr
    def _obs (self, _, xpr) :
      return xpr
    def _neg (self, _, xpr) :
      f = visit (self, _, xpr.f)
      # Keep the simplified operand even when it does not fold to a constant.
      return -f if _isconst ([f]) else _neg (f)
    def _binary (self, acc, xpr, node, op) :
      # Simplify both operands. Fold when both are constant; otherwise
      # rebuild the node (class 'node') around the simplified operands.
      l = visit (self, acc, xpr.lhs); r = visit (self, acc, xpr.rhs)
      return const (op (l.f, r.f)) if _isconst ([l, r]) else node (l, r)
    def _add (self, _, xpr) :
      return self._binary (_, xpr, _add, lambda a, b : a + b)
    def _sub (self, _, xpr) :
      return self._binary (_, xpr, _sub, lambda a, b : a - b)
    def _mul (self, _, xpr) :
      return self._binary (_, xpr, _mul, lambda a, b : a * b)
    def _div (self, _, xpr) :
      return self._binary (_, xpr, _div, lambda a, b : a / b)
    def _max (self, _, xpr) :
      return self._binary (_, xpr, _max, max)
    def _min (self, _, xpr) :
      return self._binary (_, xpr, _min, min)

  return visit ( \
    _simplify_visitor (), None, visit (_apply_fixings_visitor (fs), None, x))

Sunday, October 4, 2015

List comprehensions in C++ via the list monad

Monads

As explained in Monads for functional programming by Philip Wadler, a monad is a triple $(t, unit, *)$. $t$ is a parametric type, $unit$ and $*$ are operations:

  val unit : α -> α t
  val ( * ) : α t -> (α -> β t) -> β t
We can read expressions like

$m * \lambda\;a.n$

as, "perform computation $m$, bind $a$ to the resulting value, and then perform computation $n$". Referring to the signatures of $*$ and $unit$, in terms of types we see $m$ has the type α t, $\lambda\;a.n$ has type α -> β t and the whole expression has type β t.

In order for $(t, unit, *)$ to be a monad, the operations $unit$ and $*$ need to satisfy three laws:

  • Left unit. Compute the value $a$, bind $b$ to the result, and compute $n$. The result is the same as $n$ with value $a$ substituted for variable $b$.

    $unit\;a * \lambda\;b.n = n[a/b]$.

  • Right unit. Compute $m$, bind the result to $a$, and return $a$. The result is the same as $m$.

    $m * \lambda\;a.unit\;a = m$.

  • Associative. Compute $m$, bind the result to $a$, compute $n$, bind the result to $b$, compute $o$. The order of parentheses doesn't matter.

    $m * (\lambda\;a.n * \lambda\;b.o) = (m * \lambda\;a.n) * \lambda\;b.o$.

The list monad

Lists can be viewed as monads. That is, there exist operations $unit$ and $*$ that we may define for lists such that the three monad laws from the preceding section hold.

#include <list>
#include <iterator>
#include <type_traits>
#include <algorithm>
#include <iostream>

/*
  The list monad
*/

//The unit list containing 'a'
/*
  let unit : 'a -> 'a t = fun a -> [a]
*/
// Returns the one-element list [a]: the monad's 'unit' (a.k.a. 'return').
template <class A> 
std::list<A> unit (A const& a) {
  std::list<A> singleton;
  singleton.push_back (a);
  return singleton;
}

//The 'bind' operator
/*
  let rec ( * ) : 'a t -> ('a -> 'b t) -> 'b t =
    fun l -> fun k ->
      match l with | [] -> [] | (h :: tl) -> k h @ tl * k
*/
// Applies 'k' to each element of 'a' from front to back and concatenates
// the resulting lists in that order. The result type of 'k' must provide
// list-style splice (e.g. std::list).
// Written as a loop rather than mirroring the recursive definition above.
// Recursing on the by-value tail would copy it at every level (quadratic
// work) and need stack depth proportional to the length of 'a'.
template <class A, class F>
typename std::result_of<F(A)>::type 
operator * (std::list<A> a, F k) {
  typedef typename std::result_of<F(A)>::type result_t;

  result_t res;
  for (A& x : a) {
    result_t chunk = k (x);
    // splice relinks the nodes of 'chunk' in O(1); no elements are copied.
    res.splice (res.end (), chunk);
  }
  return res;
}
The invocation $unit\;a$ forms the unit list containing $a$. The expression $m * k$ applies $k$ to each element of the list $m$ and appends together the resulting lists.

There are well-known derived forms. For example, $join\;z$ is the expression $z * \lambda\;m. m$. In the list monad, $join$ concatenates a list of lists.

//'join' concatenates a list of lists
/*
    let join : 'a t t -> 'a t = fun z -> z * fun m -> m
*/
// Binding 'z' to the identity continuation yields each inner list in turn,
// and bind concatenates them in order.
template <class A>
std::list <A> join (std::list<std::list<A>> const& z) {
  auto identity = [](auto m) { return m; };
  return z * identity;
}
The function $map$ is defined by the expression $map\;f\;m = m * \lambda\;a.unit\;(f\;a)$.
//'map' is the equivalent of 'std::transform'
/*
    let map : ('a -> 'b) -> 'a t -> 'b t =
      fun f -> fun m -> m * fun a -> unit (f a)
*/
// Applies 'f' to every element of 'm' and returns the list of results, in
// order. The element type of the result is deduced from what 'f' returns,
// so it need not be 'A', matching the ('a -> 'b) signature above.
template <class A, class F>
auto map (F f, std::list<A> const& m) {
  return m * [=](auto a) { return unit (f (a)); };
}

List comprehensions

List comprehensions are neatly expressed as monad operations. Here are some examples.
int main () {

  // Source list: l = [1, 2, 3]
  std::list<int> l {1, 2, 3};

  // Squares of l as floats: m = [1, 4, 9]
  auto square = [](int x) { return unit (float (x * x)); };
  auto m = l * square;

  // Cartesian product l x m = [(1, 1), (1, 4), (1, 9), (2, 1), (2, 4), (2, 9), ...]
  auto pairs_with = [&m](int x) {
    return m * [x](float y) { return unit (std::make_pair (x, y)); };
  };
  auto n = l * pairs_with;

  return 0;
}