ABC147F Sum Difference


题解:

记有$N$个数,起始为$s$,公差为$d$,高桥那有$n$个数,有$k_1$份$d$,青木那有$m$个数,$k_2份d$

显然的,我们可以将高桥的值设为$(ns+k1d)$,将青木的值设为$(ms+k2d)$

则$\Delta=(n-m)s+(k1-k2)d$

又因为$n+m=N$,$k_1+k_2=\sum\limits_{i=0}^{N-1}i=N*(N-1)/2$

所以$\Delta=(2n-N)s+(2k_1-N(N-1)/2)*d$

也就是说$\Delta$只与$ns+k_1d$相关,其它的都是些常数

并且,其实$k_1\in[\sum\limits_{i=0}^{n-1}i,\sum\limits_{i=N-n}^{N-1}i]$,区间内的所有整数$k_1$都可以取到,这一点很容易证明

因此现在问题变成了给出N+1个形如$\{ ns+k_1d |k_1\in[\sum\limits_{i=0}^{n-1}i,\sum\limits_{i=N-n}^{N-1}i] \}$的等差数列,求它们并集的大小

注意到公差都是d,这说明两个不同的等差数列是否有交集只取决于他们的起点%d的值,我们记起点$n*s\equiv p (\mod d)$,则可以将p值相同的等差数列放在一起处理

此时p值相同的等差数列又好比一条条线段,让你求它们所覆盖的总点数

这就十分好处理了,先记下每条线段的左右端点,按照左右端点从左到右再排个序,维护左右端点极值的同时分类讨论,扫一遍即可


值得注意的是,因为算法中用到了模d,而d的值域是包括0的,所以还需要特判d==0:

  • s==0则直接输出1
  • s!=0则直接输出n+1

在我的代码中会有不少冗余的部分,主要是为了提供一个比较容易理解的思路

另外,代码中并没有特意算出p值,具体处理可以见注释


代码:

#include <bits/stdc++.h>
using namespace std;
template<class t> inline t read(t &x){
    x=0;char c=getchar();bool f=0;
    while(!isdigit(c)) f|=c=='-',c=getchar();
    while(isdigit(c)) x=(x<<1)+(x<<3)+(c^48),c=getchar();
    if(f) x=-x;return  x;
}
template<class t> inline void write(t x){
    if(x<0){putchar('-'),write(-x);}
    else{if(x>9)write(x/10);putchar('0'+x%10);}
}

#define int long long

const int N=2e5+5;
int ans,n,s,d;
/*该结构体记录一个精简后等差数列的信息*/
struct seg{
    int i,l,r; //i是n,l是左端点,r是右端点
    inline bool operator < (const seg &nt) const {
        int d1=i%d*s%d,d2=nt.i%d*s%d; //d1和d2是两个等差数列的p值
        if(d1^d2) return d1<d2; //将p值作为第一关键字
        if(l^nt.l) return l<nt.l; //然后将左端点从左向右排(其实好像并不用考虑相同的l和r)
        return r>nt.r;
    }
}f[N];
/*计算l+l+1+...+r-1+r*/
int calc(int l,int r){
    return (l+r)*(r-l+1)/2;
}

signed main(){
    read(n);read(s);read(d);
    if(d==0){ //特判
        if(s==0) write(1);
        else write(n+1);
        return 0;
    }
    for(int i=0;i<=n;i++) f[i]=(seg){i,i*s/d+calc(0,i-1),i*s/d+calc(n-i,n-1)}; //计算等差数列的各种信息,注意n的枚举是从0到n
    sort(f,f+1+n);
    for(int i=0,j;i<=n;i=j+1){
        j=i; //区间[i,j]内的等差数列的p值都是相等的
        while(j+1<=n&&f[j+1].i%d*s%d==f[i].i%d*s%d) j++;
        int l=f[i].l,r=f[i].r; //l和r是已访问过最右独立线段并的左右端点
        ans+=r-l+1;
        for(int k=i+1;k<=j;k++){
            if(f[k].r<=r) continue; //被当前区间所包括,可以直接跳过不处理
            if(f[k].l<=r){ //与当前区间有交集,更新答案和区间右边界
                ans-=r-l+1; //删除原贡献
                r=f[k].r;
                ans+=r-l+1; //加上新贡献
            }
            else{ //无交集,重新设置左右边界
                l=f[k].l;
                r=f[k].r;
                ans+=r-l+1; //加上新贡献
            }
        }
    }
    write(ans);
}

fighter