给你一颗 nn 个点的树 , 每个节点有两个权值 ai,bia_i,b_i

uu 跳到 vv 的代价是 au×bva_u\times b_v你需要计算每个节点跳到叶子的最小代价.

n105;105ai,bi105n\le 10^5;-10^5\le a_i, b_i\le 10^5

CF932F

Solution

dpidp_iii跳到叶子的最小代价,枚举uu的后代vv,有:

dpu=minv{a[u]×b[v]+dpv} \displaystyle dp_u = \min_v \{a[u] \times b[v] + dp_v\}

这个斜率的形式很标准,直接上李超线段树即可

并且很显然需要线段树合并

李超线段树合并就是把另外一棵线段树的每个优势线段都插到当前的线段树中即可,是O(nlog2n)O(n\log^2n)的复杂度

这道题下标可能是负数,需要注意线段树的midmid要向下取整(默认强转是向00取整)


第一次写李超线段树,按照自己想法写的

没想到一遍就过了,没有任何精度问题,并且代码也很简洁

在判断向线段树哪个儿子走的时候,直接算l,rl,r这两个端点处哪个线段更优即可

因为若两直线在某区间内有交,那么它们在两端点处的大小关系一定不同

完全避免了计算交点的精度问题

insert部分代码:

1
2
3
4
5
6
7
8
inline void insert (int &o, int l, int r, line now)
{
if (!o) o = ++node_cnt;
if (node[o].get_val (mid) > now.get_val (mid)) swap (node[o].seg, now);
if (l == r || !now.id) return ;
if (node[o].get_val (l) > now.get_val (l)) insert (lson, now);
if (node[o].get_val (r) > now.get_val (r)) insert (rson, now);
}

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include <bits/stdc++.h>

#define x first
#define y second
#define y1 Y1
#define y2 Y2
#define mp make_pair
#define pb push_back
#define DEBUG(x) cout << #x << " = " << x << endl;

using namespace std;

typedef long long LL;
typedef pair <int, int> pii;

template <typename T> inline int Chkmax (T &a, T b) { return a < b ? a = b, 1 : 0; }
template <typename T> inline int Chkmin (T &a, T b) { return a > b ? a = b, 1 : 0; }
template <typename T> inline T read ()
{
T sum = 0, fl = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = (sum << 3) + (sum << 1) + ch - '0';
return sum * fl;
}

inline void proc_status ()
{
ifstream t ("/proc/self/status");
cerr << string (istreambuf_iterator <char> (t), istreambuf_iterator <char> ()) << endl;
}

const int Maxn = 1e5 + 100;
const int LIM = 1e5 + 5;
const LL inf = 1e18;

int N, A[Maxn], B[Maxn];
int e, Begin[Maxn], To[Maxn << 1], Next[Maxn << 1];

inline void add_edge (int x, int y) { To[++e] = y; Next[e] = Begin[x]; Begin[x] = e; }

struct line
{
int id;
LL k, b;
inline LL get_val (LL x) { if (!id) return inf; return k * x + b; }
};

namespace SEG
{
#define mid (floor(1.0 * (l + r) / 2))
#define ls node[o].ch[0]
#define rs node[o].ch[1]
#define lson ls, l, mid
#define rson rs, mid + 1, r

struct info
{
int ch[2];
line seg;
inline LL get_val (LL x) { return seg.get_val (x); }
} node[Maxn * 60];
int node_cnt;

inline void insert (int &o, int l, int r, line now)
{
if (!o) o = ++node_cnt;
if (node[o].get_val (mid) > now.get_val (mid)) swap (node[o].seg, now);

if (l == r || !now.id) return ;

if (node[o].get_val (l) > now.get_val (l)) insert (lson, now);
if (node[o].get_val (r) > now.get_val (r)) insert (rson, now);
}

inline void merge (int x, int &o, int l, int r)
{
if (!x || !o) return void(o = x | o);
insert (o, l, r, node[x].seg);
merge (node[x].ch[0], lson);
merge (node[x].ch[1], rson);
}

inline LL query (int o, int l, int r, int x)
{
if (!o) return inf;
if (l == r) return node[o].seg.get_val (x);
if (x <= mid) return min (query (lson, x), node[o].get_val (x));
return min (query (rson, x), node[o].get_val (x));
}

#undef mid
}

int O[Maxn];
LL Dp[Maxn];

inline void dfs (int x, int f)
{
int fl = 0;
for (int i = Begin[x]; i; i = Next[i])
{
int y = To[i];
if (y == f) continue;
fl = 1;
dfs (y, x);
SEG :: merge (O[y], O[x], -LIM, LIM);
}
if (fl) Dp[x] = SEG :: query (O[x], -LIM, LIM, A[x]);
SEG :: insert (O[x], -LIM, LIM, (line){x, B[x], Dp[x]});
}

inline void Solve ()
{
dfs (1, 0);
for (int i = 1; i <= N; ++i) printf("%lld ", Dp[i]);
}

inline void Input ()
{
N = read<int>();
for (int i = 1; i <= N; ++i) A[i] = read<int>();
for (int i = 1; i <= N; ++i) B[i] = read<int>();
for (int i = 1; i < N; ++i)
{
int x = read<int>(), y = read<int>();
add_edge (x, y);
add_edge (y, x);
}
}

int main()
{

#ifndef ONLINE_JUDGE
freopen("F.in", "r", stdin);
freopen("F.out", "w", stdout);
#endif

Input ();
Solve ();

return 0;
}