Mind and Hand Help

AVL Tree

基础的二分搜索树在插入和删除时不保证平衡，最坏情况下会退化成链表；AVL 树在每次插入和删除后通过旋转，使每个节点左右子树的高度差不超过 1，从而保证树高始终为 O(log n)。

AVL树的实现

/**
 * AVL tree: a binary search tree that, after every insertion and removal, rotates nodes on the
 * path back to the root so that each node's left and right subtree heights differ by at most 1.
 *
 * New nodes are created as [BasicBSTNode]; rebalancing is delegated to [rebalance].
 */
class AVLTree<K : Comparable<K>, V> : BST<K, V> {
    private var root: BSTNode<K, V>? = null
    private var count = 0

    override fun size(): Int = this.count

    override fun getRoot(): BSTNode<K, V>? = this.root

    /** Inserts [key] with [value]; if [key] is already present its value is replaced. */
    override fun insert(key: K, value: V) {
        root = add(root, key, value)
    }

    /**
     * Recursively inserts ([k], [v]) into the subtree rooted at [node].
     *
     * @return the root of the updated subtree, which may differ from [node] after a rotation
     */
    private fun add(node: BSTNode<K, V>?, k: K, v: V): BSTNode<K, V> {
        if (node == null) {
            count++
            return BasicBSTNode(k, v)
        }
        val cmp = k.compareTo(node.getKey())
        when {
            cmp < 0 -> node.setLeft(add(node.getLeft(), k, v))
            cmp > 0 -> node.setRight(add(node.getRight(), k, v))
            else -> node.setValue(v)
        }
        return node.updateHeight().rebalance(Action.ADD)
    }

    /**
     * Removes [k] from the tree.
     *
     * @return the value that was stored under [k], or null if [k] was not present
     */
    override fun remove(k: K): V? {
        return getNode(k)?.let { node ->
            // Capture the value first: deletion copies entries between nodes in place,
            // so `node` may hold a different key/value afterwards.
            val removedValue = node.getValue()
            root = delete(root, k)
            removedValue
        }
    }

    /**
     * Recursively deletes [k] from the subtree rooted at [node].
     *
     * Nodes with children are not unlinked directly. Instead their entry is overwritten with
     * the in-order predecessor (or successor, when there is no left child), and that entry is
     * then deleted from the corresponding subtree, which recurses down until a leaf is removed.
     *
     * @return the root of the updated subtree (null if it became empty)
     */
    private fun delete(node: BSTNode<K, V>?, k: K): BSTNode<K, V>? {
        if (node == null) {
            return null
        }
        val cmp = k.compareTo(node.getKey())
        when {
            cmp < 0 -> node.setLeft(delete(node.getLeft(), k))
            cmp > 0 -> node.setRight(delete(node.getRight(), k))
            else -> {
                val left = node.getLeft()
                val right = node.getRight()
                if (left == null && right == null) {
                    count--
                    return null
                }
                // Remaining cases: both children, only a left child, or only a right child.
                if (left != null) {
                    // Both children or left only: take over the largest entry of the left subtree.
                    val leftMax = checkNotNull(getMax(left)) { "non-empty left subtree has no maximum" }
                    node.setKey(leftMax.getKey())
                        .setValue(leftMax.getValue())
                    // The result must be re-attached: the left subtree may have been emptied
                    // or rotated, and dropping it would leave the predecessor as a duplicate.
                    node.setLeft(delete(left, leftMax.getKey()))
                } else {
                    // Right child only: take over the smallest entry of the right subtree.
                    val rightMin = checkNotNull(getMin(right)) { "non-empty right subtree has no minimum" }
                    node.setKey(rightMin.getKey())
                        .setValue(rightMin.getValue())
                    node.setRight(delete(right, rightMin.getKey()))
                }
            }
        }
        return node.updateHeight().rebalance(Action.REMOVE)
    }

    override fun clear() {
        this.root = null
        this.count = 0
    }
}

/**
 * Retrieves the height of the specified binary search tree (BST) node.
 * If the node is null, returns 0.
 *
 * @param node the node for which the height is to be determined; can be null
 * @return the height of the node, or 0 if the node is null
 */
internal fun <K : Comparable<K>, V> getNodeHeight(node: BSTNode<K, V>?): Int {
    return node?.getHeight() ?: 0
}

/**
 * Recomputes this node's height as one more than the taller child's height (a missing child
 * counts as 0), stores it, and returns this node so calls can be chained.
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.updateHeight(): BSTNode<K, V> {
    val max = max(getNodeHeight(this.getLeft()), getNodeHeight(this.getRight()))
    this.setHeight(max + 1)
    return this
}

关键的旋转逻辑

/** The kind of tree modification that led to a [rebalance] call. */
internal enum class Action {
    ADD,
    REMOVE
}

/** Height of the left subtree minus height of the right subtree (a missing child counts as 0). */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.getBalanceFactor(): Int {
    val leftHeight = getNodeHeight(getLeft())
    val rightHeight = getNodeHeight(getRight())
    return leftHeight - rightHeight
}

/**
 * Returns true when this value lies in the closed range [min, max].
 *
 * @throws IllegalArgumentException if [min] is greater than [max]
 */
internal fun Int.valueIn(min: Int, max: Int): Boolean {
    require(min <= max) { "min should be less than or equal to max" }
    return this >= min && this <= max
}

/**
 * Restores the AVL condition at this node. Callers refresh this node's height
 * (via updateHeight) before calling.
 *
 * If the balance factor is within [-1, 1], this node is returned unchanged. Otherwise the taller
 * side is fixed with a single rotation (LL / RR) or a double rotation (LR / RL), and the new
 * subtree root is returned. For [Action.ADD] the heavy child must lean exactly one way (factor
 * +1 or -1). For [Action.REMOVE] a heavy child with factor 0 is handled by a single rotation.
 *
 * @throws IllegalStateException if the balance factor, or the heavy child's factor for
 *   [Action.ADD], falls outside the cases listed above
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rebalance(action: Action): BSTNode<K, V> {
    val factor = getBalanceFactor()
    if (factor.valueIn(-1, 1)) return this

    log.debug("balanceFactor: {}, action: {}, node: {}", factor, action, this)

    fun unexpected() =
        IllegalStateException("balanceFactor: $factor, action: $action, node: ${this@rebalance}")

    val leftHeavy = when (factor) {
        2 -> true
        -2 -> false
        else -> throw unexpected()
    }

    // The heavy child's own lean decides between a single and a double rotation.
    val childFactor =
        if (leftHeavy) getLeft()!!.getBalanceFactor() else getRight()!!.getBalanceFactor()

    // "Outer" lean: the heavy child tilts the same way as this node (LL or RR shape).
    val outerLean = if (leftHeavy) 1 else -1
    val singleRotation = when (action) {
        Action.ADD -> when (childFactor) {
            outerLean -> true
            -outerLean -> false
            else -> throw unexpected()
        }
        Action.REMOVE -> if (leftHeavy) childFactor >= 0 else childFactor <= 0
    }

    return when {
        leftHeavy && singleRotation -> ll()
        leftHeavy -> lr()
        singleRotation -> rr()
        else -> rl()
    }
}

/** Left-left case: one right rotation around this node. */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.ll(): BSTNode<K, V> = rotateNodeAndLeft()

/** Right-right case: one left rotation around this node. */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rr(): BSTNode<K, V> = rotateNodeAndRight()

/**
 * Left-right case: the left child leans right. First rotate the left child left, turning the
 * shape into left-left, then rotate this node right.
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.lr(): BSTNode<K, V> {
    //     x
    //    /
    //   y
    //    \
    //     z
    val straightenedLeft = getLeft()?.rotateNodeAndRight()
    return setLeft(straightenedLeft).updateHeight().rotateNodeAndLeft()
}

/**
 * Right-left case: the right child leans left. First rotate the right child right, turning the
 * shape into right-right, then rotate this node left.
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rl(): BSTNode<K, V> {
    //   x
    //    \
    //     y
    //    /
    //   z
    val straightenedRight = getRight()?.rotateNodeAndLeft()
    return setRight(straightenedRight).updateHeight().rotateNodeAndRight()
}

/**
 * Left rotation: this node's right child becomes the subtree root, this node becomes that
 * child's left child, and the child's former left subtree becomes this node's right subtree.
 * Returns the new subtree root, or this node unchanged if it has no right child.
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rotateNodeAndRight(): BSTNode<K, V> {
    //     n              p
    //    / \            / \
    //   ?   p    ->    n   b?
    //      / \        / \
    //     a?  b?     ?   a?
    val pivot = getRight() ?: return this
    // This node ends up below the pivot, so its height must be refreshed first.
    setRight(pivot.getLeft()).updateHeight()
    return pivot.setLeft(this).updateHeight()
}

/**
 * Right rotation: this node's left child becomes the subtree root, this node becomes that
 * child's right child, and the child's former right subtree becomes this node's left subtree.
 * Returns the new subtree root, or this node unchanged if it has no left child.
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rotateNodeAndLeft(): BSTNode<K, V> {
    //       n            p
    //      / \          / \
    //     p   ?   ->   a?  n
    //    / \              / \
    //   a?  b?           b?  ?
    val pivot = getLeft() ?: return this
    // This node ends up below the pivot, so its height must be refreshed first.
    setLeft(pivot.getRight()).updateHeight()
    return pivot.setRight(this).updateHeight()
}
27 January 2026