mirror of
https://github.com/AdaCore/why3.git
synced 2026-02-12 12:34:55 -08:00
240 lines
6.6 KiB
Plaintext
240 lines
6.6 KiB
Plaintext
|
|
(* Given an n x n matrix m of nonnegative integers, we want to pick one element
   in each row and each column, so that their sum is maximal.

   We generalize the problem as follows: f(i,c) is the maximum for rows >= i
   and columns in set c. Thus the solution is f(0,{0,1,...,n-1}).

   f is easily defined recursively, as we have

      f(i,c) = max{j in c} (m[i][j] + f(i+1, c\{j}))

   with f(n,{}) = 0.

   As such, it would still be a brute force approach (of complexity n!)
   but we can memoize f and then the search space decreases to n*2^n.

   The following code implements such a solution. Sets of integers are
   provided in theory Bitset. Hash tables for memoization are provided
   in module HashTable (see file hash_tables.mlw for an implementation).
   Code for f is in module MaxMatrixMemo (mutually recursive functions
   maximum and memo).
*)
|
|
|
|
theory Bitset

  (* Abstract finite sets of integers, specified through membership,
     removal of one element, initial segments {0,...,n-1} and cardinality.
     The bracketed snippets in the comments below sketch a possible
     realization on machine integers, written with OCaml operators. *)

  use int.Int

  constant size : int (* elements belong to 0..size-1 *)

  type set

  (* membership
     [mem i s] can be implemented as [s land (1 lsl i) <> 0] *)
  val predicate mem int set

  (* removal
     [remove x s] can be implemented as [s land (lnot (1 lsl x))];
     the shorter [s - (1 lsl x)] is equivalent only when x belongs to s,
     which the specification below does not require *)
  val function remove (x: int) (s: set): set
    ensures { forall y: int. mem y result <-> y <> x /\ mem y s }

  (* the set {0,1,...,n-1}
     [below n] can be implemented as [(1 lsl n) - 1] *)
  val function below (n: int): set
    requires { 0 <= n <= size }
    ensures  { forall x: int. mem x result <-> 0 <= x < n }

  (* number of elements; characterized only by the three axioms below *)
  val function cardinal set: int

  (* a set has cardinal 0 exactly when it has no element *)
  axiom cardinal_empty:
    forall s: set. cardinal s = 0 <-> (forall x: int. not (mem x s))

  (* removing an element that is present decreases the cardinal by one *)
  axiom cardinal_remove:
    forall x: int. forall s: set.
      mem x s -> cardinal s = 1 + cardinal (remove x s)

  (* the premise already gives 0 <= n, so the cardinal is simply n *)
  axiom cardinal_below:
    forall n: int. 0 <= n <= size ->
      cardinal (below n) = n

end
|
|
|
|
module HashTable

  (* Specification-only interface of a mutable hash table.
     A table is modelled by a ghost map sending each key either to
     [Some v] (the key is bound to v) or to [None] (the key is unbound).
     Only declarations are given here, no implementation. *)

  use option.Option
  use int.Int
  use map.Map

  type t 'a 'b = private { ghost mutable contents: map 'a (option 'b) }

  (* `h[k]` is the model-level lookup of key `k` in table `h` *)
  function ([]) (h: t 'a 'b) (k: 'a) : option 'b = Map.get h.contents k

  (* `create n` returns a table in which every key is unbound.
     `n` must be positive; the postcondition does not otherwise mention it
     (presumably an initial size hint -- to be confirmed against the
     implementation) *)
  val create (n:int) : t 'a 'b
    requires { 0 < n }
    ensures  { forall k: 'a. result[k] = None }

  (* `clear h` unbinds every key of `h` *)
  val clear (h: t 'a 'b) : unit writes {h}
    ensures { forall k: 'a. h[k] = None }

  (* `add h k v` binds `k` to `v` and leaves all other keys unchanged *)
  val add (h: t 'a 'b) (k: 'a) (v: 'b) : unit writes {h}
    ensures { h[k] = Some v /\ forall k': 'a. k' <> k -> h[k'] = (old h)[k'] }

  exception Not_found

  (* `find h k` returns the value bound to `k`;
     raises `Not_found` when `k` is unbound *)
  val find (h: t 'a 'b) (k: 'a) : 'b
    ensures { h[k] = Some result }
    raises  { Not_found -> h[k] = None }

end
|
|
|
|
module Appmap

  (* Maps from an abstract key type, specified against a model map
     `contents`. Every operation is a function returning a value; in
     particular `m[k <- v]` returns an updated map instead of modifying
     `m`. *)

  use map.Map
  use map.Const

  type key

  type t 'a = abstract { contents: map key 'a }

  (* `create x` is the map sending every key to `x` *)
  val function create (x: 'a): t 'a
    ensures { result.contents = const x }

  (* `m[k]` is the value associated with `k` in `m` *)
  val function ([]) (m: t 'a) (k: key): 'a
    ensures { result = m.contents[k] }

  (* `m[k <- v]` maps `k` to `v` and agrees with `m` on the model elsewhere *)
  val function ([<-]) (m: t 'a) (k: key) (v: 'a): t 'a
    ensures { result.contents = m.contents[k <- v] }

end
|
|
|
|
module Sum

  (* Sums of an integer-valued function `f` over a half-open range of
     indices, generic in the type `container`. Both `container` and `f`
     are left abstract so that the module can be instantiated with
     `clone`. *)

  use int.Int
  use map.Map

  type container

  function f container int : int
  (** `f c i` is the `i`-th element in the container `c` *)

  function sum container int int : int
  (** `sum c i j` is the sum `\sum_{i <= k < j} f c k` *)

  (* an empty range (j <= i) sums to 0 *)
  axiom Sum_def_empty :
    forall c : container, i j : int. j <= i -> sum c i j = 0

  (* a non-empty range is its leftmost term plus the rest *)
  axiom Sum_def_non_empty :
    forall c: container, i j : int. i < j -> sum c i j = f c i + sum c (i+1) j

  (* a non-empty range is also its rightmost term plus what precedes it *)
  axiom Sum_right_extension:
    forall c : container, i j : int.
      i < j -> sum c i j = sum c i (j-1) + f c (j-1)

  (* a range can be split at any intermediate index k *)
  axiom Sum_transitivity :
    forall c : container, i k j : int. i <= k <= j ->
      sum c i j = sum c i k + sum c k j

  (* containers that agree pointwise on i..j-1 have the same sum there *)
  axiom Sum_eq :
    forall c1 c2 : container, i j : int.
      (forall k : int. i <= k < j -> f c1 k = f c2 k) -> sum c1 i j = sum c2 i j

end
|
|
|
|
module MaxMatrixMemo

  (* Memoized search for a maximal-sum selection of one entry per row and
     per column of an n x n matrix. `maximum i c` explores row i with the
     set c of still-available columns; `memo` caches its results in a
     global hash table keyed by (i, c). *)

  use int.Int
  use int.MinMax
  use ref.Ref
  use Bitset
  use map.Map

  clone Appmap with type key = int, axiom .

  (* dimension of the matrix; at most `size` so that column sets fit in a
     Bitset set *)
  val constant n : int
    ensures { 0 <= result <= size }

  (* the matrix, as a map of rows; entries with both indices in 0..n-1 are
     nonnegative *)
  val constant m : t (t int)
    ensures { forall i j: int. 0 <= i < n -> 0 <= j < n -> 0 <= result[i][j] }

  (* an assignment of a column to each row *)
  type mapii = Map.map int int

  (* `solution s i`: on rows i..n-1, s selects columns in 0..n-1 that are
     pairwise distinct *)
  predicate solution (s: mapii) (i: int) =
    (forall k: int. i <= k < n -> 0 <= Map.get s k < n) /\
    (forall k1 k2: int. i <= k1 < k2 < n -> Map.get s k1 <> Map.get s k2)

  (* a solution covering all rows *)
  predicate permutation (s: mapii) = solution s 0

  (* matrix entry selected by s in row i *)
  function f (s: mapii) (i: int) : int = m[i][Map.get s i]
  clone Sum with type container = mapii, function f = f, axiom .

  (* setting row i to column j yields m[i][j] plus the sum of s over the
     rows after i *)
  lemma sum_ind:
    forall i: int. i < n -> forall j: int.
    forall s: mapii. sum (Map.set s i j) i n = m[i][j] + sum s (i+1) n

  use option.Option
  use HashTable as H

  (* memoization key: first row still to be assigned, available columns *)
  type key = (int, set)
  (* memoized result: best sum and an assignment achieving it *)
  type value = (int, t int)

  (* admissible arguments: i is a row index or n, c has exactly n-i
     elements, all of them column indices *)
  predicate pre (k: key) =
    let (i, c) = k in
    0 <= i <= n /\ cardinal c = n-i /\ (forall k: int. mem k c -> 0 <= k < n)

  (* expected result for key (i, c): sol is a solution from row i on that
     only uses columns of c, r is its sum, and no other such solution has
     a larger sum *)
  predicate post (k: key) (v: value) =
    let (i, c) = k in
    let (r, sol) = v in
    0 <= r /\ solution sol.contents i /\
    (forall k: int. i <= k < n -> mem sol[k] c) /\
    r = sum sol.contents i n /\
    (forall s: mapii.
       solution s i -> (forall k: int. i <= k < n -> mem (Map.get s k) c) ->
       r >= sum s i n)

  type table = H.t key value

  (* the global memoization table *)
  val table: table

  (* every entry stored in the table is a correct result for its key *)
  predicate inv (t: table) =
    forall k: key, v: value. H.([]) t k = Some v -> post k v

  (* `maximum i c` computes the best sum over rows i..n-1 using columns of
     c, together with an assignment achieving it.
     The variants of `maximum` and `memo` interleave: at a given i, the
     variant of `memo` is one more than that of `maximum`, so `memo i` may
     call `maximum i`, and `maximum i` may call `memo (i+1)`. *)
  let rec maximum (i:int) (c: set) : (int, t int) variant {2*n-2*i}
    requires { pre (i, c) /\ inv table }
    ensures  { post (i,c) result /\ inv table }
  = if i = n then
      (* no row left: the empty sum *)
      (0, create 0)
    else begin
      (* -1 marks "no candidate column seen yet" *)
      let r = ref (-1) in
      let sol = ref (create 0) in
      for j = 0 to n-1 do
        (* either no column below j belongs to c and r is still -1, or
           (r, sol) is optimal among the solutions whose column for row i
           is below j *)
        invariant {
          inv table /\
          ( (!r = -1 /\ forall k: int. 0 <= k < j -> not (mem k c))
            \/
            (0 <= !r /\ solution !sol.contents i /\
             (forall k: int. i <= k < n -> mem !sol[k] c) /\
             !r = sum !sol.contents i n /\
             (forall s: mapii.
                solution s i -> (forall k: int. i <= k < n -> mem (Map.get s k) c) ->
                mem (Map.get s i) c -> Map.get s i < j -> !r >= sum s i n)))
        }
        if mem j c then
          (* pick column j for row i and solve the remaining rows without it *)
          let (r', sol') = memo (i+1) (remove j c) in
          let x = m[i][j] + r' in
          if x > !r then begin r := x; sol := sol'[i <- j] end
      done;
      assert { 0 <= !r };
      (!r, !sol)
    end

  (* `memo i c` has the same contract as `maximum i c`; it returns the
     table entry for (i, c) when there is one, and otherwise computes the
     result with `maximum` and stores it *)
  with memo (i:int) (c: set) : (int, t int) variant {2*n-2*i+1}
    requires { pre (i,c) /\ inv table }
    ensures  { post (i,c) result /\ inv table }
  = try H.find table (i,c)
    with H.Not_found -> let r = maximum i c in H.add table (i,c) r; r end

  (* entry point: the result is the sum of some permutation and is at
     least the sum of every permutation *)
  let maxmat ()
    ensures { exists s: mapii. permutation s /\ result = sum s 0 n }
    ensures { forall s: mapii. permutation s -> result >= sum s 0 n }
  = (* an empty table has no entry, hence satisfies `inv` *)
    H.clear table;
    assert { inv table };
    let (r, _) = maximum 0 (below n) in r

end
|