题目思路
本文提供一种依赖 $a_i$ 随机生成的解决方式。
首先 $\min$ 和 $\max$ 可以拆开,原式就是:
$$
(\sum_{l \in [l_1,r_1]} \sum_{r \in [l_2,r_2]} \max_{i \in [l,r]} a_i)-(\sum_{l \in [l_1,r_1]} \sum_{r \in [l_2,r_2]} \min_{i \in [l,r]} a_i)
$$
然而把 $a$ 序列取相反数之后的 $\max$ 就是原序列的 $\min$,做两次 $\max$ 就行,这个 $\min$ 也可以先忽略不计。
那么我们其实求的就是:
$$
\sum_{l \in [l_1,r_1]} \sum_{r \in [l_2,r_2]} \max_{i \in [l,r]} a_i
$$
这个第二个 $\sum$ 可以前缀和优化。
设 $f_{lx,rx,x}$ 表示:
$$
\sum_{l \in [lx,rx]} \sum_{r \in [1,x]} \max_{i \in [l,r]} a_i
$$
就是左端点在 $[lx,rx]$ 且右端点不超过 $x$ 的和。那么 $f_{l_1,r_1,r_2}-f_{l_1,r_1,l_2-1}$ 就是我们这个询问的答案。
先把询问 $(l_1,r_1,l_2,r_2)$ 拆成 $(l_1,r_1,l_2-1,-1)$ 和 $(l_1,r_1,r_2,1)$ 这 $2$ 条询问。后面的 $-1$ 和 $1$ 表示应该减还是加。
那么这个问题我们可以对 $x$ 做扫描线处理。
具体的,我们从 $1$ 到 $n$ 扫一遍,观察到每一次修改,会在上一个比自己大的位置停下来。那么我们维护这样的单调下降的单调栈,栈内储存的是单调下降的元素的下标,设 $stk_j$ 为栈内元素。那么我们每一次会对 $(stk_{j-1},stk_j]$ 的元素造成 $a_{stk_j}$ 的贡献。
对拆出来的 $2Q$ 条询问以 $(l,r,x,sign)$ 的 $x$ 为关键字排序,维护一个指针,表示目前处理到了哪一条询问。对于 $x=i$ 的询问,我们只需要对于对应的原询问的 $ans$ 加上或减去 $[l,r]$ 部分的总和。
那么我们需要一个支持『区间加法,区间求和』的数据结构。两棵 BIT 或者一棵 SGT 即可。
这个做法的复杂度是 $\mathcal O(n\log^2 n)$。
复杂度是两只 $\log$ 而不是 $\mathcal O(n^2\log n)$ 的原因,是因为这个 $a_i$ 生成是随机的。随机数据情况下,单调栈的大小期望是 $\mathcal O(\log n)$ 级别的。
特殊处理一下如果询问 $l_2=1$ 那么会有 $l_2-1=0$ 的情况,这时候相当于没有任何需要减掉的东西,把这一部分的询问丢掉。
上面讲的是 $\max$ 的实现,把 $a_i$ 取相反数再做一次一模一样的就是 $\min$。
完整代码
#include <bits/stdc++.h>
using namespace std;
#define getchar() p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin), p1 == p2) ? EOF : *p1++
char buf[1000000], *p1 = buf, *p2 = buf;
template <typename T>
void read(T &x)
{
x = 0;
int f = 1;
char c = getchar();
for (; c < '0' || c > '9'; c = getchar())
if (c == '-')
f = -f;
for (; c >= '0' && c <= '9'; c = getchar())
x = x * 10 + c - '0';
x *= f;
}
template <typename T, typename... Args>
void read(T &x, Args &...y)
{
read(x);
read(y...);
}
template <class T>
void write(T x)
{
static int stk[30];
if (x < 0)
putchar('-'), x = -x;
int top = 0;
do
{
stk[top++] = x % 10, x /= 10;
} while (x);
while (top)
putchar(stk[--top] + '0');
}
template <class T>
void write(T x, char lastChar) { write(x), putchar(lastChar); }
typedef long long ll;
const int p = 1000000000;
ll pw1023 = 1, pw1025 = 1;
const int n = 100000;
int a[100020];
ll ans[40020];
struct SegTree
{
struct node
{
ll sum, lzy;
} t[100020 << 2];
#define ls id << 1
#define rs id << 1 | 1
#define Llen (mid - l + 1)
#define Rlen (r - mid)
void clear()
{
for (int i = 1; i <= 100000 << 2; i++)
t[i].sum = t[i].lzy = 0;
}
void push_up(int id) { t[id].sum = t[ls].sum + t[rs].sum; }
void push_down(int id, int l, int r)
{
int mid = l + r >> 1;
t[ls].sum += t[id].lzy * Llen;
t[rs].sum += t[id].lzy * Rlen;
t[ls].lzy += t[id].lzy;
t[rs].lzy += t[id].lzy;
t[id].lzy = 0;
}
void build(int id = 1, int l = 1, int r = n)
{
if (l == r)
return t[id].sum = 0, void();
int mid = l + r >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
push_up(id);
}
void add(int ql, int qr, ll k, int id = 1, int l = 1, int r = n)
{
if (r < ql || l > qr)
return;
if (ql <= l && r <= qr)
return t[id].lzy += k, t[id].sum += k * (r - l + 1), void();
push_down(id, l, r);
int mid = l + r >> 1;
add(ql, qr, k, ls, l, mid);
add(ql, qr, k, rs, mid + 1, r);
push_up(id);
}
ll query(int ql, int qr, int id = 1, int l = 1, int r = n)
{
if (r < ql || l > qr)
return 0;
if (ql <= l && r <= qr)
return t[id].sum;
push_down(id, l, r);
ll ans = 0;
int mid = l + r >> 1;
ans += query(ql, qr, ls, l, mid);
ans += query(ql, qr, rs, mid + 1, r);
return ans;
}
} T;
int Q;
struct query
{
int l, r, x, sign, id;
} q[80020];
int stk[100020], top;
void solve()
{
top = 0;
int j = 1;
while (j <= 2 * Q && q[j].x == 0)
j++;
for (int i = 1; i <= n; i++)
{
while (top >= 1 && a[stk[top]] <= a[i])
top--;
stk[++top] = i;
for (int j = 1; j <= top; j++)
T.add(stk[j - 1] + 1, stk[j], a[stk[j]]);
while (j <= 2 * Q && q[j].x == i)
ans[q[j].id] += q[j].sign * T.query(q[j].l, q[j].r), j++;
}
}
int main()
{
for (int i = 1; i <= n; i++)
{
(pw1023 *= 1023) %= p;
(pw1025 *= 1025) %= p;
a[i] = pw1023 ^ pw1025;
// cout << a[i] << " \n"[i == n];
}
read(Q);
for (int i = 1; i <= Q; i++)
{
int l1, r1, l2, r2;
read(l1, r1, l2, r2);
q[i * 2 - 1] = {l1, r1, l2 - 1, -1, i};
q[i * 2] = {l1, r1, r2, 1, i};
}
sort(q + 1, q + 2 * Q + 1, [&](auto x, auto y)
{ return x.x < y.x; });
solve();
for (int i = 1; i <= n; i++)
a[i] = -a[i];
T.clear();
solve();
for (int i = 1; i <= Q; i++)
write(ans[i], '\n');
return 0;
}