Polymorphic Recursion and constexpr if

November 28, 2021

Polymorphic Recursion

Recursion empowers us with great expressiveness in many areas; static typing is no exception:

// Linked list
template <typename T>
struct ListNode {
  T elem;
  std::shared_ptr<ListNode<T>> nextptr;
};

// Tree
template <typename T>
struct TreeNode {
  T elem;
  std::shared_ptr<TreeNode<T>> left, right;
};

With a linked list and a tree defined this way, we can delay the construction of the data structure until run time: ListNode and TreeNode just define the node and the edge of the graph, and we can construct and modify the graph at run time.
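
To make the run-time construction concrete, here is a minimal sketch that builds a ListNode chain from values known only at run time. The helpers push_front, length and build_from are hypothetical names introduced for this illustration; they are not part of the code in this post.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

template <typename T>
struct ListNode {
  T elem;
  std::shared_ptr<ListNode<T>> nextptr;
};

// Hypothetical helper: prepend x to the list that starts at tail.
template <typename T>
std::shared_ptr<ListNode<T>> push_front(T x, std::shared_ptr<ListNode<T>> tail) {
  return std::make_shared<ListNode<T>>(ListNode<T>{std::move(x), std::move(tail)});
}

// Hypothetical helper: count the nodes by following nextptr until it is null.
template <typename T>
std::size_t length(const std::shared_ptr<ListNode<T>>& head) {
  std::size_t n = 0;
  for (const ListNode<T>* p = head.get(); p != nullptr; p = p->nextptr.get()) ++n;
  return n;
}

// Hypothetical helper: build a list from run-time data.
// Every node has the same type ListNode<T>, whatever the input length is.
template <typename T>
std::shared_ptr<ListNode<T>> build_from(const std::vector<T>& values) {
  std::shared_ptr<ListNode<T>> list;  // a null pointer represents the empty list
  for (const T& v : values) list = push_front(v, list);
  return list;
}
```

Because each element is prepended, build_from on the values 3, 1, 4 gives a list of length 3 whose first node holds 4.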

Sounds good! Why not a more complicated recursive data structure?

template <typename T>
struct PowerNode {
  std::optional<T> elem;
  std::shared_ptr<PowerNode<std::pair<T, T>>> nextptr;
};

Here, PowerNode<T> represents a node where the type of its element is std::optional<T>, and the type of its next node is PowerNode<std::pair<T, T>>. Therefore, each node has a different type from the one before it, unlike in the case of ListNode and TreeNode.
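
The way the node type changes from level to level can be checked by the compiler. The sketch below repeats the struct from above; the aliases Level0, Level1, Level2 and the function make_three are hypothetical names used only for this illustration.

```cpp
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>

template <typename T>
struct PowerNode {
  std::optional<T> elem;
  std::shared_ptr<PowerNode<std::pair<T, T>>> nextptr;
};

// The node after a PowerNode<int> stores pairs of int ...
using Level0 = PowerNode<int>;
using Level1 = decltype(Level0::nextptr)::element_type;
static_assert(std::is_same_v<Level1, PowerNode<std::pair<int, int>>>);

// ... and the node after that stores pairs of pairs.
using Level2 = decltype(Level1::nextptr)::element_type;
static_assert(std::is_same_v<
    Level2,
    PowerNode<std::pair<std::pair<int, int>, std::pair<int, int>>>>);

// Hand-built structure holding three ints:
// 1 sits at level 0, and the pair (2, 3) sits at level 1.
inline Level0 make_three() {
  auto next = std::make_shared<Level1>();
  next->elem = std::make_pair(2, 3);
  Level0 first;
  first.elem = 1;
  first.nextptr = next;
  return first;
}
```

A node at level k can hold 2^k values of the base type, which is where the name PowerNode comes from.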

Now, let me decorate this bare-bones data structure with some member functions:

// COMPILATION FAILS!
template <typename T>
class PowerNode {
public:
  PowerNode() = default;
  PowerNode(std::optional<T> elem): elem(elem) {}
  PowerNode(std::optional<T> elem, std::shared_ptr<PowerNode<std::pair<T, T>>> nextptr): 
    elem(elem), nextptr(nextptr) {}
  PowerNode(std::optional<T> elem, PowerNode<std::pair<T,T>> nextNode):
    elem(elem), nextptr(std::make_shared<PowerNode<std::pair<T,T>>>(nextNode)) {}

  T head() const {
    if (elem) return elem.value();
    else return nextptr->head().first;
  }
  PowerNode cons(T x) const {
    if (!elem) return PowerNode(x, nextptr);
    std::optional<T> new_elem;
    auto carried_elem = std::make_pair(x, elem.value());
    if (nextptr) {
      return PowerNode(new_elem, nextptr->cons(carried_elem));
    } else {
      return PowerNode(new_elem, PowerNode<std::pair<T,T>>(carried_elem));
    }
  }
  size_t size() const {
    return (elem ? 1 : 0) + (nextptr ? 2 * nextptr->size() : 0);
  }

private:
  std::optional<T> elem;
  std::shared_ptr<PowerNode<std::pair<T, T>>> nextptr;
};

Here, PowerNode offers the same head, cons and size operations as a linked list built from ListNode; only the underlying representation and the time efficiency differ.

Unfortunately, the code above fails to compile. The problem is that a member function of PowerNode<T>, say head, calls head of PowerNode<std::pair<T,T>>, which in turn calls head of PowerNode<std::pair<std::pair<T,T>,std::pair<T,T>>>, and so on indefinitely. Since the C++ compiler tries to instantiate every template class it encounters, the chain of instantiations never terminates on its own.

This kind of polymorphic function, where a template function recursively calls itself with a different template parameter, is known as polymorphic recursion. When we implement a data structure whose underlying representation is polymorphically recursive, like PowerNode, its member functions are often polymorphically recursive as well.

Q&A

Now, let me answer two questions that I raised myself:

Q. The PowerNode example seems useless. It doesn’t even have a tail operation!

A. I just wanted to demonstrate the problem with a simple toy example. For those interested, I implement PowerNode with a tail operation in the Appendix.

Q. Why do we need to care about polymorphic recursion?

A. The usefulness of polymorphic recursion is discussed in Okasaki’s seminal book, Purely Functional Data Structures. For reasons of space, I will focus here only on the implementation of polymorphic recursion in C++.

How to tame polymorphic recursion

As far as I am aware, there are two ways to implement polymorphic recursion in C++:

  1. Use expression templates and construct the data structure at compile time
  2. Use constexpr if to introduce a cutoff on the polymorphic recursion depth, and delay the construction of the data structure until run time

Solution 1: expression templates

As for the former solution, I first saw this idea in a comment by Mathias Gaunard on Bartosz Milewski’s blog post. In short, an expression template keeps track of the structure of the object in its template parameters, just like std::array, where the size of the array is embedded in its template parameter. For details, see the C++ Templates book.

Therefore, expression templates are best suited to numerical calculations, where all the input parameters are fixed at compile time and the data structure is also constructed at compile time. However, if we want to construct the data structure based on file input or user interaction given at run time, expression templates cannot be used.
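
To illustrate what "keeping the structure in the template parameters" means, here is a minimal sketch of the general expression template technique. It is not the code from the comment mentioned above; Vec, Sum and add are hypothetical names chosen for this example.

```cpp
#include <array>
#include <cstddef>
#include <type_traits>

// Leaf of the expression tree: a fixed-size vector.
template <std::size_t N>
struct Vec {
  std::array<double, N> data;
  constexpr double operator[](std::size_t i) const { return data[i]; }
};

// Interior node: the operand types L and R are part of the type of Sum,
// so the shape of the whole expression is visible to the compiler.
template <typename L, typename R>
struct Sum {
  L lhs;  // operands are stored by value to avoid dangling references
  R rhs;
  constexpr double operator[](std::size_t i) const { return lhs[i] + rhs[i]; }
};

template <typename L, typename R>
constexpr Sum<L, R> add(L l, R r) {
  return Sum<L, R>{l, r};
}

constexpr Vec<2> a{{1.0, 2.0}};
constexpr Vec<2> b{{10.0, 20.0}};
constexpr Vec<2> c{{100.0, 200.0}};

// The type of expr records that it is (a + b) + c.
constexpr auto expr = add(add(a, b), c);
static_assert(std::is_same_v<decltype(expr),
                             const Sum<Sum<Vec<2>, Vec<2>>, Vec<2>>>);
static_assert(expr[1] == 222.0);
```

The type of expr is fixed once the expression is written down, which is why this approach cannot follow a shape that is only decided at run time.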

I will not discuss this solution in detail here, and instead focus on the latter solution.

Solution 2: constexpr if and cutoff

As for the latter solution, I am not aware of any discussion of it in the context of C++, but the idea itself seems old. According to the paper by Vasconcellos, Figueiredo, and Camarão, this solution was employed in the Mercury programming language.

I think the idea becomes clear once you look at the code, so let me remedy the problem of PowerNode. First, we define the recursive pair type as follows:

template <size_t I,typename T> 
struct pair_depth_n {
  template<typename... Args>
  using type_each = typename pair_depth_n<I-1, T>::template type<T, Args...>;

  template<typename... Args> 
  using type = std::pair<type_each<Args...>, type_each<Args...>>;
};

template <typename T> 
struct pair_depth_n<0, T> {
    template <typename... Args>
    using type = T;
};

template <size_t I,typename T>  
using nested_pair = typename pair_depth_n<I,T>::template type<>;

The code is a bit messy, but you don’t need to understand exactly what I am doing. What is important here is the meaning of nested_pair. For example,

nested_pair<0, int> = int
nested_pair<1, int> = std::pair<int, int>
nested_pair<2, int> = std::pair<std::pair<int, int>, std::pair<int, int>>
...

and so on.
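
The equalities listed above can be verified with static_assert. The sketch below repeats the definition of nested_pair and, for comparison, adds a shorter formulation under the hypothetical name nested_pair_alt, which is my own simplification and not part of the original code.

```cpp
#include <cstddef>
#include <type_traits>
#include <utility>

// Definition from this post, repeated so that the block is self-contained.
template <std::size_t I, typename T>
struct pair_depth_n {
  template <typename... Args>
  using type_each = typename pair_depth_n<I - 1, T>::template type<T, Args...>;

  template <typename... Args>
  using type = std::pair<type_each<Args...>, type_each<Args...>>;
};

template <typename T>
struct pair_depth_n<0, T> {
  template <typename... Args>
  using type = T;
};

template <std::size_t I, typename T>
using nested_pair = typename pair_depth_n<I, T>::template type<>;

// Hypothetical shorter formulation: depth I is a pair of two depth I-1 types.
template <std::size_t I, typename T>
struct nested_pair_alt_impl {
  using inner = typename nested_pair_alt_impl<I - 1, T>::type;
  using type = std::pair<inner, inner>;
};

template <typename T>
struct nested_pair_alt_impl<0, T> {
  using type = T;
};

template <std::size_t I, typename T>
using nested_pair_alt = typename nested_pair_alt_impl<I, T>::type;

// The three cases listed in the text.
static_assert(std::is_same_v<nested_pair<0, int>, int>);
static_assert(std::is_same_v<nested_pair<1, int>, std::pair<int, int>>);
static_assert(std::is_same_v<
    nested_pair<2, int>,
    std::pair<std::pair<int, int>, std::pair<int, int>>>);

// Both formulations name the same type.
static_assert(std::is_same_v<nested_pair<5, int>, nested_pair_alt<5, int>>);
```

If all the assertions hold, the file compiles without producing any output.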

Now, we rewrite PowerNode as follows:

// WARNING: If you set a large cutoff (roughly 18 or more),
// the compiler and LSP consume a lot of CPU and memory
constexpr size_t cutoff = 10;
// COMPILES SUCCESSFULLY!
template <typename T, typename BaseT>
class PowerNodeBase {
public:
  PowerNodeBase() = default;
  PowerNodeBase(std::optional<T> elem): elem(elem) {}
  PowerNodeBase(std::optional<T> elem, std::shared_ptr<PowerNodeBase<std::pair<T, T>, BaseT>> nextptr): 
    elem(elem), nextptr(nextptr) {}
  PowerNodeBase(std::optional<T> elem, PowerNodeBase<std::pair<T,T>, BaseT> nextNode):
    elem(elem), nextptr(std::make_shared<PowerNodeBase<std::pair<T,T>, BaseT>>(nextNode)) {}

  T head() const {
    if constexpr (std::is_same_v<T, nested_pair<cutoff, BaseT>>) {
      throw std::runtime_error("cutoff exceeded!");
      return T {};
    } else {
      if (elem) return elem.value();
      else return nextptr->head().first;
    }
  }
  PowerNodeBase cons(T x) const {
    if constexpr (std::is_same_v<T, nested_pair<cutoff, BaseT>>) {
      throw std::runtime_error("cutoff exceeded!");
      return PowerNodeBase {};
    } else {
      if (!elem) return PowerNodeBase(x, nextptr);
      std::optional<T> new_elem;
      auto carried_elem = std::make_pair(x, elem.value());
      if (nextptr) {
        return PowerNodeBase(new_elem, nextptr->cons(carried_elem));
      } else {
        return PowerNodeBase(new_elem, PowerNodeBase<std::pair<T,T>, BaseT>(carried_elem));
      }
    }
  }
  size_t size() const {
    if constexpr (std::is_same_v<T, nested_pair<cutoff, BaseT>>) {
      throw std::runtime_error("cutoff exceeded!");
      return size_t {};
    } else {
      return (elem ? 1 : 0) + (nextptr ? 2 * nextptr->size() : 0);
    }
  }

private:
  std::optional<T> elem;
  std::shared_ptr<PowerNodeBase<std::pair<T, T>, BaseT>> nextptr;
};

template <typename T>
using PowerNode = PowerNodeBase<T, T>;

As you can see, I added if constexpr to every member function that is polymorphically recursive. When the recursion depth reaches the predetermined value cutoff, the function simply throws a runtime error. Since the number of generated template classes is now bounded, the compilation succeeds with no problem!🎉
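
The same cutoff technique can be isolated in a free function, which may make the mechanism easier to see. This is a minimal sketch under my own assumptions: the names nesting and grow and the cutoff value 4 are hypothetical, and the depth is detected with a counting trait instead of the is_same_v comparison against nested_pair used above.

```cpp
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <utility>

// Hypothetical trait: how deeply T is nested in std::pair.
// int -> 0, std::pair<int, int> -> 1, and so on.
template <typename T>
struct nesting : std::integral_constant<std::size_t, 0> {};

template <typename T>
struct nesting<std::pair<T, T>>
    : std::integral_constant<std::size_t, 1 + nesting<T>::value> {};

constexpr std::size_t cutoff = 4;

// Polymorphically recursive: grow<T> calls grow<std::pair<T, T>>.
// Returns the nesting depth reached after the requested number of steps.
template <typename T>
std::size_t grow(const T& x, std::size_t steps) {
  if constexpr (nesting<T>::value >= cutoff) {
    // At the cutoff depth the recursive call below is discarded at compile
    // time, so grow<std::pair<T, T>> is never instantiated for this T.
    throw std::runtime_error("cutoff exceeded!");
  } else {
    if (steps == 0) return nesting<T>::value;
    return grow(std::make_pair(x, x), steps - 1);
  }
}
```

With this cutoff, grow(1, 3) returns 3, while grow(1, 4) reaches the instantiation at depth 4 and throws std::runtime_error. Without the if constexpr branch the compiler would have to instantiate grow for ever deeper pair types.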

PowerNode can be used in exactly the same way as an ordinary linked list or tree. If we perform cons 2^10 = 1024 times at run time, an error is raised. If this limit is unsatisfactory, we need to increase the cutoff parameter.

Pros

Construction of Data Structure at Run time

As we mentioned, we can construct the PowerNode at run time:

int main() {
  std::cout << "enter test size: ";
  size_t upper_limit;
  std::cin >> upper_limit;
  if (std::cin.fail()) {
    std::cout << "please enter positive integer, exit" << std::endl;
    exit(1);
  }
  auto p = PowerNode<int> {};
  for (int i = 0; i < upper_limit; ++i) {
    std::cout << "enter value to cons: ";
    int runtime_value;
    std::cin >> runtime_value;
    if (std::cin.fail()) {
      std::cout << "please enter integer, exit" << std::endl;
      exit(1);
    }
    p = p.cons(runtime_value);
    if (runtime_value != p.head()) throw std::runtime_error("Error in head!");
    if (i+1 != p.size()) throw std::runtime_error("Error in size!");
  }
  std::cout << "test passed!" << std::endl;
}

So, PowerNode can be used in exactly the same way as a linked list or a tree.

Cons

Code Bloat and Blowup of Compilation Time

Since all the classes used by polymorphic recursion are generated, the executable size grows as we increase the cutoff size. In general, the asymptotic scaling form of the executable size versus the cutoff size depends on the form of the polymorphic recursion.

In the case of PowerNode, the executable size should in theory scale linearly, since the number of classes generated scales linearly with the cutoff. I ran a simple benchmark (on my ThinkPad with Intel Core i7-8650U CPU @ 1.90GHz, 2.11 GHz and 16GB of RAM), which admittedly is not sophisticated but is enough to see the scaling, with the following script:

#!/bin/zsh

cppfname=$(echo 'main.cpp')
placeholder=$(echo 'COFFVALUE')
timefname=$(echo 'time.txt')
memfname=$(echo 'memory.txt')

rm $timefname 2> /dev/null
rm $memfname 2> /dev/null

max=22
for (( i=2; i <= $max; ++i))
do
  (time cat $cppfname | sed -e "s/$placeholder/$i/g" | g++-10 -std=c++20 -x c++ -) |& sed -n /g++-10/p | awk '{ print $6 }' >> $timefname
  wc -c < a.out >> $memfname
done

where we replaced the cutoff value with a placeholder in C++ code,

-  constexpr size_t cutoff = 10;
+  constexpr size_t cutoff = COFFVALUE;

and instantiated the PowerNode template class with the following main function:

int main() {
  auto p = PowerNode<int> {};
  p = p.cons(1);
}

[Figure: executable size]
The result of the benchmark is shown above. The executable size scales linearly with the cutoff, as expected. Therefore, in the case of PowerNode, the size of the executable does not grow drastically, but the situation may be different for other data structures.

[Figure: compile time in linear scale]
[Figure: compile time in log scale]
I also measured the compilation time with the time command, and the result is shown above. Since I measured the compilation time only once, there is some fluctuation in the data. It seems the compilation time grows exponentially (or even faster) with the cutoff size. This is not surprising, since nested_pair<n,T> consists of 2^n Ts.
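
The count of 2^n can be confirmed at compile time. In the sketch below, nested_t is a compact stand-in for nested_pair, and leaf_count is a hypothetical trait that counts the leaves of a nested std::pair type; neither name appears in the original code.

```cpp
#include <cstddef>
#include <type_traits>
#include <utility>

// Compact stand-in for nested_pair: depth I is a pair of two depth I-1 types.
template <std::size_t I, typename T>
struct nested {
  using inner = typename nested<I - 1, T>::type;
  using type = std::pair<inner, inner>;
};

template <typename T>
struct nested<0, T> {
  using type = T;
};

template <std::size_t I, typename T>
using nested_t = typename nested<I, T>::type;

// Hypothetical trait: number of leaf types inside a nested std::pair.
template <typename T>
struct leaf_count : std::integral_constant<std::size_t, 1> {};

template <typename A, typename B>
struct leaf_count<std::pair<A, B>>
    : std::integral_constant<std::size_t,
                             leaf_count<A>::value + leaf_count<B>::value> {};

// The number of leaves doubles with every level of nesting.
static_assert(leaf_count<nested_t<0, int>>::value == 1);
static_assert(leaf_count<nested_t<1, int>>::value == 2);
static_assert(leaf_count<nested_t<3, int>>::value == 8);
static_assert(leaf_count<nested_t<10, int>>::value == 1024);
```

So the type at depth 10 already spells out 1024 copies of T, and each further level doubles that number.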

If we could memoize nested_pair<n,T> at compile time, we could reduce the compilation time, but I am not sure whether that is possible.

Summary

In this post, I discussed a possible implementation of polymorphic recursion in C++ with constexpr if. Unlike the implementation with expression templates, it lets us delay the construction of the data structure until run time. Since C++ is a statically typed language, the large amount of code generated by polymorphic recursion leads to a larger executable and a longer compile time.

I guess you are now asking, “So, is this polymorphic recursion stuff useful in the real world?”, right? Well, the good thing about polymorphically recursive data structures is that more programmer errors can be caught by the type system, as discussed in Okasaki’s book and the Wikipedia article. Therefore, the implementation with constexpr if may be useful when you come up with some brand-new data structure that can be converted to a polymorphically recursive form, and you want to implement it in C++.

The take-home message is that polymorphic recursion is not impossible in a statically typed language. There is a limit on the recursion depth, but that is also the case with ordinary recursive functions. If you want to do something interesting with polymorphic recursion in a statically typed language that supports compile-time computation, you can.

Appendix: Complete Code of PowerNode with tail Operation

I checked that the code below compiles successfully with g++ 10.3.0 on Ubuntu 20.04 with -std=c++20:

#include <cstddef>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <iostream>
#include <stdexcept>
#include <type_traits>

template <size_t I,typename T> 
struct pair_depth_n {
  template<typename... Args>
  using type_each = typename pair_depth_n<I-1, T>::template type<T, Args...>;

  template<typename... Args> 
  using type = std::pair<type_each<Args...>, type_each<Args...>>;
};

template <typename T> 
struct pair_depth_n<0, T> {
    template <typename... Args>
    using type = T;
};

template <size_t I,typename T>  
using nested_pair = typename pair_depth_n<I,T>::template type<>;

constexpr size_t cutoff = 10;

template <typename T, typename BaseT>
class PowerNodeBase {
public:
  PowerNodeBase() = default;
  PowerNodeBase(std::optional<T> elem): elem(elem) {}
  PowerNodeBase(std::optional<T> elem, std::shared_ptr<PowerNodeBase<std::pair<T, T>, BaseT>> nextptr): 
    elem(elem), nextptr(nextptr) {}
  PowerNodeBase(std::optional<T> elem, PowerNodeBase<std::pair<T,T>, BaseT> nextNode):
    elem(elem), nextptr(std::make_shared<PowerNodeBase<std::pair<T,T>, BaseT>>(nextNode)) {}

  T head() const {
    if constexpr (std::is_same_v<T, nested_pair<cutoff, BaseT>>) {
      throw std::runtime_error("cutoff exceeded!");
      return T {};
    } else {
      if (elem) return elem.value();
      else return nextptr->head().first;
    }
  }
  PowerNodeBase cons(T x) const {
    if constexpr (std::is_same_v<T, nested_pair<cutoff, BaseT>>) {
      throw std::runtime_error("cutoff exceeded!");
      return PowerNodeBase {};
    } else {
      if (!elem) return PowerNodeBase(x, nextptr);
      std::optional<T> new_elem;
      auto carried_elem = std::make_pair(x, elem.value());
      if (nextptr) {
        return PowerNodeBase(new_elem, nextptr->cons(carried_elem));
      } else {
        return PowerNodeBase(new_elem, PowerNodeBase<std::pair<T,T>, BaseT>(carried_elem));
      }
    }
  }
  size_t size() const {
    if constexpr (std::is_same_v<T, nested_pair<cutoff, BaseT>>) {
      throw std::runtime_error("cutoff exceeded!");
      return size_t {};
    } else {
      return (elem ? 1 : 0) + (nextptr ? 2 * nextptr->size() : 0);
    }
  }
  PowerNodeBase tail() const {
    if constexpr (std::is_same_v<T, nested_pair<cutoff, BaseT>>) {
      throw std::runtime_error("cutoff exceeded!");
      return PowerNodeBase {};
    } else {
      if (elem) {
        return PowerNodeBase(std::nullopt, nextptr);
      } else {
        auto [_, next_node] = decr();
        return next_node;
      }
    }
  }

private:
  template<typename U, typename baseU>
  friend class PowerNodeBase;

  std::tuple<T, PowerNodeBase> decr() const {
    if constexpr (std::is_same_v<T, nested_pair<cutoff, BaseT>>) {
      throw std::runtime_error("cutoff exceeded!");
      return {};
    } else {
      if (nextptr->elem) {
        auto new_next_node = PowerNodeBase<std::pair<T,T>, BaseT>(std::nullopt, nextptr->nextptr);
        return {nextptr->elem.value().first, PowerNodeBase(nextptr->elem.value().second, new_next_node)};
      } else {
        auto [decr_elem, nextNode] = nextptr->decr();
        return {decr_elem.first, PowerNodeBase(decr_elem.second, nextNode)};
      }
    }
  }

  std::optional<T> elem;
  std::shared_ptr<PowerNodeBase<std::pair<T, T>, BaseT>> nextptr;
};

template <typename T>
using PowerNode = PowerNodeBase<T, T>;

int main() {
  auto p = PowerNode<int> {};
  for (int i = 0; i < 1000; ++i) {
    p = p.cons(i);
    if (i != p.head()) throw std::runtime_error("Error in head!");
    if (i+1 != p.size()) throw std::runtime_error("Error in size!");
  }
  for (int i = 999; i >= 0; --i) {
    if (i+1 != p.size()) throw std::runtime_error("Error in size!");
    if (i != p.head()) throw std::runtime_error("Error in head!");
    p = p.tail();
  }
  std::cout << "test passed!" << std::endl;
}