See that we can insert a new invariant at the end of the loop which asserts `p(ans[i] + k[i]) == 1`:
```cpp
if (p(1 << nbits) == 0) { return 1 << nbits; }
else {
    assert(p(1 << nbits) == 1);
    int ans = 0;
    for (int i = nbits - 1; i >= 0; i--) {
        int k = 1 << i;
        assert(p(ans + 2*k) == 1);  // invariant at the beginning of the loop
        int ans2;
        if (p(ans + k) == 0) {
            ans2 = ans + k;
        } else {
            ans2 = ans;
        }
        assert(p(ans2 + k) == 1);   // new invariant at the end of the loop
        ans = ans2;
    }
    return ans;
}
```
- We've proven that the invariant at the end of the loop holds, given the invariant at the beginning of the loop. Since `k` halves on every iteration, `p(ans + k) == 1` at the end of one iteration is exactly `p(ans + 2*k) == 1` at the beginning of the next.
- So, at the end of the `i = 0` iteration, we have `k = 1`, and so `p(ans + 1) == 1`, which is the "rightmost index" condition that we originally wanted.
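The case analysis behind the first bullet can be written out as a short derivation. As an assumption about notation, `ans[i]` is read here as the value of `ans` at the end of iteration `i`, with `ans[nbits] = 0` standing for its initial value, and `k[i] = 2^i`:

$$
\begin{aligned}
&\text{start of iteration } i: && p(\mathrm{ans}[i+1] + 2k[i]) = 1 \\
&\text{if } p(\mathrm{ans}[i+1] + k[i]) = 0: && \mathrm{ans}[i] = \mathrm{ans}[i+1] + k[i] \implies p(\mathrm{ans}[i] + k[i]) = p(\mathrm{ans}[i+1] + 2k[i]) = 1 \\
&\text{if } p(\mathrm{ans}[i+1] + k[i]) = 1: && \mathrm{ans}[i] = \mathrm{ans}[i+1] \implies p(\mathrm{ans}[i] + k[i]) = 1 \\
&\text{since } 2k[i-1] = k[i]: && p(\mathrm{ans}[i] + 2k[i-1]) = 1 \quad (\text{start of iteration } i-1) \\
&\text{base case } (i = \mathrm{nbits}-1): && p(0 + 2k[\mathrm{nbits}-1]) = p(2^{\mathrm{nbits}}) = 1 \\
&\text{after } i = 0: && p(\mathrm{ans}[0] + 1) = 1
\end{aligned}
$$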
§ Fully elaborated proof
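```cpp
if (p(1 << nbits) == 0) { return 1 << nbits; }
else {
    assert(p(1 << nbits) == 1);
    int ans = 0;
    for (int i = nbits - 1; i >= 0; i--) {
        int k = 1 << i;
        assert(p(ans + 2*k) == 1);  // invariant at the beginning of the loop
        if (p(ans + k) == 0) {
            ans = ans + k;          // now p(ans + k) == p(old ans + 2*k) == 1
        } else {
            ans = ans;              // p(ans + k) == 1 was just checked
        }
        assert(p(ans + k) == 1);    // invariant at the end of the loop
    }
    return ans;
}
```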
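To sanity-check the loop outside of a proof, here is a minimal, self-contained C++ sketch of the code above wrapped in a function. The name `rightmost_index` and the use of `std::function` for `p` are assumptions made for illustration; `p` is expected to return 0 or 1 on `[0, 2^nbits]`.

```cpp
#include <cassert>
#include <functional>

// Hypothetical wrapper around the loop above (name and signature are assumptions).
// p is assumed to map [0, 2^nbits] to {0, 1}.
// Returns 2^nbits if p(2^nbits) == 0; otherwise an ans with p(ans + 1) == 1.
int rightmost_index(const std::function<int(int)>& p, int nbits) {
    if (p(1 << nbits) == 0) { return 1 << nbits; }
    int ans = 0;
    for (int i = nbits - 1; i >= 0; i--) {
        int k = 1 << i;
        assert(p(ans + 2 * k) == 1);  // invariant at the beginning of the loop
        if (p(ans + k) == 0) {
            ans += k;                 // p(ans + k) is now p(old ans + 2k), i.e. 1
        }
        assert(p(ans + k) == 1);      // invariant at the end of the loop
    }
    return ans;
}
```

For example, with `nbits = 4` and `p(x) = (x >= 11)`, it returns 10, since `p(10) == 0` and `p(11) == 1`. Neither assertion uses monotonicity of `p`: both follow from the case analysis alone.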
§ Simplified implementation
If we are willing to accept some performance cost (up to two calls to `p` per bit instead of one), we can change the loop so that it becomes significantly easier to prove:
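```cpp
if (p(1 << nbits) == 0) { return 1 << nbits; }
else {
    assert(p(1 << nbits) == 1);
    int ans = 0;
    int i = nbits - 1;
    while (i >= 0) {
        int k = 1 << i;
        assert(p(ans + 2*k) == 1);  // after a trip that kept i, this also uses monotonicity of p
        if (p(ans + k) == 0) {
            ans += k;               // keep i and try adding the same k again
            assert(p(ans) == 0);
        } else {
            i--;                    // p(ans + k) == 1: move on to k / 2
        }
    }
    return ans;
}
```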
In this version of the loop, we only decrement `i` when we are sure that `p(ans + k) == 1`, that is, when adding `k` would overshoot.
We don't need to prove that decrementing `i` on every loop trip maintains the invariant; rather, we can try the same `i` as many times as necessary, and decrement `i` only once it turns out to no longer be useful.
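The cost of the simplified loop can be observed directly. Below is a minimal C++ sketch of it as a standalone function; the name `rightmost_index_simple` and the `calls` counter are hypothetical additions used only to count how often `p` is evaluated inside the loop. After `ans += k`, the next trip tests `p(old ans + 2k)`, which the invariant says is 1, so each `i` costs at most two calls.

```cpp
#include <cassert>
#include <functional>

// Hypothetical standalone version of the simplified loop.
// calls receives the number of evaluations of p inside the while loop
// (the initial p(2^nbits) check is not counted).
int rightmost_index_simple(const std::function<int(int)>& p, int nbits, int& calls) {
    calls = 0;
    if (p(1 << nbits) == 0) { return 1 << nbits; }
    int ans = 0;
    int i = nbits - 1;
    while (i >= 0) {
        int k = 1 << i;
        calls++;
        if (p(ans + k) == 0) {
            ans += k;  // keep the same i and try adding k again
        } else {
            i--;       // adding k overshoots: move on to k / 2
        }
    }
    return ans;
}
```

For example, with `nbits = 4` and `p(x) = (x >= 16)` it returns 15 after 8 calls inside the loop, whereas the earlier loop makes one call per bit (4 calls, not counting its assertions).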
§ Relationship to LCA / binary lifting
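This is very similar to computing the LCA (lowest common ancestor) with binary lifting, where we find the highest ancestor of `u` (the one closest to the root) that is not an ancestor of `v`. The parent of such a node must be the LCA.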
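```cpp
// Binary lifting: up[u][i] is the 2^i-th ancestor of u, and l is assumed
// to be ceil(log2(n)) for a tree with n vertices.
int lca(int u, int v) {
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;
    for (int i = l; i >= 0; --i) {
        // Jump only if we land on a vertex that is still not an ancestor of v.
        if (!is_ancestor(up[u][i], v))
            u = up[u][i];
    }
    // u is now the highest ancestor of the original u that is not an ancestor of v.
    return up[u][0];
}
```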
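For a self-contained version of the code above, the sketch below follows the common Euler-tour formulation of binary lifting, where `is_ancestor` compares DFS entry/exit times `tin`/`tout`. The struct name `LcaSketch`, the adjacency-list input, and the convention that the root is its own ancestor are assumptions, not part of the text. In terms of the earlier search, the predicate is "the vertex reached by jumping up from `u` is an ancestor of `v`", and the loop builds, bit by bit, the largest number of steps up from `u` for which it is still false, just as `ans` is built.

```cpp
#include <cassert>
#include <vector>

// Hypothetical self-contained binary-lifting LCA (vertices 0..n-1).
struct LcaSketch {
    int n, l, timer = 0;
    std::vector<std::vector<int>> adj, up;  // up[v][j] = 2^j-th ancestor of v
    std::vector<int> tin, tout;             // DFS entry / exit times

    LcaSketch(const std::vector<std::vector<int>>& g, int root)
        : n(static_cast<int>(g.size())), adj(g) {
        l = 0;
        while ((1 << l) < n) l++;           // l = ceil(log2(n))
        tin.assign(n, 0);
        tout.assign(n, 0);
        up.assign(n, std::vector<int>(l + 1));
        dfs(root, root);                    // the root is its own parent
    }

    void dfs(int v, int parent) {
        tin[v] = ++timer;
        up[v][0] = parent;
        for (int j = 1; j <= l; ++j) up[v][j] = up[up[v][j - 1]][j - 1];
        for (int u : adj[v])
            if (u != parent) dfs(u, v);
        tout[v] = ++timer;
    }

    // True if u is an ancestor of v (every vertex is its own ancestor).
    bool is_ancestor(int u, int v) const {
        return tin[u] <= tin[v] && tout[v] <= tout[u];
    }

    int lca(int u, int v) const {
        if (is_ancestor(u, v)) return u;
        if (is_ancestor(v, u)) return v;
        for (int i = l; i >= 0; --i)
            if (!is_ancestor(up[u][i], v)) u = up[u][i];  // jump while still "0"
        return up[u][0];
    }
};
```

For example, on the tree with edges 0-1, 0-2, 1-3, 1-4, 2-5, 5-6 rooted at 0, `lca(3, 4)` is 1 and `lca(3, 6)` is 0.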