Files
why3/examples/prover/BacktrackArray.mlw
2019-01-18 18:26:34 +01:00

437 lines
14 KiB
Plaintext

(* BacktrackArray.mlw :
provides support for a backtracking table.
Idea : you have an infinite number of stacks, initially empty,
and you add elements to them over time.
Moreover, you sometimes need to undo some events.
Primitives:
1) create (ghost p:'a -> bool) : t 'a : create an initially empty table
whose stored elements must satisfy the invariant p.
2) add (i:int) (b:'a) (tb:t 'a) : unit :
add an element b on top of stack i, and advance the time counter.
3) get (tb:t 'a) (i:int) : list 'a :
get the stack i at the current time.
4) stamp (tb:t 'a) : timestamp 'a : current timestamp.
5) backtrack (t:timestamp 'a) (tb:t 'a) : unit :
If the instant t is a past instant or the current instant,
go back to the instant t. *)
(* Data types shared by the specification (module Logic) and the
implementation (module Impl) of the backtracking table. *)
module Types
use int.Int
use array.Array
use list.List
use Functions.Func
use Predicates.Pred
(* Snapshot of a table at a given instant. Logic.current_timestamp
builds one from a table: time is the table's current_time, size is
the length of its buffer, and the ghost field table is the logical
view of the buffer (index -> stack). *)
type timestamp 'a = {
time : int ;
size : int ;
ghost table : int -> list 'a ;
}
(* The backtracking table itself. *)
type t 'a = {
(* Log of events, most recent first: (-1) means array resizing
(doubling in size), i >= 0 means update at cell i. *)
mutable history : list int ;
(* Time counter; Impl.add_event increments it by one per logged event. *)
mutable current_time : int ;
(* buffer[i] is the stack stored at index i. *)
mutable buffer : array (list 'a) ;
(* Invariant of stored data. *)
ghost inv : 'a -> bool ;
}
end
module Logic
use Types
use int.Int
use int.ComputerDivision
use import map.Map as M
use array.Array
use list.List
use Functions.Func
use Choice.Choice
(* Logical view of an array as a total function over the integers:
inside the array bounds it returns the stored element, outside it
returns the default value def (see axiom func_of_array_def). *)
function func_of_array (a:array 'a) (def:'a) : int -> 'a
axiom func_of_array_def : forall a:array 'a,def:'a,x:int.
func_of_array a def x = if (0 <= x < a.length)
then M.get a.elts x
else def
(* Timestamp describing the present state of table x: its time
   counter, the length of its buffer, and the logical content of
   that buffer (Nil outside the array bounds). *)
function current_timestamp (x:t 'a) : timestamp 'a =
  { time = x.current_time ;
    size = length x.buffer ;
    table = func_of_array x.buffer Nil }
(* A logical table b is correct for size sz when every stack at an
   index outside [0, sz) is empty. *)
predicate correct_table (sz:int) (b: int -> list 'a) =
  forall i:int. i < 0 \/ i >= sz -> b i = Nil
(* Tail of a list. On the empty list the result is the value
   default (brought into scope by Choice.Choice). *)
function pop (l:list 'a) : list 'a =
  match l with
  | Cons _ tl -> tl
  | Nil -> default
  end
(* Replays history h backwards for tm steps, starting from the
logical table b of size sz, and returns the resulting timestamp;
t0 is the time given to the result once all tm steps are done.
- tm = 0: nothing left to undo.
- h = Nil with tm <> 0: the history is exhausted; the remaining
step count is added to the time (t0+tm).
- event (-1): a resize is undone by halving the size.
- any other event x: the top of stack x is popped. *)
function unroll (tm:int) (t0:int) (h:list int) (b:int -> list 'a)
(sz:int) : timestamp 'a =
if tm = 0
then { time = t0 ; size = sz ; table = b }
else match h with
| Nil -> { time = (t0+tm) ; size = sz ; table = b ; }
| Cons x q -> if x = (-1)
then unroll (tm-1) t0 q b (div sz 2)
else unroll (tm-1) t0 q (b[x <- pop (b x)]) sz
end
(* Well-formedness of a history h holding tm events with respect to
the logical table b of size sz, checked event by event from the
most recent one:
- the empty history requires tm = 0;
- a resize event (-1) requires sz to be even, the rest of the
history to be well-formed for the half size, and b to be empty
outside [0, sz/2);
- an update event x requires stack x to be non-empty, x to be
within [0, sz), and the rest of the history to be well-formed
once stack x is popped. *)
predicate unroll_correct (tm:int) (h:list int) (b:int -> list 'a) (sz:int)
= match h with
| Nil -> tm = 0
| Cons x q -> if x = (-1)
then let s0 = div sz 2 in
s0 * 2 = sz /\ unroll_correct (tm-1) q b s0 /\
(forall i:int. i >= s0 \/ i < 0 -> b i = Nil)
else b x <> Nil /\ 0 <= x < sz
/\ unroll_correct (tm-1) q (b[x <- pop (b x)]) sz
end
(* t is a past (or the current) instant of table tb: its time is not
after tb's current time, and undoing the last
(tb.current_time - t.time) events of tb from its present buffer
yields exactly t. *)
predicate past_time (t:timestamp 'a) (tb:t 'a) =
tb.current_time >= t.time /\
unroll (tb.current_time - t.time) t.time
tb.history (func_of_array tb.buffer Nil) tb.buffer.length = t
(* tb1 precedes tb2 when every instant in the past of tb1 is also
   in the past of tb2. *)
predicate precede (tb1:t 'a) (tb2:t 'a) =
  forall ts:timestamp 'a.
    past_time ts tb1 -> past_time ts tb2
(* Chronological ordering of timestamps: t1 is not later than t2.
   Only the time fields are compared. *)
predicate before (t1:timestamp 'a) (t2:timestamp 'a) =
  t1.time <= t2.time
(* list_forall p l holds when every element of l satisfies p
   (vacuously true on the empty list). *)
predicate list_forall (p:'a -> bool) (l:list 'a) =
  match l with
  | Cons hd tl -> p hd /\ list_forall p tl
  | Nil -> true
  end
(* Global invariant of a table tb:
- every past instant has a strictly positive size;
- at every past instant, every element stored in a stack of
non-negative index satisfies tb.inv;
- at every past instant, the logical table is empty outside
[0, size) (see correct_table);
- the whole history is well-formed with respect to the current
buffer content and length (see unroll_correct). *)
predicate correct (tb:t 'a) =
(forall t:timestamp 'a. past_time t tb -> t.size > 0) /\
(forall t:timestamp 'a,i:int.
past_time t tb /\ i >= 0 -> list_forall tb.inv (eval t.table i)) /\
(forall t:timestamp 'a.
past_time t tb -> correct_table t.size t.table) /\
unroll_correct tb.current_time tb.history
(func_of_array tb.buffer Nil) tb.buffer.length
(* I MUST INTRODUCE THIS PREDICATE FOR ONLY ONE REASON :
ABUSIVE RECORD DECONSTRUCTION IN ASSERTIONS. *)
(* Relates three table states tbold, tbinter and tbnew. For every
instant tm in the past of tbold it states, in sequence, that tm is
in the past of tbinter, that its time is not after the current
time of tbold nor of tbnew, that it is before the current
timestamp of tbnew, and that it is in the past of tbnew. The last
two conjuncts restate that the past of tbold is included in the
past of tbnew (directly and through precede). *)
predicate backtrack_to (tbold:t 'a) (tbinter:t 'a) (tbnew:t 'a) =
(forall tm:timestamp 'a.
past_time tm tbold -> past_time tm tbinter
&& time tm <= time (current_timestamp tbold)
&& time tm <= time (current_timestamp tbnew)
&& before tm (current_timestamp tbnew)
&& past_time tm tbnew) && (forall tm:timestamp 'a.
past_time tm tbold -> past_time tm tbnew) &&
precede tbold tbnew
end
module Impl
use Types
use Logic
use int.Int
use int.ComputerDivision
use import map.Map as M
use array.Array
use list.List
use Functions.Func
use Predicates.Pred
use Choice.Choice
(* Extraction trick to speed up integer operations with
constants: the integer literals used by the program code of this
module are bound once as named constants. *)
(* -1: history marker of a resize event. *)
let constant mone : int = -1
let constant zero : int = 0
let constant one : int = 1
(* 2: growth factor of the buffer. *)
let constant two : int = 2
(* Creates an empty table whose stored elements must satisfy the
   ghost invariant p. The table starts at time zero with an empty
   history and a one-cell buffer holding the empty stack.
   Postconditions: the result is correct, its invariant is p, its
   current timestamp is one of its past instants, and every past
   instant has an everywhere-empty logical table. *)
let create (ghost p: 'a -> bool) : t 'a
  ensures { forall ts:timestamp 'a.
    past_time ts result -> ts.table = const Nil }
  ensures { past_time (current_timestamp result) result }
  ensures { result.inv = p }
  ensures { correct result }
=
  let fresh = {
    history = Nil ;
    current_time = zero ;
    buffer = make one Nil ;
    inv = p
  } in
  (* Proof step: the logical view of the fresh buffer is the
     constant function returning Nil. *)
  assert { extensionalEqual (func_of_array fresh.buffer Nil) (const Nil) } ;
  fresh
(* Internal helper, useful in practice: logs event x at the head of
   the history and advances the time counter by one. It does not
   touch the buffer, so on its own it breaks some of the table
   invariants. *)
let add_event (x:int) (tb:t 'a) : unit
  writes { tb.history, tb.current_time }
  ensures { tb.current_time = (old tb).current_time + 1 }
  ensures { tb.history = Cons x (old tb).history }
=
  let logged = tb.history in
  tb.current_time <- tb.current_time + one ;
  tb.history <- Cons x logged
(* Internal utility: grows the buffer of tb until index x fits.
The size is doubled as many times as needed; each doubling is
logged as a (-1) event (which advances the time counter), but the
actual buffer is reallocated only once, at the end.
Requires a correct table whose buffer is too short for x. Ensures
that x is within bounds afterwards, that the table is still
correct, that past instants are preserved, and that the logical
content of the table is unchanged. *)
let resize_for (x:int) (tb:t 'a) : unit
writes { tb }
requires { correct tb }
requires { x >= tb.buffer.length }
ensures { x < tb.buffer.length }
ensures { precede (old tb) tb }
ensures { correct tb }
ensures { (current_timestamp tb).table =
(current_timestamp (old tb)).table }
=
(* pattern : in order to do an optimization
(resize only once), introduce a ghost value
on which we 'execute' the unoptimized code and
'check' at end that it give the same result. *)
let ghost tbc = {
history = tb.history ;
current_time = tb.current_time ;
buffer = copy tb.buffer;
inv = tb.inv ;
} in
(* aux size: simulates on the ghost copy tbc the successive
doublings starting from size, logging each of them in both tb and
tbc; returns the first size reached that is strictly greater
than x. *)
let rec aux (size:int) : int
requires { 0 < size <= x }
requires { correct tbc }
requires { tbc.history = tb.history /\
tbc.current_time = tb.current_time /\
func_of_array tbc.buffer Nil = func_of_array tb.buffer Nil }
requires { tbc.buffer.length = size }
variant { x - size }
writes { tb.history,tb.current_time,tbc.history,tbc.current_time,tbc.buffer }
ensures { correct tbc }
ensures { tbc.history = tb.history /\
tbc.current_time = tb.current_time /\
func_of_array tbc.buffer Nil = func_of_array tb.buffer Nil }
ensures { tbc.buffer.length = result }
ensures { result > x }
ensures { result >= size }
ensures { precede (old tbc) tbc }
=
label AuxInit in
assert { past_time (current_timestamp tbc) tbc } ;
(* Log one doubling on both the real table and its ghost copy. *)
add_event mone tb ;
add_event mone tbc ;
let size2 = two * size in
(* Ghost side only: actually perform this doubling on tbc. *)
let ghost buf2 = make size2 Nil in
let buf1 = tbc.buffer in
blit buf1 zero buf2 zero size ;
tbc.buffer <- buf2 ;
assert { extensionalEqual (func_of_array tbc.buffer Nil)
(func_of_array (tbc at AuxInit).buffer Nil) } ;
assert { forall t:timestamp 'a.
(past_time t (tbc at AuxInit) \/ t = current_timestamp tbc) <-> past_time t tbc } ;
if size2 > x
then size2
else aux size2 in
let len = length tb.buffer in
assert { extensionalEqual (func_of_array tbc.buffer Nil)
(func_of_array tb.buffer Nil) } ;
assert { len = (current_timestamp tb).size } ;
(* Real resizing: performed once, with the final size computed
by aux. *)
let size = aux len in
let buf2 = make size Nil in
let buf1 = tb.buffer in
blit buf1 zero buf2 zero len ;
assert { extensionalEqual (func_of_array buf1 Nil)
(func_of_array buf2 Nil) } ;
tb.buffer <- buf2
(* Pushes b on top of stack x and logs the event x, assuming that x
is already within the bounds of the buffer (add handles the
general case). Requires a correct table and an element satisfying
the table invariant. Ensures that the table is still correct, that
past instants are preserved, and that the new logical table is the
old one with b pushed on stack x. *)
let iadd (x:int) (b:'a) (tb:t 'a) : unit
writes { tb.history,tb.current_time,tb.buffer.elts }
requires { 0 <= x < tb.buffer.length }
requires { correct tb }
requires { inv tb b }
ensures { past_time (current_timestamp tb) tb }
ensures { correct tb }
ensures { precede (old tb) tb }
ensures { let tb0 = (current_timestamp (old tb)).table in
(current_timestamp tb).table = tb0[x <- Cons b (tb0 x)] }
=
(* Init: state on entry, referred to by the assertions below. *)
label Init in
let buf = tb.buffer in
(* Push b on stack x, then log the event. *)
buf[x]<- Cons b (buf[x]) ;
add_event x tb ;
(* Proof steps: the new logical table is the old one with b pushed
on stack x, and popping stack x gives the old table back. *)
assert { let tb0 = (current_timestamp (tb at Init)).table in
extensionalEqual ((current_timestamp tb).table) (tb0[x <- Cons b (tb0 x)]) } ;
assert { let tb0 = (current_timestamp (tb at Init)).table in
let tb1 = (current_timestamp tb).table in
extensionalEqual (tb1[x <- pop (tb1 x)]) tb0 } ;
assert { past_time (current_timestamp tb) tb } ;
(* The past instants are exactly the old ones plus the new
current instant. *)
assert { forall t:timestamp 'a.
past_time t tb <-> past_time t (tb at Init) \/ t = current_timestamp tb } ;
assert { precede (tb at Init) tb }
(* Pushes b on top of stack x, for any non-negative index x: the
   buffer is first grown with resize_for when x lies beyond its
   current length, then the push itself is delegated to iadd.
   Requires a correct table and an element satisfying the table
   invariant. Ensures that the table is still correct, that past
   instants are preserved, and that the new logical table is the
   old one with b pushed on stack x. *)
let add (x:int) (b:'a) (tb:t 'a) : unit
  writes { tb }
  requires { correct tb }
  requires { x >= 0 }
  requires { inv tb b }
  ensures { past_time (current_timestamp tb) tb }
  ensures { correct tb }
  ensures { precede (old tb) tb }
  ensures { let prev = (current_timestamp (old tb)).table in
    (current_timestamp tb).table = prev[x <- Cons b (prev x)] }
=
  begin
    if x >= length tb.buffer
    then resize_for x tb
  end ;
  iadd x b tb
(* Returns stack x of the table at the current time. Indices beyond
   the buffer length denote stacks that were never written, so the
   empty list is returned for them.
   Requires a correct table and a non-negative index. Ensures that
   the result is the logical stack x of the current timestamp and
   that all its elements satisfy the table invariant. *)
let get (tb:t 'a) (x:int) : list 'a
  requires { correct tb }
  requires { x >= 0 }
  ensures { result = table (current_timestamp tb) x }
  ensures { list_forall tb.inv result }
=
  let buf = tb.buffer in
  if x < length buf
  then begin
    let stack = buf[x] in
    assert { stack = table (current_timestamp tb) x } ;
    stack
  end
  else Nil
(* Goes back to the past (or current) instant t: undoes, most recent
first, every event logged since t, so that the current timestamp
of the table becomes t again.
Requires t to be a past instant of a correct table. Ensures that
the table is still correct, that its current timestamp is t, and
that every instant before t that was in the past of the old table
is still in the past of the new one. *)
let backtrack (t:timestamp 'a) (tb:t 'a) : unit
writes { tb }
requires { past_time t tb }
requires { correct tb }
ensures { correct tb }
ensures { current_timestamp tb = t }
ensures { past_time (current_timestamp tb) tb }
ensures { forall t2:timestamp 'a. before t2 t /\ past_time t2 (old tb) ->
past_time t2 tb }
ensures { precede tb (old tb) }
=
let final_size = t.size in
(* Ghost copy of the table on which the events are undone one by
one, including the resize events; the real table is instead
shrunk directly to final_size before the events are undone. *)
let ghost tbc = {
history = tb.history ;
current_time = tb.current_time ;
buffer = copy tb.buffer ;
inv = tb.inv ;
} in
(* unroll delta: undoes the delta most recent events on both tb and
the ghost copy tbc. It is called once the buffer of tb has been
shrunk to at most final_size, so only the cells below final_size
are updated on the non-ghost side. *)
let rec unroll (delta:int) : unit
requires { correct tbc }
requires { past_time t tbc }
requires { delta >= 0 }
requires { tbc.current_time = t.time + delta }
requires { tbc.history = tb.history /\
(forall x:int. 0 <= x < final_size /\ x < tbc.buffer.length ->
func_of_array tbc.buffer Nil x = func_of_array tb.buffer Nil x) }
requires { tb.buffer.length <= final_size }
variant { delta }
writes { tb.history,tb.buffer.elts,tbc }
ensures { correct tbc }
ensures { tbc.history = tb.history /\
(forall x:int. 0 <= x < final_size ->
func_of_array tbc.buffer Nil x = func_of_array tb.buffer Nil x) }
ensures { current_timestamp tbc = t }
(* This is a trick avoiding an inductive reasoning outside. *)
ensures { tbc.buffer.length <= (old tbc).buffer.length }
ensures { precede tbc (old tbc) }
ensures { forall t2:timestamp 'a. before t2 t /\ past_time t2 (old tbc) ->
past_time t2 tbc }
=
label UInit in
if delta <> zero
then begin
match tb.history with
| Nil -> absurd
| Cons x q -> tb.history <- q ;
tbc.history <- q ;
tbc.current_time <- tbc.current_time - one ;
(* Resize event: only the ghost copy is halved. *)
if x = mone
then begin
let buf1 = tbc.buffer in
let len = length buf1 in
let len2 = div len two in
(* nothing to do on non-ghost side. *)
let buf2 = make len2 Nil in
blit buf1 zero buf2 zero len2 ;
tbc.buffer <- buf2 ;
assert { let t0 = { time = tbc.current_time ;
size = len2 ;
table = func_of_array (tbc at UInit).buffer Nil
} in
let t1 = { t0 with table = func_of_array tbc.buffer Nil } in
past_time t0 (tbc at UInit) &&
extensionalEqual t0.table t1.table && t0 = t1 } ;
assert { extensionalEqual (func_of_array tbc.buffer Nil)
(func_of_array (tbc at UInit).buffer Nil) } ;
assert { precede tbc (tbc at UInit) } ;
unroll (delta - one)
end else begin
assert { 0 <= x < tbc.buffer.length } ;
(* Update event on a cell kept by the shrunk buffer: pop stack x
on both the real table and the ghost copy. *)
if x < final_size
then begin
let h = tb.buffer[x] in
match h with
| Nil -> absurd
| Cons _ q -> tb.buffer[x]<- q ;
tbc.buffer[x]<- q ;
assert {
let tb0 = func_of_array (tbc at UInit).buffer Nil in
extensionalEqual (tb0[x <- pop (tb0 x)])
(func_of_array tbc.buffer Nil) } ;
assert { precede tbc (tbc at UInit) } ;
unroll (delta - one)
end
end
(* Update event on a cell at or beyond final_size: pop stack x on
the ghost copy only. *)
else begin
let hc = tbc.buffer[x] in
match hc with
| Nil -> absurd
| Cons _ q -> tbc.buffer[x]<- q
end ;
assert { let tb0 = func_of_array (tbc at UInit).buffer Nil in
extensionalEqual (tb0[x <- pop (tb0 x)])
(func_of_array tbc.buffer Nil) } ;
assert { precede tbc (tbc at UInit) } ;
(* nothing to do on non-ghost side. *)
unroll (delta - one)
end
end
end
end
in
assert { extensionalEqual (func_of_array tb.buffer Nil)
(func_of_array tbc.buffer Nil) } ;
(* direct resizing if necessary. *)
if final_size < length tb.buffer
then begin
let buf2 = make final_size Nil in
let buf1 = tb.buffer in
blit buf1 zero buf2 zero final_size ;
tb.buffer <- buf2
end ;
(* Set the time counter directly to the target time, then undo the
events logged since then. *)
let tm0 = tb.current_time in
tb.current_time <- t.time ;
unroll (tm0 - t.time) ;
assert { extensionalEqual (func_of_array tb.buffer Nil)
(func_of_array tbc.buffer Nil) }
(* Ghost program-level counterpart of the logic function
func_of_array, declared without an implementation (val); stamp uses
it to fill the ghost field table of the timestamp it builds. *)
val ghost hack_func_of_array (a:array 'a) (def:'a) : int -> 'a
ensures { result = func_of_array a def }
(* Returns the current timestamp of the table: its time counter,
   the length of its buffer and, as ghost data, the logical view
   of that buffer. Requires a correct table; the result is equal
   to current_timestamp tb. *)
let stamp (tb:t 'a) : timestamp 'a
  requires { correct tb }
  ensures { result = current_timestamp tb }
=
  let buf = tb.buffer in
  { time = tb.current_time ;
    size = length buf ;
    table = hack_func_of_array buf Nil }
(* TODO: find the correct syntax for the following declaration :
meta "remove_program" val hack_func_of_array *)
end