#背包dp #组合数学 #二项式定理 #幂 ## 解法 首先观察到极大的 $N$ $(1\le N\le 2\times 10^5)$,这似乎让我们无从下手,但是不难发现 $K$ 极小 $(1\le K\le 10)$,这不禁让我们想到 $K$ 是一个关键的突破口。 考虑 dp。令 $dp_{i,j}$ 表示为以第 $i$ 个元素为结尾的所有区间的的和的 $k$ 次方之和,即 $ dp_{i,j}=\sum_{k=1}^i\left(\sum_{l=k}^i a_l\right)^K $ 每次加入一个元素 $a_i$ 时,假设之前的和为 $S$,我们可以利用二项式定理 $ (a+b)^k=\sum\limits_{i=0}^k\binom{k}{i}a^i\cdot b^{k-i} $ 展开 $(S+a_i)^k$,即 $ (S+a_i)^k=\sum\limits_{m=0}^k\binom{k}{m}S^m\cdot a_i^{k-m} $ 可以发现这个式子里面的 $\binom{k}{m}$ 和 $a_i^{k-m}$ 都是可以通过预处理提前计算得到,而 $S^m$ 就是 $dp_{(i-1,m)}$。通过这种方式,我们可以递推得到 dp 数组。最后 $\sum\limits_{i=1}^N dp_{i,k}$ 就是答案了。注意取模。 ## 复杂度分析 ### 时间复杂度 读入 $\mathcal O(N)$,预处理组合数 $\mathcal O(K^2)$,预处理 $a_i^k$ $\mathcal O(NK)$,dp 过程 $\mathcal O(NK^2)$,总时间复杂度 $\mathcal O(NK^2)$。本题下 $1\le N\le 2\times 10^5,1\le K\le 10$,因此运算量级大概是 $2\times 10^7$ 级别的,可以通过。如果不预处理 $a_i^k$ 的话,时间复杂度 $\mathcal O(NK^2\log_2 K)$。 ### 空间复杂度 $A$ 数组存储 $\mathcal O(N)$,组合数数组 $\mathcal O(K^2)$,$a_i^k$ 数组 $\mathcal O(NK)$,dp 数组 $\mathcal O(NK)$,总空间复杂度 $\mathcal O(NK)$。如果不处理 $a_i^k$ 数组并对 dp 数组进行滚动数组优化的话空间复杂度为 $\mathcal O(N+K^2)$。 ## 代码 采用滚动数组优化 dp。 ```cpp #include <iostream> #include <vector> using namespace std; using LL = long long; const LL kMod = 998244353; LL n, K, ans; int main() { cin.tie(0)->sync_with_stdio(0); // 读入 cin >> n >> K; vector<LL> a(n + 1); for (LL i = 1; i <= n; i++) { cin >> a[i]; } // 预处理组合数 vector<vector<LL>> C(K + 1, vector<LL>(K + 1)); for (LL k = 0; k <= K; k++) { C[k][0] = 1; for (LL m = 1; m <= k; m++) { C[k][m] = (C[k - 1][m - 1] + C[k - 1][m]) % kMod; } } // 预处理 a[i]^k vector<vector<LL>> powA(n + 1, vector<LL>(K + 1)); for (LL i = 1; i <= n; i++) { powA[i][0] = 1; LL x = a[i] % kMod; for (LL k = 1; k <= K; k++) { powA[i][k] = powA[i][k - 1] * x % kMod; } } // dp 过程及统计答案 vector<LL> dp(K + 1); for (LL i = 1; i <= n; i++) { vector<LL> f(K + 1); for (LL k = 0; k <= K; k++) { for (LL m = 0; m <= k; m++) { LL trm = C[k][m] * powA[i][k - m] % kMod, prv = dp[m]; if (m == 0) { // 注意对零次方的特殊判断 prv = (prv + 1) % kMod; } f[k] = (f[k] + trm * prv) % kMod; } } dp = f; ans = (ans + dp[K]) % kMod; } cout << ans << "\n"; return 0; } ```