Files
why3/examples/string_base64_encoding.mlw
2024-10-31 15:27:09 +01:00

333 lines
13 KiB
Plaintext

(** {1 Base64 encoding}
Implementation of the Base64 encoding algorithm.
An encode function translates an arbitrary string into a string
containing only the following 64 characters.
{h
<pre>
Index Binary Char | Index Binary Char
0 000000 'A' | 32 100000 'g'
1 000001 'B' | 33 100001 'h'
2 000010 'C' | 34 100010 'i'
3 000011 'D' | 35 100011 'j'
4 000100 'E' | 36 100100 'k'
5 000101 'F' | 37 100101 'l'
6 000110 'G' | 38 100110 'm'
7 000111 'H' | 39 100111 'n'
8 001000 'I' | 40 101000 'o'
9 001001 'J' | 41 101001 'p'
10 001010 'K' | 42 101010 'q'
11 001011 'L' | 43 101011 'r'
12 001100 'M' | 44 101100 's'
13 001101 'N' | 45 101101 't'
14 001110 'O' | 46 101110 'u'
15 001111 'P' | 47 101111 'v'
16 010000 'Q' | 48 110000 'w'
17 010001 'R' | 49 110001 'x'
18 010010 'S' | 50 110010 'y'
19 010011 'T' | 51 110011 'z'
20 010100 'U' | 52 110100 '0'
21 010101 'V' | 53 110101 '1'
22 010110 'W' | 54 110110 '2'
23 010111 'X' | 55 110111 '3'
24 011000 'Y' | 56 111000 '4'
25 011001 'Z' | 57 111001 '5'
26 011010 'a' | 58 111010 '6'
27 011011 'b' | 59 111011 '7'
28 011100 'c' | 60 111100 '8'
29 011101 'd' | 61 111101 '9'
30 011110 'e' | 62 111110 '+'
31 011111 'f' | 63 111111 '/'
</pre>}
Four characters are required to encode each sequence of 3 characters
of the input string. The length of the encoded string is always a
multiple of four (if needed the padding character '=' is used).
A decode function does the inverse operation. It takes a previously
encoded string and returns the original string.
The main property of the algorithm is that, for every string `s`,
`decode (encode s) = s`.
*)
module Base64
use mach.int.Int
use mach.int.Int63
use string.String
use string.Char
use string.OCaml
use string.StringBuffer
(** `int2b64 i` calculates the character corresponding to index `i`
(see table above): indices 0-25 give character codes 65-90 ('A'-'Z'),
26-51 give codes 97-122 ('a'-'z'), 52-61 give codes 48-57 ('0'-'9'),
62 gives code 43 ('+') and 63 gives code 47 ('/').
Any index outside 0..63 yields `chr 0`, which makes this logical
function total. *)
function int2b64 (i: int) : char =
if 0 <= i <= 25 then chr (i + 65) else
if 26 <= i <= 51 then chr (i - 26 + 97) else
if 52 <= i <= 61 then chr (i - 52 + 48) else
if i = 62 then chr 43 else if i = 63 then chr 47
else chr 0
(* Executable counterpart of the logical function `int2b64`, working on
   63-bit integers. The character code is selected first and converted
   with a single call to `chr`. The precondition restricts `i` to the 64
   indices of the table, so the final branch is unreachable. *)
let int2b64 (i: int63) : char
  requires { 0 <= i < 64 }
  ensures  { result = int2b64 i }
= let char_code =
    if 0 <= i <= 25 then i + 65          (* 'A'..'Z' *)
    else if 26 <= i <= 51 then i + 71    (* 'a'..'z', i.e. i - 26 + 97 *)
    else if 52 <= i <= 61 then i - 4     (* '0'..'9', i.e. i - 52 + 48 *)
    else if i = 62 then 43               (* '+' *)
    else if i = 63 then 47               (* '/' *)
    else absurd
  in
  chr char_code
(** the padding character '=' (character code 61); it is declared with
`let function` and is referenced both in the specifications and in the
program code of this module *)
let function eq_symbol : char = chr 61
(** `valid_b64_char c` is true, if and only if, `c` is in the table
above, i.e. its code lies in 65-90 ('A'-'Z'), 97-122 ('a'-'z') or
48-57 ('0'-'9'), or is 43 ('+') or 47 ('/'). The padding character
'=' (code 61) does not satisfy this predicate. *)
predicate valid_b64_char (c: char) =
65 <= code c <= 90 || 97 <= code c <= 122 || 48 <= code c <= 57 ||
code c = 43 || code c = 47
(** every index of the table (0..63) is mapped by `int2b64` to a valid
Base64 character *)
lemma int2b64_valid_4_char: forall i. 0 <= i < 64 -> valid_b64_char (int2b64 i)
(** inverse of function `int2b64`: maps a Base64 character back to its
index in the table (the trailing comment on each branch gives the range
of the result). The padding character `eq_symbol` is mapped to 0, and
every other character to 64, a value outside the table. *)
function b642int (c: char) : int =
if 65 <= code c <= 90 then code c - 65 else (* 0 - 25 *)
if 97 <= code c <= 122 then code c - 97 + 26 else (* 26 - 51 *)
if 48 <= code c <= 57 then code c - 48 + 52 else (* 52 - 61 *)
if code c = 43 then 62 else (* 62 *)
if code c = 47 then 63 else (* 63 *)
if c = eq_symbol then 0 else (* 0 *)
64 (* 64 *)
(* Executable counterpart of the logical function `b642int`. The
   character code is read once and reused by every range test. The
   precondition only admits table characters and the padding character,
   so the final branch is unreachable. *)
let b642int (c: char) : int63
  requires { valid_b64_char c || c = eq_symbol }
  ensures  { result = b642int c }
= let n = code c in
  if 65 <= n <= 90 then n - 65             (* 'A'..'Z' ->  0..25 *)
  else if 97 <= n <= 122 then n - 71       (* 'a'..'z' -> 26..51, i.e. n - 97 + 26 *)
  else if 48 <= n <= 57 then n + 4         (* '0'..'9' -> 52..61, i.e. n - 48 + 52 *)
  else if n = 43 then 62                   (* '+' *)
  else if n = 47 then 63                   (* '/' *)
  else if eq_char c eq_symbol then 0       (* padding counts as 0 *)
  else absurd
(** `b642int` undoes `int2b64` on every index of the table (0..63) *)
lemma b642int_int2b64: forall i. 0 <= i < 64 -> b642int (int2b64 i) = i
(** `get_pad s` calculates the number of padding characters at the end
of the string `s` (between 0 and 2): 2 if the last two characters are
both '=', 1 if the last character is '=' but the one before is not (or
does not exist), and 0 otherwise, in particular when `s` is empty.
Only the last two positions of `s` are inspected. *)
function get_pad (s: string): int =
if length s >= 1 && s[length s - 1] = eq_symbol then
(if length s >= 2 && s[length s - 2] = eq_symbol then 2 else 1)
else 0
(* Executable counterpart of the logical function `get_pad`: counts the
   '=' characters found in the last two positions of `s`, scanning from
   the end and stopping at the first character that is not '='. *)
let partial get_pad (s: string): int63
  ensures { result = get_pad s }
= let n = length s in
  if n >= 1 && eq_char s[n - 1] eq_symbol then
    (if n >= 2 && eq_char s[n - 2] eq_symbol then 2 else 1)
  else
    0
(** `calc_pad s` returns the number of padding characters that will
appear in the string that results from encoding `s`: 2 when the
length of `s` is 1 modulo 3, 1 when it is 2 modulo 3, and 0 otherwise *)
function calc_pad (s: string): int =
if mod (length s) 3 = 1 then 2 else
if mod (length s) 3 = 2 then 1 else 0
(** the length of `s` plus its padding count is always a multiple of 3 *)
lemma calc_pad_mod3: forall s. mod (length s + calc_pad s) 3 = 0
(* Executable counterpart of the logical function `calc_pad`: the
   remainder of the length of `s` by 3 is computed once and determines
   how many '=' characters the encoding of `s` will end with. *)
let partial calc_pad (s: string): int63
  ensures { result = calc_pad s }
= let r = length s % 3 in
  if r = 1 then 2
  else if r = 2 then 1
  else 0
(** `encoding s1 s2` is true, if and only if, `s2` is a valid
encoding of `s1`. The definition is the conjunction of:
- a length constraint: `s2` has 4 characters for each group of 3
characters of `s1`, the length of `s1` being first rounded up by
`calc_pad s1`;
- for each group `i`, the characters of `s1` at positions `3*i`,
`3*i+1` and `3*i+2` (the last two only when they exist) are rebuilt
from the `b642int` values of the characters of `s2` at positions
`4*i` to `4*i+3`;
- every character of `s2` before the padding is a valid Base64
character;
- with one padding character, the value of the character preceding it
is a multiple of 4; with two, the value of the character preceding
them is a multiple of 16; in both cases the trailing characters are '=';
- the padding found in `s2` is the padding expected for `s1`. *)
predicate encoding (s1 s2: string) =
(* length of the encoded string *)
length s2 = div (length s1 + calc_pad s1) 3 * 4 &&
(* group-wise reconstruction of s1 from s2 *)
( forall i. 0 <= i < div (length s2) 4 ->
let b1,b2,b3,b4 = s2[4*i], s2[4*i+1],s2[4*i+2], s2[4*i+3] in
s1[i*3] = chr (b642int b1 * 4 + div (b642int b2) 16) &&
(i * 3 + 1 < length s1 ->
s1[i*3+1] = chr (mod (b642int b2) 16 * 16 + div (b642int b3) 4)) &&
(i * 3 + 2 < length s1 ->
s1[i*3+2] = chr (mod (b642int b3) 4 * 64 + b642int b4))) &&
(* all characters before the padding belong to the table *)
( forall i. 0 <= i < length s2 - get_pad s2 -> valid_b64_char s2[i] ) &&
(* shape of the last group when padding is present *)
( get_pad s2 = 1 -> mod (b642int s2[length s2 - 2]) 4 = 0 /\
s2[length s2 - 1] = eq_symbol ) &&
( get_pad s2 = 2 -> mod (b642int s2[length s2 - 3]) 16 = 0 /\
s2[length s2 - 2] = s2[length s2 - 1] = eq_symbol ) &&
(* expected and actual padding agree *)
calc_pad s1 = get_pad s2
(** `valid_b64 s` is true, if and only if, `s` is a valid Base64
string: the length of `s` is a multiple of 4, and all its characters
(with the exception of the last one or the last two, when they are
padding) belong to the table above. If the string terminates with "="
or "==", the character that precedes the padding must moreover have a
`b642int` value that is a multiple of 4 (one padding character) or of
16 (two padding characters). *)
predicate valid_b64 (s: string) =
(mod (length s) 4 = 0) &&
(forall i. 0 <= i < length s - get_pad s -> valid_b64_char s[i]) &&
(get_pad s = 1 -> mod (b642int s[length s - 2]) 4 = 0 &&
s[length s - 1] = eq_symbol) &&
(get_pad s = 2 -> mod (b642int s[length s - 3]) 16 = 0 &&
s[length s - 2] = eq_symbol &&
s[length s - 1] = eq_symbol)
(** every string `s2` that encodes some string `s1` is a valid Base64
string. The body is a sequence of intermediate assertions leading to
the postcondition. *)
let lemma encoding_valid_b64 (s1 s2: string)
requires { encoding s1 s2 }
ensures { valid_b64 s2 }
=
(* the length of s1 expressed from the length and padding of s2 *)
assert { length s1 = div (length s2) 4 * 3 - get_pad s2 };
(* every complete group of 3 input characters gives 4 valid characters *)
assert { forall i. 0 <= i < div (length s1) 3 ->
(valid_b64_char s2[i*4] && valid_b64_char s2[i*4+1] &&
valid_b64_char s2[i*4+2] && valid_b64_char s2[i*4+3])
};
(* shape of the last group of s2 when s1 ends with an incomplete group:
two '=' if one input character remains, one '=' otherwise *)
assert { mod (length s1) 3 <> 0 ->
let last = div (length s1) 3 in
valid_b64_char s2[last*4] && valid_b64_char s2[last*4+1] &&
if last * 3 + 1 = length s1 then
s2[last*4+2] = eq_symbol && s2[last*4+3] = eq_symbol
else valid_b64_char s2[last*4+2] && s2[last*4+3] = eq_symbol
};
(* s2 is made of valid characters followed only by padding characters *)
assert { forall i.
(0 <= i < length s2 - get_pad s2 -> valid_b64_char s2[i]) &&
(length s2 - get_pad s2 <= i < length s2 -> s2[i] = eq_symbol) }
(** the decoding of a string is unique: if `s2` is an encoding of both
`s1` and `s3`, then `s1` and `s3` are equal *)
let lemma decode_unique (s1 s2 s3: string)
requires { encoding s1 s2 /\ encoding s3 s2 }
ensures { s1 = s3 }
(* first step: s1 and s3 agree on each group of 3 characters *)
= assert { forall i. 0 <= i < div (length s1 + calc_pad s1) 3 ->
s1[i*3] = s3[i*3] &&
(i * 3 + 1 < length s1 -> s1[i*3+1] = s3[i*3+1]) &&
(i * 3 + 2 < length s1 -> s1[i*3+2] = s3[i*3+2]) };
(* second step: s1 and s3 agree at every position *)
assert { forall i. 0 <= i < length s1 -> s1[i] = s3[i] }
(** the encoding of a string is unique: if `s2` and `s3` are both
encodings of `s1`, then `s2` and `s3` are equal. The second and third
assertions state the same characterisation, once for `s2` and once for
`s3`: each group of 4 encoded characters is expressed through
`int2b64` from the corresponding characters of `s1`, with '=' in the
positions for which `s1` has no character left. *)
let lemma encode_unique (s1 s2 s3: string)
requires { encoding s1 s2 /\ encoding s1 s3 }
ensures { s2 = s3 }
= assert { length s2 = length s3 };
(* characterisation of the groups of s2 *)
assert { forall i. 0 <= i < div (length s2) 4 ->
let a1, a2, a3 = s1[i*3], s1[i*3+1], s1[i*3+2] in
s2[i*4] = int2b64 (div (code a1) 4) &&
if i * 3 + 1 < length s1 then
( s2[i*4+1] = int2b64 (mod (code a1) 4 * 16 + div (code a2) 16) &&
( if i * 3 + 2 < length s1 then
s2[i*4+2] = int2b64 (mod (code a2) 16 * 4 + div (code a3) 64) &&
s2[i*4+3] = int2b64 (mod (code a3) 64)
else (s2[i*4+2] = int2b64 (mod (code a2) 16 * 4) && s2[i*4+3] = eq_symbol)
)
) else
( s2[i*4+1] = int2b64 (mod (code a1) 4 * 16)
&& s2[i*4+2] = eq_symbol
&& s2[i*4+3] = eq_symbol)
};
(* the same characterisation for the groups of s3 *)
assert { forall i. 0 <= i < div (length s3) 4 ->
let a1, a2, a3 = s1[i*3], s1[i*3+1], s1[i*3+2] in
s3[i*4] = int2b64 (div (code a1) 4) &&
if i * 3 + 1 < length s1 then
( s3[i*4+1] = int2b64 (mod (code a1) 4 * 16 + div (code a2) 16) &&
( if i * 3 + 2 < length s1 then
s3[i*4+2] = int2b64 (mod (code a2) 16 * 4 + div (code a3) 64) &&
s3[i*4+3] = int2b64 (mod (code a3) 64)
else (s3[i*4+2] = int2b64 (mod (code a2) 16 * 4) && s3[i*4+3] = eq_symbol)
)
) else
( s3[i*4+1] = int2b64 (mod (code a1) 4 * 16)
&& s3[i*4+2] = eq_symbol
&& s3[i*4+3] = eq_symbol)
};
(* s2 and s3 agree group by group ... *)
assert { forall i. 0 <= i < div (length s2) 4 ->
s2[i*4] = s3[i*4] /\ s2[i*4+1] = s3[i*4+1]
/\ s2[i*4+2] = s3[i*4+2] /\ s2[i*4+3] = s3[i*4+3]
};
(* ... hence at every position, hence as strings *)
assert { forall i. 0 <= i < length s2 -> s2[i] = s3[i]};
assert { eq_string s2 s3}
(** `encode s` returns a Base64 encoding of `s`, as specified by the
predicate `encoding`.
The input is first extended with `calc_pad s` characters of code 0, so
that it can be read in groups of 3 characters. Each group is split
into four indices, each of which is emitted through `int2b64`.
Finally, the last `calc_pad s` characters of the output are removed
and replaced by the same number of '=' characters. *)
let partial encode (s: string) : string
ensures { encoding s result }
= let padding = calc_pad s in
(* sp: the input completed with characters of code 0 *)
let sp = concat s (make padding (chr 0)) in
(* i: read position in sp, advanced by 3 at each iteration *)
let ref i = 0 in
(* r: output buffer.
NOTE(review): 42 is presumably an initial size hint for the buffer;
confirm against string.StringBuffer *)
let r = StringBuffer.create 42 in
(* b: ghost counter of the groups already encoded *)
let ghost ref b = 0 : int in
while i < length sp do
variant {length sp - i}
(* i and the length of r are tied to the number of groups processed *)
invariant { i = b * 3 }
invariant { length r = b * 4 }
invariant { 0 <= i <= length sp }
(* the groups already written to r encode the matching groups of sp *)
invariant { forall j. 0 <= j < b ->
let a1, a2, a3 = sp[j*3], sp[j*3+1], sp[j*3+2] in
r[j*4] = int2b64 (div (code a1) 4) &&
r[j*4+1] = int2b64 (mod (code a1) 4 * 16 + div (code a2) 16) &&
r[j*4+2] = int2b64 (mod (code a2) 16 * 4 + div (code a3) 64) &&
r[j*4+3] = int2b64 (mod (code a3) 64)
}
(* read one group of 3 characters and compute its 4 table indices *)
let c1,c2,c3 = sp[i], sp[i+1], sp[i+2] in
let b1 = code c1 / 4 in
let b2 = (code c1 % 4) * 16 + code c2 / 16 in
let b3 = (code c2 % 16) * 4 + code c3 / 64 in
let b4 = code c3 % 64 in
(* label L names the state of r before the four characters are added *)
label L in
StringBuffer.add_char r (int2b64 b1);
StringBuffer.add_char r (int2b64 b2);
StringBuffer.add_char r (int2b64 b3);
StringBuffer.add_char r (int2b64 b4);
i <- i + 3;
ghost (b <- b + 1);
(* the four characters just added are the encoding of c1, c2, c3 *)
assert { r[length r - 4] = int2b64 (div (code c1) 4) &&
r[length r - 3] = int2b64 (mod (code c1) 4 * 16 + div (code c2) 16) &&
r[length r - 2] = int2b64 (mod (code c2) 16 * 4 + div (code c3) 64) &&
r[length r - 1] = int2b64 (mod (code c3) 64) };
(* the characters previously in r are unchanged *)
assert { forall j. 0 <= j < length r - 4 -> r[j] = (r at L)[j] };
done;
(* the character that will precede the padding has the value required
by `encoding`: a multiple of 4 (one '=') or of 16 (two '=') *)
assert { padding = 1 -> mod (b642int r[length r - 2]) 4 = 0 };
assert { padding = 2 -> mod (b642int r[length r - 3]) 16 = 0 };
(* replace the last `padding` characters of r by '=' characters *)
StringBuffer.truncate r (length r - padding);
StringBuffer.add_string r (make padding eq_symbol);
StringBuffer.contents r
(** `decode s` returns a string of which `s` is an encoding, as
specified by the predicate `encoding`; `s` must be a valid Base64
string.
Each group of 4 characters of `s` is converted with `b642int` (which
maps the padding character '=' to 0) and recombined into 3 output
characters. The last `get_pad s` characters of the output are then
dropped. *)
let partial decode (s: string) : string
requires { valid_b64 s }
ensures { encoding result s }
(* i: read position in s, advanced by 4 at each iteration *)
= let ref i = 0 in
(* r: output buffer.
NOTE(review): 42 is presumably an initial size hint for the buffer;
confirm against string.StringBuffer *)
let r = StringBuffer.create 42 in
(* b: ghost counter of the groups already decoded *)
let ghost ref b = 0:int in
while i < length s do
variant { length s - i }
invariant { 0 <= i <= length s }
(* i and the length of r are tied to the number of groups processed *)
invariant { i = b * 4 }
invariant { length r = b * 3 }
(* the groups already written to r decode the matching groups of s *)
invariant { forall j. 0 <= j < b ->
let b1,b2,b3,b4 = s[4*j], s[4*j+1], s[4*j+2], s[4*j+3] in
r[j*3] = chr (b642int b1 * 4 + div (b642int b2) 16) &&
r[j*3+1] = chr (mod (b642int b2) 16 * 16 + div (b642int b3) 4) &&
r[j*3+2] = chr (mod (b642int b3) 4 * 64 + b642int b4)
}
(* label L names the state of r before the three characters are added *)
label L in
(* read one group of 4 characters and rebuild the 3 character codes *)
let b1,b2,b3,b4 = s[i], s[i+1], s[i+2], s[i+3] in
let a1 = b642int b1 * 4 + b642int b2 / 16 in
let a2 = b642int b2 % 16 * 16 + b642int b3 / 4 in
let a3 = b642int b3 % 4 * 64 + b642int b4 in
StringBuffer.add_char r (chr a1);
StringBuffer.add_char r (chr a2);
StringBuffer.add_char r (chr a3);
i <- i + 4;
ghost (b <- b + 1);
(* the three characters just added are the decoding of b1, b2, b3, b4 *)
assert { r[length r - 3] = chr (b642int b1 * 4 + div (b642int b2) 16) &&
r[length r - 2] = chr (mod (b642int b2) 16 * 16 + div (b642int b3) 4) &&
r[length r - 1] = chr (mod (b642int b3) 4 * 64 + b642int b4) };
(* the characters previously in r are unchanged *)
assert { forall j. 0 <= j < length r - 3 -> r[j] = (r at L)[j]}
done;
(* drop the output characters that correspond to the padding of s *)
StringBuffer.sub r 0 (length r - get_pad s)
(** round-trip check: encodes `s`, decodes the result, and asserts that
the original string is recovered. The function returns unit; what
matters is the proof obligation generated for the assertion. The
postconditions of `encode` and `decode` give `encoding s s1` and
`encoding s2 s1`, which are the hypotheses of lemma `decode_unique`. *)
let partial decode_encode (s: string) : unit
= let s1 = encode s in
let s2 = decode s1 in
assert { s = s2 }
end