Skip to content

Commit 0f27055

Browse files
committed
fix tests
1 parent 5e4e56b commit 0f27055

File tree

7 files changed

+128
-134
lines changed

7 files changed

+128
-134
lines changed

tests/library_checker_aizu_tests/edge_cd_asserts.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
2-
void edge_cd_asserts(const vector<vi>& adj, int cent,
3-
int split) {
2+
auto edge_cd_asserts = [&](int cent, int split) -> void {
43
assert(0 < split && split < sz(adj[cent]));
54
auto dfs = [&](auto&& self, int u, int p) -> int {
65
int siz = 1;
@@ -47,4 +46,4 @@ void edge_cd_asserts(const vector<vi>& adj, int cent,
4746
assert(!is_balanced(a, cnts[0] + b));
4847
assert(!is_balanced(b, cnts[0] + a));
4948
}
50-
}
49+
};

tests/library_checker_aizu_tests/handmade_tests/count_paths.test.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,24 +91,23 @@ vector<vector<ll>> naive(const vector<vi>& adj) {
9191
//! @time O(n * logφ(n) * log2(n))
9292
//! @space this function allocates/returns various vectors
9393
//! which are each O(n)
94-
vector<ll> count_paths_per_length(const vector<vi>& adj) {
94+
vector<ll> count_paths_per_length(vector<vi> adj) {
9595
vector<ll> num_paths(sz(adj));
9696
if (sz(adj) >= 2) num_paths[1] = sz(adj) - 1;
97-
edge_cd(adj,
98-
[&](const vector<vi>& cd_adj, int cent, int split) {
99-
vector<vector<double>> cnt(2, vector<double>(1));
100-
auto dfs = [&](auto&& self, int u, int p, int d,
101-
int side) -> void {
102-
if (sz(cnt[side]) == d) cnt[side].push_back(0.0);
103-
cnt[side][d]++;
104-
for (int c : cd_adj[u])
105-
if (c != p) self(self, c, u, 1 + d, side);
106-
};
107-
rep(i, 0, sz(cd_adj[cent]))
108-
dfs(dfs, cd_adj[cent][i], cent, 1, i < split);
109-
vector<double> prod = conv(cnt[0], cnt[1]);
110-
rep(i, 0, sz(prod)) num_paths[i] += llround(prod[i]);
111-
});
97+
edge_cd(adj, [&](int cent, int split) {
98+
vector<vector<double>> cnt(2, vector<double>(1));
99+
auto dfs = [&](auto&& self, int u, int p, int d,
100+
int side) -> void {
101+
if (sz(cnt[side]) == d) cnt[side].push_back(0.0);
102+
cnt[side][d]++;
103+
for (int c : adj[u])
104+
if (c != p) self(self, c, u, 1 + d, side);
105+
};
106+
rep(i, 0, sz(adj[cent]))
107+
dfs(dfs, adj[cent][i], cent, 1, i < split);
108+
vector<double> prod = conv(cnt[0], cnt[1]);
109+
rep(i, 0, sz(prod)) num_paths[i] += llround(prod[i]);
110+
});
112111
return num_paths;
113112
}
114113
int main() {

tests/library_checker_aizu_tests/handmade_tests/edge_cd_small_trees.test.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
11
#define PROBLEM \
22
"https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A"
33
#include "../template.hpp"
4-
#include "../edge_cd_asserts.hpp"
54
#include "../../../kactl/stress-tests/utilities/genTree.h"
65
#include "../../../library/trees/edge_cd.hpp"
76
int main() {
87
{
98
vector<vector<int>> adj;
10-
edge_cd(adj,
11-
[&](const vector<vector<int>>&, int, int) -> void {
12-
assert(false);
13-
});
9+
edge_cd(adj, [&](int, int) -> void { assert(false); });
1410
}
1511
{
1612
vector<vector<int>> adj(1);
17-
edge_cd(adj,
18-
[&](const vector<vector<int>>&, int, int) -> void {
19-
assert(false);
20-
});
13+
edge_cd(adj, [&](int, int) -> void { assert(false); });
2114
}
2215
for (int n = 2; n <= 7; n++) {
2316
int num_codes = 1;
@@ -42,6 +35,7 @@ int main() {
4235
adj[u].push_back(v);
4336
adj[v].push_back(u);
4437
}
38+
#include "../edge_cd_asserts.hpp"
4539
edge_cd(adj, edge_cd_asserts);
4640
}
4741
}

tests/library_checker_aizu_tests/trees/edge_cd_contour_range_query.test.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#define PROBLEM \
22
"https://judge.yosupo.jp/problem/vertex_add_range_contour_sum_on_tree"
33
#include "../template.hpp"
4-
#include "../edge_cd_asserts.hpp"
54
#include "../../../library/data_structures_[l,r)/bit.hpp"
65
#include "../../../library/trees/edge_cd.hpp"
76
struct sum_adj {
@@ -47,25 +46,23 @@ struct contour_range_query {
4746
//! @param a a[u] = initial number for node u
4847
//! @time O(n logφ n)
4948
//! @space O(n logφ n) for `info` and `bits`
50-
contour_range_query(const vector<vi>& adj,
51-
const vector<ll>& a):
49+
contour_range_query(vector<vi> adj, const vector<ll>& a):
5250
n(sz(a)), sum_a(adj, a), info(n) {
53-
edge_cd(adj,
54-
[&](const vector<vi>& cd_adj, int cent, int split) {
55-
vector<vector<ll>> sum_num(2, vector<ll>(1));
56-
auto dfs = [&](auto&& self, int u, int p, int d,
57-
int side) -> void {
58-
info[u].push_back({int(sz(bits)), d, side});
59-
if (sz(sum_num[side]) == d)
60-
sum_num[side].push_back(0);
61-
sum_num[side][d] += a[u];
62-
for (int c : cd_adj[u])
63-
if (c != p) self(self, c, u, 1 + d, side);
64-
};
65-
rep(i, 0, sz(cd_adj[cent]))
66-
dfs(dfs, cd_adj[cent][i], cent, 1, i < split);
67-
bits.push_back({BIT(sum_num[0]), BIT(sum_num[1])});
68-
});
51+
edge_cd(adj, [&](int cent, int split) {
52+
vector<vector<ll>> sum_num(2, vector<ll>(1));
53+
auto dfs = [&](auto&& self, int u, int p, int d,
54+
int side) -> void {
55+
info[u].push_back({int(sz(bits)), d, side});
56+
if (sz(sum_num[side]) == d)
57+
sum_num[side].push_back(0);
58+
sum_num[side][d] += a[u];
59+
for (int c : adj[u])
60+
if (c != p) self(self, c, u, 1 + d, side);
61+
};
62+
rep(i, 0, sz(adj[cent]))
63+
dfs(dfs, adj[cent][i], cent, 1, i < split);
64+
bits.push_back({BIT(sum_num[0]), BIT(sum_num[1])});
65+
});
6966
}
7067
//! @param u node
7168
//! @param delta number to add to node u's number
@@ -108,8 +105,11 @@ int main() {
108105
adj[u].push_back(v);
109106
adj[v].push_back(u);
110107
}
111-
{ edge_cd(adj, edge_cd_asserts); }
112108
contour_range_query cq(adj, a);
109+
{
110+
#include "../edge_cd_asserts.hpp"
111+
edge_cd(adj, edge_cd_asserts);
112+
}
113113
while (q--) {
114114
int type;
115115
cin >> type;

tests/library_checker_aizu_tests/trees/edge_cd_contour_range_update.test.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#define PROBLEM \
22
"https://judge.yosupo.jp/problem/vertex_get_range_contour_add_on_tree"
33
#include "../template.hpp"
4-
#include "../edge_cd_asserts.hpp"
54
#include "../../../library/data_structures_[l,r)/bit_uncommon/rupq.hpp"
65
#include "../../../library/trees/edge_cd.hpp"
76
struct sum_adj {
@@ -49,24 +48,23 @@ struct contour_range_update {
4948
//! @param a a[u] = initial number for node u
5049
//! @time O(n logφ n)
5150
//! @space O(n logφ n) for `info` and `bits`
52-
contour_range_update(const vector<vi>& adj,
51+
contour_range_update(vector<vi> adj,
5352
const vector<ll>& a):
5453
n(sz(a)), a(a), sum_a(adj, vector<ll>(n)), info(n) {
55-
edge_cd(adj,
56-
[&](const vector<vi>& cd_adj, int cent, int split) {
57-
array<int, 2> mx_d = {0, 0};
58-
auto dfs = [&](auto&& self, int u, int p, int d,
59-
int side) -> void {
60-
mx_d[side] = max(mx_d[side], d);
61-
info[u].push_back({int(sz(bits)), d, side});
62-
for (int v : cd_adj[u])
63-
if (v != p) self(self, v, u, 1 + d, side);
64-
};
65-
rep(i, 0, sz(cd_adj[cent]))
66-
dfs(dfs, cd_adj[cent][i], cent, 1, i < split);
67-
bits.push_back(
68-
{bit_rupq(mx_d[0] + 1), bit_rupq(mx_d[1] + 1)});
69-
});
54+
edge_cd(adj, [&](int cent, int split) {
55+
array<int, 2> mx_d = {0, 0};
56+
auto dfs = [&](auto&& self, int u, int p, int d,
57+
int side) -> void {
58+
mx_d[side] = max(mx_d[side], d);
59+
info[u].push_back({int(sz(bits)), d, side});
60+
for (int v : adj[u])
61+
if (v != p) self(self, v, u, 1 + d, side);
62+
};
63+
rep(i, 0, sz(adj[cent]))
64+
dfs(dfs, adj[cent][i], cent, 1, i < split);
65+
bits.push_back(
66+
{bit_rupq(mx_d[0] + 1), bit_rupq(mx_d[1] + 1)});
67+
});
7068
}
7169
//! @param u,l,r,delta add delta to all nodes v such
7270
//! that l <= dist(u, v) < r
@@ -106,8 +104,11 @@ int main() {
106104
adj[u].push_back(v);
107105
adj[v].push_back(u);
108106
}
109-
{ edge_cd(adj, edge_cd_asserts); }
110107
contour_range_update cu(adj, a);
108+
{
109+
#include "../edge_cd_asserts.hpp"
110+
edge_cd(adj, edge_cd_asserts);
111+
}
111112
while (q--) {
112113
int type;
113114
cin >> type;

tests/library_checker_aizu_tests/trees/edge_cd_count_paths_per_length.test.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#define PROBLEM \
22
"https://judge.yosupo.jp/problem/frequency_table_of_tree_distance"
33
#include "../template.hpp"
4-
#include "../edge_cd_asserts.hpp"
54
#include "../../../kactl/content/numerical/FastFourierTransform.h"
65
#include "../../../library/trees/edge_cd.hpp"
76
//! @param adj unrooted, connected tree
@@ -10,24 +9,23 @@
109
//! @time O(n * logφ(n) * log2(n))
1110
//! @space this function allocates/returns various vectors
1211
//! which are each O(n)
13-
vector<ll> count_paths_per_length(const vector<vi>& adj) {
12+
vector<ll> count_paths_per_length(vector<vi> adj) {
1413
vector<ll> num_paths(sz(adj));
1514
if (sz(adj) >= 2) num_paths[1] = sz(adj) - 1;
16-
edge_cd(adj,
17-
[&](const vector<vi>& cd_adj, int cent, int split) {
18-
vector<vector<double>> cnt(2, vector<double>(1));
19-
auto dfs = [&](auto&& self, int u, int p, int d,
20-
int side) -> void {
21-
if (sz(cnt[side]) == d) cnt[side].push_back(0.0);
22-
cnt[side][d]++;
23-
for (int c : cd_adj[u])
24-
if (c != p) self(self, c, u, 1 + d, side);
25-
};
26-
rep(i, 0, sz(cd_adj[cent]))
27-
dfs(dfs, cd_adj[cent][i], cent, 1, i < split);
28-
vector<double> prod = conv(cnt[0], cnt[1]);
29-
rep(i, 0, sz(prod)) num_paths[i] += llround(prod[i]);
30-
});
15+
edge_cd(adj, [&](int cent, int split) {
16+
vector<vector<double>> cnt(2, vector<double>(1));
17+
auto dfs = [&](auto&& self, int u, int p, int d,
18+
int side) -> void {
19+
if (sz(cnt[side]) == d) cnt[side].push_back(0.0);
20+
cnt[side][d]++;
21+
for (int c : adj[u])
22+
if (c != p) self(self, c, u, 1 + d, side);
23+
};
24+
rep(i, 0, sz(adj[cent]))
25+
dfs(dfs, adj[cent][i], cent, 1, i < split);
26+
vector<double> prod = conv(cnt[0], cnt[1]);
27+
rep(i, 0, sz(prod)) num_paths[i] += llround(prod[i]);
28+
});
3129
return num_paths;
3230
}
3331
int main() {
@@ -41,8 +39,11 @@ int main() {
4139
adj[u].push_back(v);
4240
adj[v].push_back(u);
4341
}
44-
{ edge_cd(adj, edge_cd_asserts); }
4542
vector<ll> cnt_len = count_paths_per_length(adj);
43+
{
44+
#include "../edge_cd_asserts.hpp"
45+
edge_cd(adj, edge_cd_asserts);
46+
}
4647
for (int i = 1; i < n; i++) cout << cnt_len[i] << " ";
4748
cout << '\n';
4849
return 0;

tests/library_checker_aizu_tests/trees/edge_cd_reroot_dp.test.cpp

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#define PROBLEM \
22
"https://judge.yosupo.jp/problem/tree_path_composite_sum"
33
#include "../template.hpp"
4-
#include "../edge_cd_asserts.hpp"
54
#include "../../../library/trees/edge_cd.hpp"
65
const int mod = 998244353;
76
int main() {
@@ -45,57 +44,58 @@ int main() {
4544
assert(u_low ^ v_low);
4645
return u_low ? par[u].second : par[v].second;
4746
};
48-
{ edge_cd(base_adj, edge_cd_asserts); }
49-
edge_cd(adj,
50-
[&](const vector<vi>& cd_adj, int cent,
51-
int split) -> void {
52-
array<vector<array<int, 3>>, 2> all_backwards;
53-
array<int, 2> sum_forward = {0, 0};
54-
array<int, 2> cnt_nodes = {0, 0};
55-
auto dfs = [&](auto&& self, int u, int p,
56-
array<int, 2> forwards,
57-
array<int, 2> backwards,
58-
int side) -> void {
59-
all_backwards[side].push_back(
60-
{u, backwards[0], backwards[1]});
61-
sum_forward[side] =
62-
(sum_forward[side] + 1LL * forwards[0] * a[u] +
63-
forwards[1]) %
64-
mod;
65-
cnt_nodes[side]++;
66-
for (int v : cd_adj[u]) {
67-
if (v == p) continue;
68-
int e_id = edge_id(u, v);
69-
// f(x) = ax+b
70-
// g(x) = cx+d
71-
// f(g(x)) = a(cx+d)+b = acx+ad+b
72-
array<int, 2> curr_forw = {
73-
int(1LL * forwards[0] * b[e_id] % mod),
74-
int(
75-
(1LL * forwards[0] * c[e_id] + forwards[1]) %
76-
mod)};
77-
array<int, 2> curr_backw = {
78-
int(1LL * backwards[0] * b[e_id] % mod),
79-
int((1LL * backwards[1] * b[e_id] + c[e_id]) %
80-
mod)};
81-
self(self, v, u, curr_forw, curr_backw, side);
82-
}
83-
};
84-
for (int i = 0; i < sz(cd_adj[cent]); i++) {
85-
int e_id = edge_id(cent, cd_adj[cent][i]);
86-
dfs(dfs, cd_adj[cent][i], cent, {b[e_id], c[e_id]},
87-
{b[e_id], c[e_id]}, i < split);
47+
edge_cd(adj, [&](int cent, int split) -> void {
48+
array<vector<array<int, 3>>, 2> all_backwards;
49+
array<int, 2> sum_forward = {0, 0};
50+
array<int, 2> cnt_nodes = {0, 0};
51+
auto dfs = [&](auto&& self, int u, int p,
52+
array<int, 2> forwards,
53+
array<int, 2> backwards,
54+
int side) -> void {
55+
all_backwards[side].push_back(
56+
{u, backwards[0], backwards[1]});
57+
sum_forward[side] =
58+
(sum_forward[side] + 1LL * forwards[0] * a[u] +
59+
forwards[1]) %
60+
mod;
61+
cnt_nodes[side]++;
62+
for (int v : adj[u]) {
63+
if (v == p) continue;
64+
int e_id = edge_id(u, v);
65+
// f(x) = ax+b
66+
// g(x) = cx+d
67+
// f(g(x)) = a(cx+d)+b = acx+ad+b
68+
array<int, 2> curr_forw = {
69+
int(1LL * forwards[0] * b[e_id] % mod),
70+
int((1LL * forwards[0] * c[e_id] + forwards[1]) %
71+
mod)};
72+
array<int, 2> curr_backw = {
73+
int(1LL * backwards[0] * b[e_id] % mod),
74+
int((1LL * backwards[1] * b[e_id] + c[e_id]) %
75+
mod)};
76+
self(self, v, u, curr_forw, curr_backw, side);
8877
}
89-
for (int side = 0; side < 2; side++) {
90-
for (
91-
auto [u, curr_b, curr_c] : all_backwards[side]) {
92-
res[u] =
93-
(res[u] + 1LL * curr_b * sum_forward[!side] +
94-
1LL * curr_c * cnt_nodes[!side]) %
95-
mod;
96-
}
78+
};
79+
for (int i = 0; i < sz(adj[cent]); i++) {
80+
int e_id = edge_id(cent, adj[cent][i]);
81+
dfs(dfs, adj[cent][i], cent, {b[e_id], c[e_id]},
82+
{b[e_id], c[e_id]}, i < split);
83+
}
84+
for (int side = 0; side < 2; side++) {
85+
for (
86+
auto [u, curr_b, curr_c] : all_backwards[side]) {
87+
res[u] =
88+
(res[u] + 1LL * curr_b * sum_forward[!side] +
89+
1LL * curr_c * cnt_nodes[!side]) %
90+
mod;
9791
}
98-
});
92+
}
93+
});
94+
swap(base_adj, adj);
95+
{
96+
#include "../edge_cd_asserts.hpp"
97+
edge_cd(adj, edge_cd_asserts);
98+
}
9999
for (int i = 0; i < n; i++) cout << res[i] << ' ';
100100
cout << '\n';
101101
return 0;

0 commit comments

Comments
 (0)