@@ -64,29 +64,31 @@ struct TrieNode {
64
64
TrieNode * next[ 2] = {};
65
65
int cnt = 0;
66
66
};
67
+ const int BIT = 15;
67
68
class Solution {
68
- void add (TrieNode * node, int n) {
69
- for (int i = 15 ; i >= 0; --i) {
69
+ void addNode (TrieNode * node, int n) {
70
+ for (int i = BIT ; i >= 0; --i) {
70
71
int b = n >> i & 1;
71
- if (node->next[ b] == NULL ) node->next[ b] = new TrieNode();
72
+ if (! node->next[ b] ) node->next[ b] = new TrieNode();
72
73
node = node->next[ b] ;
73
74
node->cnt++;
74
75
}
75
76
}
76
- int count(TrieNode * node, int i, int n, int rl, int rh, int low, int high) {
77
- if (rl >= low && rh <= high) return node->cnt;
78
- if (rh < low || rl > high) return 0;
79
- int b = n >> i & 1, r = b ^ 1, mask = ~ (1 << i);
80
- return (node->next[ 0] ? count(node->next[ 0] , i - 1, n, rl & mask | (b << i), rh & mask | (b << i), low, high) : 0)
81
- + (node->next[ 1] ? count(node->next[ 1] , i - 1, n, rl & mask | (r << i), rh & mask | (r << i), low, high) : 0);
77
+ int count(TrieNode * node, int n, int low, int high, int i = BIT, int rangeMin = 0, int rangeMax = (1 << (BIT + 1)) - 1) {
78
+ if (rangeMin >= low && rangeMax <= high) return node->cnt;
79
+ if (rangeMax < low || rangeMin > high) return 0;
80
+ int ans = 0, b = n >> i & 1, r = 1 - b, mask = 1 << i;
81
+ if (node->next[ b] ) ans += count(node->next[ b] , n, low, high, i - 1, rangeMin & ~ mask, rangeMax & ~ mask);
82
+ if (node->next[ r] ) ans += count(node->next[ r] , n, low, high, i - 1, rangeMin | mask, rangeMax | mask);
83
+ return ans;
82
84
}
83
85
public:
84
86
int countPairs(vector<int >& A, int low, int high) {
85
87
TrieNode root;
86
88
int ans = 0;
87
89
for (int n : A) {
88
- ans += count(&root, 15, n, 0, (1 << 16) - 1 , low, high);
89
- add (&root, n);
90
+ ans += count(&root, n , low, high);
91
+ addNode (&root, n);
90
92
}
91
93
return ans;
92
94
}
0 commit comments