# Library grille

Require Import ssreflect ssrfun ssrbool eqtype ssrnat seq finfun path.
Require Import choice fintype tuple div zmodp bigop ssralg perm fingroup.

Section Grid.

(******************************************************************************)
(*                                                                            *)
(*                           MODELISATION                                     *)
(*                                                                            *)
(******************************************************************************)

(* A grid is 9 natural numbers; the cells are numbered row by row:          *)
(*     0 1 2                                                                  *)
(*     3 4 5                                                                  *)
(*     6 7 8                                                                  *)

Definition grid := (nat * nat * nat * nat * nat * nat * nat * nat * nat)%type.

(* Cell indices: the ordinals of 'I_9, i.e. 0 .. 8.                          *)

Definition index := 'I_9.

(* Initial grid: every cell holds 0.                                          *)

Definition grid0 : grid := (0, 0, 0, 0, 0, 0, 0, 0, 0).

(* Content of cell [i] of grid [g] (cells numbered 0 .. 8 row by row).       *)
(* The tuple is destructured with an explicit match (this is exactly what    *)
(* ssreflect's [let:] expands to) and [i] is compared through [nat_of_ord].  *)

Definition get (g : grid) (i : index) : nat :=
match g with
| (c0, c1, c2, c3, c4, c5, c6, c7, c8) =>
  if nat_of_ord i == 0 then c0 else
  if nat_of_ord i == 1 then c1 else
  if nat_of_ord i == 2 then c2 else
  if nat_of_ord i == 3 then c3 else
  if nat_of_ord i == 4 then c4 else
  if nat_of_ord i == 5 then c5 else
  if nat_of_ord i == 6 then c6 else
  if nat_of_ord i == 7 then c7 else
  (* i : 'I_9, so the only case left is cell 8 *)
  c8
end.

(* Sum of the cells surrounding cell [i] in grid [g]: its horizontal,        *)
(* vertical and diagonal neighbours in the 3x3 layout (3, 5 or 8 cells).     *)
(* Same destructuring / comparison style as [get].                           *)

Definition sum (g : grid) (i : index) : nat :=
match g with
| (c0, c1, c2, c3, c4, c5, c6, c7, c8) =>
  if nat_of_ord i == 0 then c1 + c3 + c4 else
  if nat_of_ord i == 1 then c0 + c2 + c3 + c4 + c5 else
  if nat_of_ord i == 2 then c1 + c4 + c5 else
  if nat_of_ord i == 3 then c0 + c1 + c4 + c6 + c7 else
  if nat_of_ord i == 4 then c0 + c1 + c2 + c3 + c5 + c6 + c7 + c8 else
  if nat_of_ord i == 5 then c1 + c2 + c4 + c7 + c8 else
  if nat_of_ord i == 6 then c3 + c4 + c7 else
  if nat_of_ord i == 7 then c3 + c4 + c5 + c6 + c8 else
  (* remaining case: cell 8, bottom-right corner *)
  c4 + c5 + c7
end.

(* Generic update: the copy of [g] in which cell [i] is replaced by the       *)
(* value [f g i].  [f] is evaluated on the grid *before* the write; every     *)
(* other cell is kept unchanged.                                              *)
(* The tuple binders are named a0 .. a8, as in [get] and [sum], so that [ak]  *)
(* always denotes the content of cell k.                                      *)

Definition update (g : grid) (f : grid -> index -> nat) (i : index) : grid :=
let: (a0, a1, a2, a3, a4, a5, a6, a7, a8) := g in
let b := f g i in
if i == 0 :> nat then ( b, a1, a2, a3, a4, a5, a6, a7, a8) else
if i == 1 :> nat then (a0,  b, a2, a3, a4, a5, a6, a7, a8) else
if i == 2 :> nat then (a0, a1,  b, a3, a4, a5, a6, a7, a8) else
if i == 3 :> nat then (a0, a1, a2,  b, a4, a5, a6, a7, a8) else
if i == 4 :> nat then (a0, a1, a2, a3,  b, a5, a6, a7, a8) else
if i == 5 :> nat then (a0, a1, a2, a3, a4,  b, a6, a7, a8) else
if i == 6 :> nat then (a0, a1, a2, a3, a4, a5,  b, a7, a8) else
if i == 7 :> nat then (a0, a1, a2, a3, a4, a5, a6,  b, a8)
else (a0, a1, a2, a3, a4, a5, a6, a7, b).

(* Write a 1 in cell [i].  The previous content is overwritten, not added to; *)
(* the update function ignores both the grid and the index it is given.       *)

Definition o_update g i := update g (fun _ _ => 1) i.

(* Write in cell [i] the sum of its neighbours, computed on [g] before the    *)
(* write (the previous content of cell [i] is overwritten).                   *)

Definition s_update g i := update g sum i.

(* A walk through the grid is a permutation of the 9 cells; [w_update] below *)
(* visits cell [w 0], then [w 1], ..., then [w 8].                            *)

Definition walk := 'S_9.

(* Grid obtained from [g] by following the walk [w]: the first two cells      *)
(* visited receive a 1, then each of the seven following cells receives the   *)
(* sum of its neighbours at the moment it is visited.                         *)
(* The body is read in ring_scope (%R) so that [k%:R] turns the number k into *)
(* an element of 'I_9 (ring structure provided by zmodp).                     *)

Definition w_update g (w : index -> index) :=
(* first, two cells receive a 1                                              *)
(let g0 := o_update g (w 0%:R) in
let g1 := o_update g0 (w 1%:R) in
(* then every following cell receives the sum of its neighbours              *)
let g2 := s_update g1 (w 2%:R) in
let g3 := s_update g2 (w 3%:R) in
let g4 := s_update g3 (w 4%:R) in
let g5 := s_update g4 (w 5%:R) in
let g6 := s_update g5 (w 6%:R) in
let g7 := s_update g6 (w 7%:R) in
let g8 := s_update g7 (w 8%:R) in
g8)%R.

(* Value of walk [w]: run it from the empty grid [grid0] and read the cell    *)
(* that is visited last ([w 8]).                                              *)
Definition w_val w :=
let g1 := w_update grid0 w in get g1 (w 8%:R)%R.

(* Goal: prove \max_(w : walk) (w_val w) = 57.  This is done by a computation *)
(* that is itself proved correct (CALCUL and PREUVE sections below).          *)

(* Sanity check: the statement of the goal is well-typed.                     *)
Check \max_(w : walk) (w_val w).

(******************************************************************************)
(*                                                                            *)
(*                           CALCUL                                           *)
(*                                                                            *)
(******************************************************************************)

(* All the ways of inserting [i] into [l]: the k-th list of the result is    *)
(* [l] with [i] inserted at position k, for k = 0 .. size l.                 *)
(* The explicit match is what [if l is ... then ... else ...] expands to.    *)

Fixpoint insertl A (i : A) (l : seq A) : seq (seq A) :=
  (i :: l) ::
  match l with
  | [::] => [::]
  | x :: rest => [seq x :: s | s <- insertl i rest]
  end.

(* Maximum of [f] over all permutations.                                     *)
(* [get_max n max f l] inserts n-1, then n-2, ..., then 0 (read in 'I_m.+2)  *)
(* into [l] at every possible position, and returns the maximum of the       *)
(* accumulator [max] and of [f] applied to every list so built.  Started     *)
(* from [l = [::]], this ranges over all orderings of 0 .. n-1.              *)

Fixpoint get_max (n m : nat) (max : nat) f (l : seq 'I_m.+2) :=
  match n with
  | 0 => maxn max (f l)
  | n1.+1 =>
    let perms := insertl (n1%:R)%R l in
    (* thread the running maximum through every insertion of n1 *)
    foldr (fun l1 acc => get_max n1 acc f l1) max perms
  end.

(* Read the list [l] as a function on 'I_n: index [i] is sent to the i-th    *)
(* element of [l], or to itself when [l] is too short.                       *)
Definition f_of_l n (l : seq 'I_n) : 'I_n -> 'I_n :=
  fun i => nth i l i.

(* Computation of the maximum: [get_max] enumerates the orderings of the 9   *)
(* cells, each list is turned into a walk function by [f_of_l], and the      *)
(* running maximum of [w_val] starts at 0.                                   *)

Definition all_max :=
get_max 9 0 (fun l => w_val (f_of_l l)) [::].

(* The computed maximum is 57.                                               *)
Lemma all_max_57 : all_max = 57.

(******************************************************************************)
(*                                                                            *)
(*                           PREUVE                                           *)
(*                                                                            *)
(******************************************************************************)

(* Every list produced by [insertl i l] has the same elements as [i :: l].   *)
Lemma mem_insertl (A : eqType) (i : A) (l l1 : seq A) :
l1 \in insertl i l -> l1 =i i :: l.

(* Every list produced by [insertl i l] is one element longer than [l].      *)
Lemma size_insertl (A : eqType) (i : A) (l l1 : seq A) :
l1 \in insertl i l -> size l1 = (size l).+1.

(* A list ordered with respect to a permutation: [ordLS p l] holds when the  *)
(* elements of [l] appear in strictly increasing order of [p^-1], i.e. in    *)
(* the order in which [p 0], [p 1], ... enumerates them.                      *)

Definition ordLS (n : nat) (p : 'S_n) l :=
sorted (fun x y => p^-1%g x < p^-1%g y) l.

(* Unfolding of [ordLS] on a list with at least two elements.               *)
Lemma ordLSE n p a b (l : seq 'I_n) :
ordLS p [:: a, b & l] = (p^-1 a < p^-1 b)%g && ordLS p [:: b & l].

(* The head of an ordered list has a smaller rank than any later element.   *)
Lemma ordLS_in n p a b (l : seq 'I_n) :
ordLS p (a :: l) -> b \in l -> p^-1%g a < p^-1%g b.

(* Ranks strictly increase along an ordered list.                           *)
Lemma ordLS_nth n p (l : seq 'I_n.+1) i j :
ordLS p l -> i < j < size l -> p^-1%g (nth ord0 l i) < p^-1%g (nth ord0 l j).

(* The i-th element of an ordered list has rank at least i ...              *)
Lemma ordLS_nth_ge n p (l : seq 'I_n.+1) i :
ordLS p l -> i < size l -> i <= p^-1%g (nth ord0 l i).

(* ... and leaves room for the [size l - i.+1] elements that follow it.     *)
Lemma ordLS_nth_le n p (l : seq 'I_n.+1) i :
ordLS p l -> i < size l -> p^-1%g (nth ord0 l i) <= n - (size l - i.+1).

(* An ordered list has full size exactly when it is the enumeration of [p]. *)
Lemma ordLS_enum n p (l : seq 'I_n.+1) :
ordLS p l -> (size l == n.+1) = (l == [seq p i | i <- enum 'I_n.+1]).

(* Inserting a fresh element into an ordered list: at least one of the      *)
(* insertions produced by [insertl] is still ordered.                       *)
Lemma insertl_ordLS n (p : 'S_n.+2) i (l : seq 'I_n.+2) :
i \notin l -> ordLS p l -> exists l1, (l1 \in insertl i l) && ordLS p l1.

(* [w_val] only depends on the values taken by the walk function.           *)
Lemma w_val_eq (f1 f2 : index -> index) : f1 =1 f2 -> w_val f1 = w_val f2.

(* [get_max] never returns less than its starting accumulator [v].          *)
Lemma leq_get_max m k v f (l : seq 'I_m.+2) : v <= get_max k v f l.

(* Finally, the intended result: the maximum of [w_val] over all walks is 57. *)

Lemma result : \max_(w : walk) (w_val w) = 57.

End Grid.