其他分享
首页 > 其他分享 > K-Set Tree (树的节点贡献+组合数+减法思维)(Codeforces Round #795)

K-Set Tree (树的节点贡献+组合数+减法思维)(Codeforces Round #795)

作者:互联网

F. K-Set Tree
time limit per test: 3 seconds
memory limit per test: 512 megabytes
input: standard input
output: standard output
You are given a tree G with n vertices and an integer k. The vertices of the tree are numbered from 1 to n.

For a vertex r and a subset S of vertices of G, such that |S|=k, we define f(r,S) as the size of the smallest rooted subtree containing all vertices in S when the tree is rooted at r. A set of vertices T is called a rooted subtree, if all the vertices in T are connected, and for each vertex in T, all its descendants belong to T.

You need to calculate the sum of f(r,S) over all possible distinct combinations of vertices r and subsets S, where |S|=k. Formally, compute the following:
$$\sum_{r \in V} \sum_{S \subseteq V,\ |S| = k} f(r, S),$$
where V is the set of vertices in G.

Output the answer modulo $10^9+7$.

Input
The first line contains two integers $n$ and $k$ ($3 \le n \le 2 \cdot 10^5$, $1 \le k \le n$).

Each of the following n−1 lines contains two integers x and y (1≤x,y≤n), denoting an edge between vertex x and y.

It is guaranteed that the given edges form a tree.

Output
Print the answer modulo $10^9+7$.

Examples
Input
3 2
1 2
1 3
Output
25
Input
7 2
1 2
2 3
2 4
1 5
4 6
4 7
Output
849
Note
The tree in the second example is given below:


We have 21 subsets of size 2 in the given tree. Hence,
S∈{{1,2},{1,3},{1,4},{1,5},{1,6},{1,7},{2,3},{2,4},{2,5},{2,6},{2,7},{3,4},{3,5},{3,6},{3,7},{4,5},{4,6},{4,7},{5,6},{5,7},{6,7}}.
And since we have 7 vertices, 1≤r≤7. We need to find the sum of f(r,S) over all possible pairs of r and S.

Below we have listed the value of f(r,S) for some combinations of r and S.

r=1, S={3,7}. The value of f(r,S) is 5 and the corresponding subtree is {2,3,4,6,7}.
r=1, S={5,4}. The value of f(r,S) is 7 and the corresponding subtree is {1,2,3,4,5,6,7}.
r=1, S={4,6}. The value of f(r,S) is 3 and the corresponding subtree is {4,6,7}.
View problem

思路:

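核心:枚举每个顶点 $a$,计算它作为 $S$ 的最近公共祖先(LCA)时对答案的贡献,用组合数减法保证 LCA 恰好是 $a$。删去 $a$ 后树被分成若干连通块,大小为 $c_1,\dots,c_d$(其和为 $n-1$)。若根 $r=a$,则 $f=n$,满足条件的 $S$ 个数为 $\binom{n}{k}-\sum_i\binom{c_i}{k}$(减去 $S$ 全部落在同一个连通块里的情况);若根 $r$ 落在第 $j$ 个连通块中(共 $c_j$ 种选法),则 $a$ 的子树大小为 $n-c_j$,满足条件的 $S$ 个数为 $\binom{n-c_j}{k}-\sum_{i\ne j}\binom{c_i}{k}$。把这些贡献对所有顶点求和即为答案。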

#include <bits/stdc++.h>
using namespace std;
// Loop-counter shorthand used by the original code. It used to expand to
// `register int`, but the `register` storage-class specifier was removed in
// C++17 and is ill-formed there, so it now expands to plain `int`.
#define ri int
// Capacity of every per-vertex array (vertices are 1-based, n <= 2*10^5).
#define M  200005

/*
 * Fast getchar-based integer reader: stores the next integer from stdin in x.
 *
 * Every non-digit character before the number is skipped; a '-' seen while
 * skipping makes the result negative. If end of input is reached before any
 * digit, x is left as 0 instead of looping forever. `ch` is an int so that
 * EOF can be distinguished from every valid character value.
 */
template <class G> void read(G &x)
{
    x=0;int f=0;int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')){f|=ch=='-';ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    x=f?-x:x;
}
// ---- global state shared by the routines below ----
long long ans=0;            // running answer; every update reduces it modulo `mod`
const int mod=1000000007;   // modulus 1e9+7
long long n;                // number of vertices
long long m;                // required subset size k
vector<int> p[M];           // adjacency lists, vertices numbered 1..n
int vis[M];                 // visited marks used by dfs1 and dfs2
int sz[M];                  // subtree sizes filled in by dfs1
long long inv[M];           // inverse factorials filled in by init()
long long arr[M];           // factorials filled in by init()
/*
 * Modular exponentiation by repeated squaring: returns a^n modulo `mod`.
 * `a` should already be reduced modulo `mod` so that base*base fits in a
 * long long; n is consumed bit by bit until it reaches zero.
 */
long long qsn(long long a,int n)
{
    long long result=1;
    for(long long base=a;n!=0;n>>=1)
    {
        if(n&1)
            result=result*base%mod;
        base=base*base%mod;
    }
    return result;
}
/*
 * Binomial coefficient C(a, b) modulo `mod`, looked up from the factorial
 * table `arr` and inverse-factorial table `inv` (both filled by init()).
 * Returns 1 when b == 0 or b == a, and 0 when b > a or a == 0.
 */
long long zh(long long a,long long b)
{
    if(b==0||b==a)
        return 1;
    if(b>a||a==0)
        return 0;
    const long long partial=arr[a]*inv[b]%mod;
    return partial*inv[a-b]%mod;
}
/*
 * Precomputes factorials arr[i] = i! and inverse factorials inv[i] = (i!)^-1
 * modulo `mod` for 0 <= i <= n. main() calls it after reading n and before
 * any binomial coefficient is evaluated.
 *
 * Only one modular exponentiation is performed (for n!). The remaining
 * inverses follow from ((i-1)!)^-1 = (i!)^-1 * i, which makes this
 * O(n + log mod) instead of O(n log mod). Plain loop counters replace the
 * `ri` (register int) macro, because `register` is ill-formed since C++17.
 */
void init(){
    arr[0]=1;
    for(long long i=1;i<=n;i++)
        arr[i]=arr[i-1]*i%mod;

    inv[n]=qsn(arr[n],mod-2);   // Fermat: x^(mod-2) is x^-1 for prime mod
    for(long long i=n;i>=1;i--)
        inv[i-1]=inv[i]*i%mod;  // also yields inv[0] == 1
}
/*
 * Subtree sizes for the tree rooted at `a`. After the call, sz[v] has been
 * increased by the size of v's subtree, and vis[v] == 1 for every vertex
 * reached. Vertices already marked in vis[] are not entered. sz[] is
 * accumulated with +=, so it is expected to start at zero (as the global
 * array does).
 *
 * Uses an explicit stack instead of recursion. A path-shaped tree with
 * n = 2*10^5 would otherwise recurse that many levels deep and can overflow
 * a default-sized call stack. Plain `int` replaces the old `ri` macro,
 * because `register` is ill-formed since C++17.
 */
void dfs1(int a)
{
    // Discovery order as (vertex, parent) pairs, with parent -1 for the root.
    // A vertex is recorded only after its parent.
    std::vector<std::pair<int,int>> order;
    std::vector<std::pair<int,int>> stk(1,std::make_pair(a,-1));
    vis[a]=1;
    while(!stk.empty())
    {
        const std::pair<int,int> cur=stk.back();
        stk.pop_back();
        order.push_back(cur);
        for(int b:p[cur.first])
        {
            if(vis[b]) continue;
            vis[b]=1;
            stk.push_back(std::make_pair(b,cur.first));
        }
    }
    // Walking the discovery order backwards handles every child before its
    // parent, so sz[v] is complete before it is added into sz[parent].
    for(int i=static_cast<int>(order.size())-1;i>=0;i--)
    {
        const int v=order[i].first;
        const int par=order[i].second;
        sz[v]++;
        if(par!=-1) sz[par]+=sz[v];
    }
}


/*
 * Adds to `ans` every term of the answer in which vertex `a` is the root of
 * the minimal rooted subtree (the LCA of S), summed over all roots r.
 *
 * Removing `a` splits the tree into one component per neighbour b. Its size
 * c_b is sz[b] for a child, and n - sz[a] for the parent, which is the only
 * neighbour already marked in vis[] when this runs. With
 * T = sum_b C(c_b, k):
 *   r == a           : f = n,       number of valid S = C(n,k) - T
 *   r in component b : f = n - c_b, number of valid S = C(n-c_b,k) - (T - C(c_b,k)),
 *                      and there are c_b such roots.
 * Subtracting T removes the subsets lying entirely inside one component,
 * whose LCA would not be `a`.
 */
static void dfs2_vertex_contribution(int a)
{
    // Size of the component on b's side of the edge (a, b).
    auto comp_size=[a](int b)->long long { return vis[b] ? n-sz[a] : sz[b]; };

    long long total=0;   // T modulo `mod`
    for(int b:p[a])
        total=(total+zh(comp_size(b),m))%mod;

    // Root r == a.
    ans=(ans+(zh(n,m)-total+mod)%mod*(n%mod))%mod;

    // Root r inside the component of each neighbour b.
    for(int b:p[a])
    {
        const long long c=comp_size(b);
        const long long rest=n-c;   // size of a's subtree when rooted in b's component
        const long long ways=((zh(rest,m)-total+zh(c,m))%mod+mod)%mod;
        ans=(ans+ways*(rest%mod)%mod*(c%mod))%mod;
    }
}

/*
 * Adds to `ans` the contribution of every vertex in the subtree of `a`,
 * rooted as in dfs1, and marks each processed vertex in vis[]. It relies on
 * sz[] from dfs1 with the same root, on vis[] having been cleared, and on
 * init() having run.
 *
 * The traversal uses an explicit stack instead of recursion, so deep trees
 * cannot overflow the call stack; plain `int`s replace the old `ri`
 * (register int) macro, which is ill-formed since C++17. Each vertex is
 * marked before its contribution is computed, and its children are pushed
 * only afterwards. So when a vertex is processed, its parent is already
 * visited and its children are not, which is exactly how
 * dfs2_vertex_contribution tells them apart.
 */
void dfs2(int a)
{
    std::vector<int> stk(1,a);
    while(!stk.empty())
    {
        const int u=stk.back();
        stk.pop_back();
        vis[u]=1;
        dfs2_vertex_contribution(u);
        for(int b:p[u])
            if(!vis[b]) stk.push_back(b);
    }
}
/*
 * Reads n, k and the n-1 tree edges, then prints the sum of f(r, S) over all
 * roots r and all k-subsets S, modulo `mod`.
 *
 * Steps:
 *  1. Build the adjacency lists.
 *  2. Precompute factorial tables.
 *  3. Compute subtree sizes rooted at vertex 1 (dfs1).
 *  4. Accumulate each vertex's LCA contribution (dfs2).
 *
 * The edge loop uses plain `int` instead of the `ri` macro, because its
 * `register` specifier is ill-formed since C++17.
 */
int main(){
    read(n);
    read(m);
    for(int i=1;i<n;i++)
    {
        int a=0,b=0;
        read(a);read(b);
        p[a].push_back(b);
        p[b].push_back(a);
    }
    init();
    dfs1(1);
    // dfs2 uses vis[] to tell a vertex's parent from its children, so the
    // marks left behind by dfs1 must be cleared first.
    memset(vis,0,sizeof(vis));
    dfs2(1);
    printf("%lld",ans);
    return 0;
}
View Code

 

标签:795,ch,int,tree,Tree,vertices,long,Set,size
来源: https://www.cnblogs.com/Lamboofhome/p/16343653.html