jingyulog

Binary Search Tree, AVL Tree, Red-Black Tree 본문

컴퓨터 사이언스/자료구조

Binary Search Tree, AVL Tree, Red-Black Tree

jingyulog 2025. 10. 11. 19:20

개요

Binary Search Tree (BST)

  • 정의: 각 노드가 최대 2개의 자식을 가지며, 왼쪽 서브트리의 모든 값 < 현재 노드의 값, 오른쪽 서브트리의 모든 값 > 현재 노드의 값인 조건을 만족하는 트리.
  • 특징: 탐색, 삽입, 삭제 모두 평균 O(log n). 하지만, 균현이 무너지면 한쪽으로 치우친 편향 트리가 되어 O(n)까지 느려진다.
  • 예시
        10
       /  \
      5    15
     / \     \
    3   7     20

AVL Tree (Adelson-Velsky and Landis Tree)

  • 정의: BST의 일종이지만, 항상 균형을 유지하는 트리
  • 균형 조건: 모든 노드에 대해 |왼쪽 서브트리 높이 - 오른쪽 서브트리 높이| <= 1 조건을 만족해야 한다. 즉, 높이 차이가 1을 넘지 않도록 삽입/삭제 시 회전(Ratation)을 수행한다.
  • 장점: 항상 균형을 유지하므로 탐색 O(log n)성능을 보장한다.
  • 단점: 삽입/삭제시 회전이 자주 일어나서 오버헤드가 있다.

Red-Black Tree

  • 정의 AVL처럼 균형 잡힌 BST이지만, 완벽한 균형 대신 느슨한 균형을 유지해서 삽입/삭제 속도를 더 빠르게 한 트리이다.
  • 균형 조건
    1. 각 노드는 Red, Black
    2. Root는 항상 Black
    3. 모든 leaf node(NIL)은 Black
    4. Red 노드의 자식은 항상 Black (red-red 연속 불가)
    5. 어떤 노드에서 leaf node까지 가는 모든 경로의 Black node 개수는 동일해야 한다.
  • 장점
    • 완전한 균형은 아니지만 높이가 O(log n)으로 유지된다.
    • 삽입/삭제시 회전이 적다.
    • 언어 라이브러리의 기본 트리로 많이 사용된다. (C++ STL의 std::map, std::set, Java의 TreeMap, TreeSet)
  • 단점: AVL보다 탐색이 약간 느리다.

동작 원리

AVL Tree

AVL 트리는 모든 노드에서 왼쪽과 오른쪽 서브트리의 높이 차이가 최대 1이 되도록 유지한다. 이때, 모든 노드가 균형 인수(Balance Factor)를 가진다. (BF = 왼쪽 서브트리의 높이 - 오른쪽 서브트리의 높이). 따라서 BF가 -1, 0, 1이면 균형을 유지하고, BF가 -2, 2이면 불균형하여 회전이 필요하다.

삽입 동작

  1. BST 규칙으로 삽입한다. (왼쪽 < 부모 < 오른쪽)
  2. 루트까지 거슬러 올라가며 BF를 갱신한다. 즉, BF가 절대값 2가 된 지점을 찾아서 불균형 패턴에 따라 회전을 수행한다.

회전 케이스

불균형 유형 발생 상황 회전 방식
LL 왼쪽 자식의 왼쪽에 노드 삽입시 Right Rotation
RR 오른쪽 자식의 오른쪽에 노드 삽입시 Left Rotation
LR 왼쪽 자식의 오른쪽에 노드 삽입시 Left -> Right
RL 오른쪽 자식의 왼쪽에 노드 삽입시 Right -> Left

삭제 동작

삭제 시에도 동일하게 동작한다.

  1. BST 규칙으로 삭제한다.
  2. 위로 올라가며 BF를 갱신한다.
  3. 절댓값 2가 된 곳마다 회전을 수행한다.

Red-Black Tree

Red-Black Tree는 색깔 정보(Red/Black)을 이용해서 균형을 간접적으로 유지하는데, 이렇게 해서 완벽히 높이를 맞추진 않지만, 모든 리프까지의 검정 노드 수가 동일하도록 보장한다.

기본 규칙

  1. 각 노드는 Red, Black
  2. Root는 항상 Black
  3. leaf node(NIL)는 Black
  4. Red node의 자식은 모두 Black
  5. Root에서 leaf node까지 경로의 Black node 수는 동일하다.

삽입 동작

  1. BST 규칙으로 삽입한다. 이때, 새 노드는 항상 Red이다.
  2. Red-Black 규칙이 위반됐는지 확인한다. 이때, 부모가 Red이면, Red-Red 위반이 발생할 가능성이 있다.
  3. 조정한다. (Recoloring 또는 Rotation) -> 부모의 형제(삼촌)의 색깔에 따라 달라진다.

삽입 케이스

케이스 상황 해결 방법
Case 1 삽입 노드가 Root Black으로 변경한다.
Case 2 부모가 Black 문제 없음
Case 3 부모가 Red, 삼촌도 Red 부모와 삼촌을 Black으로 만들고, 조부모를 Red로 만든 후 위로 올라간다.
Case 4 부모가 Red, 삼촌이 Black 회전 + 색상 변경(Rotation + Recoloring)

Case 3

삽입: 10 → 5 → 15 → 1 → 6 → 12 → 18 → 3

  1. insert 5

    10(B)
    /
    5(R)
  2. insert 15

    10(B)
    /  \
    5(R) 15(R)

-> 부모와 삼촌 모두 Red, 조부모는 Black이므로, 부모와 삼촌을 Black, 조부모를 Red로 바꿔주는데, 조부모는 Root이므로 다시 Black으로 바꿔준다.

   10(B)
   /  \
  5(B) 15(B)
  1. insert 1

     10(B)
     /
    5(B)
    /
    1(R)
  2. insert 6

      10(B)
      /
    5(B)
    /  \
    1(R) 6(R)
  3. insert 12

      10(B)
      /   \
    5(B)  15(B)
         /
       12(R)
  4. insert 18

      10(B)
      /   \
    5(B)  15(B)
         /   \
       12(R) 18(R)
  5. insert 3 (문제 발생)
    부모와 삼촌이 모두 Red인데 조부모는 Black이고, 조부모가 Root가 아니다.

        10(B)
        /   \
      5(B)  15(B)
     / \    / \
    1(R)6(R)12(R)18(R)
     \
      3(R)

->

  1. 부모와 삼촌은 Black으로 만들고,
  2. 조부모를 Red로 만든다.
    다만, 이 케이스의 경우에는 조부모(5)가 Red로 바뀌어도, 그 부모(Root, 10) 또한 Black이라 더 이상 위로 전파되지 않고 종료된다.
        10(B)
        /   \
      5(R)  15(B)
     / \    / \
    1(B)6(B)12(R)18(R)
     \
      3(R)

따라서 부모와 삼촌이 모두 Red이면,

  1. 부모와 삼촌을 모두 Black으로 바꾸고, 조부모를 Red로 바꾸고,
  2. 이때, 조부모가 Root면 다시 Black으로 바꾸고,
  3. 조부모가 아니면 조부모가 Red로 변하면서 그 위로 재귀적으로 fix-up이 반복된다.

Case 4

삽입: 10 -> 5 -> 1
Right Rotation 수행 후 색상을 교체한다.

    10(B)
    /
  5(R)
  /
1(R)

->

   5(B)
  /  \
1(R) 10(R)

삭제 동작

  1. BST 규칙으로 삭제
  2. 삭제된 노드가 Black이면 검정 높이 불균형 발생이 가능함
  3. Double Black 처리를 위해 형제의 색깔을 보고 Recoloring, Rotation을 수행한다.

코드 구현

AVL 트리 삽입

class AVLNode:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.height = 1


class AVLTree:
    def get_height(self, node):
        return node.height if node else 0

    def get_balance(self, node):
        return self.get_height(node.left) - self.get_height(node.right) if node else 0

    def right_rotate(self, y):
        x = y.left
        T2 = x.right
        x.right = y
        y.left = T2
        y.height = max(self.get_height(y.left), self.get_height(y.right)) + 1
        x.height = max(self.get_height(x.left), self.get_height(x.right)) + 1
        return x

    def left_rotate(self, x):
        y = x.right
        T2 = y.left
        y.left = x
        x.right = T2
        x.height = max(self.get_height(x.left), self.get_height(x.right)) + 1
        y.height = max(self.get_height(y.left), self.get_height(y.right)) + 1
        return y

    def insert(self, node, key):
        # 기본 BST 삽입
        if not node:
            return AVLNode(key)
        elif key < node.key:
            node.left = self.insert(node.left, key)
        else:
            node.right = self.insert(node.right, key)

        # 높이 갱신
        node.height = 1 + max(self.get_height(node.left), self.get_height(node.right))

        # 불균형 감지
        balance = self.get_balance(node)

        # LL
        if balance > 1 and key < node.left.key:
            return self.right_rotate(node)
        # RR
        if balance < -1 and key > node.right.key:
            return self.left_rotate(node)
        # LR
        if balance > 1 and key > node.left.key:
            node.left = self.left_rotate(node.left)
            return self.right_rotate(node)
        # RL
        if balance < -1 and key < node.right.key:
            node.right = self.right_rotate(node.right)
            return self.left_rotate(node)

        return node

    def inorder(self, node):
        return self.inorder(node.left) + [node.key] + self.inorder(node.right) if node else []

    def print_tree(self, node, level=0, prefix="Root: "):
        if node:
            print(" " * (level * 4) + prefix + f"{node.key} (h={node.height})")
            if node.left:
                self.print_tree(node.left, level + 1, "L--- ")
            if node.right:
                self.print_tree(node.right, level + 1, "R--- ")


if __name__ == "__main__":
    avl = AVLTree()
    root = None
    for key in [30, 20, 10, 25, 40]:
        root = avl.insert(root, key)

    print("\n[AVL Tree]")
    avl.print_tree(root)
    print("Inorder:", avl.inorder(root))
  • 실행 결과
    [AVL Tree]
    Root: 20 (h=3)
    L--- 10 (h=1)
    R--- 30 (h=2)
      L--- 25 (h=1)
      R--- 40 (h=1)
    Inorder: [10, 20, 25, 30, 40]

RBT 삽입

RED = True
BLACK = False

class RBNode:
    def __init__(self, key, color=RED, left=None, right=None, parent=None):
        self.key = key
        self.color = color
        self.left = left
        self.right = right
        self.parent = parent


class RBTree:
    def __init__(self):
        self.NIL = RBNode(key=None, color=BLACK)
        self.root = self.NIL

    def left_rotate(self, x):
        y = x.right
        x.right = y.left
        if y.left != self.NIL:
            y.left.parent = x
        y.parent = x.parent
        if x.parent == None:
            self.root = y
        elif x == x.parent.left:
            x.parent.left = y
        else:
            x.parent.right = y
        y.left = x
        x.parent = y

    def right_rotate(self, x):
        y = x.left
        x.left = y.right
        if y.right != self.NIL:
            y.right.parent = x
        y.parent = x.parent
        if x.parent == None:
            self.root = y
        elif x == x.parent.right:
            x.parent.right = y
        else:
            x.parent.left = y
        y.right = x
        x.parent = y

    def insert(self, key):
        node = RBNode(key)
        node.left = node.right = self.NIL
        parent = None
        current = self.root

        while current != self.NIL:
            parent = current
            if node.key < current.key:
                current = current.left
            else:
                current = current.right

        node.parent = parent
        if parent is None:
            self.root = node
        elif node.key < parent.key:
            parent.left = node
        else:
            parent.right = node

        node.color = RED
        self.insert_fix(node)

    def insert_fix(self, k):
        while k.parent and k.parent.color == RED:
            if k.parent == k.parent.parent.left:
                u = k.parent.parent.right  # uncle
                if u.color == RED:
                    # 부모, 삼촌 Red → recolor
                    k.parent.color = BLACK
                    u.color = BLACK
                    k.parent.parent.color = RED
                    k = k.parent.parent
                else:
                    if k == k.parent.right:
                        k = k.parent
                        self.left_rotate(k)
                    k.parent.color = BLACK
                    k.parent.parent.color = RED
                    self.right_rotate(k.parent.parent)
            else:
                u = k.parent.parent.left
                if u.color == RED:
                    k.parent.color = BLACK
                    u.color = BLACK
                    k.parent.parent.color = RED
                    k = k.parent.parent
                else:
                    if k == k.parent.left:
                        k = k.parent
                        self.right_rotate(k)
                    k.parent.color = BLACK
                    k.parent.parent.color = RED
                    self.left_rotate(k.parent.parent)
        self.root.color = BLACK

    def inorder(self, node):
        if node == self.NIL:
            return []
        return self.inorder(node.left) + [(node.key, "R" if node.color else "B")] + self.inorder(node.right)

    def print_tree(self, node, indent="", last=True):
        if node != self.NIL:
            print(indent, "└── " if last else "├── ", f"{node.key}({'R' if node.color else 'B'})", sep="")
            indent += "    " if last else "│   "
            self.print_tree(node.left, indent, False)
            self.print_tree(node.right, indent, True)


if __name__ == "__main__":
    rbt = RBTree()
    for key in [10, 5, 1, 15, 20]:
        rbt.insert(key)

    print("\n[Red-Black Tree]")
    rbt.print_tree(rbt.root)
    print("Inorder:", rbt.inorder(rbt.root))
  • 실행 결과
    [Red-Black Tree]
    └── 5(B)
      ├── 1(B)
      └── 15(B)
          ├── 10(R)
          └── 20(R)
    Inorder: [(1, 'B'), (5, 'B'), (10, 'R'), (15, 'B'), (20, 'R')]

AVL/RBT inserting, deleting

# ============================================================
# AVL TREE IMPLEMENTATION (Insert / Delete / Step Visualization)
# ============================================================

class AVLNode:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.height = 1


class AVLTree:
    def get_height(self, node):
        return node.height if node else 0

    def get_balance(self, node):
        return self.get_height(node.left) - self.get_height(node.right) if node else 0

    def right_rotate(self, y):
        x = y.left
        T2 = x.right
        x.right = y
        y.left = T2
        y.height = max(self.get_height(y.left), self.get_height(y.right)) + 1
        x.height = max(self.get_height(x.left), self.get_height(x.right)) + 1
        return x

    def left_rotate(self, x):
        y = x.right
        T2 = y.left
        y.left = x
        x.right = T2
        x.height = max(self.get_height(x.left), self.get_height(x.right)) + 1
        y.height = max(self.get_height(y.left), self.get_height(y.right)) + 1
        return y

    def insert(self, node, key, step=False):
        if not node:
            return AVLNode(key)
        elif key < node.key:
            node.left = self.insert(node.left, key, step)
        else:
            node.right = self.insert(node.right, key, step)

        node.height = 1 + max(self.get_height(node.left), self.get_height(node.right))
        balance = self.get_balance(node)

        # 회전 처리
        if balance > 1 and key < node.left.key:
            node = self.right_rotate(node)
        elif balance < -1 and key > node.right.key:
            node = self.left_rotate(node)
        elif balance > 1 and key > node.left.key:
            node.left = self.left_rotate(node.left)
            node = self.right_rotate(node)
        elif balance < -1 and key < node.right.key:
            node.right = self.right_rotate(node.right)
            node = self.left_rotate(node)

        if step:
            print(f"\n[AVL Insert Step: {key}]")
            self.print_tree(node)
        return node

    def min_value_node(self, node):
        while node.left:
            node = node.left
        return node

    def delete(self, root, key, step=False):
        if not root:
            return root
        elif key < root.key:
            root.left = self.delete(root.left, key, step)
        elif key > root.key:
            root.right = self.delete(root.right, key, step)
        else:
            if not root.left:
                return root.right
            elif not root.right:
                return root.left
            temp = self.min_value_node(root.right)
            root.key = temp.key
            root.right = self.delete(root.right, temp.key, step)

        if not root:
            return root

        root.height = 1 + max(self.get_height(root.left), self.get_height(root.right))
        balance = self.get_balance(root)

        # 균형 조정
        if balance > 1 and self.get_balance(root.left) >= 0:
            root = self.right_rotate(root)
        elif balance > 1 and self.get_balance(root.left) < 0:
            root.left = self.left_rotate(root.left)
            root = self.right_rotate(root)
        elif balance < -1 and self.get_balance(root.right) <= 0:
            root = self.left_rotate(root)
        elif balance < -1 and self.get_balance(root.right) > 0:
            root.right = self.right_rotate(root.right)
            root = self.left_rotate(root)

        if step:
            print(f"\n[AVL Delete Step: {key}]")
            self.print_tree(root)
        return root

    def print_tree(self, node, level=0, prefix="Root: "):
        if node:
            print(" " * (level * 4) + prefix + f"{node.key} (h={node.height})")
            if node.left:
                self.print_tree(node.left, level + 1, "L--- ")
            if node.right:
                self.print_tree(node.right, level + 1, "R--- ")

# ============================================================
# RED-BLACK TREE IMPLEMENTATION (Insert / Delete / Step Visualization)
# ============================================================

RED = True
BLACK = False


class RBNode:
    def __init__(self, key, color=RED, left=None, right=None, parent=None):
        self.key = key
        self.color = color
        self.left = left
        self.right = right
        self.parent = parent


class RBTree:
    def __init__(self):
        self.NIL = RBNode(None, color=BLACK)
        self.root = self.NIL

    # --- Rotation ---
    def left_rotate(self, x):
        y = x.right
        x.right = y.left
        if y.left != self.NIL:
            y.left.parent = x
        y.parent = x.parent
        if not x.parent:
            self.root = y
        elif x == x.parent.left:
            x.parent.left = y
        else:
            x.parent.right = y
        y.left = x
        x.parent = y

    def right_rotate(self, x):
        y = x.left
        x.left = y.right
        if y.right != self.NIL:
            y.right.parent = x
        y.parent = x.parent
        if not x.parent:
            self.root = y
        elif x == x.parent.right:
            x.parent.right = y
        else:
            x.parent.left = y
        y.right = x
        x.parent = y

    # --- Insertion ---
    def insert(self, key, step=False):
        node = RBNode(key, color=RED, left=self.NIL, right=self.NIL)
        parent = None
        current = self.root

        while current != self.NIL:
            parent = current
            if node.key < current.key:
                current = current.left
            else:
                current = current.right

        node.parent = parent
        if not parent:
            self.root = node
        elif node.key < parent.key:
            parent.left = node
        else:
            parent.right = node

        self.insert_fix(node)

        if step:
            print(f"\n[RBT Insert Step: {key}]")
            self.print_tree(self.root)

    def insert_fix(self, k):
        while k.parent and k.parent.color == RED:
            if k.parent == k.parent.parent.left:
                u = k.parent.parent.right
                if u.color == RED:  # Case 1: 부모, 삼촌이 Red
                    k.parent.color = BLACK
                    u.color = BLACK
                    k.parent.parent.color = RED
                    k = k.parent.parent
                else:
                    if k == k.parent.right:  # Case 2: Left rotation
                        k = k.parent
                        self.left_rotate(k)
                    # Case 3: Right rotation
                    k.parent.color = BLACK
                    k.parent.parent.color = RED
                    self.right_rotate(k.parent.parent)
            else:
                u = k.parent.parent.left
                if u.color == RED:
                    k.parent.color = BLACK
                    u.color = BLACK
                    k.parent.parent.color = RED
                    k = k.parent.parent
                else:
                    if k == k.parent.left:
                        k = k.parent
                        self.right_rotate(k)
                    k.parent.color = BLACK
                    k.parent.parent.color = RED
                    self.left_rotate(k.parent.parent)
        self.root.color = BLACK

    # --- Delete ---
    def transplant(self, u, v):
        if not u.parent:
            self.root = v
        elif u == u.parent.left:
            u.parent.left = v
        else:
            u.parent.right = v
        v.parent = u.parent

    def delete(self, key, step=False):
        z = self.search(self.root, key)
        if z == self.NIL:
            return

        y = z
        y_original_color = y.color
        if z.left == self.NIL:
            x = z.right
            self.transplant(z, z.right)
        elif z.right == self.NIL:
            x = z.left
            self.transplant(z, z.left)
        else:
            y = self.minimum(z.right)
            y_original_color = y.color
            x = y.right
            if y.parent == z:
                x.parent = y
            else:
                self.transplant(y, y.right)
                y.right = z.right
                y.right.parent = y
            self.transplant(z, y)
            y.left = z.left
            y.left.parent = y
            y.color = z.color

        if y_original_color == BLACK:
            self.delete_fix(x)

        if step:
            print(f"\n[RBT Delete Step: {key}]")
            self.print_tree(self.root)

    def delete_fix(self, x):
        while x != self.root and x.color == BLACK:
            if x == x.parent.left:
                s = x.parent.right
                if s.color == RED:
                    s.color = BLACK
                    x.parent.color = RED
                    self.left_rotate(x.parent)
                    s = x.parent.right
                if s.left.color == BLACK and s.right.color == BLACK:
                    s.color = RED
                    x = x.parent
                else:
                    if s.right.color == BLACK:
                        s.left.color = BLACK
                        s.color = RED
                        self.right_rotate(s)
                        s = x.parent.right
                    s.color = x.parent.color
                    x.parent.color = BLACK
                    s.right.color = BLACK
                    self.left_rotate(x.parent)
                    x = self.root
            else:
                s = x.parent.left
                if s.color == RED:
                    s.color = BLACK
                    x.parent.color = RED
                    self.right_rotate(x.parent)
                    s = x.parent.left
                if s.left.color == BLACK and s.right.color == BLACK:
                    s.color = RED
                    x = x.parent
                else:
                    if s.left.color == BLACK:
                        s.right.color = BLACK
                        s.color = RED
                        self.left_rotate(s)
                        s = x.parent.left
                    s.color = x.parent.color
                    x.parent.color = BLACK
                    s.left.color = BLACK
                    self.right_rotate(x.parent)
                    x = self.root
        x.color = BLACK

    def minimum(self, node):
        while node.left != self.NIL:
            node = node.left
        return node

    def search(self, node, key):
        while node != self.NIL and key != node.key:
            node = node.left if key < node.key else node.right
        return node

    # --- Visualization ---
    def print_tree(self, node, indent="", last=True):
        if node != self.NIL:
            print(indent, "└── " if last else "├── ", f"{node.key}({'R' if node.color else 'B'})", sep="")
            indent += "    " if last else "│   "
            self.print_tree(node.left, indent, False)
            self.print_tree(node.right, indent, True)


# ============================================================
# DEMO
# ============================================================

if __name__ == "__main__":
    print("\n===== AVL Tree Demo =====")
    avl = AVLTree()
    root = None
    for key in [30, 20, 10, 25, 40]:
        root = avl.insert(root, key, step=True)
    root = avl.delete(root, 25, step=True)

    print("\n===== Red-Black Tree Demo =====")
    rbt = RBTree()
    for key in [10, 5, 1, 15, 20, 12, 18, 3]:
        rbt.insert(key, step=True)
    rbt.delete(15, step=True)

탐색 (AVL, RBT 모두 탐색은 BST와 동일)

BST

def search(self, key):
    node = self.root
    while node is not None:
        if key == node.key:
            return node
        elif key < node.key:
            node = node.left
        else:
            node = node.right
    return None

AVL

class AVLTree:
    def __init__(self):
        self.root = None

    class Node:
        def __init__(self, key):
            self.key = key
            self.left = None
            self.right = None
            self.height = 1

    def search(self, key):
        node = self.root
        while node:
            if key == node.key:
                return node
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return None

avl = AVLTree()
for k in [30, 20, 40, 10, 25]:
    avl.insert(k)

result = avl.search(25)
print("Found:", result.key if result else "Not found")

RBT

class RBTree:
    def __init__(self):
        self.NIL = self.Node(None, color='B')
        self.root = self.NIL

    class Node:
        def __init__(self, key, color='R', left=None, right=None, parent=None):
            self.key = key
            self.color = color
            self.left = left
            self.right = right
            self.parent = parent

    def search(self, key):
        node = self.root
        while node != self.NIL:
            if key == node.key:
                return node
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return None

rbt = RBTree()
for k in [10, 5, 1, 15, 20]:
    rbt.insert(k)

result = rbt.search(15)
print("Found:", result.key if result else "Not found")

최종 버전

"""
avl_rbt.py

Contains:
 - AVLTree: insert, delete, search, inorder, print_tree
 - RedBlackTree: insert, delete, search, inorder, print_tree

Usage: run this file directly to see demo usage at the bottom.
"""

# --------------------------
# AVL Tree
# --------------------------
class AVLNode:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.height = 1

class AVLTree:
    def __init__(self):
        self.root = None

    # ---------- helpers ----------
    def _height(self, node):
        return node.height if node else 0

    def _update_height(self, node):
        node.height = 1 + max(self._height(node.left), self._height(node.right))

    def _balance_factor(self, node):
        return self._height(node.left) - self._height(node.right) if node else 0

    def _rotate_right(self, y):
        x = y.left
        T2 = x.right
        # rotate
        x.right = y
        y.left = T2
        # update heights
        self._update_height(y)
        self._update_height(x)
        return x

    def _rotate_left(self, x):
        y = x.right
        T2 = y.left
        # rotate
        y.left = x
        x.right = T2
        # update heights
        self._update_height(x)
        self._update_height(y)
        return y

    # ---------- insertion ----------
    def _insert(self, node, key):
        if not node:
            return AVLNode(key)
        if key < node.key:
            node.left = self._insert(node.left, key)
        else:
            node.right = self._insert(node.right, key)

        self._update_height(node)
        bf = self._balance_factor(node)

        # LL
        if bf > 1 and key < node.left.key:
            return self._rotate_right(node)
        # RR
        if bf < -1 and key > node.right.key:
            return self._rotate_left(node)
        # LR
        if bf > 1 and key > node.left.key:
            node.left = self._rotate_left(node.left)
            return self._rotate_right(node)
        # RL
        if bf < -1 and key < node.right.key:
            node.right = self._rotate_right(node.right)
            return self._rotate_left(node)

        return node

    def insert(self, key):
        """Insert key into AVL tree."""
        self.root = self._insert(self.root, key)

    # ---------- search ----------
    def search(self, key):
        node = self.root
        while node:
            if key == node.key:
                return node
            node = node.left if key < node.key else node.right
        return None

    # ---------- delete ----------
    def _min_node(self, node):
        cur = node
        while cur.left:
            cur = cur.left
        return cur

    def _delete(self, node, key):
        if not node:
            return node
        if key < node.key:
            node.left = self._delete(node.left, key)
        elif key > node.key:
            node.right = self._delete(node.right, key)
        else:
            # node to be deleted
            if not node.left:
                return node.right
            elif not node.right:
                return node.left
            else:
                temp = self._min_node(node.right)
                node.key = temp.key
                node.right = self._delete(node.right, temp.key)

        # update & rebalance
        self._update_height(node)
        bf = self._balance_factor(node)

        # LL
        if bf > 1 and self._balance_factor(node.left) >= 0:
            return self._rotate_right(node)
        # LR
        if bf > 1 and self._balance_factor(node.left) < 0:
            node.left = self._rotate_left(node.left)
            return self._rotate_right(node)
        # RR
        if bf < -1 and self._balance_factor(node.right) <= 0:
            return self._rotate_left(node)
        # RL
        if bf < -1 and self._balance_factor(node.right) > 0:
            node.right = self._rotate_right(node.right)
            return self._rotate_left(node)

        return node

    def delete(self, key):
        """Delete key from AVL tree (if present)."""
        self.root = self._delete(self.root, key)

    # ---------- traversal & pretty ----------
    def inorder(self):
        res = []
        def _in(n):
            if not n: return
            _in(n.left); res.append(n.key); _in(n.right)
        _in(self.root)
        return res

    def print_tree(self, node=None, prefix=""):
        if node is None:
            node = self.root
        def _print(n, indent="", last=True):
            if not n:
                return
            print(indent + ("└── " if last else "├── ") + f"{n.key} (h={n.height})")
            indent += "    " if last else "│   "
            if n.left or n.right:
                # print left then right; ensure consistent ordering
                if n.left:
                    _print(n.left, indent, False if n.right else True)
                else:
                    # placeholder for visual alignment
                    print(indent + ("└── " if n.right else "└── ") + "None")
                if n.right:
                    _print(n.right, indent, True)

        if self.root is None:
            print("(empty)")
        else:
            _print(node, "", True)


# --------------------------
# Red-Black Tree
# --------------------------
RED = True
BLACK = False

class RBNode:
    def __init__(self, key=None, color=BLACK, left=None, right=None, parent=None):
        self.key = key
        self.color = color
        self.left = left
        self.right = right
        self.parent = parent

class RedBlackTree:
    def __init__(self):
        # single NIL sentinel
        self.NIL = RBNode(key=None, color=BLACK)
        self.NIL.left = self.NIL.right = self.NIL.parent = self.NIL
        self.root = self.NIL

    # ---------- rotations ----------
    def _left_rotate(self, x):
        y = x.right
        x.right = y.left
        if y.left != self.NIL:
            y.left.parent = x
        y.parent = x.parent
        if x.parent == self.NIL:
            self.root = y
        elif x == x.parent.left:
            x.parent.left = y
        else:
            x.parent.right = y
        y.left = x
        x.parent = y

    def _right_rotate(self, y):
        x = y.left
        y.left = x.right
        if x.right != self.NIL:
            x.right.parent = y
        x.parent = y.parent
        if y.parent == self.NIL:
            self.root = x
        elif y == y.parent.right:
            y.parent.right = x
        else:
            y.parent.left = x
        x.right = y
        y.parent = x

    # ---------- insert ----------
    def insert(self, key):
        node = RBNode(key=key, color=RED, left=self.NIL, right=self.NIL, parent=None)
        y = self.NIL
        x = self.root
        while x != self.NIL:
            y = x
            if node.key < x.key:
                x = x.left
            else:
                x = x.right
        node.parent = y
        if y == self.NIL:
            self.root = node
        elif node.key < y.key:
            y.left = node
        else:
            y.right = node

        self._insert_fixup(node)

    def _insert_fixup(self, z):
        while z.parent.color == RED:
            if z.parent == z.parent.parent.left:
                y = z.parent.parent.right  # uncle
                if y.color == RED:
                    # case 1: recolor
                    z.parent.color = BLACK
                    y.color = BLACK
                    z.parent.parent.color = RED
                    z = z.parent.parent
                else:
                    if z == z.parent.right:
                        # case 2
                        z = z.parent
                        self._left_rotate(z)
                    # case 3
                    z.parent.color = BLACK
                    z.parent.parent.color = RED
                    self._right_rotate(z.parent.parent)
            else:
                y = z.parent.parent.left
                if y.color == RED:
                    z.parent.color = BLACK
                    y.color = BLACK
                    z.parent.parent.color = RED
                    z = z.parent.parent
                else:
                    if z == z.parent.left:
                        z = z.parent
                        self._right_rotate(z)
                    z.parent.color = BLACK
                    z.parent.parent.color = RED
                    self._left_rotate(z.parent.parent)
            if z == self.root:
                break
        self.root.color = BLACK

    # ---------- search ----------
    def search(self, key):
        node = self.root
        while node != self.NIL and node.key is not None:
            if key == node.key:
                return node
            node = node.left if key < node.key else node.right
        return None

    # ---------- delete ----------
    def _transplant(self, u, v):
        if u.parent == self.NIL:
            self.root = v
        elif u == u.parent.left:
            u.parent.left = v
        else:
            u.parent.right = v
        v.parent = u.parent

    def _minimum(self, node):
        while node.left != self.NIL:
            node = node.left
        return node

    def delete(self, key):
        z = self.root
        # find node
        while z != self.NIL and z.key != key:
            z = z.left if key < z.key else z.right
        if z == self.NIL:
            return  # not found

        y = z
        y_original_color = y.color
        if z.left == self.NIL:
            x = z.right
            self._transplant(z, z.right)
        elif z.right == self.NIL:
            x = z.left
            self._transplant(z, z.left)
        else:
            y = self._minimum(z.right)
            y_original_color = y.color
            x = y.right
            if y.parent == z:
                x.parent = y
            else:
                self._transplant(y, y.right)
                y.right = z.right
                y.right.parent = y
            self._transplant(z, y)
            y.left = z.left
            y.left.parent = y
            y.color = z.color

        if y_original_color == BLACK:
            self._delete_fixup(x)

    def _delete_fixup(self, x):
        while x != self.root and x.color == BLACK:
            if x == x.parent.left:
                s = x.parent.right
                if s.color == RED:
                    s.color = BLACK
                    x.parent.color = RED
                    self._left_rotate(x.parent)
                    s = x.parent.right
                if s.left.color == BLACK and s.right.color == BLACK:
                    s.color = RED
                    x = x.parent
                else:
                    if s.right.color == BLACK:
                        s.left.color = BLACK
                        s.color = RED
                        self._right_rotate(s)
                        s = x.parent.right
                    s.color = x.parent.color
                    x.parent.color = BLACK
                    s.right.color = BLACK
                    self._left_rotate(x.parent)
                    x = self.root
            else:
                s = x.parent.left
                if s.color == RED:
                    s.color = BLACK
                    x.parent.color = RED
                    self._right_rotate(x.parent)
                    s = x.parent.left
                if s.left.color == BLACK and s.right.color == BLACK:
                    s.color = RED
                    x = x.parent
                else:
                    if s.left.color == BLACK:
                        s.right.color = BLACK
                        s.color = RED
                        self._left_rotate(s)
                        s = x.parent.left
                    s.color = x.parent.color
                    x.parent.color = BLACK
                    s.left.color = BLACK
                    self._right_rotate(x.parent)
                    x = self.root
        x.color = BLACK

    # ---------- traversal & pretty ----------
    def inorder(self):
        res = []
        def _in(n):
            if n == self.NIL:
                return
            _in(n.left); res.append((n.key, 'R' if n.color else 'B')); _in(n.right)
        _in(self.root)
        return res

    def print_tree(self, node=None, prefix=""):
        if node is None:
            node = self.root
        def _print(n, indent="", last=True):
            if n == self.NIL:
                return
            print(indent + ("└── " if last else "├── ") + f"{n.key} ({'R' if n.color else 'B'})")
            indent += "    " if last else "│   "
            if n.left != self.NIL or n.right != self.NIL:
                if n.left != self.NIL:
                    _print(n.left, indent, False if n.right != self.NIL else True)
                else:
                    print(indent + ("└── " if n.right != self.NIL else "└── ") + "NIL")
                if n.right != self.NIL:
                    _print(n.right, indent, True)

        if self.root == self.NIL:
            print("(empty)")
        else:
            _print(node, "", True)


# --------------------------
# Demo (if run as script)
# --------------------------
if __name__ == "__main__":
    print("=== AVL Demo ===")
    avl = AVLTree()
    for k in [30, 20, 10, 25, 40]:
        avl.insert(k)
    avl.print_tree()
    print("AVL inorder:", avl.inorder())
    print("Search 25:", avl.search(25).key if avl.search(25) else None)
    avl.delete(25)
    print("After delete 25:")
    avl.print_tree()
    print("AVL inorder:", avl.inorder())

    print("\n=== Red-Black Demo ===")
    rbt = RedBlackTree()
    for k in [10, 5, 1, 15, 20, 12, 18, 3]:
        rbt.insert(k)
        print(f"\nInserted {k}:")
        rbt.print_tree()
    print("\nRBT inorder:", rbt.inorder())
    print("Search 15:", rbt.search(15).key if rbt.search(15) else None)
    print("\nDelete 15:")
    rbt.delete(15)
    rbt.print_tree()
    print("RBT inorder:", rbt.inorder())

'컴퓨터 사이언스 > 자료구조' 카테고리의 다른 글

Set  (0) 2025.10.26
HashSet  (0) 2025.10.19
이진 트리(Binary Tree) 소개  (0) 2025.10.04
AVL 트리  (0) 2023.11.07
Red-Black Tree 구현  (0) 2023.11.05