题目翻译

给定 $n$ 个点的 BST(二叉搜索树),目前点权为 $-1$ 的节点可以染色 $[1,C]$,问可能的 BST 个数。答案对 $998,244,353$ 取模。

多测,$1\leq t\leq 10^5,1\leq \sum n\leq 5\times 10^5,1\leq C\leq 10^9$。

题目思路

BST 的一个性质是中序遍历点权单调不降。

中序遍历的顺序之后,我们只需要对于相邻两对不为 $-1$ 的数(我们定为 $L$ 和 $R$)直接的 $-1$ 填充即可。填充只能填充 $[L,R]$ 的数。

那么这个问题转化为对于长度为 $k$ 的序列,求填充 $[L,R]$ 且单调不降的方案数。观察到我们只需要知道可以选的个数,不需要知道具体选什么,所以答案只与 $len=R-L+1$ 有关。也就是用 $len$ 个数填充长度为 $k$ 的序列且单调不降的方案数。

这是经典组合数题,答案是 $\binom{len+(k-1)}{k}$。我的考虑方式和别的题解不太一样,你考虑对于只能单调上升的答案是 $\binom{len}{k}$,那么你加入 $k-1$ 个决策,表示 $\forall i\in[2,k]$,第 $i$ 个位置和第 $i-1$ 个位置填一样颜色。

实现需要支持多次查询 $\binom{n}{i}$,但是 $\sum i\leq n$ 所以直接暴力做就行。

完整代码

modint 部分已经省略。

#include <bits/stdc++.h>
using namespace std;
const int p = 998244353;
using Z = mod_int<p>;
#define getchar() p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin), p1 == p2) ? EOF : *p1++
char buf[1000000], *p1 = buf, *p2 = buf;
template <typename T>
void read(T &x)
{
    x = 0;
    int f = 1;
    char c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-')
            f = -f;
    for (; c >= '0' && c <= '9'; c = getchar())
        x = x * 10 + c - '0';
    x *= f;
}
template <typename T, typename... Args>
void read(T &x, Args &...y)
{
    read(x);
    read(y...);
}
typedef long long ll;
int n;
ll c;
struct node
{
    int l, r;
    ll val;
} a[500020];
vector<int> v;
void dfs(int u)
{
    if (~a[u].l)
        dfs(a[u].l);
    v.push_back(u);
    if (~a[u].r)
        dfs(a[u].r);
}
Z C(int n, int i) // n 选 i 方案数
{
    Z ret = 1;
    for (int j = 1; j <= i; j++)
        ret *= n - j + 1;
    for (int j = 1; j <= i; j++)
        ret /= j;
    return ret;
}
Z F(int n, int L, int R) // n 个数升序,且都在 [L,R] 的方案数
{
    int len = R - L + 1;
    // 等价于 C(n + len - 1 , n)
    return C(n + len - 1, n);
}
void solve()
{
    read(n, c);
    for (int i = 1; i <= n; i++)
        read(a[i].l, a[i].r, a[i].val);
    v.clear();
    dfs(1);
    int minus = 0, l = 1;
    Z ans = 1;
    for (int i : v)
    {
        if (~a[i].val)
        {
            ans *= F(minus, l, a[i].val);
            minus = 0;
            l = a[i].val;
        }
        else
            minus++;
    }
    ans *= F(minus, l, c);
    cout << ans << '\n';
}
int main()
{
    int t;
    read(t);
    while (t--)
        solve();
    return 0;
}
最后修改:2024 年 02 月 23 日
v我50吃疯狂星期四