§ Binary search to find the rightmost index that does not possess some property
if (p(1 << NBITS) == 0) { return 1 << NBITS; }
else {
    assert(p(1 << NBITS) == 1);
    int ans = 0;  // precondition: p(0) == 0
    for (int i = NBITS-1; i >= 0; i--) {
        int k = 1 << i;
        assert(p(ans + 2*k) == 1);
        if (p(ans + k) == 0) {
            ans = ans + k;
        }
    }
    return ans;
}
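A minimal runnable sketch of the search above follows. It assumes the predicate is a C++ callable returning `bool` (false for 0, true for 1) and that the precondition p(0) == false holds. The function name `rightmost_without_property` is hypothetical.

```cpp
#include <cassert>
#include <functional>

// Hypothetical name. Assumes p(0) == false (the precondition behind Claim 1).
// Returns 1 << nbits if p(1 << nbits) is false; otherwise returns an ans with
// p(ans) == false and p(ans + 1) == true.
int rightmost_without_property(const std::function<bool(int)>& p, int nbits) {
    assert(!p(0));
    if (!p(1 << nbits)) { return 1 << nbits; }
    int ans = 0;
    for (int i = nbits - 1; i >= 0; i--) {
        int k = 1 << i;
        assert(p(ans + 2 * k));  // Claim 2: a jump of 2k always overshoots
        if (!p(ans + k)) {       // a jump of k keeps p(ans) == false
            ans += k;
        }
    }
    return ans;
}
```

With nbits = 3 and p(x) = (x >= 5), the jump of size 4 succeeds, and the jumps of size 2 and 1 both fail. The result is therefore 4.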
- Claim 1 (Correctness): p(ans[i]) = 0. Before the loop ans = 0, and p(0) = 0 by precondition. This is a loop invariant, because we only update ans[i] to ans[i]+k when p(ans[i]+k) = 0. Thus, it is true after the loop.
- Claim 2 (Maximality): at loop iteration i, p(ans[i] + 2k[i]) = 1, where k[i] = 2^i. That is, we cannot improve our solution by using the previous jump length.
This implies optimality once the loop ends. At the end of the loop we have i = -1, so k[-1] = 2^(-1) = 1/2.
So:
2k[-1] = 2(1/2) = 1
finalans = ans[-1]
---
p(ans[-1] + 2k[-1]) = 1
=> p(finalans+1) = 1
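- Proof of Claim 2: induction on i, going down from NBITS-1 to -1.
- Base case: at i = NBITS-1 we have ans = 0 and 2k = 2^NBITS. Since we are in the else branch, p(2^NBITS) = 1.
- Suppose Claim 2 is true at index i: p(ans[i] + 2k[i]) = 1.
- To prove: the induction hypothesis holds at index (i-1).
- Case analysis based on the loop body at i: p(ans[i] + k[i]) = 0 or 1.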
- (a) p(ans[i] + k[i]) = 0. We update ans[i-1] = ans[i] + k[i].
- We wish to show that the loop invariant holds at i-1: p(ans[i-1] + 2k[i-1]) == 1.
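$$
\begin{aligned}
&k[i] = 2^i, \qquad k[i-1] = 2^{i-1} = k[i]/2 \\
&\text{Ind. hyp.: } p(\mathit{ans}[i] + 2k[i]) = 1 \\
&\text{Case (a): } p(\mathit{ans}[i] + k[i]) = 0 \\
&\text{Update: } \mathit{ans}[i-1] \equiv \mathit{ans}[i] + k[i] \\
&p(\mathit{ans}[i-1] + 2k[i-1]) = p((\mathit{ans}[i] + k[i]) + 2k[i-1]) \\
&\quad = p(\mathit{ans}[i] + k[i] + k[i]) = p(\mathit{ans}[i] + 2k[i]) = 1 \quad \text{(by ind. hyp.)}
\end{aligned}
$$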
- We've shown that the induction hypothesis holds at index (i−1) in case (a), where we update the value of ans.
- (b) If p(ans[i] + k[i]) = 1, then we update ans[i-1] = ans[i].
- We wish to show that the loop invariant holds at i-1: p(ans[i-1] + 2k[i-1]) == 1.
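$$
\begin{aligned}
&k[i] = 2^i, \qquad k[i-1] = 2^{i-1} = k[i]/2 \\
&\text{Ind. hyp.: } p(\mathit{ans}[i] + 2k[i]) = 1 \\
&\text{Case (b): } p(\mathit{ans}[i] + k[i]) = 1 \\
&\text{Update: } \mathit{ans}[i-1] \equiv \mathit{ans}[i] \\
&p(\mathit{ans}[i-1] + 2k[i-1]) = p(\mathit{ans}[i] + 2k[i-1]) = p(\mathit{ans}[i] + k[i]) = 1 \quad \text{(by the case (b) assumption)}
\end{aligned}
$$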
- We've shown that the induction hypothesis holds at index (i−1) in case (b), where we don't change the value of ans.
- In summary, the loop invariant holds at index (i−1) whenever it holds at index i, for both updates of ans. Thus, by induction, the loop invariant holds for all iterations.
- Elaborated proof of why p(finalans + 1) = 1 at the end of the loop: we can insert a new invariant at the end of the loop body, which asserts p(ans + k) == 1 for the updated ans:
if (p(1 << nbits) == 0) { return 1 << nbits; }
else {
    assert(p(1 << nbits) == 1);
    int ans = 0;
    for (int i = nbits-1; i >= 0; i--) {
        int k = 1 << i;
        assert(p(ans + 2*k) == 1);
        int ans2;
        if (p(ans + k) == 0) {
            ans2 = ans + k;
        } else {
            ans2 = ans;
        }
        assert(p(ans2 + k) == 1);  // new end-of-loop invariant
        ans = ans2;
    }
    return ans;
}
- Since 2k[i-1] = k[i], this end-of-loop assertion is exactly Claim 2 at index i-1. The case analysis above therefore proves it, given the loop invariant at the beginning of the loop.
- So, at the end of the (i = 0) iteration we have k = 1, and so p(ans + 1) == 1. Together with p(ans) == 0 (Claim 1), this is the "rightmost index" condition that we originally wanted.
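A small trace makes the end-of-loop assertion concrete. The sketch below uses a hypothetical helper, `trace_ans`, that reruns the loop, checks p(ans + k) at the end of every iteration, and records ans. It assumes p(0) == false and p(1 << nbits) == true.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical helper: runs the loop and records ans after each iteration.
// Assumes p(0) == false and p(1 << nbits) == true.
std::vector<int> trace_ans(const std::function<bool(int)>& p, int nbits) {
    assert(!p(0) && p(1 << nbits));
    std::vector<int> trace;
    int ans = 0;
    for (int i = nbits - 1; i >= 0; i--) {
        int k = 1 << i;
        if (!p(ans + k)) { ans += k; }
        // End-of-loop invariant; on the last iteration (k == 1) this is p(ans + 1).
        assert(p(ans + k));
        trace.push_back(ans);
    }
    return trace;
}
```

For nbits = 3 and p(x) = (x >= 7), the recorded values are 4, 6, 6. The jumps of size 4 and 2 succeed, but the jump of size 1 fails because p(7) = 1. The final answer 6 satisfies p(6) = 0 and p(7) = 1.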
§ Fully elaborated proof
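if (p(1 << nbits) == 0) { return 1 << nbits; }
else {
    assert(p(1 << nbits) == 1);
    assert(p(0) == 0);                // precondition for Claim 1
    int ans = 0;
    for (int i = nbits-1; i >= 0; i--) {
        int k = 1 << i;
        assert(p(ans) == 0);          // Claim 1
        assert(p(ans + 2*k) == 1);    // Claim 2
        if (p(ans + k) == 0) {
            ans = ans + k;
        } else {
            ans = ans;                // unchanged; kept to show both cases
        }
        assert(p(ans + k) == 1);      // end-of-loop invariant (Claim 2 at i-1)
    }
    assert(p(ans) == 0 && p(ans + 1) == 1);
    return ans;
}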
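The proof never uses monotonicity of p. It only needs p(0) = 0 and p(1 << nbits) = 1. As a sanity check (our own harness, not part of the original argument), the hypothetical function `check_all_predicates` below enumerates every predicate on [0, 2^nbits] as a bitmask, including non-monotone ones. For each predicate, it checks every assertion of the fully elaborated loop.

```cpp
#include <cassert>

// Hypothetical harness. Bit x of `mask` encodes p(x) for x in [0, 2^nbits].
// Returns true iff every assertion of the fully elaborated loop holds for
// every predicate with p(0) == 0 and p(2^nbits) == 1.
bool check_all_predicates(int nbits) {
    const int n = 1 << nbits;
    for (long long mask = 0; mask < (1LL << (n + 1)); mask++) {
        auto p = [mask](int x) { return static_cast<int>((mask >> x) & 1); };
        if (p(0) != 0 || p(n) != 1) continue;  // preconditions
        int ans = 0;
        for (int i = nbits - 1; i >= 0; i--) {
            int k = 1 << i;
            if (p(ans) != 0 || p(ans + 2 * k) != 1) return false;  // Claims 1, 2
            if (p(ans + k) == 0) ans += k;
            if (p(ans + k) != 1) return false;  // end-of-loop invariant
        }
        if (p(ans) != 0 || p(ans + 1) != 1) return false;
    }
    return true;
}
```

For a non-monotone p, the result is an index where p switches from 0 to 1, not necessarily the largest index with p = 0. For a monotone p (all 0s, then all 1s), the two coincide.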
§ Simplified implementation
If we are willing to accept a small performance cost (at most one extra loop trip per value of i), we can change the loop
to become significantly easier to prove:
if (p(1 << nbits) == 0) { return 1 << nbits; }
else {
    assert(p(1 << nbits) == 1);
    int ans = 0;
    int i = nbits-1;
    while (i >= 0) {
        int k = 1 << i;
        if (p(ans + k) == 0) {
            ans = ans + k;  // the jump succeeds; try the same k again
        } else {
            i--;            // the jump fails; only now shrink k
        }
        assert(p(ans) == 0);
    }
    return ans;
}
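In this version of the loop, we only decrement i when we are sure that p(ans + k) == 1, that is, when a jump of size k would fail.
We don't need to prove that decrementing i on every loop trip maintains
the invariant p(ans + 2k) == 1. Instead, we can try "as many i's as necessary" and then decrement i
once it turns out not to be useful. The only invariant we need is p(ans) == 0. The loop can only exit through the decrement at i = 0, which happens only when p(ans + 1) == 1, so the result again satisfies the "rightmost index" condition.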
§ Relationship to LCA / binary lifting
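This is very similar to computing the LCA with binary lifting. There, we lift u to its ancestor farthest from u (closest to the root) that is still not an ancestor of v. The parent of such a node must be an ancestor of v, and hence the LCA.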
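// Assumes: l = ceil(log2(n)); up[v][i] is the 2^i-th ancestor of v (the root
// is its own ancestor); is_ancestor(a, b) is true iff a is an ancestor of b.
int lca(int u, int v) {
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;
    for (int i = l; i >= 0; --i) {
        // Like "if (p(ans + k) == 0) ans += k": jump while u stays below the LCA.
        if (!is_ancestor(up[u][i], v))
            u = up[u][i];
    }
    return up[u][0];
}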
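The snippet above relies on `l`, `up`, and `is_ancestor` being defined elsewhere. One self-contained way to provide them is sketched below. It assumes the tree is given as an undirected adjacency list and implements `is_ancestor` with DFS entry/exit timestamps. The struct `LCA` and the members `tin`, `tout`, `timer`, and `dfs` are hypothetical names introduced here. The `lca` loop is the binary search from the first section with p(x) = "the x-th ancestor of u is an ancestor of v".

```cpp
#include <cassert>
#include <vector>

// Hypothetical struct: LCA, tin, tout, timer, dfs are illustrative names;
// l, up, is_ancestor, lca follow the snippet above.
struct LCA {
    int l = 0;      // up[v][i] is the 2^i-th ancestor of v, for i in [0, l]
    int timer = 0;
    std::vector<int> tin, tout;
    std::vector<std::vector<int>> up;

    LCA(const std::vector<std::vector<int>>& adj, int root) {
        int n = static_cast<int>(adj.size());
        while ((1 << l) < n) l++;
        tin.assign(n, 0);
        tout.assign(n, 0);
        up.assign(n, std::vector<int>(l + 1));
        dfs(adj, root, root);
    }

    void dfs(const std::vector<std::vector<int>>& adj, int v, int parent) {
        tin[v] = timer++;
        up[v][0] = parent;  // the root is its own parent
        for (int i = 1; i <= l; i++) up[v][i] = up[up[v][i - 1]][i - 1];
        for (int w : adj[v]) {
            if (w != parent) dfs(adj, w, v);
        }
        tout[v] = timer++;
    }

    // True iff a is an ancestor of b (every node is its own ancestor).
    bool is_ancestor(int a, int b) const {
        return tin[a] <= tin[b] && tout[b] <= tout[a];
    }

    int lca(int u, int v) const {
        if (is_ancestor(u, v)) return u;
        if (is_ancestor(v, u)) return v;
        // Invariant, like p(ans) == 0: u is never an ancestor of v.
        for (int i = l; i >= 0; --i) {
            if (!is_ancestor(up[u][i], v)) u = up[u][i];
        }
        // Now the parent of u is an ancestor of v, like p(ans + 1) == 1.
        return up[u][0];
    }
};
```

On the tree with edges 0-1, 0-2, 1-3, 1-4, 2-5, 5-6 rooted at 0, lca(3, 4) is 1 and lca(3, 6) is 0.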