// vi:set ft=cpp: -*- Mode: C++ -*-
/**
 * \file
 * \brief AVL tree
 */
/*
 * (c) 2008-2009 Alexander Warg <warg@os.inf.tu-dresden.de>,
 *               Carsten Weinhold <weinhold@os.inf.tu-dresden.de>
 *     economic rights: Technische Universität Dresden (Germany)
 *
 * This file is part of TUD:OS and distributed under the terms of the
 * GNU General Public License 2.
 * Please see the COPYING-GPL-2 file for details.
 *
 * As a special exception, you may use this file as part of a free software
 * library without restriction.  Specifically, if other files instantiate
 * templates or use macros or inline functions from this file, or you compile
 * this file and link it with other files to produce an executable, this
 * file does not by itself cause the resulting executable to be covered by
 * the GNU General Public License.  This exception does not however
 * invalidate any other reasons why the executable file might be covered by
 * the GNU General Public License.
 */

#pragma once

#include "std_ops"
#include "pair"

#include "bits/bst.h"
#include "bits/bst_iter.h"

namespace cxx {

/**
 * \brief Base class for the nodes stored in an Avl_tree.
 *
 * Adds the balance value needed by the AVL algorithms to a plain binary
 * search tree node. Apart from the protected default constructor, all
 * members declared here are private and only reachable by Avl_tree, which
 * is declared a friend.
 */
class Avl_tree_node : public Bits::Bst_node
{
private:
  // The tree implementation works directly on links and balance values.
  template< typename Node, typename Get_key, typename Compare >
  friend class Avl_tree;

  /// Balance values are expressed with the Direction type.
  using Bal = Bits::Direction;
  /// Short name for Direction.
  using Dir = Bits::Direction;

  // This is a final BST node: keep the link helpers of the base class
  // out of reach of derived node types.
  /**@{*/
  using Bits::Bst_node::next;
  using Bits::Bst_node::next_p;
  using Bits::Bst_node::rotate;
  /**@}*/

  /// Balance of the subtree rooted here (#Direction::N means balanced).
  Bal _balance;

protected:
  /// Create an uninitialized node; this is the constructor user code uses.
  Avl_tree_node() = default;

private:
  // Nodes cannot be copy- or move-constructed.
  Avl_tree_node(Avl_tree_node const &) = delete;
  Avl_tree_node(Avl_tree_node &&) = delete;

  /// Default copy assignment, usable by the friend Avl_tree only.
  Avl_tree_node &operator = (Avl_tree_node const &) = default;
  /// Default move assignment, usable by the friend Avl_tree only.
  Avl_tree_node &operator = (Avl_tree_node &&) = default;

  /// Create an initialized, balanced node (internal use).
  explicit Avl_tree_node(bool) : Bits::Bst_node(true), _balance(Dir::N) {}

  /// Double rotation of the subtree referenced by \a t.
  static Bits::Bst_node *rotate2(Bst_node **t, Bal idir, Bal pre);

  /// Return true if the balance value of this node is #Direction::N.
  bool balanced() const { return _balance == Bal::N; }

  /// Return the balance value of this node.
  Bal balance() const { return _balance; }

  /// Store \a b as the balance value of this node.
  void balance(Bal b) { _balance = b; }
};


/**
 * \brief A generic AVL tree.
 * \tparam Node    The data type of the nodes (must inherit from Avl_tree_node).
 * \tparam Get_key The meta function to get the key value from a node.
 *                 The implementation uses `Get_key::key_of(ptr_to_node)`. The
 *                 type of the key values must be defined in `Get_key::Key_type`.
 * \tparam Compare Binary relation to establish a total order for the
 *                 nodes of the tree. `Compare()(l, r)` must return true if
 *                 the key \a l is smaller than the key \a r.
 *
 * This implementation does not provide any memory management. It is the
 * responsibility of the caller to allocate nodes before inserting them and
 * to free them when they are removed or when the tree is destroyed.
 * Conversely, the caller must also ensure that nodes are removed
 * from the tree before they are destroyed.
 *
 * Trees can be neither copied nor moved.
 */
template< typename Node, typename Get_key,
          typename Compare = Lt_functor<typename Get_key::Key_type> >
class Avl_tree : public Bits::Bst<Node, Get_key, Compare>
{
private:
  typedef Bits::Bst<Node, Get_key, Compare> Bst;

  /// Hide this from possible descendants.
  using Bst::_head;

  /// Provide access to keys of nodes.
  using Bst::k;

  /// Alias type for balance values.
  typedef typename Avl_tree_node::Bal Bal;
  /// Alias type for Direction values.
  // Taken from Avl_tree_node::Dir (not Bal) so that each alias mirrors its
  // counterpart in the node class; both name the same underlying type.
  typedef typename Avl_tree_node::Dir Dir;

  Avl_tree(Avl_tree const &o) = delete;
  Avl_tree &operator = (Avl_tree const &o) = delete;
  Avl_tree(Avl_tree &&o) = delete;
  Avl_tree &operator = (Avl_tree &&o) = delete;

public:
  ///@{
  typedef typename Bst::Key_type Key_type;
  typedef typename Bst::Key_param_type Key_param_type;
  ///@}

  // Grab iterator types from Bst
  ///@{
  /// Forward iterator for the tree.
  typedef typename Bst::Iterator Iterator;
  /// Constant forward iterator for the tree.
  typedef typename Bst::Const_iterator Const_iterator;
  /// Backward iterator for the tree.
  typedef typename Bst::Rev_iterator Rev_iterator;
  /// Constant backward iterator for the tree.
  typedef typename Bst::Const_rev_iterator Const_rev_iterator;
  ///@}

  /**
   * \brief Insert a new node into this AVL tree.
   * \param new_node A pointer to the new node.
   * \return A pair, with \a second set to `true` and \a first pointing to
   *         \a new_node, on success.  If there is already a node with the
   *         same key then \a first points to this node and \a second is 'false'.
   */
  Pair<Node *, bool> insert(Node *new_node);

  /**
   * \brief Remove the node with \a key from the tree.
   * \param key The key to the node to remove.
   * \return The pointer to the removed node on success,
   *         or 0 if no node with the \a key exists.
   */
  Node *remove(Key_param_type key);
  /**
   * \brief An alias for remove().
   */
  Node *erase(Key_param_type key) { return remove(key); }

  /// Create an empty AVL tree.
  Avl_tree() = default;

  /**
   * \brief Destroy the tree.
   *
   * Calls `remove_all()` of the base class with a callback that does
   * nothing, so the nodes themselves are not freed here; releasing them
   * remains the job of the caller.
   */
  ~Avl_tree() noexcept
  {
    this->remove_all([](Node *){});
  }

#ifdef __DEBUG_L4_AVL
  /**
   * \brief Dump and check the subtree rooted at \a n (debug builds only).
   * \param n      Root of the subtree, may be null.
   * \param depth  Depth of \a n within the whole tree.
   * \param dp     Receives the deepest level found below (and including) \a n.
   * \param print  If true, the nodes and any balance error go to stderr.
   * \param pfx    Character printed in front of the item of \a n.
   * \return true if all stored balance values match the subtree depths.
   */
  bool rec_dump(Avl_tree_node *n, int depth, int *dp, bool print, char pfx);

  /// Dump and check the whole tree, starting at the root.
  bool rec_dump(bool print)
  {
    int dp=0;
    return rec_dump(static_cast<Avl_tree_node *>(_head), 0, &dp, print, '+');
  }
#endif
};


//----------------------------------------------------------------------------
/* IMPLEMENTATION: Avl_tree_node */


/**
 * \brief Double rotation of the subtree whose root is stored in \a *t.
 *
 * The grandchild reached by going \a idir and then `!idir` from the old
 * root becomes the new root of the subtree; the old root and its \a idir
 * child become the two children of that grandchild.
 *
 * \param t     Link that refers to the root of the subtree; it is updated
 *              to refer to the new root.
 * \param idir  Direction from the old root to the child taking part in the
 *              rotation.
 * \param pre   Balance value that decides how the balance of the old root
 *              and of its child is set afterwards.
 * \return Null if \a pre is balanced. Otherwise the child in direction
 *         `!pre` of the node (old root or its former child) whose balance
 *         was reset to balanced.
 */
inline
Bits::Bst_node *
Avl_tree_node::rotate2(Bst_node **t, Bal idir, Bal pre)
{
  using N = Bits::Bst_node;
  using A = Avl_tree_node;

  N *old_root = *t;
  N *child = N::next(old_root, idir);
  N *new_root = N::next(child, !idir);

  // Lift the grandchild into the root position of this subtree.
  *t = new_root;

  // Hand the two subtrees of the new root over to the old root and its
  // child, then attach both below the new root.
  N::next(old_root, idir, N::next(new_root, !idir));
  N::next(child, !idir, N::next(new_root, idir));
  N::next(new_root, !idir, old_root);
  N::next(new_root, idir, child);

  static_cast<A*>(new_root)->balance(Bal::N);

  if (pre == Bal::N)
    {
      // Nothing leans: all three nodes involved end up balanced.
      static_cast<A*>(old_root)->balance(Bal::N);
      static_cast<A*>(child)->balance(Bal::N);
      return nullptr;
    }

  // One of the two lower nodes receives the balance !pre, the other one
  // becomes balanced; which is which depends on pre relative to idir.
  N *tilted = (pre != idir) ? child : old_root;
  N *leveled = (pre == idir) ? child : old_root;

  static_cast<A*>(tilted)->balance(!pre);
  static_cast<A*>(leveled)->balance(Bal::N);

  return N::next(leveled, !pre);
}

//----------------------------------------------------------------------------
/* Implementation of AVL Tree */

/*
 * Insert new_node into the tree.
 *
 * The search remembers the link to the deepest node on the path that is
 * not balanced; that is the only place where a rotation can be needed.
 * Below that node the balance values are simply set to the direction in
 * which the new node was inserted.
 *
 * Returns (new_node, true) on success, or (existing node, false) if a node
 * with the same key is already in the tree; in that case new_node is left
 * untouched.
 */
template< typename Node, typename Get_key, class Compare>
Pair<Node *, bool>
Avl_tree<Node, Get_key, Compare>::insert(Node *new_node)
{
  typedef Avl_tree_node A;
  typedef Bits::Bst_node N;

  Key_param_type new_key = Get_key::key_of(new_node);
  N **link = &_head;   // link that will finally refer to the new node
  N **pivot = &_head;  // link to the node where rebalancing may occur

  // Descend to the insertion point.
  while (N *cur = *link)
    {
      Dir step = this->dir(new_key, cur);
      if (step == Dir::N)
        return pair(static_cast<Node*>(cur), false); // key already present

      if (!static_cast<A const *>(cur)->balanced())
        pivot = link;

      link = A::next_p(cur, step);
    }

  // Reset the node part of new_node to the initialized, balanced state and
  // hook the node into the tree.
  *static_cast<A*>(new_node) = A(true);
  *link = new_node;

  N *n = *pivot;
  A *top = static_cast<A*>(n);
  if (!top->balanced())
    {
      Bal side(this->greater(new_key, n));
      if (top->balance() != side)
        {
          // The node went into the shorter subtree: the pivot is balanced
          // now and only the nodes below it need their balance updated.
          top->balance(Bal::N);
          n = A::next(n, side);
        }
      else if (side == Bal(this->greater(new_key, A::next(n, side))))
        {
          // Left-left or right-right case: a single rotation is enough.
          A::rotate(pivot, side);
          top->balance(Bal::N);
          static_cast<A*>(*pivot)->balance(Bal::N);
          n = A::next(*pivot, side);
        }
      else
        {
          // Left-right or right-left case: double rotation around the
          // grandchild of the pivot.
          n = A::next(A::next(n, side), !side);

          Bal pre = Bal::N;
          if (n != new_node)
            pre = Bal(this->greater(new_key, n));

          n = A::rotate2(pivot, side, pre);
        }
    }

  // Every remaining node on the way down to new_node now leans towards the
  // side on which the new node was inserted.
  while (n && n != new_node)
    {
      Bal lean(this->greater(new_key, n));
      static_cast<A*>(n)->balance(lean);
      n = A::next(n, lean);
    }

  return pair(new_node, true);
}


/*
 * Remove the node with the given key and return it, or return 0 if no such
 * node is in the tree.
 *
 * The function works in three steps:
 *  1. walk down the tree to find the node to remove (i) and the last node
 *     on the search path (*q), remembering where rebalancing has to start,
 *  2. walk the same path again from the remembered position and fix the
 *     balance values, rotating where necessary,
 *  3. unlink the last node on the path and put it into the place of i.
 *
 * The statements of step 3 depend on their exact order, because the links
 * t and q may refer to the same slot or to a slot inside node i.
 */
template< typename Node, typename Get_key, class Compare>
inline
Node *Avl_tree<Node, Get_key, Compare>::remove(Key_param_type key)
{
  typedef Avl_tree_node A;
  typedef Bits::Bst_node N;
  N **q = &_head; /* search variable */
  N **s = &_head; /* last ('deepest') node on the search path to q
                   * with balance 0, at this place the rebalancing
                   * stops in any case */
  // Link that refers to the node to remove; stays 0 until a match is seen.
  N **t = 0;
  // Direction of the most recent step; the value left behind by the second
  // loop is used again for the final unlinking.
  Dir dir;

  // find target node and rebalancing entry
  for (N *n; (n = *q); q = A::next_p(n, dir))
    {
      dir = Dir(this->greater(key, n));
      // A match is a node for which neither greater(key, n) nor
      // greater(k(n), key) holds (presumably Dir(false) is Dir::L -- confirm
      // in the Direction definition). The walk does not stop at the match
      // but continues in direction dir.
      if (dir == Dir::L && !this->greater(k(n), key))
	/* found node */
	t = q;

      // No child in the search direction: *q is the last node on the path.
      if (!A::next(n, dir))
	break;

      // Remember this link if n is balanced, or if n has balance !dir and
      // its child on the !dir side is balanced.
      A const *a = static_cast<A const *>(n);
      if (a->balanced() || (a->balance() == !dir && A::next<A>(n, !dir)->balanced()))
	s = q;
    }

  // nothing found
  if (!t)
    return 0;

  // The node that is handed back to the caller.
  A *i = static_cast<A*>(*t);

  // Second pass: follow the search path again, starting at the link
  // remembered in s, and adjust the balance of every node passed.
  for (N *n; (n = *s); s = A::next_p(n, dir))
    {
      dir = Dir(this->greater(key, n));

      // Same stop condition as in the first pass.
      if (!A::next(n, dir))
	break;

      A *a = static_cast<A*>(n);
      // got one out of balance
      if (a->balanced())
	a->balance(!dir);
      else if (a->balance() == dir)
	a->balance(Bal::N);
      else
	{
	  // we need rotations to get in balance
	  // b is the balance of the child on the side opposite to the path.
	  Bal b = A::next<A>(n, !dir)->balance();
	  if (b == dir)
	    // Double rotation; the third argument is the balance of the
	    // grandchild that becomes the new root of this subtree.
	    A::rotate2(s, !dir, A::next<A>(A::next(n, !dir), dir)->balance());
	  else
	    {
	      // Single rotation; afterwards *s is the new subtree root and
	      // a (== n) is the old one.
	      A::rotate(s, !dir);
	      if (b != Bal::N)
		{
		  a->balance(Bal::N);
		  static_cast<A*>(*s)->balance(Bal::N);
		}
	      else
		{
		  a->balance(!dir);
		  static_cast<A*>(*s)->balance(dir);
		}
	    }
	  // If the rotated node is the one to remove, the link recorded in t
	  // is stale; take the link from the new subtree root instead
	  // (presumably n is its child in direction dir after the rotation).
	  if (n == i)
	    t = A::next_p(*s, dir);
	}
    }

  // n is the last node on the search path; it replaces i in the tree.
  A *n = static_cast<A*>(*q);
  // Let the link that referred to i refer to n.
  *t = n;
  // Take n out of its old place: its child on the !dir side moves up.
  *q = A::next(n, !dir);
  // Copy the node state of i (base class part and balance) into n, using
  // the private assignment operator of Avl_tree_node. If n and i are the
  // same node this is a self-assignment.
  *n = *i;

  return static_cast<Node*>(i);
}

#ifdef __DEBUG_L4_AVL
/**
 * \brief Recursively dump and verify the subtree rooted at \a n.
 *
 * The right subtree is visited first, then the node itself is printed and
 * finally the left subtree is visited. For every node the difference
 * between the deepest level of the right and of the left subtree is
 * compared with the stored balance value.
 *
 * Printing needs `fprintf`, `stderr` and `std::cerr`, and a streamable
 * member `item` in \a Node. NOTE(review): none of the includes visible in
 * this header obviously provides the stdio/iostream names, so the including
 * file presumably has to supply `<cstdio>` and `<iostream>` -- confirm.
 *
 * \param n      Root of the subtree, may be null.
 * \param depth  Depth of \a n within the whole tree.
 * \param dp     Receives the deepest level found in this subtree; left
 *               untouched if \a n is null.
 * \param print  If true, the nodes and any balance error go to stderr.
 * \param pfx    Character printed in front of the item of \a n.
 * \return true if \a n is null or if the balance values of \a n and of all
 *         nodes below it are consistent, false otherwise.
 */
template< typename Node, typename Get_key, class Compare>
bool Avl_tree<Node, Get_key, Compare>::rec_dump(Avl_tree_node *n, int depth, int *dp, bool print, char pfx)
{
  typedef Avl_tree_node A;

  if (!n)
    return true;

  // Deepest level below n: [0] for the left, [1] for the right subtree.
  // An empty subtree leaves its entry at the depth of n itself.
  int dpx[2] = {depth,depth};

  bool const right_ok
    = rec_dump(A::next<A>(n, Dir::R), depth + 1, dpx + 1, print, '/');

  if (print)
    {
      // %p requires a void pointer argument, hence the explicit cast.
      fprintf(stderr, "%2d: [%8p] b=%1d: ", depth, static_cast<void *>(n),
              static_cast<int>(n->balance().d));

      for (int i = 0; i < depth; ++i)
        std::cerr << "    ";

      std::cerr << pfx << (static_cast<Node*>(n)->item) << std::endl;
    }

  // Always visit the left subtree too, even if the right one already
  // failed the check, so that the dump is complete.
  bool const left_ok
    = rec_dump(A::next<A>(n, Dir::L), depth + 1, dpx, print, '\\');
  bool const res = right_ok && left_ok;

  // Depth difference: right minus left.
  int b = dpx[1] - dpx[0];

  // Report the deeper of the two sides to the caller.
  if (b < 0)
    *dp = dpx[0];
  else
    *dp = dpx[1];

  Bal x = n->balance();
  if ((b < -1 || b > 1) ||
      (b == 0 && x != Bal::N) ||
      (b == -1 && x != Bal::L) ||
      (b == 1 && x != Bal::R))
    {
      if (print)
        fprintf(stderr, "%2d: [%8p] b=%1d: balance error %d\n", depth,
                static_cast<void *>(n), static_cast<int>(n->balance().d), b);
      return false;
    }
  return res;
}
#endif

}

