平衡二叉搜索树C++模板实现

2019-01-31  本文已影响0人  Jesson3264

原文:https://www.cnblogs.com/zhangbaochong/p/5164994.html
改正了几个错误。

  1. root 的初始化；2. LR、RL 双旋转中子树的连接问题。
#ifndef AVL_H
#define AVL_H

#include <iostream>
#include <algorithm>

using namespace  std;

/// A single node of an AVL tree.
///
/// `height` is the height of the subtree rooted at this node; a freshly
/// created node is a leaf and starts at height 0.
template <typename T>
struct AvlNode
{
    T data;               // value stored in this node
    int height;           // height of the subtree rooted here (leaf == 0)
    AvlNode<T> *left;     // subtree holding values smaller than data
    AvlNode<T> *right;    // subtree holding values larger than data

    // Creates a leaf node holding a copy of theData.
    // Initializers follow declaration order (members are always initialized in
    // declaration order, so listing them otherwise only triggers -Wreorder), and
    // the constructor uses the plain injected-class-name: the `AvlNode<T>(...)`
    // spelling is ill-formed as of C++20.
    explicit AvlNode(const T &theData) : data(theData), height(0), left(NULL), right(NULL) {}
};

template <typename T>
class AvlTree
{
public:
    AvlTree<T>(){
        root = NULL;
    }
    ~AvlTree<T>(){}
    AvlNode<T> *root;

    //插入结点
    void Insert(AvlNode<T> *&t, T x);
    //删除结点
    bool Delete(AvlNode<T> *&t, T x);
    //查找是否存在给定值的结点
    bool Contains(AvlNode<T> *t, const T x) const;
    //中序遍历
    void InorderTraversal(AvlNode<T> *t);
    //前序遍历
    void PreorderTraversal(AvlNode<T> *t);
    //最小值结点
    AvlNode<T> *FindMin(AvlNode<T> *t) const;
    //最大值结点
    AvlNode<T> *FindMax(AvlNode<T> *t) const;
    private:
    //求树的高度
    int GetHeight(AvlNode<T> *t);
    //单旋转 左
    AvlNode<T> *LL(AvlNode<T> *t);
    //单旋转 右
    AvlNode<T> *RR(AvlNode<T> *t);
    //双旋转 右左
    AvlNode<T> *LR(AvlNode<T> *t);
    //双旋转 左右
    AvlNode<T> *RL(AvlNode<T> *t);
};

// Returns the node holding the largest value in subtree t, i.e. its right-most
// node, or NULL when t is empty. Follows right links iteratively.
template <typename T>
AvlNode<T>* AvlTree<T>::FindMax(AvlNode<T> *t) const
{
    AvlNode<T> *cursor = t;
    while (cursor != NULL && cursor->right != NULL)
        cursor = cursor->right;
    return cursor;
}

// Returns the node holding the smallest value in subtree t, i.e. its left-most
// node, or NULL when t is empty. Follows left links iteratively.
template <typename T>
AvlNode<T>* AvlTree<T>::FindMin(AvlNode<T> *t) const
{
    AvlNode<T> *cursor = t;
    while (cursor != NULL && cursor->left != NULL)
        cursor = cursor->left;
    return cursor;
}

// Height of subtree t. An empty subtree counts as -1, so a leaf (stored
// height 0) is exactly one taller than its missing children.
template <typename T>
int AvlTree<T>::GetHeight(AvlNode<T> *t)
{
    return (t != NULL) ? t->height : -1;
}

// Single right rotation for the left-left imbalance (the subtree under t's
// left child's left side is too tall).
// t's left child q becomes the new subtree root, t becomes q's right child,
// and q's former right subtree is re-attached as t's left subtree.
// Returns the new subtree root; the caller must store it in place of t.
template <typename T>
AvlNode<T>* AvlTree<T>::LL(AvlNode<T> *t)
{
    AvlNode<T> *q = t->left;
    t->left = q->right;
    q->right = t;
    // Recompute heights bottom-up: t is now q's child, so t's height must be
    // fixed before q's height is derived from it. (Assigning `t = q` before
    // these lines made both updates hit q and left the demoted node stale.)
    t->height = max(GetHeight(t->left), GetHeight(t->right)) + 1;
    q->height = max(GetHeight(q->left), GetHeight(q->right)) + 1;
    return q;
}

// Single left rotation for the right-right imbalance (the subtree under t's
// right child's right side is too tall).
// t's right child q becomes the new subtree root, t becomes q's left child,
// and q's former left subtree is re-attached as t's right subtree.
// Returns the new subtree root; the caller must store it in place of t.
template <typename T>
AvlNode<T>* AvlTree<T>::RR(AvlNode<T> *t)
{
    AvlNode<T> *q = t->right;
    t->right = q->left;
    q->left = t;
    // Recompute heights bottom-up: t is now q's child, so t's height must be
    // fixed before q's height is derived from it. (Assigning `t = q` before
    // these lines made both updates hit q and left the demoted node stale.)
    t->height = max(GetHeight(t->left), GetHeight(t->right)) + 1;
    q->height = max(GetHeight(q->left), GetHeight(q->right)) + 1;
    return q;
}

// Double rotation for the left-right case: the imbalance at t comes from the
// right subtree of t's left child. Rotating t's left child left (RR) turns it
// into a left-left case, which a right rotation of t (LL) then fixes.
// Returns the new subtree root.
template <typename T>
AvlNode<T>* AvlTree<T>::LR(AvlNode<T> *t)
{
    t->left = RR(t->left);
    return LL(t);
}

// Double rotation for the right-left case: the imbalance at t comes from the
// left subtree of t's right child. Rotating t's right child right (LL) turns it
// into a right-right case, which a left rotation of t (RR) then fixes.
// Returns the new subtree root.
template <typename T>
AvlNode<T> * AvlTree<T>::RL(AvlNode<T> *t)
{
    t->right = LL(t->right);
    return RR(t);
}

// Inserts x into the subtree rooted at t. t is passed by reference so the
// caller's link is updated when a node is created or a rotation replaces the
// subtree root. Duplicate values are ignored. On the way back up the
// recursion, t is rebalanced if one side became more than one level taller,
// and its height is then refreshed.
template <typename T>
void AvlTree<T>::Insert(AvlNode<T> *&t, T x)
{
    if (t == NULL)
    {
        t = new AvlNode<T>(x);
    }
    else if (x < t->data)
    {
        Insert(t->left, x);
        const bool leftTooTall = GetHeight(t->left) - GetHeight(t->right) > 1;
        if (leftTooTall)
            // x landed in the left child's left subtree -> single rotation;
            // in its right subtree -> double rotation.
            t = (x < t->left->data) ? LL(t) : LR(t);
    }
    else if (x > t->data)
    {
        Insert(t->right, x);
        const bool rightTooTall = GetHeight(t->right) - GetHeight(t->left) > 1;
        if (rightTooTall)
            // Mirror image of the left case.
            t = (x > t->right->data) ? RR(t) : RL(t);
    }
    // Otherwise x is already present and the tree is left as it is.

    t->height = max(GetHeight(t->left), GetHeight(t->right)) + 1;
}

// Removes the node holding x from the subtree rooted at t, rebalancing on the
// way back up. t is passed by reference so the caller's link is updated when a
// node is unlinked or a rotation replaces the subtree root.
// Returns true if x was found and removed, false if it was not present.
template <typename T>
bool AvlTree<T>::Delete(AvlNode<T> *&t, T x)
{
    if (!t) return false;  // reached an empty subtree: x is not in the tree

    bool removed = false;
    if (t->data == x)
    {
        removed = true;
        if (t->left != NULL && t->right != NULL)
        {
            // Two children: overwrite t's value with its in-order neighbour from
            // the taller side, then delete that neighbour from that side.
            // Shrinking the taller side can never push t past the AVL bound.
            if (GetHeight(t->left) > GetHeight(t->right))
            {
                t->data = FindMax(t->left)->data;   // in-order predecessor
                Delete(t->left, t->data);
            }
            else
            {
                // Must be the *minimum* of the right subtree (in-order
                // successor); the maximum would leave smaller values on t's right.
                t->data = FindMin(t->right)->data;
                Delete(t->right, t->data);
            }
            // The side we deleted from may have shrunk; refresh t's height so
            // ancestors rebalance against the correct value.
            t->height = max(GetHeight(t->left), GetHeight(t->right)) + 1;
        }
        else
        {
            // Zero or one child: splice the (possibly empty) child into t's place.
            AvlNode<T> *old = t;
            t = t->left ? t->left : t->right;
            delete old;
        }
    }
    else if (x < t->data)
    {
        removed = Delete(t->left, x);
        // The left side may have shrunk: rotate if the right side is now too tall.
        if (GetHeight(t->right) - GetHeight(t->left) > 1)
        {
            // Right-left shape needs a double rotation; otherwise a single one.
            if (GetHeight(t->right->left) > GetHeight(t->right->right))
                t = RL(t);
            else
                t = RR(t);
        }
        else
        {
            t->height = max(GetHeight(t->left), GetHeight(t->right)) + 1;
        }
    }
    else if (x > t->data)
    {
        removed = Delete(t->right, x);
        // The right side may have shrunk: rotate if the left side is now too tall.
        if (GetHeight(t->left) - GetHeight(t->right) > 1)
        {
            // Left-right shape needs a double rotation; otherwise a single one.
            if (GetHeight(t->left->right) > GetHeight(t->left->left))
                t = LR(t);
            else
                t = LL(t);
        }
        else
        {
            t->height = max(GetHeight(t->left), GetHeight(t->right)) + 1;
        }
    }

    return removed;
}

// Reports whether some node in subtree t holds a value equal to x, walking
// down from t and choosing the side by comparison at each node.
template <typename T>
bool AvlTree<T>::Contains(AvlNode<T> *t, const T x) const
{
    for (AvlNode<T> *node = t; node != NULL; )
    {
        if (x < node->data)
            node = node->left;
        else if (x > node->data)
            node = node->right;
        else
            return true;   // neither smaller nor larger: found it
    }
    return false;
}

// Prints the values of subtree t in-order (left subtree, node, right subtree),
// each followed by a single space, to standard output.
template <typename T>
void AvlTree<T>::InorderTraversal(AvlNode<T> *t)
{
    if (t == NULL)
        return;

    InorderTraversal(t->left);
    cout << t->data << " ";
    InorderTraversal(t->right);
}

// Prints the values of subtree t in pre-order (node, left subtree, right
// subtree), each followed by a single space, to standard output.
template <typename T>
void AvlTree<T>::PreorderTraversal(AvlNode<T> *t)
{
    if (t == NULL)
        return;

    cout << t->data << " ";
    PreorderTraversal(t->left);
    PreorderTraversal(t->right);
}
#endif // AVL_H

#include "avl.h"
using namespace std;

// Interactive driver: reads integers until -1 (or end of input) into an AVL
// tree, prints both traversals, looks up one value, deletes another, and
// prints the traversals again.
int main()
{
    AvlTree<int> tree;
    int value;
    cout << "input(-1 end):" << endl;
    while (cin >> value)
    {
        if (value == -1)
            break;
        tree.Insert(tree.root, value);
    }
    cout << "InorderTraversal:";
    tree.InorderTraversal(tree.root);
    cout << "\nPreorderTraversal:";
    tree.PreorderTraversal(tree.root);

    // If the loop above ended on EOF or non-numeric input, cin is in a failed
    // state and further extractions leave their target untouched. Check each
    // read instead of using an unset value.
    int tmp = 0;
    cout << "\nfind:";
    if (!(cin >> tmp))
    {
        cout << endl;
        return 1;
    }
    if (tree.Contains(tree.root, tmp))
        cout << "found" << endl;
    else
        cout << "value:" << tmp << " not found" << endl;

    cout << "delete:";
    if (!(cin >> tmp))
    {
        cout << endl;
        return 1;
    }
    tree.Delete(tree.root, tmp);
    cout << "InorderTraversal:";
    tree.InorderTraversal(tree.root);
    cout << "\nPreorderTraversal:";
    tree.PreorderTraversal(tree.root);
    cout << endl;
    return 0;
}
上一篇 下一篇

猜你喜欢

热点阅读