给定 nn 个字符集为小写字母的字符串 sis_i

一个串 tt 是可接受的,当且仅当 tt 可以表示成 p1+p2++pnp_1 + p_2 + \cdots + p_n,其中 pip_isis_i的一个子串(可以为空),++ 表示字符串的拼接

问有多少种本质不同的字符串 tt 是可接受的

答案对 109+710^ 9 + 7取模

LOJ 6071

Solution

考虑对每个串建出SAM

一个直观的想法是把每个串的本质不同子串个数直接乘起来,但是这样肯定会算重

通过观察不难发现,算重的情况就是第ii个串的某个状态的一个后缀与第i+1i+1个串的某个前缀重叠

如何不考虑这种情况呢?

考虑在SAM上dp,从第nn个串往前倒着dp。如果当前串SAM的某个节点上存在某种后继状态,那么在后面的选择中第一个字符就不能等于这个后继状态,只有这样才能保证 dp 不会算重

利用这个性质,我们就只需要存dp[c]dp[c]表示后面的串以cc开头的方案数,直接在SAM上转移即可

具体来说,先把所有节点top排序,按逆top序的顺序进行当前SAM上的dp

f[i]f[i]表示当前SAM中状态ii的答案,枚举这个点的所有后继状态,如果这个后继状态 jj 不存在,那么f[i]+=dp[j]f[i] += dp[j] ,否则f[i]+=f[j]f[i] += f[j]。第一个串根节点存的dp值即为最后的答案

Summary

SAM上dp时,如果当前状态往后有转移边的话,就强制从当前SAM中转移;否则从后面字符串的答案转移

这样就能保证dp不算重

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#include <bits/stdc++.h>

#define x first
#define y second
#define y1 Y1
#define y2 Y2
#define mp make_pair
#define pb push_back

using namespace std;

typedef long long LL;
typedef pair <int, int> pii;

template <typename T> inline int Chkmax (T &a, T b) { return a < b ? a = b, 1 : 0; }
template <typename T> inline int Chkmin (T &a, T b) { return a > b ? a = b, 1 : 0; }
template <typename T> inline T read ()
{
T sum = 0, fl = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = (sum << 3) + (sum << 1) + ch - '0';
return sum * fl;
}

inline void proc_status ()
{
ifstream t ("/proc/self/status");
cerr << string (istreambuf_iterator <char> (t), istreambuf_iterator <char> ()) << endl;
}

const int Maxn = 1e6 + 100, Mod = 1e9 + 7;

int N, Dp[30];

inline void Add (int &a, int b) { if ((a += b) >= Mod) a -= Mod; }

namespace SAM
{
int node_cnt, last;
struct info
{
int maxlen, fa, ch[26];
} node[Maxn << 1];

inline int new_node (int pre)
{
node[++node_cnt].maxlen = node[pre].maxlen + 1;
return node_cnt;
}

inline void extend (int c)
{
int now = new_node (last), pre = last; last = now;

for (; pre && !node[pre].ch[c]; pre = node[pre].fa) node[pre].ch[c] = now;

if (!pre) node[now].fa = 1;
else
{
int x = node[pre].ch[c];
if (node[x].maxlen == node[pre].maxlen + 1) node[now].fa = x;
else
{
int y = ++node_cnt;
node[y] = node[x], node[y].maxlen = node[pre].maxlen + 1;
node[x].fa = node[now].fa = y;
for (; node[pre].ch[c] == x; pre = node[pre].fa) node[pre].ch[c] = y;
}
}
}

inline void init ()
{
for (int i = 1; i <= node_cnt; ++i)
{
memset(node[i].ch, 0, sizeof node[i].ch);
node[i].maxlen = node[i].fa = 0;
}
node_cnt = last = 1;
}

inline void build (string S)
{
init ();
for (int i = 0; i < S.length(); ++i) extend (S[i] - 'a');
}

int f[Maxn << 1], P[Maxn << 1], deg[Maxn << 1];

inline void top_sort ()
{
static queue <int> Q;

for (int i = 1; i <= node_cnt; ++i) deg[i] = 0;

for (int i = 1; i <= node_cnt; ++i)
for (int c = 0; c < 26; ++c)
if (node[i].ch[c]) ++deg[node[i].ch[c]];

for (int i = 1; i <= node_cnt; ++i) if (!deg[i]) Q.push(i);

int cnt = 0;
while (!Q.empty())
{
int x = Q.front(); Q.pop();
P[++cnt] = x;
for (int c = 0; c < 26; ++c)
if (node[x].ch[c])
{
int y = node[x].ch[c];
--deg[y];
if (!deg[y]) Q.push(y);
}
}
}

inline void solve ()
{
top_sort();

for (int i = node_cnt; i >= 1; --i)
{
int x = P[i];
f[x] = 1;

for (int c = 0; c < 26; ++c)
if (node[x].ch[c]) Add (f[x], f[node[x].ch[c]]);
else Add (f[x], Dp[c]);
}

for (int c = 0; c < 26; ++ c) if (node[1].ch[c]) Dp[c] = f[node[1].ch[c]];
}
}

string S[Maxn];

inline void Solve ()
{
for (int i = N; i >= 1; --i)
{
SAM :: build (S[i]);
SAM :: solve ();
}

int ans = 1;
for (int i = 0; i < 26; ++i) Add (ans, Dp[i]);
printf("%d\n", ans);
}

inline void Input ()
{
N = read<int>();
for (int i = 1; i <= N; ++i) cin >> S[i];
}

int main()
{

#ifndef ONLINE_JUDGE
freopen("name.in", "r", stdin);
freopen("name.out", "w", stdout);
#endif

Input ();
Solve ();

return 0;
}