数据结构与算法-平衡二叉搜索树AVL
平衡二叉搜索树:简称平衡二叉树。由前苏联的数学家 Adelse-Velskil 和 Landis 在 1962 年提出的高度平衡的二叉树,根据科学家的英文名也称为 AVL 树。它具有如下几个性质:
1.可以是空树。
2.假如不是空树,任何一个结点的左子树与右子树都是平衡二叉树,并且高度之差的绝对值不超过 1。
上篇文章优先级队列PriorityQueue源码分析分析了优先级队列PriorityQueue
的实现,PriorityQueue
所用的是二叉堆,是一种具备"下沉"和"上移"功能的二叉搜索树。二叉搜索树在一定程度上可以提高查找效率。但是当原本的数据趋向于有序时,如数据123456,数据结构将会退化成链表,查找时间复杂度O(n),这是平衡二叉搜索树出现的原因。
平衡因子
平衡因子(Balance Factor)是指某个节点左子树与右子树的高度差,平衡二叉树的平衡因子只可能是-1,0,1,。如果平衡因子的绝对值大于1,说明此树不是平衡二叉树。
平衡二叉树和非平衡二叉树
基础节点设计
public class BalanceTree<E extends Comparable<E>> {//实现Comparable,节点的值必须是可比较的
private Node root;
private int size;
private class Node {
private E e;
private Node left;
private Node right;
private int height;//height方便计算平衡因子
public Node(E e) {
this.e = e;
this.left = null;
this.right = null;
this.height = 1;//高度初始值是1
}
}
public BalanceTree() {
this.root = null;
this.size = 0;
}
public int getSize() {
return size;
}
private int getHeight(Node node) {//获取节点高度
if (node == null) {
return 0;
}
return node.height;
}
private int getBalanceFactor(Node node) {//获取节点平衡因子
if (node == null) {
return 0;
}
return getHeight(node.left) - getHeight(node.right);
}
public boolean isBalance(Node node) {//判断是否是一颗平衡二叉树,需左右子树都是平衡二叉树
if (node == null) {
return true;
}
int balanceFactor = Math.abs(getBalanceFactor(node));
if (balanceFactor > 1) {
return false;
}
return isBalance(node.left) && isBalance(node.right);
}
}
- 平衡二叉树也是二叉搜索树,所以节点的值必须是可比较的,需实现
Comparable
- 为了方便计算平衡因子的值,设置height变量
- 平衡因子等于左右子树的高度差
- 判断是否是一颗平衡二叉树,需左右子树都是平衡二叉树
添加节点
往平衡二叉树添加节点很有可能导致平衡二叉树失去平衡,所以每次添加节点后我们需要进行平衡维护,添加节点破坏平衡有以下四种情况
- LL(需要右旋)
LL的意思是为往左节点(L)添加左子节点(L)导致失去平衡的情况,需要右旋维护平衡
右旋新插入的节点4,比9和5都小,所以插入到5的左边,9变成了失衡点,所以将失衡点9作为参数进行右旋
private Node rightRotate(Node imbalance) {
Node left = imbalance.left;//获取9的左节点5
Node leftRight = left.right;//5的右节点,这里是null
left.right = imbalance;//5的右节点赋值为9
imbalance.left = leftRight;//5的右节点放到9的左边
//右旋影响了9和5的高度,重新计算赋值
imbalance.height = Math.max(getHeight(imbalance.left), getHeight(imbalance.right)) + 1;
left.height = Math.max(getHeight(left.left), getHeight(left.right)) + 1;
//将新的头节点返回 这里是5这个节点
return left;
}
右旋思路:将失衡点放到失衡点左节点的右边,并重新计算影响到节点的高度。
- RR(需要左旋)
RR的意思是为往右节点(R)添加右子节点(R)导致失去平衡的情况,需要左旋维护平衡
左旋新插入的节点10,比7和9都大,所以插入到9的右边,7变成了失衡点,将7作为参数左旋
private Node leftRotate(Node imbalance) {
Node right = imbalance.right;//获取失衡点7的右节点,这里是9
Node rightLeft = right.left;//获取9的左节点,这里是null
right.left = imbalance;//将9的左节点指向7
imbalance.right = rightLeft;//7的右节点指向9的原来的左节点
//影响了7和9,重新计算高度赋值
right.height = Math.max(getHeight(right.left), getHeight(right.right)) + 1;
imbalance.height = Math.max(getHeight(imbalance.left), getHeight(imbalance.right)) + 1;
//返回新的根节点9
return right;
}
左旋思路:将失衡点放到失衡点右节点的左边,并重新计算影响到节点的高度。
- LR(需要先左旋再右旋)
新插入的节点8,比9小,比7大,所以插在7的右边,形成LR的情况,先将左节点7左旋,再将根节点9右旋
if (balanceFactor > 1 && getBalanceFactor(node.left) < 0) {//LR
//先左旋 再右旋
node.left = leftRotate(node.left);
return rightRotate(node);
}
先左旋再右旋思路:先将根节点的左节点左旋,再把根节点右旋
RL(需要先右旋再左旋)
新插入的节点9,比7大,比7的右节点10小,所以放在了10的左节点上。需要先对10右旋
右旋
然后根节点7左旋
左旋
if (balanceFactor < -1 && getBalanceFactor(node.right) > 0) {//RL
//先右旋 再左旋
node.right = rightRotate(node.right);
return leftRotate(node);
}
先右旋再左旋思路:先将根节点的右节点右旋,再把根节点左旋
添加节点
private Node add(Node node, E e) {
if (node == null) {
size++;
return new Node(e);
}
if (e.compareTo(node.e) < 0) {
node.left = add(node.left, e);
} else if (e.compareTo(node.e) > 0) {
node.right = add(node.right, e);
}
node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
int balanceFactor = getBalanceFactor(node);
if (balanceFactor > 1 && getBalanceFactor(node.left) > 0) {//LL
//右旋
return rightRotate(node);
}
if (balanceFactor < -1 && getBalanceFactor(node.right) < 0) {//RR
//左旋
return leftRotate(node);
}
if (balanceFactor > 1 && getBalanceFactor(node.left) < 0) {//LR
//先左旋 再右旋
node.left = leftRotate(node.left);
return rightRotate(node);
}
if (balanceFactor < -1 && getBalanceFactor(node.right) > 0) {//RL
//先右旋 再左旋
node.right = rightRotate(node.right);
return leftRotate(node);
}
return node;
}
删除节点
public E remove(E e) {
Node node = getNode(root, e);
if (node != null) {
root = remove(root, e);
return node.e;
}
return null;
}
private Node remove(Node node, E e) {
if (node == null) {
return null;
}
Node retNode;
if (e.compareTo(node.e) < 0) {
node.left = remove(node.left, e);
retNode = node;
} else if (e.compareTo(node.e) > 0) {
node.right = remove(node.right, e);
retNode = node;
} else {
if (node.left == null) {
Node rightNode = node.right;
node.right = null;
size--;
retNode = rightNode;
} else if (node.right == null) {
Node leftNode = node.left;
node.left = null;
size--;
retNode = leftNode;
} else {
Node successor = minimum(node.right);
successor.right = remove(node.right, successor.e);
successor.left = node.left;
node.left = node.right = null;
retNode = successor;
}
}
if (retNode == null) {
return null;
}
retNode.height = Math.max(getHeight(retNode.left), getHeight(retNode.right)) + 1;
int balanceFactor = getBalanceFactor(retNode);
if (balanceFactor > 1 && getBalanceFactor(retNode.left) > 0) {
return rightRotate(retNode);
}
if (balanceFactor < -1 && getBalanceFactor(retNode.right) <= 0) {
return leftRotate(retNode);
}
if (balanceFactor > 1 && getBalanceFactor(retNode.left) < 0) {
node.left = leftRotate(retNode.left);
return rightRotate(retNode);
}
if (balanceFactor < -1 && getBalanceFactor(retNode.right) > 0) {
node.right = rightRotate(retNode.right);
return leftRotate(retNode);
}
return node;
}
整体代码
public class BalanceTree<E extends Comparable<E>> {
private Node root;
private int size;
private class Node {
private E e;
private Node left;
private Node right;
private int height;
public Node(E e) {
this.e = e;
this.left = null;
this.right = null;
this.height = 1;
}
}
public BalanceTree() {
this.root = null;
this.size = 0;
}
public int getSize() {
return size;
}
private int getHeight(Node node) {
if (node == null) {
return 0;
}
return node.height;
}
private int getBalanceFactor(Node node) {
if (node == null) {
return 0;
}
return getHeight(node.left) - getHeight(node.right);
}
public boolean isBalance(Node node) {
if (node == null) {
return true;
}
int balanceFactor = Math.abs(getBalanceFactor(node));
if (balanceFactor > 1) {
return false;
}
return isBalance(node.left) && isBalance(node.right);
}
private Node rightRotate(Node imbalance) {
Node left = imbalance.left;
Node leftRight = left.right;
left.right = imbalance;
imbalance.left = leftRight;
imbalance.height = Math.max(getHeight(imbalance.left), getHeight(imbalance.right)) + 1;
left.height = Math.max(getHeight(left.left), getHeight(left.right)) + 1;
return left;
}
private Node leftRotate(Node imbalance) {
Node right = imbalance.right;
Node rightLeft = right.left;
right.left = imbalance;
imbalance.right = rightLeft;
right.height = Math.max(getHeight(right.left), getHeight(right.right)) + 1;
imbalance.height = Math.max(getHeight(imbalance.left), getHeight(imbalance.right)) + 1;
return right;
}
public void add(E e) {
root = add(root, e);
}
private Node add(Node node, E e) {
if (node == null) {
size++;
return new Node(e);
}
if (e.compareTo(node.e) < 0) {
node.left = add(node.left, e);
} else if (e.compareTo(node.e) > 0) {
node.right = add(node.right, e);
}
node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
int balanceFactor = getBalanceFactor(node);
if (balanceFactor > 1 && getBalanceFactor(node.left) > 0) {//LL
//右旋
return rightRotate(node);
}
if (balanceFactor < -1 && getBalanceFactor(node.right) < 0) {//RR
//左旋
return leftRotate(node);
}
if (balanceFactor > 1 && getBalanceFactor(node.left) < 0) {//LR
//先左旋 再右旋
node.left = leftRotate(node.left);
return rightRotate(node);
}
if (balanceFactor < -1 && getBalanceFactor(node.right) > 0) {//RL
//先右旋 再左旋
node.right = rightRotate(node.right);
return leftRotate(node);
}
return node;
}
private Node getNode(Node node, E e) {
if (node == null) {
return null;
}
if (e.equals(node.e)) {
return node;
} else if (e.compareTo(node.e) < 0) {
return getNode(node.left, e);
} else {
return getNode(node.right, e);
}
}
private Node minimum(Node node) {
if (node.left == null) {
return node;
}
return minimum(node.left);
}
public E remove(E e) {
Node node = getNode(root, e);
if (node != null) {
root = remove(root, e);
return node.e;
}
return null;
}
private Node remove(Node node, E e) {
if (node == null) {
return null;
}
Node retNode;
if (e.compareTo(node.e) < 0) {
node.left = remove(node.left, e);
retNode = node;
} else if (e.compareTo(node.e) > 0) {
node.right = remove(node.right, e);
retNode = node;
} else {
if (node.left == null) {
Node rightNode = node.right;
node.right = null;
size--;
retNode = rightNode;
} else if (node.right == null) {
Node leftNode = node.left;
node.left = null;
size--;
retNode = leftNode;
} else {
Node successor = minimum(node.right);
successor.right = remove(node.right, successor.e);
successor.left = node.left;
node.left = node.right = null;
retNode = successor;
}
}
if (retNode == null) {
return null;
}
retNode.height = Math.max(getHeight(retNode.left), getHeight(retNode.right)) + 1;
int balanceFactor = getBalanceFactor(retNode);
if (balanceFactor > 1 && getBalanceFactor(retNode.left) > 0) {
return rightRotate(retNode);
}
if (balanceFactor < -1 && getBalanceFactor(retNode.right) <= 0) {
return leftRotate(retNode);
}
if (balanceFactor > 1 && getBalanceFactor(retNode.left) < 0) {
node.left = leftRotate(retNode.left);
return rightRotate(retNode);
}
if (balanceFactor < -1 && getBalanceFactor(retNode.right) > 0) {
node.right = rightRotate(retNode.right);
return leftRotate(retNode);
}
return node;
}
}