// Offline solution: for each query, count equal-value node pairs between two subtrees under the
// current (movable) root, using Mo's algorithm over pairs of DFS-order prefixes.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>
#define gc (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 65536, stdin), p1 == p2) ? EOF : *p1 ++)
typedef long long ll;

// Buffer state used by the `gc` fast-input macro.
char buf[65536], *p1, *p2;

// Reads the next non-negative decimal integer from stdin, skipping any non-digit characters
// before it. Returns 0 if end of input is reached before a digit is seen.
// `ch` is an int, not a char, so EOF stays distinguishable from data bytes: with a char, EOF became
// 255 where plain char is unsigned and was accepted as a digit, so the loop never terminated.
inline int read() {
    int ch = gc;
    while (ch != EOF && (ch < '0' || ch > '9')) ch = gc;
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = gc;
    }
    return x;
}

// a[]      : sorted raw values during compression; afterwards a[k] = compressed value of the node
//            whose DFS entry index is k.
// A[i]     : value of node i (compressed to 1..n after sorting).
// cnt[c][d]: first the total number of nodes with value c (column 0); later reused by Mo's
//            algorithm as the count of value c inside the left (d = 0) / right (d = 1) prefix.
// In/Out   : DFS entry index of a node and the last entry index inside its subtree (rooted at 1).
// P[k]     : node whose DFS entry index is k.
int a[100005], A[100005], cnt[100005][2], In[100005], Out[100005], P[100005];
// Now: DFS counter; S: Mo block size; tot: number of Mo sub-queries; root: current root.
int Now, S, tot, root(1);
// g[k]: equal-value pairs (x, y) with x among the first k DFS positions and y anywhere.
// Q[id]: answer accumulator of the id-th pair query; ans: current sum of cnt[c][0] * cnt[c][1].
ll g[100005], Q[500005], ans;
// sons[u]: adjacency list; dfn[u]: DFS entry indices of u's children, in increasing order.
std::vector<int> sons[100005], dfn[100005];

// True iff u is a proper ancestor of the current root in the tree rooted at node 1.
inline bool chk(int u) { return In[u] < In[root] && Out[root] <= Out[u]; }

// For u with chk(u) true: the child of u (in the tree rooted at 1) whose subtree contains the
// current root. dfn[u] is sorted, so the last child entry index <= In[root] identifies it.
inline int getance(int u) {
    return P[*(std::upper_bound(dfn[u].begin(), dfn[u].end(), In[root]) - 1)];
}

// One Mo sub-query: adds op * f(l, r) to answer `id`, where f(x, y) counts equal-value pairs with
// one element among the first x DFS positions and the other among the first y.
struct Question {
    int op, id, l, r;
    // Mo ordering: by block of l, then by block of r. Requires S > 0; S is assigned in main.
    bool operator < (const Question &other) const {
        const int block = (l - 1) / S, other_block = (other.l - 1) / S;
        if (block != other_block) return block < other_block;
        return (r - 1) / S < (other.r - 1) / S;
    }
} q[2000005];
// Mo window update: an element of compressed value x enters side d (0 = left prefix,
// 1 = right prefix). It pairs with every element of value x already on the other side.
inline void add(const int x, const int d) {
    ans += cnt[x][d ^ 1];
    ++cnt[x][d];
}

// Inverse of add: an element of compressed value x leaves side d.
inline void del(const int x, const int d) {
    ans -= cnt[x][d ^ 1];
    --cnt[x][d];
}

// Depth-first traversal from u (fa = its parent, or -1 for none). It assigns entry indices In[],
// exit indices Out[] and the inverse map P[]. It also appends each child's entry index to dfn[parent]
// in visiting order.
// Implemented with an explicit stack so that a path-shaped tree with ~1e5 nodes cannot overflow
// the call stack; the visiting order is the same as the recursive formulation.
void dfs(const int u, const int fa) {
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index of the next entry of sons[node] to examine
    };
    std::vector<Frame> frames;
    frames.push_back(Frame{u, fa, 0});
    P[In[u] = ++Now] = u;
    while (!frames.empty()) {
        Frame &top = frames.back();
        const int node = top.node;
        if (top.next < sons[node].size()) {
            const int v = sons[node][top.next++];
            if (v == top.parent) continue;
            P[In[v] = ++Now] = v;
            frames.push_back(Frame{v, node, 0});  // invalidates `top`; it is not used afterwards
        } else {
            Out[node] = Now;
            frames.pop_back();
            if (!frames.empty()) dfn[frames.back().node].push_back(In[node]);
        }
    }
}
int main() { int n, m, l(0), r(0), qid(0); n = read(), m = read(); S = n / sqrt(m); for (int i(1); i <= n; ++ i) A[i] = a[i] = read(); std::sort(a + 1, a + n + 1); for (int i(1); i <= n; ++ i) A[i] = std::lower_bound(a + 1, a + n + 1, A[i]) - a; for (int i(1); i < n; ++ i) { int u(read()), v(read()); sons[u].push_back(v), sons[v].push_back(u); } dfs(1, -1); for (int i(1); i <= n; ++ i) ++ cnt[a[In[i]] = A[i]][0]; for (int i(1); i <= n; ++ i) g[i] = g[i - 1] + cnt[a[i]][0]; memset(cnt, 0, sizeof cnt); for (int i(1); i <= m; ++ i) { int op(read()); if (op == 1) root = read(); else { ++ qid; int u(read()), v(read()); if (!chk(u) && chk(v)) std::swap(u, v); for (int j(1); j <= 4; ++ j) q[tot + j].id = qid; if (!chk(u) && !chk(v)) { int l1(In[u]), r1(Out[u]), l2(In[v]), r2(Out[v]); if (u == root) l1 = 1, r1 = n; if (v == root) l2 = 1, r2 = n; q[++ tot].op = 1, q[tot].l = r1, q[tot].r = r2; q[++ tot].op = -1, q[tot].l = l1 - 1, q[tot].r = r2; q[++ tot].op = -1, q[tot].l = l2 - 1, q[tot].r = r1; q[++ tot].op = 1, q[tot].l = l1 - 1, q[tot].r = l2 - 1; } else if (chk(u) && !chk(v)) { int t(getance(u)); int x(In[t] - 1), y(Out[t] + 1), l(In[v]), r(Out[v]); if (v == root) l = 1, r = n; Q[qid] = g[r] - g[l - 1]; q[++ tot].op = 1, q[tot].l = x, q[tot].r = r; q[++ tot].op = -1, q[tot].l = l - 1, q[tot].r = x; q[++ tot].op = 1, q[tot].l = y - 1, q[tot].r = l - 1; q[++ tot].op = -1, q[tot].l = y - 1, q[tot].r = r; } else { int t1(getance(u)), t2(getance(v)); int x0(In[t1] - 1), x1(In[t2] - 1), y0(Out[t1] + 1), y1(Out[t2] + 1); Q[qid] = g[x0] + g[x1] - g[y1 - 1] - g[y0 - 1] + g[n]; q[++ tot].op = 1, q[tot].l = x0, q[tot].r = x1; q[++ tot].op = -1, q[tot].l = y1 - 1, q[tot].r = x0; q[++ tot].op = -1, q[tot].l = y0 - 1, q[tot].r = x1; q[++ tot].op = 1, q[tot].l = y0 - 1, q[tot].r = y1 - 1; } } } for (int i(1); i <= tot; ++ i) if (q[i].l > q[i].r) std::swap(q[i].l, q[i].r); std::sort(q + 1, q + tot + 1); for (int i(1); i <= tot; ++ i) { while (l < q[i].l) add(a[++ l], 0); while (l > 
q[i].l) del(a[l --], 0); while (r < q[i].r) add(a[++ r], 1); while (r > q[i].r) del(a[r --], 1); Q[q[i].id] += q[i].op * ans; } for (int i(1); i <= qid; ++ i) printf("%lld\n", Q[i]); return 0; }