Algebraic Datatypes for C++

Andrew Stitcher

Red Hat

What are Algebraic Datatypes?

Algebraic data types are named such since they correspond to an initial algebra in category theory, giving us some laws, some operations and some symbols to manipulate. We may even use algebraic notation for describing regular data structures...
Don Stewart in an answer on stackoverflow.com http://stackoverflow.com/a/16810/280555

What are Algebraic Datatypes?

Also known as:
  • Sum-of-product types
  • Discriminated (tagged) unions

Simple Types

  • Simple enumeration types
  • type boolean = False | True
    
    type primary_colours = Red | Green | Blue
    
  • Simple tuple/record type
  • type complex = float * float
    
    type rational = int * int
    

Enumerations and tuples combined

Simple model for numbers
(* A number is exactly one of four alternatives, each tagged by its
   constructor: a machine int, a float, a rational (the int * int tuple
   type) or a complex (the float * float tuple type). *)
type number =
    | Int of int
    | Float of float
    | Rational of rational
    | Complex of complex

Enumerations and tuples combined

Calculate absolute magnitude
(* Absolute magnitude of any number, returned as a float.
   For Complex, hypot computes sqrt (r*r + i*i) without the intermediate
   overflow/underflow that squaring very large or very small components
   can cause (e.g. r = 1e200 would otherwise give infinity). *)
let magnitude n =
    match n with
    | Int i -> abs_float (float i)
    | Float f -> abs_float f
    | Complex (r,i) -> hypot r i
    | Rational (a,b) -> abs_float (float a /. float b)

Simple recursive list type

(* A polymorphic list: either empty (Nil) or a head of type 'a followed
   by the rest of the list. *)
type 'a mylist = Nil | Cons of 'a * 'a mylist

(* The int list 12, 23, 56 built directly from the constructors. *)
let seq = Cons(12, Cons(23, Cons(56, Nil)))

More complex recursive type

(* Syntax tree for arithmetic expressions: each binary operator holds its
   two operand sub-expressions, Negate holds one, and the leaves are
   variable names and integer or float literals. *)
type expression =
  | Plus of expression * expression
  | Minus of expression * expression
  | Times of expression * expression
  | Divide of expression * expression
  | Power of expression * expression
  | Negate of expression
  | Variable of string
  | Integer of int
  | Float of float

More complex recursive type

(* Evaluate [exp] to a float. [env] maps a variable name to its value;
   integer literals are promoted to float so every operator works on
   floats. *)
let rec eval exp env =
  (* Evaluate both operands in [env] and combine them with [op]. *)
  let binary op left right = op (eval left env) (eval right env) in
  match exp with
  | Plus (l, r) -> binary ( +. ) l r
  | Minus (l, r) -> binary ( -. ) l r
  | Times (l, r) -> binary ( *. ) l r
  | Divide (l, r) -> binary ( /. ) l r
  | Power (l, r) -> binary ( ** ) l r
  | Negate e -> -. (eval e env)
  | Integer i -> float_of_int i
  | Float f -> f
  | Variable name -> env name
Note that we model the environment as a function taking the string name of a variable and returning the float value of the variable.

Simple Integer List

type intlist = Cons of int * intlist | Nil

(* Number of elements in the list, counted by structural recursion on
   its constructors (not tail recursive). *)
let rec length = function
  | Nil -> 0
  | Cons (_, rest) -> length rest + 1

let seq = Cons (12, Cons (23, Nil))
Note that this definition of length isn't very space efficient because it isn't tail recursive: each element adds a stack frame that stays live until the recursive call returns. It is more efficient to accumulate the length in an extra argument and make a tail call to a two-argument helper, like so:
(* Tail-recursive length: [acc] carries the number of elements seen so
   far, so every recursive call to [count] is a tail call. *)
let length1 il =
  let rec count acc = function
    | Nil -> acc
    | Cons (_, tl) -> count (acc + 1) tl
  in
  count 0 il
But this version is more complex, and its extra details would only make the exposition that follows harder to follow without adding much to it.

C++ Datatype Approaches

  • Union
  • Class
  • Boost Variant
These are the options I will cover for representing the type itself. How to implement functions over the type is a separate, independent choice.

Union for the data

enum intlist_const {Cons, Nil};

// Tagged-union representation of an integer list node: `type` records
// which alternative is active, and for Cons the anonymous union holds the
// (head, tail) payload. Nil carries no payload.
struct intlist {
  intlist_const type;
  union {
    tuple<int, intlist_ptr> consdata;  // live only while type == Cons
  };

  // A union member of non-trivial type is never constructed implicitly,
  // so begin the payload's lifetime here for Cons. Without this, makeCons
  // would assign into (and ~intlist would destroy) a tuple that was never
  // constructed: tuple's assignment operator is not trivial, so assigning
  // does not start the member's lifetime and the behaviour is undefined.
  intlist(intlist_const t): type(t) {
    if (type == Cons)
      ::new (static_cast<void*>(&consdata)) tuple<int, intlist_ptr>();
  }

  // Destroy only the union member that is actually live.
  ~intlist() {
      switch (type) {
      case Cons:
          consdata.~tuple<int, intlist_ptr>();return;
      case Nil:
          return;
      }
  }
};

Union for the data

Data constructors:
// Data constructor for the empty list: allocates a node tagged Nil, which
// carries no payload.
intlist_ptr makeNil() {
  intlist_ptr node = new intlist(Nil);
  return node;
}

// Data constructor for a non-empty list: allocates a node tagged Cons and
// stores head `i` and tail `il` into its payload.
intlist_ptr makeCons(int i, const intlist_ptr& il) {
  intlist* node = new intlist(Cons);
  auto& payload = cons(node);
  std::get<0>(payload) = i;
  std::get<1>(payload) = il;
  return node;
}
Helper Accessors:
// Access the (head, tail) payload of a node. Only meaningful when
// il->type == Cons; the tag is not checked here.
tuple<int, intlist_ptr>& cons(const intlist_ptr& il) {
  intlist& node = *il;
  return node.consdata;
}

Union for the data

int length(const intlist_ptr& v) {
  switch (v->type) {
  case Nil:
    return 0;
  case Cons:
    return 1 + length(get<1>(cons(v)));
  default:
    throw logic_error("intlist: not all cases covered");
  }
}

Union for the data

Length function using std::tie for destructuring:
int length1(const intlist_ptr& v) {
  switch (v->type) {
  case Nil:
    return 0;
  case Cons: {
    int i; intlist_ptr tl;
    std::tie(i,tl) = cons(v);
    return 1 + length1(tl);
  }
  default:
    throw logic_error("intlist: not all cases covered");
  }
}

Class for the data: RTTI

// Class-hierarchy representation: intlist is an empty polymorphic base
// (the virtual destructor makes it polymorphic, so typeid/dynamic_cast on
// it see the dynamic type), and each alternative is a derived class.
class intlist {
public:
  virtual ~intlist() {}
};

// The empty list: no payload.
class Nil: public intlist {
};

// A non-empty list cell holding (head, tail).
class Cons: public intlist {
  tuple<int,intlist_ptr> v;
public:
  Cons(int i0, intlist_ptr il0): v(i0, il0) {}
  // Read-only access to the (head, tail) payload.
  const tuple<int, intlist_ptr>& data() const {
    return v;
  }
};

Class for the data: RTTI

// Length via typeid: compare the dynamic type of *v against each
// alternative, throwing if it matches neither.
int length(intlist_ptr v) {
  if (typeid(*v) == typeid(Nil))
    return 0;
  if (typeid(*v) != typeid(Cons))
    throw logic_error("intlist: not all cases covered");
  int head; intlist_ptr tail;
  tie(head, tail) = cons(v);
  return length(tail) + 1;
}

Class for the data: RTTI

// Length via dynamic_cast: each cast yields a null pointer unless *v is
// (or derives from) the requested class.
int length1(intlist_ptr v) {
  if (dynamic_cast<Nil*>(v) != nullptr)
    return 0;
  Cons* const node = dynamic_cast<Cons*>(v);
  if (node == nullptr)
    throw logic_error("intlist: not all cases covered");
  intlist_ptr tail;
  int head;
  tie(head, tail) = node->data();
  return length1(tail) + 1;
}

Class for the data: Handrolled typeinfo

// Class hierarchy with a hand-rolled type tag: each derived class reports
// its alternative through the virtual type(), and also exposes the same
// tag as a constexpr code() so it can be used as a switch case label.
class intlist {
public:
  virtual ~intlist() {}
  enum intlist_const {TCons, TNil};
  virtual intlist_const type() const = 0;
};

class Nil: public intlist {
  // `override` makes the compiler reject this if it ever stops matching
  // the base-class signature.
  intlist_const type() const override {return code();}
public:
  static constexpr intlist_const code() {return TNil;}
};

Class for the data: Handrolled typeinfo

// Length via the hand-rolled type tag: switch on the virtual type() using
// each class's constexpr code() as the case label.
int length(intlist_ptr v) {
  switch (v->type()) {
  case Nil::code():
    return 0;
  case Cons::code(): {
    // Braces give these locals their own scope. Without them the jump to
    // `default:` crosses the initialization of `il`, which is ill-formed
    // when intlist_ptr has a non-trivial constructor (e.g. shared_ptr).
    // NOTE(review): cons(v) is assumed to be an accessor for the Cons
    // payload defined elsewhere for this representation — confirm.
    int i; intlist_ptr il;
    tie(i, il) = cons(v);
    return 1 + length(il);
  }
  default:
    throw logic_error("intlist: not all cases covered");
  }
}

Class for the data: Virtual function

// Abstract base for the virtual-member approach: each alternative
// implements length() itself, so dispatch is a single virtual call.
class intlist {
public:
  virtual ~intlist() {}
  virtual int length() const = 0;
};

Class for the data: Virtual function

int length(intlist_ptr v) {
  return v->length();
}
int Nil::length() const {
  return 0;
}
int Cons::length() const {
  return 1 + get<1>(data())->length();
}

Class for the data: Visitor

// The visitor interface refers to both alternatives before their class
// definitions, so they must already be declared.
class Nil;
class Cons;

// Visitor interface over the list alternatives: one call operator per
// alternative, each returning T.
template <typename T>
class intlistVisitor {
public:
  // Polymorphic base: a virtual destructor keeps destruction through a
  // base pointer well defined.
  virtual ~intlistVisitor() {}
  virtual T operator()(const Nil&) const = 0;
  virtual T operator()(const Cons&) const = 0;
};

Class for the data: Visitor

// Visitor-based base: every alternative implements apply() by passing
// itself to the visitor, which selects the matching operator() overload.
class intlist {
public:
  virtual ~intlist() {}
  virtual int apply(const intlistVisitor<int>&) = 0;
};

class Nil: public intlist {
  // `override` makes the compiler check this matches intlist::apply.
  int apply(const intlistVisitor<int>& v) override {
    return v(*this);
  }
};

Class for the data: Visitor

// Visitor computing the length of a list: Nil contributes 0, and a Cons
// contributes 1 plus the length of its tail, obtained by applying this
// same visitor to the tail.
class intlistlength: public intlistVisitor<int> {
  // `override` makes the compiler check each overload against the
  // intlistVisitor<int> interface.
  int operator() (const Nil&) const override {
    return 0;
  }
  int operator() (const Cons& c) const override {
    int i; intlist_ptr il;
    tie(i, il) = c.data();
    return 1 + il->apply(*this); // This is recursion
  }
};

// Entry point: run the length visitor over v.
int length(intlist_ptr v) {
  return v->apply(intlistlength());
}

Variant

// Forward declarations: boost::variant<Cons, Nil> below names both
// alternatives before their definitions, so they must already be declared.
class Nil;
class Cons;

// An intlist is a variant holding exactly one of Cons or Nil; cells are
// linked through plain pointers to such variants.
typedef boost::variant<Cons, Nil> intlist;
typedef intlist* intlist_ptr;

// The empty list: no payload.
class Nil {
};

// A non-empty list cell holding (head, tail).
class Cons {
  tuple<int, intlist_ptr> v;
public:
  Cons(int i0, intlist_ptr il0): v(i0, il0) {}
  const tuple<int, intlist_ptr>& data() const {
    return v;
  }
};

Variant

// Length via boost::get on a variant pointer, which returns a pointer to
// the requested alternative, or a null pointer if another one is held.
int length(intlist_ptr v) {
  if (boost::get<Nil>(v) != nullptr)
    return 0;
  const Cons* const node = boost::get<Cons>(v);
  if (node == nullptr)
    throw logic_error("intlist: not all cases covered");
  int head;
  intlist_ptr tail;
  tie(head, tail) = node->data();
  return length(tail) + 1;
}

Variant

// static_visitor computing list length: Nil yields 0, and Cons yields one
// more than this visitor applied to the variant its tail points at.
class lengthVisitor: public boost::static_visitor<int> {
public:
  int operator()(const Nil&) const { return 0; }

  int operator()(const Cons& c) const {
    const intlist_ptr tail = std::get<1>(c.data());
    return boost::apply_visitor(*this, *tail) + 1;
  }
};

// Entry point: visit the variant that v points at.
int length1(intlist_ptr v) {
  return boost::apply_visitor(lengthVisitor(), *v);
}

Templated Class

// Tag enumeration naming each alternative of the datatype.
enum class types {Cons, Nil};

// Abstract base for the list, built on the ADatatypeBase helper template.
// NOTE(review): ADatatypeBase and Visitor are defined elsewhere; they
// presumably supply tag/dispatch machinery keyed by `types` and the
// Visitor<int, Nil, Cons> signature — confirm against their definitions.
class intlist :
  public ADatatypeBase<types, Visitor<int,Nil,Cons>>
{
public:
  // Implemented by each alternative.
  virtual int length() const = 0;
};

Templated Class

// Empty-list alternative. ADatatype (defined elsewhere) is given the base
// class, this alternative's tag, the class itself (presumably a CRTP
// parameter — confirm) and the payload type (void: no payload).
class Nil:
  public ADatatype<
    intlist, types::Nil, Nil, void>
{
  int length() const;
};

// Non-empty alternative whose payload type is tuple<int, intlist_ptr>
// (head, tail); the constructor forwards both parts to ADatatype.
class Cons:
  public ADatatype <
    intlist, types::Cons, Cons, tuple<int, intlist_ptr>>
{
  int length() const;

public:
  Cons(int i0, const intlist_ptr& il0):
    ADatatype(i0, il0)
  {}
};

Performance Comparison

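| Representation | Technique                   | Relative time (virtual member = 1.00) |
|----------------|-----------------------------|---------------------------------------|
| Union          | —                           | 1.28                                  |
| Union          | Using std::tie & shared_ptr | 3.01                                  |
| Class          | Virtual member              | 1.00                                  |
| Class          | RTTI - typeinfo             | 16.39                                 |
| Class          | RTTI - dynamic_cast         | 26.73                                 |
| Class          | Handrolled typeinfo         | 1.86                                  |
| Class          | Visitor                     | 1.65                                  |
| Variant        | Using get<>                 | 2.42                                  |
| Variant        | Using visitor               | 1.29                                  |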

This was tested by doing the length calculation 1,000,000 times for a list of length 2.

Except where noted, the figures are nearly identical for raw pointers and shared_ptr.