Persistent Data Structures - Lưu Lịch Sử Thay Đổi
Tác giả: FPTOJ Team
Nội dung tham khảo từ: VNOI Wiki - Persistent Data Structures, CP-Algorithms
1. Bản chất vấn đề
Bài toán: Truy vấn tổng đoạn trên phiên bản cũ
Cho mảng \(A\) gồm \(N\) phần tử, thực hiện \(Q\) truy vấn:
Type 1: update(i, v) — Gán \(A_i = v\) .
Type 2: query(ver, l, r) — Truy vấn tổng đoạn \([l, r]\) trên phiên bản thứ ver .
Vấn đề: Nếu lưu toàn bộ mảng cho mỗi phiên bản \(\Rightarrow O(N \cdot Q)\) bộ nhớ \(\Rightarrow\) tràn bộ nhớ!
Giải pháp: Persistent Segment Tree — chỉ lưu các nút thay đổi giữa các phiên bản.
So sánh
Cấu trúc
Thời gian cập nhật
Thời gian truy vấn
Không gian
Segment Tree thường
\(O(\log N)\)
\(O(\log N)\)
\(O(N)\)
Copy toàn bộ mảng
\(O(N)\)
\(O(1)\)
\(O(N \cdot Q)\)
Persistent Segment Tree
\(O(\log N)\)
\(O(\log N)\)
\(O(N + Q \log N)\)
2. Tư duy cốt lõi
Ý tưởng: Path Copying
Khi cập nhật 1 phần tử, chỉ có \(O(\log N)\) nút trên đường đi từ gốc đến lá bị thay đổi. Thay vì sửa trực tiếp, ta tạo bản sao của các nút đó và liên kết với các nút cũ.
flowchart TD
subgraph "Phiên bản 0 (gốc)"
R0["Gốc v0: 15"] --> L0["Trái v0: 7"]
R0 --> R0R["Phải v0: 8"]
L0 --> LL0["Lá: 3"]
L0 --> LR0["Lá: 4"]
R0R --> RL0["Lá: 5"]
R0R --> RR0["Lá: 3"]
end
subgraph "Phiên bản 1 (sửa A[2] = 10)"
R1["Gốc v1: 20"] --> L1["Trái v1: 7 (dùng lại)"]
R1 --> R1R["Phải v1: 13 (mới)"]
R1R --> RL1["Lá: 10 (mới)"]
R1R --> RR1["Lá: 3 (dùng lại)"]
L1 --> LL1["Lá: 3 (dùng lại)"]
L1 --> LR1["Lá: 4 (dùng lại)"]
end
Các nút màu "dùng lại" là con trỏ trỏ đến nút cũ — không tốn thêm bộ nhớ!
Cấu trúc nút
Mỗi nút lưu:
left, right: con trỏ đến con trái, con phải
sum: tổng đoạn mà nút quản lý
Khi cập nhật, tạo nút mới cho mỗi nút trên đường đi, các nút còn lại giữ nguyên.
Trace chi tiết
Mảng ban đầu: \(A = [3, 4, 5, 3]\) (phiên bản 0)
Truy vấn: update(2, 10) (phiên bản 1) — sửa \(A_2\) từ \(5\) thành \(10\) .
Cây phiên bản 0:
Nút
Đoạn
Tổng
Con trái
Con phải
node0
\([0, 3]\)
\(15\)
node1
node2
node1
\([0, 1]\)
\(7\)
node3 (lá: 3)
node4 (lá: 4)
node2
\([2, 3]\)
\(8\)
node5 (lá: 5)
node6 (lá: 3)
Cập nhật \(A_2 = 10\) :
Bước
Nút cũ
Nút mới tạo
Giá trị mới
1
node5 (lá, \(A_2\) )
node7
\(10\)
2
node2 (quản lý \([2,3]\) )
node8
\(10 + 3 = 13\)
3
node0 (gốc)
node9
\(7 + 13 = 20\)
Các nút node1, node3, node4, node6 không thay đổi — giữ nguyên con trỏ.
Cây phiên bản 1:
Nút
Đoạn
Tổng
Con trái
Con phải
node9
\([0, 3]\)
\(20\)
node1 (cũ)
node8 (mới)
node8
\([2, 3]\)
\(13\)
node7 (mới)
node6 (cũ)
node7
\([2, 2]\)
\(10\)
—
—
3. Phân tích tính đúng đắn
Tại sao truy vấn trên phiên bản cũ vẫn đúng?
Mỗi phiên bản là 1 cây nhị phân đầy đủ. Gốc của phiên bản \(v\) là root[v]. Khi truy vấn, duyệt từ root[v] xuống lá — tất cả nút trên đường đi đều tồn tại (nút mới hoặc nút cũ được liên kết).
Tại sao không gian là \(O(N + Q \log N)\) ?
Xây cây ban đầu: \(O(N)\) nút.
Mỗi lần cập nhật: tạo \(O(\log N)\) nút mới.
\(Q\) lần cập nhật: \(O(Q \log N)\) nút.
Tổng: \(O(N + Q \log N)\) .
4. Đánh giá độ phức tạp
Thao tác
Thời gian
Không gian
Xây cây phiên bản 0
\(O(N)\)
\(O(N)\)
Cập nhật (tạo phiên bản mới)
\(O(\log N)\)
\(O(\log N)\)
Truy vấn tổng đoạn
\(O(\log N)\)
\(O(1)\)
Truy vấn phiên bản \(k\)
\(O(\log N)\)
\(O(1)\)
Code minh họa
Persistent Segment Tree — Truy vấn tổng đoạn theo phiên bản
C++ Python
#include <bits/stdc++.h>
using namespace std ;
struct Node {
int left = -1 , right = -1 ;
long long sum = 0 ;
};
vector < Node > tree ;
vector < int > roots ; // gốc của mỗi phiên bản
int build ( vector < int >& a , int lo , int hi ) {
int id = tree . size ();
tree . push_back ( Node ());
if ( lo == hi ) {
tree [ id ]. sum = a [ lo ];
return id ;
}
int mid = ( lo + hi ) / 2 ;
tree [ id ]. left = build ( a , lo , mid );
tree [ id ]. right = build ( a , mid + 1 , hi );
tree [ id ]. sum = tree [ tree [ id ]. left ]. sum + tree [ tree [ id ]. right ]. sum ;
return id ;
}
int update ( int old , int lo , int hi , int pos , long long val ) {
int id = tree . size ();
tree . push_back ( tree [ old ]); // copy nút cũ
if ( lo == hi ) {
tree [ id ]. sum = val ;
return id ;
}
int mid = ( lo + hi ) / 2 ;
if ( pos <= mid )
tree [ id ]. left = update ( tree [ old ]. left , lo , mid , pos , val );
else
tree [ id ]. right = update ( tree [ old ]. right , mid + 1 , hi , pos , val );
tree [ id ]. sum = tree [ tree [ id ]. left ]. sum + tree [ tree [ id ]. right ]. sum ;
return id ;
}
long long query ( int id , int lo , int hi , int l , int r ) {
if ( r < lo || hi < l ) return 0 ;
if ( l <= lo && hi <= r ) return tree [ id ]. sum ;
int mid = ( lo + hi ) / 2 ;
return query ( tree [ id ]. left , lo , mid , l , r ) +
query ( tree [ id ]. right , mid + 1 , hi , l , r );
}
int main () {
ios_base :: sync_with_stdio ( false );
cin . tie ( NULL );
int n , q ;
cin >> n >> q ;
vector < int > a ( n );
for ( int i = 0 ; i < n ; i ++ ) cin >> a [ i ];
roots . push_back ( build ( a , 0 , n - 1 ));
while ( q -- ) {
int type ;
cin >> type ;
if ( type == 1 ) {
int pos , val ;
cin >> pos >> val ;
pos -- ;
int newRoot = update ( roots . back (), 0 , n - 1 , pos , val );
roots . push_back ( newRoot );
} else {
int ver , l , r ;
cin >> ver >> l >> r ;
ver -- ; l -- ; r -- ;
cout << query ( roots [ ver ], 0 , n - 1 , l , r ) << " \n " ;
}
}
return 0 ;
}
import sys
input = sys . stdin . readline
sys . setrecursionlimit ( 1 << 25 )
class Node :
__slots__ = [ 'left' , 'right' , 'sum_val' ]
def __init__ ( self , left =- 1 , right =- 1 , sum_val = 0 ):
self . left = left
self . right = right
self . sum_val = sum_val
tree = []
roots = []
def build ( a , lo , hi ):
node_id = len ( tree )
tree . append ( Node ())
if lo == hi :
tree [ node_id ] . sum_val = a [ lo ]
return node_id
mid = ( lo + hi ) // 2
tree [ node_id ] . left = build ( a , lo , mid )
tree [ node_id ] . right = build ( a , mid + 1 , hi )
tree [ node_id ] . sum_val = tree [ tree [ node_id ] . left ] . sum_val + tree [ tree [ node_id ] . right ] . sum_val
return node_id
def update ( old , lo , hi , pos , val ):
node_id = len ( tree )
tree . append ( Node ( tree [ old ] . left , tree [ old ] . right , tree [ old ] . sum_val ))
if lo == hi :
tree [ node_id ] . sum_val = val
return node_id
mid = ( lo + hi ) // 2
if pos <= mid :
tree [ node_id ] . left = update ( tree [ old ] . left , lo , mid , pos , val )
else :
tree [ node_id ] . right = update ( tree [ old ] . right , mid + 1 , hi , pos , val )
tree [ node_id ] . sum_val = tree [ tree [ node_id ] . left ] . sum_val + tree [ tree [ node_id ] . right ] . sum_val
return node_id
def query ( node_id , lo , hi , l , r ):
if r < lo or hi < l :
return 0
if l <= lo and hi <= r :
return tree [ node_id ] . sum_val
mid = ( lo + hi ) // 2
return query ( tree [ node_id ] . left , lo , mid , l , r ) + query ( tree [ node_id ] . right , mid + 1 , hi , l , r )
n , q = map ( int , input () . split ())
a = list ( map ( int , input () . split ()))
roots . append ( build ( a , 0 , n - 1 ))
for _ in range ( q ):
parts = list ( map ( int , input () . split ()))
if parts [ 0 ] == 1 :
pos , val = parts [ 1 ] - 1 , parts [ 2 ]
new_root = update ( roots [ - 1 ], 0 , n - 1 , pos , val )
roots . append ( new_root )
else :
ver , l , r = parts [ 1 ] - 1 , parts [ 2 ] - 1 , parts [ 3 ] - 1
print ( query ( roots [ ver ], 0 , n - 1 , l , r ))