Skip to content

Commit d124b1f

Browse files
committed
first draft of dsu weighted
1 parent 9307635 commit d124b1f

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

library/dsu/dsu_weighted.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#pragma once
2+
struct dsu_weighted {
3+
int n;
4+
vi p;
5+
vector<ll> d;
6+
dsu_weighted(int n): n(n), p(n, -1), d(n) {}
7+
int f(int u) {
8+
if (p[u] < 0) return u;
9+
int root = f(p[u]);
10+
d[u] += d[p[u]];
11+
return p[u] = root;
12+
}
13+
int size(int u) { return -p[f(u)]; }
14+
ll diff(int u, int v) {
15+
return f(u) == f(v) ? d[v] - d[u] : 1e18;
16+
}
17+
bool join(int u, int v, ll w) {
18+
w += d[u] - d[v];
19+
u = f(u), v = f(v);
20+
if (u == v) return 0;
21+
if (p[u] > p[v]) swap(u, v), w = -w;
22+
p[u] += p[v];
23+
p[v] = u;
24+
d[v] = w;
25+
return 1;
26+
}
27+
};
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#define PROBLEM \
2+
"https://judge.yosupo.jp/problem/unionfind_with_potential"
3+
#include "../template.hpp"
4+
#include "../../../library/dsu/dsu_weighted.hpp"
5+
#include "../../../library/dsu/dsu.hpp"
6+
const int mod = 998'244'353;
7+
int main() {
8+
cin.tie(0)->sync_with_stdio(0);
9+
int n, q;
10+
cin >> n >> q;
11+
dsu_weighted dsu_w(n);
12+
DSU dsu(n);
13+
while (q--) {
14+
int type, u, v;
15+
cin >> type >> u >> v;
16+
if (type == 0) {
17+
assert(dsu.size(u) == dsu_w.size(u));
18+
assert(dsu.size(v) == dsu_w.size(v));
19+
int w;
20+
cin >> w;
21+
ll curr_diff = dsu_w.diff(u, v);
22+
if (curr_diff == 1e18) {
23+
assert(dsu_w.join(u, v, w));
24+
cout << 1 << '\n';
25+
} else
26+
cout << ((curr_diff % mod + mod) % mod == w)
27+
<< '\n';
28+
dsu.join(u, v);
29+
assert(dsu.size(u) == dsu_w.size(u));
30+
assert(dsu.size(v) == dsu_w.size(v));
31+
} else {
32+
assert(dsu.size(u) == dsu_w.size(u));
33+
assert(dsu.size(v) == dsu_w.size(v));
34+
ll curr_diff = dsu_w.diff(u, v);
35+
if (curr_diff == 1e18) cout << -1 << '\n';
36+
else cout << (curr_diff % mod + mod) % mod << '\n';
37+
}
38+
}
39+
return 0;
40+
}

0 commit comments

Comments
 (0)