1 条题解

  • 1
    @ 2025-6-12 23:17:01
    /**
     *    author: 小飞侠cy
     *    created: 2025.06.12 18:39:51
     */
    #include <bits/stdc++.h>
    using namespace std;
    // Competitive-programming shorthand macros and constants.
    // NOTE: object-like macros below rewrite EVERY later occurrence of the
    // identifier (e.g. a local variable named `lb`, `se` or `sz` would be
    // silently replaced), so avoid those names elsewhere in the file.
    //
    // ls / rs: 2*u and 2*u+1 -- presumably left/right child indices of a
    // heap-indexed (e.g. segment) tree; they require a variable `u` in scope
    // at the use site. The expansions are NOT parenthesised, so `ls + 1`
    // becomes `u << 1 + 1` == `u << 2`; wrap uses in parentheses.
    #define ls u << 1
    #define rs u << 1 | 1
    #define LL long long
    // Disabled switch: if re-enabled, every later `int` becomes `long long`
    // (presumably why the entry point is spelled `signed main`).
    //#define int long long
    #define PII pair <int, int>
    #define fi first
    #define se second
    // Container method shorthands (`v.pub(x)` == `v.push_back(x)`, etc.).
    #define pub push_back
    #define pob pop_back
    #define puf push_front
    #define pof pop_front
    #define lb lower_bound
    #define ub upper_bound
    // GCC/Clang extension type.
    #define i128 __int128
    // __builtin_popcount takes an unsigned int; 64-bit values need
    // __builtin_popcountll instead.
    #define pcnt(x) __builtin_popcount(x)
    // memset fills BYTES: only byte-repeating patterns (0, -1, 0x3f, ...) give
    // meaningful int values, and sizeof(a) is the pointer size if `a` has
    // decayed to a pointer (e.g. an array function parameter).
    #define mem(a,goal) memset(a, (goal), sizeof(a))
    // rep(x, s, e): declares int x and walks the half-open range between s and
    // e in the direction from s to e -- ascending s, s+1, ..., e-1 when s < e;
    // descending s-1, s-2, ..., e when s > e; no iterations when s == e.
    // `start` and `end` are re-evaluated on every iteration, so pass only
    // side-effect-free expressions.
    #define rep(x,start,end) for(int x = (start) - ((start) > (end)); x != (end) - ((start) > (end)); ((start) < (end) ? x ++ : x --))
    // Expands to the iterator pair `begin(), end()` for algorithm calls.
    #define aLL(x) (x).begin(), (x).end()
    // Container size as a signed int.
    #define sz(x) (int)(x).size()
    const int INF = 998244353;
    // 1e9 + 7 is a double literal; it converts exactly to 1000000007.
    const int mod = 1e9 + 7;
    const int N = 100010;
    /**
     * Counts paths v1-v2-v3-v4-v5 whose vertex labels spell "SCCPC" in order.
     *
     * @param label  label[v] is the letter of vertex v for v in [1, n];
     *               label[0] is padding. Must hold at least n + 1 characters.
     * @param adj    1-indexed adjacency lists, adj.size() == n + 1.
     * @return       the number of such paths.
     *
     * Works layer by layer, each pass linear in n plus the total adjacency size:
     *   s_to_c[v]  ('C' vertex v) = number of 'S' neighbours   -> walks S-v
     *   sc_to_c[v] ('C' vertex v) = number of walks S-C-v
     *   for each 'P' vertex v, into_p = number of walks S-C-C-v.
     * Each such walk extends to any 'C' neighbour x of v except the one it
     * arrived through. Summing (into_p - sc_to_c[x]) over the k 'C' neighbours
     * therefore gives (k - 1) * into_p.
     *
     * On a tree (presumably the input shape: solve() reads exactly n - 1
     * edges) these walks are simple paths. A repeated vertex would require a
     * triangle. In a tree every per-vertex count is bounded by n (one path per
     * endpoint), so int storage suffices; only the total needs long long.
     */
    static long long count_sccpc_paths(const std::string& label,
                                       const std::vector<std::vector<int>>& adj)
    {
        const int n = static_cast<int>(adj.size()) - 1;

        // Layer 1: S -> C.
        std::vector<int> s_to_c(n + 1, 0);
        for (int v = 1; v <= n; ++v)
        {
            if (label[v] != 'S') continue;
            for (int w : adj[v])
                if (label[w] == 'C') ++s_to_c[w];
        }

        // Layer 2: S -> C -> C.
        std::vector<int> sc_to_c(n + 1, 0);
        for (int v = 1; v <= n; ++v)
        {
            if (label[v] != 'C') continue;
            for (int w : adj[v])
                if (label[w] == 'C') sc_to_c[w] += s_to_c[v];
        }

        // Layers 3 and 4: S -> C -> C -> P -> C, folded into one pass per P.
        long long total = 0;
        for (int v = 1; v <= n; ++v)
        {
            if (label[v] != 'P') continue;
            long long into_p = 0;        // walks S-C-C-v
            long long c_neighbours = 0;  // k: number of 'C' neighbours of v
            for (int w : adj[v])
            {
                if (label[w] != 'C') continue;
                into_p += sc_to_c[w];
                ++c_neighbours;
            }
            // Every walk may leave through any C neighbour but the one it
            // came in through. When k == 0, into_p is 0, so this adds 0.
            total += (c_neighbours - 1) * into_p;
        }
        return total;
    }

    /**
     * Reads one test case and prints its "SCCPC" path count on its own line.
     * Input: n, a label string of n letters (vertex v gets the v-th letter),
     * then n - 1 undirected edges "a b" with 1-based endpoints.
     */
    void solve()
    {
        int n;
        std::cin >> n;
        std::string s;
        std::cin >> s;
        s = " " + s;  // shift to 1-indexed so s[v] is vertex v's label
        std::vector<std::vector<int>> g(n + 1);
        for (int i = 1; i < n; ++i)
        {
            int a, b;
            std::scanf("%d%d", &a, &b);
            g[a].push_back(b);
            g[b].push_back(a);
        }
        std::printf("%lld\n", count_sccpc_paths(s, g));
    }
    // Entry point: the first input token is the number of test cases; each
    // case is handled by one call to solve().
    signed main()
    {
        int cases = 1;
        std::cin >> cases;
        for (; cases != 0; --cases)
            solve();
        return 0;
    }
    
    • 1

    信息

    ID
    5610
    时间
    5000ms
    内存
    256MiB
    难度
    10
    标签
    (无)
    递交数
    1
    已通过
    1
    上传者