(** * Module Maps2: trees and maps indexed by pairs of keys *)

Require Import Axioms.
Require Import Coqlib.
Require Import Maps.

Set Implicit Arguments.

(** * Product trees: a [TREE] indexed by pairs of keys, built from two [TREE]s *)

(** Functor building a [TREE] whose keys are pairs [(k1, k2)] of an [M1]
    key and an [M2] key.  A product tree is represented as an [M1]-tree
    whose data are [M2]-trees: the outer tree is indexed by the first key
    component, each inner tree by the second. *)
Module MakeProdTree (M1:TREE) (M2:TREE) <: TREE.

(** Keys are pairs of component keys. *)
Definition elt: Type := (M1.elt * M2.elt)%type.

(** Decidable equality on pair keys, built componentwise from
    [M2.elt_eq] and [M1.elt_eq].  Closed with [Defined] rather than [Qed]
    so the decision procedure stays transparent and can be reduced by
    computation, e.g. inside [if elt_eq i j then ... else ...]. *)
Definition elt_eq: forall (a b: elt), {a = b} + {a <> b}.
Proof.
decide equality.
apply M2.elt_eq.
apply M1.elt_eq.
Defined.

(** A product tree: an outer [M1]-tree whose data are inner [M2]-trees. *)
Definition t (A:Type) : Type := M1.t (M2.t A).

(** Decidable equality on product trees, obtained by lifting the data
    equality [eq] through [M2.eq] and then [M1.eq]. *)
Definition eq: forall (A: Type), (forall (x y: A), {x=y} + {x<>y}) ->
forall (a b: t A), {a = b} + {a <> b} :=
fun A eq => M1.eq (M2.eq eq).

(** The empty product tree is the empty outer tree. *)
Definition empty (A: Type) : t A := M1.empty _.

(** [get (a1,a2) m] looks up the inner tree bound to [a1] in [m], then
    looks up [a2] in that inner tree.  Returns [None] if either lookup
    fails. *)
Definition get (A: Type) (a:elt) (m: t A) : option A :=
let (a1,a2) := a in
match M1.get a1 m with
| None => None
| Some m => M2.get a2 m
end.

(** [set (a1,a2) v m] binds [(a1,a2)] to [v].  The inner tree at [a1]
    (or [M2.empty] when [a1] is unbound) is updated at [a2] and stored
    back at [a1]. *)
Definition set (A: Type) (a:elt) (v: A) (m:t A) : t A :=
let (a1,a2) := a in
let m1 := match M1.get a1 m with
| None => M2.set a2 v (M2.empty _)
| Some m1 => M2.set a2 v m1
end in
M1.set a1 m1 m.

(** [remove (a1,a2) m] removes the binding of [(a1,a2)].  [a2] is removed
    from the inner tree at [a1] (or from [M2.empty] when [a1] is unbound)
    and the result is stored back at [a1].
    NOTE(review): when [a1] is unbound, this still stores an empty inner
    tree at [a1] instead of returning [m] unchanged.  [get] cannot observe
    the difference, and the proofs of [grs] and [gro] rely on the
    unconditional [M1.set]. *)
Definition remove (A: Type) (a:elt) (m:t A) : t A :=
let (a1,a2) := a in
let m1 := match M1.get a1 m with
| None => M2.empty _
| Some m1 => m1
end in
M1.set a1 (M2.remove a2 m1) m.

(** The ``good variables'' properties for trees, expressing commutations
    between [get], [set] and [remove]. *)

(** Nothing is bound in the empty tree. *)
Lemma gempty:
forall (A: Type) (i: elt), get i (empty A) = None.
Proof.
intros A [i1 i2]; unfold get, empty.
rewrite M1.gempty; auto.
Qed.

(** Reading a key just set returns the value written. *)
Lemma gss:
forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = Some x.
Proof.
intros A [i1 i2] x m; unfold get, set.
rewrite M1.gss; auto.
destruct (M1.get i1 m).
rewrite M2.gss; auto.
rewrite M2.gss; auto.
Qed.

(** Setting a key leaves every other key unchanged.  The proof splits on
    whether the first components coincide. *)
Lemma gso:
forall (A: Type) (i j: elt) (x: A) (m: t A),
i <> j -> get i (set j x m) = get i m.
Proof.
intros A [i1 i2] [j1 j2] x m H; unfold get, set.
destruct (M1.elt_eq i1 j1); subst.
rewrite M1.gss; auto.
destruct (M1.get j1 m).
rewrite M2.gso; intuition congruence.
rewrite M2.gso; try intuition congruence.
rewrite M2.gempty; auto.
rewrite M1.gso; auto.
Qed.

(** Combined form of [gss] and [gso]. *)
Lemma gsspec:
forall (A: Type) (i j: elt) (x: A) (m: t A),
get i (set j x m) = if elt_eq i j then Some x else get i m.
Proof.
intros A i j x m.
destruct (elt_eq i j); subst.
apply gss.
apply gso; auto.
Qed.

(** Re-setting a key to its current value yields the same tree. *)
Lemma gsident:
forall (A: Type) (i: elt) (m: t A) (v: A),
get i m = Some v -> set i v m = m.
Proof.
intros A [i1 i2] m v; unfold get, set.
case_eq (M1.get i1 m); [intros m2 H2 H| intros H2 H].
apply M1.gsident.
rewrite M2.gsident; auto.
congruence.
Qed.

(** A removed key is unbound. *)
Lemma grs:
forall (A: Type) (i: elt) (m: t A), get i (remove i m) = None.
Proof.
intros A [i1 i2] m; unfold get, remove.
rewrite M1.gss.
rewrite M2.grs; auto.
Qed.

(** Removing a key leaves every other key unchanged. *)
Lemma gro:
forall (A: Type) (i j: elt) (m: t A),
i <> j -> get i (remove j m) = get i m.
Proof.
intros A [i1 i2] [j1 j2] m H; unfold get, remove.
destruct (M1.elt_eq i1 j1); subst.
rewrite M1.gss.
rewrite M2.gro.
destruct (M1.get j1 m); auto.
rewrite M2.gempty; auto.
congruence.
rewrite M1.gso; auto.
Qed.

(** Combined form of [grs] and [gro]. *)
Lemma grspec:
forall (A: Type) (i j: elt) (m: t A),
get i (remove j m) = if elt_eq i j then None else get i m.
Proof.
intros A i j m.
destruct (elt_eq i j); subst.
apply grs.
apply gro; auto.
Qed.

(** Extensional equality between trees.  [beq f m1 m2] compares the outer
    trees with [M1.beq], using [M2.beq f] to compare inner trees. *)
Definition beq (A: Type) (f:A -> A -> bool) (m1 m2: t A) : bool :=
M1.beq (M2.beq f) m1 m2.

(** Soundness of [beq]: if [cmp] implies [P] and [beq cmp t1 t2 = true],
    then at every key either both trees are unbound, or both are bound to
    values related by [P].  The proof case-splits on the outer and inner
    lookups, then applies [M1.beq_correct] with the relation
    [M2.beq cmp _ _ = true], followed by [M2.beq_correct]. *)
Lemma beq_correct:
forall (A: Type) (P: A -> A -> Prop) (cmp: A -> A -> bool),
(forall (x y: A), cmp x y = true -> P x y) ->
forall (t1 t2: t A), beq cmp t1 t2 = true ->
forall (x: elt),
match get x t1, get x t2 with
| None, None => True
| Some y1, Some y2 => P y1 y2
| _, _ => False
end.
Proof.
intros A P cmp H t1 t2 H1 [x1 x2]; unfold get, beq in *.
case_eq (M1.get x1 t1); [intros m1 Hm1| intros Hm1].
case_eq (M1.get x1 t2); [intros m2 Hm2| intros Hm2].
case_eq (M2.get x2 m1); [intros mm1 Hmm1| intros Hmm1].
case_eq (M2.get x2 m2); [intros mm2 Hmm2| intros Hmm2].
assert (Hd: (forall x y : M2.t A, M2.beq cmp x y = true -> M2.beq cmp x y = true)) by auto.
generalize (M1.beq_correct (fun m1 m2 => M2.beq cmp m1 m2 = true) Hd H1 x1).
rewrite Hm1; rewrite Hm2.
intros He.
generalize (M2.beq_correct P H He x2).
rewrite Hmm1; rewrite Hmm2; auto.
assert (Hd: (forall x y : M2.t A, M2.beq cmp x y = true -> M2.beq cmp x y = true)) by auto.
generalize (M1.beq_correct (fun m1 m2 => M2.beq cmp m1 m2 = true) Hd H1 x1).
rewrite Hm1; rewrite Hm2.
intros He.
generalize (M2.beq_correct P H He x2).
rewrite Hmm1; rewrite Hmm2; auto.
assert (Hd: (forall x y : M2.t A, M2.beq cmp x y = true -> M2.beq cmp x y = true)) by auto.
generalize (M1.beq_correct (fun m1 m2 => M2.beq cmp m1 m2 = true) Hd H1 x1).
rewrite Hm1; rewrite Hm2.
intros He.
generalize (M2.beq_correct P H He x2).
rewrite Hmm1; auto.
assert (Hd: (forall x y : M2.t A, M2.beq cmp x y = true -> M2.beq cmp x y = true)) by auto.
generalize (M1.beq_correct (fun m1 m2 => M2.beq cmp m1 m2 = true) Hd H1 x1).
rewrite Hm1; rewrite Hm2.
intuition.
case_eq (M1.get x1 t2); [intros m2 Hm2| intros Hm2].
case_eq (M2.get x2 m2); [intros mm2 Hmm2| intros Hmm2].
assert (Hd: (forall x y : M2.t A, M2.beq cmp x y = true -> M2.beq cmp x y = true)) by auto.
generalize (M1.beq_correct (fun m1 m2 => M2.beq cmp m1 m2 = true) Hd H1 x1).
rewrite Hm1; rewrite Hm2.
auto.
auto.
auto.
Qed.

(** Applying a function to all data of a tree.  [f] receives the full pair
    key [(a1,a2)] of each binding. *)
Definition map (A B: Type) (f: elt -> A -> B) (m:t A) : t B :=
M1.map (fun a1 => M2.map (fun a2 => f (a1,a2))) m.

(** [map] acts pointwise on bindings. *)
Lemma gmap:
forall (A B: Type) (f: elt -> A -> B) (i: elt) (m: t A),
get i (map f m) = option_map (f i) (get i m).
Proof.
intros A B f [i1 i2] m; unfold get, map.
rewrite M1.gmap.
case_eq (M1.get i1 m); intros; simpl; auto.
rewrite M2.gmap; auto.
Qed.

(** Same as [map], but the function does not receive the [elt] argument. *)
Definition map1 (A B: Type) (f:A -> B) (m:t A) : t B :=
M1.map1 (M2.map1 f) m.

(** [map1] acts pointwise on bindings. *)
Lemma gmap1:
forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
get i (map1 f m) = option_map f (get i m).
Proof.
intros A B f [i1 i2] m; unfold get, map1.
rewrite M1.gmap1.
case_eq (M1.get i1 m); intros; simpl; auto.
rewrite M2.gmap1; auto.
Qed.

(** Applying a function pairwise to all data of two trees.  When a first
    key component is present in only one of the trees, its inner tree is
    combined with [M2.empty] on the other side, so [f] sees [None] there.
    The specification [gcombine] requires [f None None = None]. *)
Definition combine (A B: Type) (f:option A -> option A -> option B) (m1 m2:t A) : t B :=
M1.combine (fun om1 om2 =>
match om1, om2 with
| None, None => None
| None, Some m2 => Some (M2.combine f (@M2.empty _) m2)
| Some m2, None => Some (M2.combine f m2 (@M2.empty _))
| Some m1, Some m2 => Some (M2.combine f m1 m2)
end) m1 m2.

(** [combine] acts pointwise, provided [f None None = None]. *)
Lemma gcombine:
forall (A B: Type) (f: option A -> option A -> option B),
f None None = None ->
forall (m1 m2: t A) (i: elt),
get i (combine f m1 m2) = f (get i m1) (get i m2).
Proof.
intros A B f Hf m1 m2 [i1 i2]; unfold get, combine.
rewrite M1.gcombine; auto.
case_eq (M1.get i1 m1); intros; simpl; auto.
case_eq (M1.get i1 m2); intros; simpl; auto.
rewrite M2.gcombine; auto.
rewrite M2.gcombine; auto.
rewrite M2.gempty; auto.
case_eq (M1.get i1 m2); intros; simpl; auto.
rewrite M2.gcombine; auto.
rewrite M2.gempty; auto.
Qed.

(** Swapping the trees and the arguments of [f] gives the same result. *)
Lemma combine_commut:
forall (A B: Type) (f g: option A -> option A -> option B),
(forall (i j: option A), f i j = g j i) ->
forall (m1 m2: t A),
combine f m1 m2 = combine g m2 m1.
Proof.
intros A B f g Hf m1 m2; unfold get, combine.
apply M1.combine_commut.
clear m1 m2; intros m1 m2.
destruct m1; destruct m2; auto; apply f_equal.
apply M2.combine_commut; auto.
apply M2.combine_commut; auto.
apply M2.combine_commut; auto.
Qed.

(** Enumerating the bindings of a tree.  [elements] folds over the outer
    tree and, for each inner tree, over that inner tree, prepending
    [((x1,x2),a)] to the accumulator.  The resulting order follows
    [M1.fold] and [M2.fold] combined with prepending. *)
Definition elements (A: Type) (m:t A) : list (elt * A) :=
M1.fold (fun l x1 m1 => M2.fold (fun l x2 a => ((x1,x2),a)::l) m1 l) m (@nil _).

(** Derived properties of the component trees. *)
Module P1 := Tree_Properties(M1).
Module P2 := Tree_Properties(M2).

(** Step functions for [xelements]: [xf1 x1] prepends one inner binding
    under outer key [x1]; [xf2] folds [xf1] over a whole inner binding
    list. *)
Definition xf1 {A:Type} (x1:M1.elt) l (p:M2.elt*A) := let (x2,a) := p in ((x1,x2),a)::l.

Definition xf2 {A:Type} l (p:(M1.elt* (list (M2.elt* A)))) :=
let (x1,m1) := p in
List.fold_left (xf1 x1) m1 l.

(** Formulation of [elements] in terms of [M1.elements], [M2.elements] and
    [List.fold_left].  It is used only to reason about [elements]; see
    [xelements_eq]. *)
Definition xelements (A: Type) (m:t A) : list (elt * A) :=
List.fold_left xf2 (List.map (fun p => (fst p, M2.elements (snd p))) (M1.elements m)) (@nil _).

(** [xelements] and [elements] compute the same list.  The proof uses
    [M1.fold_spec], [M2.fold_spec] and functional extensionality. *)
Lemma xelements_eq : forall A (m:t A), xelements m = elements m.
Proof.
intros; unfold xelements, elements.
rewrite M1.fold_spec.
assert (forall l l0,
fold_left
(fun (l : list (M1.elt * M2.elt * A)) (p : M1.elt * list (M2.elt * A)) =>
let (x1, m1) := p in
fold_left
(fun (l0 : list (M1.elt * M2.elt * A)) (p0 : M2.elt * A) =>
let (x2, a) := p0 in (x1, x2, a) :: l0) m1 l)
(List.map (fun p : M1.elt * M2.t A => (fst p, M2.elements (snd p)))
l) l0 =
fold_left
(fun (a : list (M1.elt * M2.elt * A)) (p : M1.elt * M2.t A) =>
M2.fold
(fun (l : list (M1.elt * M2.elt * A)) (x2 : M2.elt) (a0 : A) =>
(fst p, x2, a0) :: l) (snd p) a) l l0).
induction l; simpl; auto.
intros; rewrite <- IHl; clear IHl; auto.
apply f_equal3; auto.
destruct a as [m1 m2]; simpl.
rewrite M2.fold_spec.
apply f_equal3; auto.
apply extensionality; intros.
apply extensionality; intros.
destruct x0; simpl; auto.
auto.
Qed.

(** Every binding of [m] appears in [elements m]. *)
Lemma elements_correct:
forall (A: Type) (m: t A) (i: elt) (v: A),
get i m = Some v -> In (i, v) (elements m).
Proof.
unfold get, elements; intros A m [i1 i2] v.
case_eq (M1.get i1 m); [idtac| intros H1 H; discriminate].
apply (P1.fold_rec _
(fun m l => forall i1 i2 v m2,
M1.get i1 m = Some m2 ->
M2.get i2 m2 = Some v ->
In (i1, i2, v) l)); clear i1 i2 v.

intros m' m'' l H1 H2 i1 i2 v m2 Heq1 Heq2.
eapply H2; eauto.
rewrite H1; auto.

intros i1 i2 v m2 Heq1 Heq2.
rewrite M1.gempty in *; congruence.

assert (forall i x m2 l,
In x l ->
In x (M2.fold
(fun (l0 : list (M1.elt * M2.elt * A)) (x2 : M2.elt) (a : A) =>
(i, x2, a) :: l0) m2 l)).
intros i x m2 l Hl.
apply (P2.fold_rec
(fun l0 x2 a => (i, x2, a) :: l0)
(fun m2 l' => In x l')); auto with datatypes.

intros m' l i m2 H1 H2 H3 i1 i2 v m0 Heq1 Heq2.
rewrite M1.gsspec in Heq1; destruct M1.elt_eq; subst; eauto.
inv Heq1.
clear dependent m'.

generalize i2 v Heq2; clear i2 v Heq2.
apply (P2.fold_rec (fun l0 x2 a => (i, x2, a) :: l0)
(fun m2 l => forall i2 v,
M2.get i2 m2 = Some v ->
In (i, i2, v) l) l); eauto.
intros m2 m2' l1 H1 H3 i2 v.
rewrite <- H1; auto.
intros i2 v; rewrite M2.gempty; congruence.
intros m1 a i2 v Heq1 Heq2 Heq3 i2' v'.
rewrite M2.gsspec; destruct M2.elt_eq; intros T.
inv T; left; auto.
right; eauto.
Qed.

(** Every element of [elements m] is a binding of [m]. *)
Lemma elements_complete:
forall (A: Type) (m: t A) (i: elt) (v: A),
In (i, v) (elements m) -> get i m = Some v.
Proof.
unfold elements; intros A m.
apply (P1.fold_rec _
(fun m l =>
(forall i v, In (i,v) l -> get i m = Some v))).
intros m'' m' a Hyp1 Hyp0; auto.
intros [i1 i2] Hv; unfold get.
rewrite <- Hyp1 in *; apply (Hyp0 (i1,i2)).
simpl; intuition.
intros m1 l i v1 H1 H2 H3.
apply (P2.fold_rec _
(fun m l =>
(forall i0 v, In (i0,v) l -> get i0 (M1.set i v1 m1) = Some v))); auto.
intros [i1 i2] v Hv.
unfold get. rewrite M1.gsspec; destruct M1.elt_eq; subst; auto.
generalize (H3 _ _ Hv); unfold get.
rewrite H1; congruence.
generalize (H3 _ _ Hv); unfold get; auto.
intros m0 a k v0 H H0 H4 [i1 i2] v; unfold get in *.
simpl; destruct 1.
inv H5.
rewrite M1.gss; auto.
rewrite M1.gsspec; destruct M1.elt_eq; subst.
generalize (H4 _ _ H5); rewrite M1.gss; auto.
generalize (H4 _ _ H5); rewrite M1.gso; auto.
Qed.

(** Folding [xf1 a1] over inner bindings [l1] keeps the keys of the
    accumulator distinct.  This requires that the keys of [l1] are distinct
    and that no key [(a1, a3)] with [a3] a key of [l1] is already in the
    accumulator. *)
Lemma xelements_keys_norepet_aux : forall A a1 l1 acc,
list_norepet (List.map (@fst (M1.elt*M2.elt) A) acc) ->
(forall a3 x y, In (a3,x) l1 -> In (a1, a3, y) acc -> False) ->
list_norepet (List.map (@fst M2.elt A) l1) ->
list_norepet (List.map (@fst (M1.elt*M2.elt) A) (fold_left (xf1 a1) l1 acc)).
Proof.
induction l1; simpl; auto.
intros acc H1 H2 H3.
apply IHl1; clear IHl1.
destruct a as [a2 x]; simpl; constructor.
intros H4.
elim list_in_map_inv with (1:=H4); clear H4; intros b [B1 B2].
destruct b; simpl in *; inv B1; eauto.
inv H3; auto.
destruct a as [a2 x]; simpl.
intros a3 z y H5 H4; simpl in H3.
destruct H4 as [H4|H4].
inv H4.
inv H3.
elim H4; apply in_map with (1:=H5) (f:= (@fst M2.elt A)).
eauto.
inv H3; auto.
Qed.

(** Folding [xf2] over outer bindings produces distinct pair keys.  This
    requires distinct outer keys, distinct inner keys within each inner
    list, and no clash with keys already in the accumulator. *)
Lemma xelements_keys_norepet:
forall (A: Type) (l: list (M1.elt* (list (M2.elt* A)))) acc,
list_norepet (List.map (@fst (M1.elt*M2.elt) A) acc) ->
list_norepet (List.map (@fst M1.elt (list (M2.elt* A))) l) ->
(forall a1 l1 a2 x, In (a1,l1) l -> In ((a1,a2),x) acc -> False) ->
(forall a1 l1, In (a1,l1) l -> list_norepet (List.map (@fst M2.elt A) l1)) ->
list_norepet (List.map (@fst (M1.elt*M2.elt) A) (fold_left xf2 l acc)).
Proof.
induction l; simpl; auto.
intros acc H1 H4 H2 H3; apply IHl; clear IHl.
destruct a as [a1 l1].
unfold xf2.
apply xelements_keys_norepet_aux; eauto.
inv H4; auto.
destruct a as [a1 l1]; simpl in *.
inv H4.
assert (forall (a2 : M1.elt) (l2 : list (M2.elt * A)) (a3 : M2.elt) (x : A),
In (a2, l2) l -> In (a2, a3, x) acc -> False) by eauto.
clear H3 H2 H1 H6.
intros a3 l3; generalize dependent acc.
induction l1; simpl; eauto.
intros.
elim IHl1 with (2:=H0) (3:=H1).
destruct a; simpl; intros.
destruct H3; eauto.
inv H3.
elim H5; apply in_map with (1:=H2) (f:=@fst M1.elt (list (M2.elt * A))).
eauto.
Qed.

(** Fusion of two [List.map]s. *)
Lemma map_map : forall (A B C: Type) (g:A->B) (f:B->C) l,
List.map f (List.map g l) = List.map (fun x => f (g x)) l.
Proof.
induction l; simpl; try f_equal; auto.
Qed.

(** The keys of [elements m] are pairwise distinct. *)
Lemma elements_keys_norepet:
forall (A: Type) (m: t A),
list_norepet (List.map (@fst elt A) (elements m)).
Proof.
intros A m; rewrite <- xelements_eq; unfold xelements.
apply xelements_keys_norepet; auto.
constructor.
rewrite map_map; simpl.
replace (fun x : M1.elt * M2.t A => fst x) with (@fst M1.elt (M2.t A)).
apply M1.elements_keys_norepet.
apply extensionality; auto.
intros.
elim list_in_map_inv with (1:=H); clear H.
intros [b1 q1]; simpl; destruct 1.
inv H.
apply M2.elements_keys_norepet.
Qed.

(** If two outer binding lists are pointwise related (their inner lists
    related by [R] on the reconstructed pair bindings) and the accumulators
    are related by [list_forall2 R], then folding [xf2] yields
    [list_forall2 R]-related results. *)
Lemma elements_canonical_order_aux1:
forall (A B: Type) R (l1: list (M1.elt * list (M2.elt * A))) (l2: list (M1.elt * list (M2.elt * B))) acc1 acc2,
list_forall2 (fun x1 x2 => list_forall2 (fun b1 b2 => R (fst x1, fst b1, snd b1) (fst x2, fst b2, snd b2)) (snd x1) (snd x2)) l1 l2 ->
list_forall2 R acc1 acc2 ->
list_forall2 R (fold_left xf2 l1 acc1) (fold_left xf2 l2 acc2).
Proof.
induction l1; destruct l2; simpl; intros; try constructor; auto.
inv H.
inv H.
apply IHl1; auto; clear IHl1.
inv H; auto.
destruct a; destruct p; simpl.
inv H.
clear H6; simpl in *.
generalize dependent acc1; generalize acc2; clear acc2.
induction H4; simpl; auto.
intros; apply IHlist_forall2; clear IHlist_forall2.
destruct a1; destruct b1; simpl in *.
constructor; auto.
Qed.

(** [list_forall2] commutes with [List.map] on both sides. *)
Lemma list_forall2_map: forall (A B C D:Type) (R:C->D->Prop) f g (l1:list A) (l2:list B),
list_forall2 (fun x y => R (f x) (g y)) l1 l2 ->
list_forall2 R (List.map f l1) (List.map g l2).
Proof.
induction 1; constructor; auto.
Qed.

(** [list_forall2] is monotone in its relation. *)
Lemma list_forall2_monotone : forall (A B:Type) (R1 R2:A->B->Prop) l1 l2,
(forall x y, R1 x y -> R2 x y) ->
list_forall2 R1 l1 l2 -> list_forall2 R2 l1 l2.
Proof.
induction 2; constructor; auto.
Qed.

(** Fold over the bindings of [m] in the order given by [elements]. *)
Definition fold (A B: Type) (f: B -> elt -> A -> B) (m: t A) (b:B) : B :=
List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) b.

(** [fold_left] gives equal results for pointwise-equal step functions.
    The lemma name keeps its original spelling so existing references
    still resolve. *)
Lemma fold_left_extensionnal : forall (A B: Type) (f1 f2:A->B->A) l x,
(forall b a, f1 a b = f2 a b) ->
fold_left f1 l x = fold_left f2 l x.
Proof.
induction l; simpl; auto.
intros; rewrite IHl; auto.
rewrite H; auto.
Qed.

(** Fusion of [fold_left] with a preceding [List.map]. *)
Lemma fold_left_map : forall (A B C: Type) (f:A->B->A) (g:C->B) l x,
fold_left f (List.map g l) x = fold_left (fun x y => f x (g y)) l x.
Proof.
induction l; simpl; auto.
Qed.

(** Specification of [fold]; it holds by definition. *)
Lemma fold_spec:
forall (A B: Type) (f: B -> elt -> A -> B) (v: B) (m: t A),
fold f m v =
List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
Proof.
unfold fold; auto.
Qed.

End MakeProdTree.

(** Functor building a [MAP] (a total map with a default value) whose keys
    are pairs [(k1, k2)].  It is represented as an [M1]-map whose values
    are [M2]-maps. *)
Module MakeProdMap (M1:MAP) (M2:MAP) <: MAP.

(** Keys are pairs of component keys. *)
Definition elt: Type := (M1.elt * M2.elt)%type.

(** Decidable equality on pair keys, built componentwise from
    [M2.elt_eq] and [M1.elt_eq].  Declared as a [Definition] closed with
    [Defined] (not [Qed]) so the decision procedure stays transparent and
    can be reduced by computation, e.g. inside
    [if elt_eq i j then ... else ...]. *)
Definition elt_eq: forall (a b: elt), {a = b} + {a <> b}.
Proof.
decide equality.
apply M2.elt_eq.
apply M1.elt_eq.
Defined.

(** A product map: an outer [M1]-map whose values are inner [M2]-maps. *)
Definition t (A:Type) : Type := M1.t (M2.t A).

(** [init a] maps every key to [a]: the outer map is initialized with inner
    maps that are themselves initialized to [a]. *)
Definition init (A: Type) (a: A) : t A := M1.init (M2.init a).

(** [get (x1,x2) m] reads the inner map at [x1], then the value at [x2]. *)
Definition get (A: Type) (x:elt) (m:t A) : A :=
let (x1,x2) := x in
M2.get x2 (M1.get x1 m).

(** [set (x1,x2) v m] updates the inner map at [x1] at position [x2] with
    [v], and stores it back at [x1]. *)
Definition set (A: Type) (x:elt) (v:A) (m:t A) : t A :=
let (x1,x2) := x in
M1.set x1 (M2.set x2 v (M1.get x1 m)) m.

(** Every key of [init x] maps to [x]. *)
Lemma gi:
forall (A: Type) (i: elt) (x: A), get i (init x) = x.
Proof.
intros A [i1 i2] x; unfold get, init.
rewrite M1.gi.
rewrite M2.gi; auto.
Qed.

(** Reading a key just set returns the value written. *)
Lemma gss:
forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = x.
Proof.
intros A [i1 i2] x m; unfold get, set.
rewrite M1.gss; rewrite M2.gss; auto.
Qed.

(** Setting a key leaves every other key unchanged.  The proof splits on
    whether the first components coincide. *)
Lemma gso:
forall (A: Type) (i j: elt) (x: A) (m: t A),
i <> j -> get i (set j x m) = get i m.
Proof.
intros A [i1 i2] [j1 j2] x m H; unfold get, set.
destruct (M1.elt_eq i1 j1); subst.
rewrite M1.gss; rewrite M2.gso; congruence.
rewrite M1.gso; auto.
Qed.

(** Combined form of [gss] and [gso]. *)
Lemma gsspec:
forall (A: Type) (i j: elt) (x: A) (m: t A),
get i (set j x m) = if elt_eq i j then x else get i m.
Proof.
intros A i j x m.
destruct (elt_eq i j); subst.
apply gss.
apply gso; auto.
Qed.

(** Writing back the current value of [i] changes no binding. *)
Lemma gsident:
forall (A: Type) (i j: elt) (m: t A), get j (set i (get i m) m) = get j m.
Proof.
intros A i j m.
destruct (elt_eq i j); subst.
rewrite gss; auto.
rewrite gso; auto.
Qed.

(** Apply [f] to every value of the map. *)
Definition map (A B: Type) (f: A -> B) (m: t A) : t B :=
M1.map (M2.map f) m.

(** [map] acts pointwise. *)
Lemma gmap:
forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
get i (map f m) = f(get i m).
Proof.
intros A B f [i1 i2] m; unfold get, map.
rewrite M1.gmap; rewrite M2.gmap; auto.
Qed.

End MakeProdMap.