其他分享
首页 > 其他分享> > [NOIP2021] 数列

[NOIP2021] 数列

作者:互联网

洛谷题面

感觉这道题纯动态规划的边界等问题非常麻烦,所以这里采用记忆化搜索。

题目大意

给出 \(n,m,k\) 及 \(val_0\cdots val_m\),定义一个值 \(\in [0,m]\) 的序列 \(a\),其权值为 \(\prod\limits_{i=1}^{n} val_{a_i}\)

我们称 \(S\) 满足条件当且仅当 \(S=\sum\limits_{i=1}^{n} 2^{a_i}\) 的二进制表示中,\(1\) 的个数小于等于 \(k\)。此时,也称序列 \(a\) 为合法序列。

求所有合法序列 \(a\) 的权值和 \(\mod 998244353\) 的结果。

题目分析

令 \(dfs(bit,now,x,y)\) 表示:

\(S\) 从低到高二进制的 \(bit\) 位中,用了序列 \(a\) 的前 \(now\) 个数,此时 \(S\) 二进制下有 \(x\) 个 \(1\),上一位(第 \(bit+1\) 位)进位为 \(y\)。

\(mem[biw][now][x][y]\) 则储存答案。

于是,我们有:

\[mem[bit][now][x][y]=\sum\limits_{i=0}^{n-now}{mem[bit][now+i][x+(y+i)\%2][\left\lfloor\frac{y+i}{2}\right\rfloor])\times sum[bit][i]\times C_{now+i}^{i}} \]

其中 \(C_{i}^{j}\) 表示组合数,\(sum[i][j]\) 表示:

for(register int i=0;i<=m;i++)
{
	sum[i][0]=1;
		
	for(register int j=1;j<=n;j++)
	{
		sum[i][j]=sum[i][j-1]*val[i]%mod;
	}
}

可以看到,\(sum[i][j]\) 主要作用类似于前缀和,目的是简化计算。


边界部分:

当前转移到 \(dfs(bit,now,x,y)\)。

当 \(x+getcnt(y)>k\) 时,返回 \(0\)。表示不需要继续转移了。

否则返回 \(1\)。

代码

//2021/11/30

//2021/12/1

//2021/12/2

#define _CRT_SECURE_NO_WARNINGS

#include <iostream>

#include <cstdio>

#include <climits>//need "INT_MAX","INT_MIN"

#include <cstring>

#define int long long

#define enter() putchar(10)

#define debug(c,que) cerr<<#c<<" = "<<c<<que

#define cek(c) puts(c)

#define blow(arr,st,ed,w) for(register int i=(st);i<=(ed);i++)cout<<arr[i]<<w;

#define speed_up() cin.tie(0),cout.tie(0)

#define endl "\n"

#define Input_Int(n,a) for(register int i=1;i<=n;i++)scanf("%d",a+i);

#define Input_Long(n,a) for(register long long i=1;i<=n;i++)scanf("%lld",a+i);

namespace Newstd
{
	inline int read()
	{
		int x=0,k=1;
		char ch=getchar();
		while(ch<'0' || ch>'9')
		{
			if(ch=='-')
			{
				k=-1;
			}
			ch=getchar();
		}
		while(ch>='0' && ch<='9')
		{
			x=(x<<1)+(x<<3)+ch-'0';
			ch=getchar();
		}
		return x*k;
	}
	inline void write(int x)
	{
		if(x<0)
		{
			putchar('-');
			x=-x;
		}
		if(x>9)
		{
			write(x/10);
		}
		putchar(x%10+'0');
	}
}

using namespace Newstd;

using namespace std;

const int mod=998244353;

const int MA_1=105;

const int MA_2=35;

int val[MA_1];

int C[MA_1][MA_1],sum[MA_1][MA_1];

int mem[MA_1][MA_2][MA_2][MA_2]; 

int n,m,k;

inline void init()
{
	C[0][0]=1;
	
	for(register int i=1;i<=n;i++)
	{
		C[i][0]=1;
		
		for(register int j=1;j<=i;j++)
		{
			C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod; 
		}
	}
}

inline int lowbit(int x)
{
	return x&-x;
}

inline int getcnt(int x)
{
	int ans(0);
	
	while(x!=0)
	{
		x-=lowbit(x);
		
		ans++;
	}
	
	return ans;
}

//dfs(k,now,x,y)
//S从低到高二进制的 bit 位中,用了数列 a 的前 now 项,且此时 S 中共有 x 个二进制位为 1,第 now+1 位进了 y 过去 
inline int dfs(int bit,int now,int x,int y)
{
	if(now==n)
	{
		if(x+getcnt(y)>k)
		{
			return 0;
		}
		
		return 1;
	}
	
	if(bit>m)
	{
		return 0;
	}
	
	if(mem[bit][now][x][y]!=-1)
	{
		return mem[bit][now][x][y];
	}
	
	int ans(0);
	
	for(register int i=0;i<=n-now;i++)
	{
		ans=(ans+dfs(bit+1,now+i,x+(y+i)%2,(y+i)/2)*sum[bit][i]%mod*C[now+i][i]%mod)%mod; 
	}
	
	return mem[bit][now][x][y]=ans;
}

#undef int

int main(void)
{
	#define int long long
	
	memset(mem,-1,sizeof(mem));
	
	n=read(),m=read(),k=read();
	
	init();
	
	for(register int i=0;i<=m;i++)
	{
		val[i]=read();
	}
	
	for(register int i=0;i<=m;i++)
	{
		sum[i][0]=1;
		
		for(register int j=1;j<=n;j++)
		{
			sum[i][j]=sum[i][j-1]*val[i]%mod;
		}
	}
	
	printf("%lld\n",dfs(0,0,0,0));
	
	return 0;
}

标签:now,MA,数列,int,sum,mem,bit,NOIP2021
来源: https://www.cnblogs.com/Coros-Trusds/p/15632869.html