Skip to content

Static Wavelet Tree - Truy Vấn Thứ Tự Trên Đoạn

Tác giả: FPTOJ Team
Nội dung tham khảo từ: CP-Algorithms - Wavelet Tree


1. Bản chất vấn đề

Bài toán: Truy vấn phần tử lớn thứ \(k\) trong đoạn

Cho mảng \(A\) gồm \(N\) phần tử. Thực hiện \(Q\) truy vấn:

  • \(k\)-th smallest in \([l, r]\): Tìm phần tử nhỏ thứ \(k\) trong đoạn \([l, r]\).
  • Count less than \(x\) in \([l, r]\): Đếm số phần tử nhỏ hơn \(x\) trong đoạn \([l, r]\).
Bài toán Segment Tree Wavelet Tree
\(k\)-th smallest \(O(\log^2 N)\) \(O(\log N)\)
Count less than \(x\) \(O(\log^2 N)\) \(O(\log N)\)
Không gian \(O(N)\) \(O(N \log N)\)

2. Tư duy cốt lõi

Ý tưởng: Phân hoạch theo bit

Wavelet Tree là cây nhị phân mà mỗi nút quản lý 1 khoảng giá trị \([lo, hi]\):

  • Nút lá: \(lo = hi\) (chỉ 1 giá trị).
  • Nút trong: Chia \([lo, hi]\) thành \([lo, mid]\) (con trái) và \([mid+1, hi]\) (con phải).

Mỗi nút lưu mảng B — đếm số phần tử thuộc con trái.

Cấu trúc cây

graph TD A["[1,8]: arr=[4,2,7,1,5,3,8,6]\nB=[0,1,1,1,2,2,3,3]"] --> B["[1,4]: arr=[4,2,1,3]\nB=[0,0,1,2]"] A --> C["[5,8]: arr=[7,5,8,6]\nB=[0,0,1,1]"] B --> D["[1,2]: arr=[2,1]\nB=[0,1]"] B --> E["[3,4]: arr=[4,3]\nB=[0,0]"] C --> F["[5,6]: arr=[5,6]\nB=[0,1]"] C --> G["[7,8]: arr=[7,8]\nB=[0,0]"] D --> H["[1,1]: 1"] D --> I["[2,2]: 2"] E --> J["[3,3]: 3"] E --> K["[4,4]: 4"] F --> L["[5,5]: 5"] F --> M["[6,6]: 6"] G --> N["[7,7]: 7"] G --> O["[8,8]: 8"]

Trace: Tìm phần tử nhỏ thứ 3 trong \([0, 7]\) (toàn mảng)

Mảng: \([4, 2, 7, 1, 5, 3, 8, 6]\), \(k = 3\)

Bước Nút Khoảng \(B\) Số trái So sánh Hành động
1 Gốc \([1,8]\) \([0,1,1,1,2,2,3,3]\) \(4\) \(k=3 \le 4\) Sang trái
2 Con trái \([1,4]\) \([0,0,1,2]\) \(2\) \(k=3 > 2\) Sang phải, \(k=3-2=1\)
3 Con phải \([3,4]\) \([0,0]\) \(0\) \(k=1 > 0\) Sang phải, \(k=1-0=1\)
4 \([4,4]\) Kết quả: 4

3. Phân tích tính đúng đắn

Mảng \(B\) đếm gì?

\(B[i]\) = số phần tử từ \(A[0]\) đến \(A[i]\) thuộc con trái (có giá trị \(\le mid\)).

Khi query trên đoạn \([l, r]\):

  • Số phần tử thuộc con trái trong \([l, r]\): B[r] - B[l-1]
  • Nếu \(k \le\) số trái → đệ quy sang trái.
  • Nếu \(k >\) số trái → đệ quy sang phải, \(k' = k - \text{trái}\).

4. Đánh giá độ phức tạp

Thao tác Thời gian Không gian
Xây cây \(O(N \log N)\) \(O(N \log N)\)
\(k\)-th smallest \(O(\log N)\) \(O(1)\)
Count less than \(x\) \(O(\log N)\) \(O(1)\)

Code minh họa

#include <bits/stdc++.h>
using namespace std;

struct WaveletTree {
    int lo, hi;
    vector<int> B;
    WaveletTree *left, *right;

    WaveletTree(vector<int>::iterator from, vector<int>::iterator to, int x, int y)
        : lo(x), hi(y), left(nullptr), right(nullptr) {
        if (from == to || lo == hi) return;
        int mid = (lo + hi) / 2;
        auto f = [mid](int x) { return x <= mid; };
        B.reserve(to - from + 1);
        B.push_back(0);
        for (auto it = from; it != to; it++)
            B.push_back(B.back() + f(*it));
        auto pivot = stable_partition(from, to, f);
        left = new WaveletTree(from, pivot, lo, mid);
        right = new WaveletTree(pivot, to, mid + 1, hi);
    }

    // Số phần tử <= k trong [l, r] (1-indexed)
    int countLessEq(int l, int r, int k) {
        if (l > r || k < lo) return 0;
        if (hi <= k) return r - l + 1;
        int lb = B[l - 1], rb = B[r];
        return left->countLessEq(lb + 1, rb, k) +
               right->countLessEq(l - lb, r - rb, k);
    }

    // Phần tử nhỏ thứ k trong [l, r] (1-indexed)
    int kth(int l, int r, int k) {
        if (lo == hi) return lo;
        int lb = B[l - 1], rb = B[r];
        int inLeft = rb - lb;
        if (k <= inLeft)
            return left->kth(lb + 1, rb, k);
        else
            return right->kth(l - lb, r - rb, k - inLeft);
    }

    ~WaveletTree() { delete left; delete right; }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, q;
    cin >> n >> q;
    vector<int> a(n);
    for (int i = 0; i < n; i++) cin >> a[i];

    int minVal = *min_element(a.begin(), a.end());
    int maxVal = *max_element(a.begin(), a.end());
    WaveletTree wt(a.begin(), a.end(), minVal, maxVal);

    while (q--) {
        int type;
        cin >> type;
        if (type == 1) {
            int l, r, k;
            cin >> l >> r >> k;
            cout << wt.kth(l, r, k) << "\n";
        } else {
            int l, r, x;
            cin >> l >> r >> x;
            cout << wt.countLessEq(l, r, x) << "\n";
        }
    }
    return 0;
}
# Wavelet Tree trong Python (chỉ dùng cho mục đích học tập)
# Với N lớn, nên dùng C++ vì Wavelet Tree tốn nhiều bộ nhớ trong Python

class WaveletTree:
    def __init__(self, data, lo, hi):
        self.lo = lo
        self.hi = hi
        self.B = [0]
        self.left = self.right = None
        if lo == hi or not data:
            return
        mid = (lo + hi) // 2
        self.B = [0]
        for x in data:
            self.B.append(self.B[-1] + (1 if x <= mid else 0))
        left_data = [x for x in data if x <= mid]
        right_data = [x for x in data if x > mid]
        self.left = WaveletTree(left_data, lo, mid)
        self.right = WaveletTree(right_data, mid + 1, hi)

    def kth(self, l, r, k):
        if self.lo == self.hi:
            return self.lo
        in_left = self.B[r] - self.B[l - 1]
        if k <= in_left:
            new_l = self.B[l - 1] + 1
            new_r = self.B[r]
            return self.left.kth(new_l, new_r, k)
        else:
            new_l = l - self.B[l - 1]
            new_r = r - self.B[r]
            return self.right.kth(new_l, new_r, k - in_left)

n, q = map(int, input().split())
a = list(map(int, input().split()))
lo, hi = min(a), max(a)
wt = WaveletTree(a, lo, hi)

for _ in range(q):
    parts = list(map(int, input().split()))
    if parts[0] == 1:
        l, r, k = parts[1], parts[2], parts[3]
        print(wt.kth(l, r, k))

💬 Bình luận