C++ : Sums with constructors
I've been working recently on a type to model "sums with constructors" in C++ (ala OCaml). The implementation technique is "novel" in that it makes use of C++11's "unrestricted unions" feature. I learned it from the FTL library where the idea is credited to Björn Aili. FTL also shows how to provide a NEAT (for C++) syntax for pattern matching but, unless I just didn't get it, the FTL version doesn't admit recursive types "out-of-the-box". So, I extended Björn's work to admit recursive types by applying the recursive wrapper idea from Boost.Variant (Eric Friedman, Itay Maman). The resulting library, I call the "pretty good sum" library. It's C++14 but can be back-ported to C++11 (update : that's done and a lot of bug-fixes to). The code is online here if you want to play with it in your own programs.
There are a number of usage examples provided in the library tests/documentation. I'll provide a small one here - the ubiquitous option<>
type (c.f. Boost.Optional and OCaml's builtin type α option
).
In OCaml, the type definition is given by
type α option = Some of α | Nonewhich is not recursive (see the other examples on github for that e.g. functional lists, abstract syntax trees) but I hope this example is still interesting in that it explores the type's monadic nature to implement so called "safe-arithmetic", that is, integer arithmetic that guards against overflow and division by zero (source : "Ensure that operations on signed integers do not result in overflow"). See this post for more on monads in C++.
The code in the example is fairly extensively commented so I hope you will excuse me this time if I don't provide my usual narrative (I've presented this program before in a Felix tutorial - there's a narrative there note to self : and some typos that I mean to get back to and fix).
Without further ado... A type for optional values in C++ using the "pretty good sum" type!
#include <pgs/pgs.hpp> #include <gtest/gtest.h> #include <iostream> #include <cstdlib> #include <climits> #include <functional> //type 'a t = Some of 'a | None namespace { using namespace pgs; template <class T> struct some_t { //Case 1 T data; template <class U> explicit some_t (U&& data) : data { std::forward<U> (data) } {} }; struct none_t //Case 2 {}; //Options are a type that can either hold a value of type `none_t` //(undefined) or `some_t<T>` template<class T> using option = sum_type<some_t<T>, none_t>; //is_none : `true` if a `some_t<>`, `false` otherwise template<class T> bool is_none (option<T> const& o) { return o.template is<none_t> (); } //A trait that can "get at" the type `T` contained by an option template <class> struct option_value_type; template <class T> struct option_value_type<option<T>> { typedef T type; }; template <class T> using option_value_type_t = typename option_value_type<T>::type; //Factory function for case `none_t` template <class T> option<T> none () { return option<T>{constructor<none_t>{}}; } //Factory function for case `some_t<>` template <class T> option<decay_t<T>> some (T&& val) { using t = decay_t<T>; return option<t>{constructor<some_t<t>>{}, std::forward<T> (val)}; } //is_some : `false` if a `none_t`, `true` otherwise template<class T> inline bool is_some (option<T> const& o) { return o.template is<some_t<T>>(); } //Attempt to get a `const` reference to the value contained by an //option template <class T> T const& get (option<T> const & u) { return u.template match<T const&> ( [](some_t<T> const& o) -> T const& { return o.data; }, [](none_t const&) -> T const& { throw std::runtime_error {"get"}; } ); } //Attempt to get a non-`const` reference to the value contained by an //option template <class T> T& get (option<T>& u) { return u.template match<T&> ( [](some_t<T>& o) -> T& { return o.data; }, [](none_t&) -> T& { throw std::runtime_error {"get"}; } ); } //`default x (Some v)` returns `v` and `default x None` returns `x` template <class T> T default_ (T x, option<T> const& u) { return u.template match<T> ( [](some_t<T> const& o) -> T { return o.data; }, [=](none_t const&) -> T { return x; } ); } //`map_default f x (Some v)` returns `f v` and `map_default f x None` //returns `x` template<class F, class U, class T> auto map_default (F f, U const& x, option<T> const& u) -> U { return u.template match <U> ( [=](some_t<T> const& o) -> U { return f (o.data); }, [=](none_t const&) -> U { return x; } ); } //Option monad 'bind' template<class T, class F> auto operator * (option<T> const& o, F k) -> decltype (k (get (o))) { using result_t = decltype (k ( get (o))); using t = option_value_type_t<result_t>; return o.template match<result_t> ( [](none_t const&) -> result_t { return none<t>(); }, [=](some_t<T> const& o) -> result_t { return k (o.data); } ); } //Option monad 'unit' template<class T> option<decay_t<T>> unit (T&& a) { return some (std::forward<T> (a)); } //map template <class T, class F> auto map (F f, option<T> const& m) -> option<decltype (f (get (m)))>{ using t = decltype (f ( get (m))); return m.template match<option<t>> ( [](none_t const&) -> option<t> { return none<t>(); }, [=](some_t<T> const& o) -> option<t> { return some (f (o.data)); } ); } }//namespace<anonymous> TEST (pgs, option) { ASSERT_EQ (get(some (1)), 1); ASSERT_THROW (get (none<int>()), std::runtime_error); auto f = [](int i) { //avoid use of lambda in unevaluated context return some (i * i); }; ASSERT_EQ (get (some (3) * f), 9); auto g = [](int x) { return x * x; }; ASSERT_EQ (get (map (g, some (3))), 9); ASSERT_TRUE (is_none (map (g, none<int>()))); ASSERT_EQ (default_(1, none<int>()), 1); ASSERT_EQ (default_(1, some(3)), 3); auto h = [](int y) -> float{ return float (y * y); }; ASSERT_EQ (map_default (h, 0.0, none<int>()), 0.0); ASSERT_EQ (map_default (h, 0.0, some (3)), 9.0); } namespace { //safe "arithmetic" std::function<option<int>(int)> add (int x) { return [=](int y) -> option<int> { if ((x > 0) && (y > INT_MAX - x) || (x < 0) && (y < INT_MIN - x)) { return none<int>(); //overflow } return some (y + x); }; } std::function<option<int>(int)> sub (int x) { return [=](int y) -> option<int> { if ((x > 0) && (y < (INT_MIN + x)) || (x < 0) && (y > (INT_MAX + x))) { return none<int>(); //overflow } return some (y - x); }; } std::function<option<int>(int)> mul (int x) { return [=](int y) -> option<int> { if (y > 0) { //y positive if (x > 0) { //x positive if (y > (INT_MAX / x)) { return none<int>(); //overflow } } else { //y positive, x nonpositive if (x < (INT_MIN / y)) { return none<int>(); //overflow } } } else { //y is nonpositive if (x > 0) { // y is nonpositive, x is positive if (y < (INT_MIN / x)) { return none<int>(); } } else { //y, x nonpositive if ((y != 0) && (x < (INT_MAX / y))) { return none<int>(); //overflow } } } return some (y * x); }; } std::function<option<int>(int)> div (int x) { return [=](int y) { if (x == 0) { return none<int>();//division by 0 } if (y == INT_MIN && x == -1) return none<int>(); //overflow return some (y / x); }; } }//namespace<\anonymous> TEST(pgs, safe_arithmetic) { //2 * (INT_MAX/2) + 1 (won't overflow since `INT_MAX` is odd and //division will truncate) ASSERT_EQ (get (unit (INT_MAX) * div (2) * mul (2) * add (1)), INT_MAX); // //2 * (INT_MAX/2 + 1) (overflow) ASSERT_TRUE (is_none (unit (INT_MAX) * div (2) * add (1) * mul (2))); // //INT_MIN/(-1) ASSERT_TRUE (is_none (unit (INT_MIN) * div (-1))); }