1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
#include <cstdio>
#include <vector>

// Competitive-programming shorthand: every `int` token after this line is a
// 64-bit `long long` (hence the %lld formats and `signed main()` in this file).
#define int long long

// Return the left / right child id of segment-tree node p, first allocating a
// fresh id from the pool counter `tot` if that child does not exist yet (0 means
// "no child"). Expands to an lvalue expression; p is evaluated more than once.
#define getl(p) (ls[p] ? ls[p] : (ls[p] = ++tot))
#define getr(p) (rs[p] ? rs[p] : (rs[p] = ++tot))
// Returns the smaller of x and y (file-local helper used instead of std::min).
inline int min(const int x, const int y) { return x < y ? x : y; }

// Capacity of all per-vertex arrays; vertices are numbered 1..n, so n < kMaxN.
constexpr int kMaxN = 300005;

// Capacity of the segment-tree node pool (ls / rs / v). Ids 1..n are reserved as
// the per-vertex roots, and a single point update allocates at most one node per
// level below its root, i.e. at most 19 nodes while n < 2^19. One update per
// vertex therefore hands out at most n + 19 * n ids, which fits below 20 * kMaxN.
// NOTE(review): ~144 MB when `int` is 64-bit. ls/rs only need 32 bits, but
// merge() binds them to `int&`, so narrowing them requires changing merge() too.
constexpr int kMaxNodes = 20 * kMaxN;

// Adjacency lists as singly linked edge chains: head[u] is the most recently
// added edge out of u, e[i].nxt is the next one, and 0 terminates the chain.
// Every undirected tree edge is stored twice (once per direction).
struct Edge {
    int to, nxt;
} e[2 * kMaxN];

int ls[kMaxNodes], rs[kMaxNodes], v[kMaxNodes];  // pool: left child, right child, subtree sum
int root[kMaxN], tot, tot2;                      // per-vertex tree root; pool / edge counters
int Dep[kMaxN], head[kMaxN], cnt[kMaxN], ans[kMaxN], n;  // NOTE(review): assumes q < kMaxN for ans[]

// An offline query attached to its vertex: id = 1-based input position, k = distance bound.
struct Node {
    int id, k;
};
std::vector<Node> ques[kMaxN];  // ques[p]: the queries asked at vertex p

// Prepends the directed edge u -> to onto u's adjacency chain.
inline void AddEdge(const int u, const int to) {
    e[++tot2] = Edge{to, head[u]};
    head[u] = tot2;
}
void update(const int O, const int x, const int d, const int l, const int r) { if (l == r) {v[O] += d; return;} const int mid(l + r >> 1); if (x <= mid) update(getl(O), x, d, l, l + r >> 1); else update(getr(O), x, d, (l + r >> 1) + 1, r); v[O] = v[ls[O]] + v[rs[O]]; } int query(const int O, const int L, const int R, const int l, const int r) { if (L <= l && r <= R) return v[O]; const int mid(l + r >> 1); int ans(0); if (L <= mid && ls[O]) ans += query(ls[O], L, R, l, mid); if (mid < R && rs[O]) ans += query(rs[O], L, R, mid + 1, r); return ans; } void merge(int& x, const int y) { if (!x || !y) {x |= y; return;} v[x] += v[y]; merge(ls[x], ls[y]); merge(rs[x], rs[y]); } void dfs(const int u, const int fa) { Dep[u] = Dep[fa] + 1; for (int i(head[u]); i; i = e[i].nxt) if (e[i].to != fa) { const int v(e[i].to); dfs(v, u); cnt[u] += cnt[v] + 1; merge(root[u], root[v]); } for (Node& i : ques[u]) { ans[i.id] = query(root[u], Dep[u], Dep[u] + i.k, 1, n); ans[i.id] += min(Dep[u] - 1, i.k) * cnt[u]; } update(root[u], Dep[u], cnt[u], 1, n); } signed main() { int q; scanf("%lld%lld", &n, &q); for (int i(1); i < n; ++ i) { int u, v; scanf("%lld%lld", &u, &v); AddEdge(u, v), AddEdge(v, u); } for (int i(1); i <= n; ++ i) root[i] = i; tot = n; for (int i(1); i <= q; ++ i) { int p, k; scanf("%lld%lld", &p, &k); ques[p].push_back(Node{i, k}); } dfs(1, -1); for (int i(1); i <= q; ++ i) printf("%lld\n", ans[i]); }
|