(* Mirror of https://github.com/AdaCore/why3.git
   (synced 2026-02-12 12:34:55 -08:00; 291 lines, 7.9 KiB, plain text). *)
(* Binary search

   A classical example. Searches a sorted array for a given value v. *)
|
|
|
|
module BinarySearch

  use int.Int
  use int.ComputerDivision
  use ref.Ref
  use array.Array

  (* the code and its specification *)

  exception Not_found (* raised to signal a search failure *)

  (* [binary_search a v] searches the array [a] for the value [v].

     requires: [a] is sorted in non-decreasing order.
     ensures:  the result is a valid index of [a] and the cell at that
               index holds [v].
     raises:   [Not_found] when no cell of [a] holds [v]. *)
  let binary_search (a: array int) (v: int) : int
    requires { forall i1 i2. 0 <= i1 <= i2 < length a -> a[i1] <= a[i2] }
    ensures  { 0 <= result < length a /\ a[result] = v }
    raises   { Not_found -> forall i. 0 <= i < length a -> a[i] <> v }
  =
    (* the search window is the closed index interval from l to u *)
    let ref l = 0 in
    let ref u = length a - 1 in
    while l <= u do
      (* the window never leaves the array on the sides it can move to *)
      invariant { 0 <= l /\ u < length a }
      (* every occurrence of v, if there is one, is inside the window *)
      invariant {
        forall i. 0 <= i < length a -> a[i] = v -> l <= i <= u }
      (* the window gets strictly smaller at each iteration *)
      variant { u - l }
      let m = l + div (u - l) 2 in
      assert { l <= m <= u };
      if a[m] < v then
        (* v can only be strictly to the right of m *)
        l := m + 1
      else if a[m] > v then
        (* v can only be strictly to the left of m *)
        u := m - 1
      else
        return m
    done;
    (* empty window: by the second invariant, v does not occur in a *)
    raise Not_found

end
|
|
|
|
(* A generalization: the midpoint is computed by some abstract function.
   The only requirement is that it lies between l and u. *)
|
|
|
|
module BinarySearchAnyMidPoint

  use int.Int
  use ref.Ref
  use array.Array

  exception Not_found (* raised to signal a search failure *)

  (* [midpoint l u] is declared without a body: all that is known about
     its result is that it lies between [l] and [u], bounds included,
     provided [l <= u]. *)
  val midpoint (l: int) (u: int) : int
    requires { l <= u } ensures { l <= result <= u }

  (* [binary_search a v] searches the array [a] for the value [v], using
     [midpoint] to pick the index probed at each step.

     requires: [a] is sorted in non-decreasing order.
     ensures:  the result is a valid index of [a] and the cell at that
               index holds [v].
     raises:   [Not_found] when no cell of [a] holds [v]. *)
  let binary_search (a: array int) (v: int) : int
    requires { forall i1 i2. 0 <= i1 <= i2 < length a -> a[i1] <= a[i2] }
    ensures  { 0 <= result < length a /\ a[result] = v }
    raises   { Not_found -> forall i. 0 <= i < length a -> a[i] <> v }
  =
    (* the search window is the closed index interval from l to u *)
    let ref l = 0 in
    let ref u = length a - 1 in
    while l <= u do
      invariant { 0 <= l /\ u < length a }
      (* every occurrence of v, if there is one, is inside the window *)
      invariant { forall i. 0 <= i < length a -> a[i] = v -> l <= i <= u }
      (* m is excluded from the next window, so u - l strictly decreases *)
      variant { u - l }
      (* the loop condition l <= u establishes the precondition of midpoint *)
      let m = midpoint l u in
      if a[m] < v then
        l := m + 1
      else if a[m] > v then
        u := m - 1
      else
        return m
    done;
    raise Not_found

end
|
|
|
|
(* The following version of binary search is faster in practice, by being
   friendlier to the branch predictor of most processors. It also happens
   to be stable, since it always returns the highest index. *)
|
|
|
|
module BinarySearchBranchless

  use int.Int
  use int.ComputerDivision
  use ref.Ref
  use array.Array

  exception Not_found (* raised to signal a search failure *)

  (* [binary_search a v] searches the array [a] for the value [v].

     requires: [a] is sorted in non-decreasing order.
     ensures:  the result is a valid index of [a] and the cell at that
               index holds [v];
               no index strictly greater than the result holds [v].
     raises:   [Not_found] when no cell of [a] holds [v]. *)
  let binary_search (a: array int) (v: int) : int
    requires { forall i1 i2. 0 <= i1 <= i2 < length a -> a[i1] <= a[i2] }
    ensures  { 0 <= result < length a /\ a[result] = v }
    ensures  { forall i. result < i < length a -> a[i] <> v }
    raises   { Not_found -> forall i. 0 <= i < length a -> a[i] <> v }
  =
    (* the window starts at index l and contains s cells *)
    let ref l = 0 in
    let ref s = length a in
    (* an empty array cannot contain v; it would also break s >= 1 below *)
    if s = 0 then raise Not_found;
    while s > 1 do
      (* the window is non-empty and stays inside the array *)
      invariant { 0 <= l /\ l + s <= length a /\ s >= 1 }
      (* if v occurs at all, then the first cell of the window is at most v
         and every occurrence of v is below the end of the window *)
      invariant {
        forall i. 0 <= i < length a -> a[i] = v -> a[l] <= v /\ i < l + s }
      variant { s }
      let h = div s 2 in
      let m = l + h in
      (* the start moves up to m unless a[m] is already greater than v;
         the size shrinks by h in both cases *)
      l := if a[m] > v then l else m;
      s := s - h;
    done;
    (* a single candidate cell is left *)
    if a[l] = v then l
    else raise Not_found

end
|
|
|
|
(* binary search using 32-bit integers *)
|
|
|
|
module BinarySearchInt32

  use int.Int
  use mach.int.Int32
  use ref.Ref
  use mach.array.Array32

  exception Not_found (* raised to signal a search failure *)

  (* [binary_search a v] searches the array [a] for the value [v]; indices,
     elements and the result all have type [int32].

     requires: [a] is sorted in non-decreasing order.
     ensures:  the result is a valid index of [a] and the cell at that
               index holds [v].
     raises:   [Not_found] when no cell of [a] holds [v]. *)
  let binary_search (a: array int32) (v: int32) : int32
    requires { forall i1 i2. 0 <= i1 <= i2 < a.length -> a[i1] <= a[i2] }
    ensures  { 0 <= result < a.length /\ a[result] = v }
    raises   { Not_found -> forall i. 0 <= i < a.length -> a[i] <> v }
  =
    (* the search window is the closed index interval from l to u *)
    let ref l = 0 in
    let ref u = length a - 1 in
    while l <= u do
      invariant { 0 <= l /\ u < a.length }
      (* every occurrence of v, if there is one, is inside the window *)
      invariant { forall i. 0 <= i < a.length -> a[i] = v -> l <= i <= u }
      variant { u - l }
      (* NOTE(review): the midpoint is written l + (u - l) / 2, presumably so
         that every intermediate value stays in the int32 range, which l + u
         might not; confirm against the contracts of mach.int.Int32 *)
      let m = l + (u - l) / 2 in
      assert { l <= m <= u };
      if a[m] < v then
        l := m + 1
      else if a[m] > v then
        u := m - 1
      else
        return m
    done;
    raise Not_found

end
|
|
|
|
(* A particular case with Boolean values (0 or 1) and a sentinel 1 at the end.
   We look for the first position containing a 1. *)
|
|
|
|
module BinarySearchBoolean

  use int.Int
  use int.ComputerDivision
  use ref.Ref
  use array.Array

  (* [binary_search a] returns the first index of [a] that contains 1.

     requires: [a] is non-empty;
               [a] is sorted in non-decreasing order and all its values
               are 0 or 1;
               the last cell of [a] contains 1.
     ensures:  the result is a valid index of [a];
               the cell at that index contains 1;
               every cell before it contains 0. *)
  let binary_search (a: array int) : int
    requires { 0 < a.length }
    requires { forall i j. 0 <= i <= j < a.length -> 0 <= a[i] <= a[j] <= 1 }
    requires { a[a.length - 1] = 1 }
    ensures  { 0 <= result < a.length }
    ensures  { a[result] = 1 }
    ensures  { forall i. 0 <= i < result -> a[i] = 0 }
  =
    let ref lo = 0 in
    let ref hi = length a - 1 in
    while lo < hi do
      invariant { 0 <= lo <= hi < a.length }
      (* hi always points at a 1; initially this is the sentinel *)
      invariant { a[hi] = 1 }
      (* everything strictly before lo is 0 *)
      invariant { forall i. 0 <= i < lo -> a[i] = 0 }
      variant { hi - lo }
      let mid = lo + div (hi - lo) 2 in
      if a[mid] = 1 then
        (* mid is kept as a candidate: it may be the first 1 *)
        hi := mid
      else
        lo := mid + 1
    done;
    (* here lo = hi, so a[lo] = 1 and all cells before lo are 0 *)
    lo

end
|
|
|
|
(* Binary search instrumented with a step counter, together with a proof
   obligation bounding the number of steps by a logarithm of the array
   length. *)
module Complexity

  use int.Int
  use int.ComputerDivision
  use ref.Ref
  use array.Array

  (* [log2 n] is 0 when [n <= 1], and [1 + log2 (div n 2)] otherwise. *)
  let rec function log2 (n: int) : int
    variant { n }
  = if n <= 1 then 0 else 1 + log2 (div n 2)

  (* [log2] is monotone: [x <= y] implies [log2 x <= log2 y].
     The lemma is established by a recursive program that follows the
     definition of [log2], halving both arguments while [y > 1]. *)
  let rec lemma log2_monotone (x y: int)
    requires { x <= y }
    ensures  { log2 x <= log2 y }
    variant  { y }
  = if y > 1 then log2_monotone (div x 2) (div y 2)

  (* [f n] is the bound used for a search window of [n] cells:
     0 when [n = 0], and [1 + log2 n] otherwise. *)
  let function f (n: int) : int
  = if n = 0 then 0 else 1 + log2 n

  (* From 2 onwards, [f] is at most twice [log2]. *)
  lemma upper_bound:
    forall n. n >= 2 -> f n <= 2 * log2 n

  (* Global step counter; [binary_search] below adds 1 to it at the end of
     every loop iteration that does not return. *)
  val ref time: int

  (* [binary_search a v] searches the array [a] for the value [v].

     requires: [a] is sorted in non-decreasing order;
               [time] is 0 on entry.
     ensures:  either the result is a valid index of [a] whose cell holds
               [v], or the result is -1 and no cell of [a] holds [v];
               [time] has grown by at most [f (length a)]. *)
  let binary_search (a: array int) (v: int) : int
    requires { forall i1 i2. 0 <= i1 <= i2 < length a -> a[i1] <= a[i2] }
    (* NOTE(review): the cost bound below is stated relative to old time, so
       this precondition looks unnecessary; confirm by replaying the proofs
       before dropping it *)
    requires { time = 0 }
    ensures  { 0 <= result < length a && a[result] = v
               || result = -1 && forall i. 0 <= i < length a -> a[i] <> v }
    ensures  { time - old time <= f (length a) }
  =
    (* the search window is the half-open index interval from lo, included,
       to hi, excluded *)
    let ref lo = 0 in
    let ref hi = length a in
    while lo < hi do
      invariant { 0 <= lo <= hi <= length a }
      (* v does not occur outside the window *)
      invariant { forall i. 0 <= i < lo || hi <= i < length a -> a[i] <> v }
      (* steps already spent plus the bound for the remaining window never
         exceed the bound for the whole array *)
      invariant { (time - old time) + f (hi - lo) <= f (length a) }
      variant { hi - lo }
      let mid = lo + div (hi - lo) 2 in
      if a[mid] < v then
        lo <- mid + 1
      else if a[mid] > v then
        hi <- mid
      else
        return mid;
      (* reached only when the iteration did not return *)
      time <- time + 1
    done;
    -1

end
|
|
|
|
(* Search in a two-dimensional grid where all rows and columns are
   sorted. Here is an example:

                           j
     +---+---+---+---+---+-->
     | 1 | 3 | 4 | 4 | 7*|
     +---+---+---+---+---+
     | 1 | 4 | 6 | 6 | 8*|
     +---+---+---+---+---+
     | 2 | 5 | 7 | 9*|11*|
     +---+---+---+---+---+
     | 4 | 8 | 8*|12*|13 |
     +---+---+---+---+---+
     |          *
   i v

   Algorithm: start from the upper right corner and then move left
   (resp. down) when the value is greater (resp. smaller) than the
   value we search for.

   In the example above, the stars depict the elements that are
   examined when we search for the value 10. Here we end up outside of
   the grid and thus the search is unsuccessful.
*)
|
|
|
|
module TwoDimensional

  use int.Int
  use matrix.Matrix

  (* [search m v] tells whether the value [v] occurs in the matrix [m].

     requires: every row of [m] is sorted in non-decreasing order;
               every column of [m] is sorted in non-decreasing order.
     ensures:  the result is true exactly when some cell of [m] holds [v]. *)
  let search (m: matrix int) (v: int) : bool
    requires { [@expl: rows are sorted] forall i. 0 <= i < rows m ->
               forall j1 j2. 0 <= j1 <= j2 < columns m ->
               get m i j1 <= get m i j2 }
    requires { [@expl: columns are sorted] forall j. 0 <= j < columns m ->
               forall i1 i2. 0 <= i1 <= i2 < rows m ->
               get m i1 j <= get m i2 j }
    ensures  { result <->
               exists i j. 0 <= i < rows m && 0 <= j < columns m && get m i j = v }
  = (* start on the first row, last column *)
    let ref i = 0 in
    let ref j = columns m - 1 in
    while i < rows m && 0 <= j do
      (* i may reach rows m and j may reach -1: this is how the loop exits *)
      invariant { 0 <= i <= rows m }
      invariant { -1 <= j < columns m }
      (* v does not occur in the rows above i nor in the columns right of j *)
      invariant { forall i' j'. 0 <= i' < rows m -> 0 <= j' < columns m ->
                  i' < i || j' > j -> get m i' j' <> v }
      (* each iteration that does not return increments i or decrements j *)
      variant { rows m - i + j }
      let x = get m i j in
      if x = v then return true;
      if x < v then i <- i + 1 else j <- j - 1
    done;
    return false

end
|