AVL Tree
基础的二分搜索树在增删过程中不保证平衡，最坏情况下会退化成链表；AVL 树通过旋转保证每个节点左右子树的高度差（平衡因子）不超过 1（即“高度平衡”，而非完全平衡），从而使树高始终保持在 O(log n)。
AVL树的实现
/**
 * A self-balancing binary search tree (AVL tree) implementing [BST].
 *
 * Every insertion and removal recomputes node heights on the way back up the recursion and
 * calls [rebalance], which rotates any node whose balance factor leaves `-1..1`.
 *
 * @param K key type, ordered by its natural [Comparable] ordering
 * @param V value type
 */
class AVLTree<K : Comparable<K>, V> : BST<K, V> {
    private var root: BSTNode<K, V>? = null
    private var count = 0

    override fun size(): Int = this.count

    override fun getRoot(): BSTNode<K, V>? = this.root

    /** Inserts [key] mapped to [value]; an existing mapping for [key] has its value replaced. */
    override fun insert(key: K, value: V) {
        root = add(root, key, value)
    }

    /**
     * Inserts `k -> v` into the subtree rooted at [node].
     *
     * @return the root of the subtree after rebalancing (may differ from [node] after a rotation)
     */
    private fun add(node: BSTNode<K, V>?, k: K, v: V): BSTNode<K, V> {
        if (node == null) {
            count++
            return BasicBSTNode(k, v)
        }
        when {
            k < node.getKey() -> node.setLeft(add(node.getLeft(), k, v))
            k > node.getKey() -> node.setRight(add(node.getRight(), k, v))
            else -> node.setValue(v)
        }
        // The recursive call already returned a child with a correct height,
        // so a single refresh of this node's height is enough before rebalancing.
        return node.updateHeight().rebalance(Action.ADD)
    }

    /**
     * Removes the mapping for [k].
     *
     * @return the value that was mapped to [k], or `null` if [k] is absent
     */
    override fun remove(k: K): V? {
        val target = getNode(k) ?: return null
        val removedValue = target.getValue()
        root = remove(root, k, removedValue)
        return removedValue
    }

    /**
     * Removes [k] from the subtree rooted at [node].
     *
     * A node with children is not unlinked directly: its entry is overwritten with its in-order
     * predecessor (or successor when there is no left child), and that neighbour is then removed
     * recursively, so the physical deletion always happens at a leaf.
     *
     * @param v unused by the algorithm; kept for signature compatibility
     * @return the root of the subtree after removal and rebalancing, or `null` if it became empty
     */
    private fun remove(node: BSTNode<K, V>?, k: K, v: V): BSTNode<K, V>? {
        if (node == null) {
            return null
        }
        when {
            k < node.getKey() -> node.setLeft(remove(node.getLeft(), k, v))
            k > node.getKey() -> node.setRight(remove(node.getRight(), k, v))
            else -> {
                val left = node.getLeft()
                val right = node.getRight()
                if (left == null && right == null) {
                    count--
                    return null
                }
                if (left != null) {
                    // Left child present (right may or may not be): take the in-order predecessor.
                    val leftMax = checkNotNull(getMax(left)) { "non-empty left subtree has no maximum" }
                    node.setKey(leftMax.getKey())
                        .setValue(leftMax.getValue())
                    // The returned subtree root MUST be linked back: it can be null (the predecessor
                    // was the leaf left child) or a different node (a rotation happened below).
                    node.setLeft(remove(left, leftMax.getKey(), leftMax.getValue()))
                } else {
                    // Only a right child: take the in-order successor.
                    val rightMin = checkNotNull(getMin(right)) { "non-empty right subtree has no minimum" }
                    node.setKey(rightMin.getKey())
                        .setValue(rightMin.getValue())
                    node.setRight(remove(right, rightMin.getKey(), rightMin.getValue()))
                }
            }
        }
        return node.updateHeight().rebalance(Action.REMOVE)
    }

    /** Drops every node and resets the size to zero. */
    override fun clear() {
        this.root = null
        this.count = 0
    }
}
/**
 * Height of [node], where an absent node (`null`) counts as height 0.
 *
 * With this convention, a node with no children gets height 1 once [updateHeight] has run on it.
 *
 * @param node the node to measure, or `null` for a missing child
 * @return the stored height of [node], or 0 when [node] is `null`
 */
internal fun <K : Comparable<K>, V> getNodeHeight(node: BSTNode<K, V>?): Int =
    node?.getHeight() ?: 0
/**
 * Recomputes this node's height as one more than the taller of its two children,
 * treating a missing child as height 0.
 *
 * The children's heights must already be correct; this does not recurse.
 *
 * @return this node, to allow call chaining
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.updateHeight(): BSTNode<K, V> = apply {
    val tallerChild = maxOf(getNodeHeight(getLeft()), getNodeHeight(getRight()))
    setHeight(tallerChild + 1)
}
关键的旋转逻辑
/**
 * The tree operation that triggered a [rebalance] call.
 *
 * The rotation rules differ slightly: after an insertion the heavy child's balance factor must be
 * exactly ±1, while after a removal a child balance factor of 0 is also accepted.
 */
internal enum class Action {
    /** Rebalancing after an insertion. */
    ADD,

    /** Rebalancing after a removal. */
    REMOVE,
}
/**
 * Balance factor of this node: left-subtree height minus right-subtree height.
 *
 * A missing child counts as height 0. A positive result means the node is left-heavy,
 * and a negative result means it is right-heavy.
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.getBalanceFactor(): Int {
    val leftHeight = getNodeHeight(getLeft())
    val rightHeight = getNodeHeight(getRight())
    return leftHeight - rightHeight
}
/**
 * Tests whether this value lies in the closed range `[min, max]`.
 *
 * @throws IllegalArgumentException if [min] is greater than [max]
 */
internal fun Int.valueIn(min: Int, max: Int): Boolean {
    require(min <= max) { "min should be less than or equal to max" }
    return min <= this && this <= max
}
/**
 * Restores the AVL balance at this node if its balance factor has left `-1..1`.
 *
 * - Factor `2` (left-heavy): LL rotation ([ll]) or LR rotation ([lr]), chosen by the left child's factor.
 * - Factor `-2` (right-heavy): RR rotation ([rr]) or RL rotation ([rl]), chosen by the right child's factor.
 *
 * For [Action.ADD] the heavy child's factor must be exactly ±1. For [Action.REMOVE] a child
 * factor of 0 is handled as the single-rotation case. Each rotation decision is logged at debug level.
 *
 * @return this node if no rotation was needed, otherwise the new root of the rotated subtree
 * @throws IllegalStateException if the balance factors match none of the handled cases
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rebalance(action: Action): BSTNode<K, V> {
    val factor = getBalanceFactor()
    if (factor in -1..1) {
        return this
    }
    log.debug("balanceFactor: {}, action: {}, node: {}", factor, action, this)

    fun unexpected(): Nothing =
        throw IllegalStateException("balanceFactor: $factor, action: $action, node: $this")

    return when (factor) {
        2 -> {
            val childFactor = getLeft()!!.getBalanceFactor()
            when (action) {
                Action.ADD -> when (childFactor) {
                    1 -> ll()
                    -1 -> lr()
                    else -> unexpected()
                }
                Action.REMOVE -> if (childFactor >= 0) ll() else lr()
            }
        }
        -2 -> {
            val childFactor = getRight()!!.getBalanceFactor()
            when (action) {
                Action.ADD -> when (childFactor) {
                    -1 -> rr()
                    1 -> rl()
                    else -> unexpected()
                }
                Action.REMOVE -> if (childFactor <= 0) rr() else rl()
            }
        }
        else -> unexpected()
    }
}
/**
 * Left-Left case: the left subtree is too tall on its left side.
 * Fixed with a single right rotation that lifts the left child ([rotateNodeAndLeft]).
 *
 * @return the new root of this subtree
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.ll(): BSTNode<K, V> = rotateNodeAndLeft()
/**
 * Right-Right case: the right subtree is too tall on its right side.
 * Fixed with a single left rotation that lifts the right child ([rotateNodeAndRight]).
 *
 * @return the new root of this subtree
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rr(): BSTNode<K, V> = rotateNodeAndRight()
/**
 * Left-Right case: this node `x` is left-heavy, but its left child `y` leans right.
 *
 * The left child is first rotated left, lifting its right child `z`. Then this node is rotated
 * right, so in the pictured shape `z` ends up as the subtree root.
 *
 *       x
 *      /
 *     y
 *      \
 *       z
 *
 * @return the new root of this subtree
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.lr(): BSTNode<K, V> {
    val straightenedLeft = getLeft()?.rotateNodeAndRight()
    return setLeft(straightenedLeft)
        .updateHeight()
        .rotateNodeAndLeft()
}
/**
 * Right-Left case: this node `x` is right-heavy, but its right child `y` leans left.
 *
 * The right child is first rotated right, lifting its left child `z`. Then this node is rotated
 * left, so in the pictured shape `z` ends up as the subtree root.
 *
 *     x
 *      \
 *       y
 *      /
 *     z
 *
 * @return the new root of this subtree
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rl(): BSTNode<K, V> {
    val straightenedRight = getRight()?.rotateNodeAndLeft()
    return setRight(straightenedRight)
        .updateHeight()
        .rotateNodeAndRight()
}
/**
 * Left rotation: lifts this node's right child `x` above this node `n`.
 *
 *       n                  x
 *      / \                / \
 *     ?   x      =>      n   b?
 *        / \            / \
 *       a?  b?         ?   a?
 *
 * `x`'s former left subtree `a?` becomes `n`'s right subtree. Heights are refreshed bottom-up,
 * `n` first and then `x`.
 *
 * @return `x` (the new subtree root), or this node unchanged if it has no right child
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rotateNodeAndRight(): BSTNode<K, V> {
    val pivot = getRight() ?: return this
    val handedOver = pivot.getLeft()
    // `n` moves below `x`, so its height has to be correct before `x` is recomputed.
    setRight(handedOver).updateHeight()
    return pivot.setLeft(this).updateHeight()
}
/**
 * Right rotation: lifts this node's left child `x` above this node `n`.
 *
 *         n                x
 *        / \              / \
 *       x   ?     =>    a?   n
 *      / \                  / \
 *     a?  b?               b?  ?
 *
 * `x`'s former right subtree `b?` becomes `n`'s left subtree. Heights are refreshed bottom-up,
 * `n` first and then `x`.
 *
 * @return `x` (the new subtree root), or this node unchanged if it has no left child
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rotateNodeAndLeft(): BSTNode<K, V> {
    val pivot = getLeft() ?: return this
    val handedOver = pivot.getRight()
    // `n` moves below `x`, so its height has to be correct before `x` is recomputed.
    setLeft(handedOver).updateHeight()
    return pivot.setRight(this).updateHeight()
}
27 January 2026