(** * Library TreeAl *)

Require Import
  Coqlib
  Maps
  DLib.

Set Implicit Arguments.

(** A bijection between two types [s] and [t]: [t_of_s] and [s_of_t]
    are mutually inverse ([inj] and [surj]), and equality on [s] is
    decidable ([eq]). *)
Module Type TYPE_EQ.
  Variable s: Type.
  Variable t: Type.
  Variable t_of_s : s -> t.
  Variable s_of_t : t -> s.
  (** [s_of_t] is a left inverse of [t_of_s]. *)
  Hypothesis inj : forall x : s, s_of_t (t_of_s x) = x.
  (** [t_of_s] is a left inverse of [s_of_t]. *)
  Hypothesis surj: forall x : t, t_of_s (s_of_t x) = x.
  Variable eq: forall (x y: s), {x = y} + {x <> y}.
End TYPE_EQ.

(** Consequences of a [TYPE_EQ] bijection [X]: both conversion functions
    are injective, stated both as "preserves [<>]" and as
    "reflects [=]". *)
Module TYPE_EQ_PROP (X:TYPE_EQ).

  (** [X.t_of_s] maps distinct elements to distinct elements. *)
  Lemma injective (a b: X.s) :
    a <> b -> X.t_of_s a <> X.t_of_s b.
  Proof.
    intros NE E. apply NE.
    (* Apply the inverse on both sides, then cancel with [X.inj]. *)
    apply (f_equal X.s_of_t) in E. rewrite !X.inj in E. exact E.
  Qed.

  (** [X.t_of_s] reflects equality. *)
  Corollary injective' (a b: X.s) :
    X.t_of_s a = X.t_of_s b -> a = b.
  Proof.
    intros E. apply (f_equal X.s_of_t) in E. rewrite !X.inj in E. exact E.
  Qed.

  (** [X.s_of_t] maps distinct elements to distinct elements. *)
  Lemma tinjective (a b: X.t) :
    a <> b -> X.s_of_t a <> X.s_of_t b.
  Proof.
    intros NE E. apply NE.
    (* Apply the inverse on both sides, then cancel with [X.surj]. *)
    apply (f_equal X.t_of_s) in E. rewrite !X.surj in E. exact E.
  Qed.

  (** [X.s_of_t] reflects equality. *)
  Lemma tinjective' (a b: X.t) :
    X.s_of_t a = X.s_of_t b -> a = b.
  Proof.
    intros E. apply (f_equal X.t_of_s) in E. rewrite !X.surj in E. exact E.
  Qed.
End TYPE_EQ_PROP.

(** Integers [Z] in bijection with [positive]:
    [Z0] is encoded as [xH], [Zpos p] as [xO p] and [Zneg p] as [xI p]. *)
Module Z_EQ_POS : TYPE_EQ
    with Definition s := Z
    with Definition t := positive.
  Definition s := Z.
  Definition t := positive.
  Definition t_of_s (z: s) : t :=
    match z with
      | Z0 => xH
      | Zpos p => xO p
      | Zneg p => xI p
    end.
  Definition s_of_t (p: t) : s :=
    match p with
      | xH => Z0
      | xO x => Zpos x
      | xI x => Zneg x
    end.
  (* Both round trips hold by case analysis on the head constructor. *)
  Lemma inj : forall x : s, s_of_t (t_of_s x) = x.
  Proof. intros x. destruct x; reflexivity. Qed.
  Lemma surj: forall x : t, t_of_s (s_of_t x) = x.
  Proof. intros x. destruct x; reflexivity. Qed.
  Definition eq: forall (x y: s), {x = y} + {x <> y} :=
    Z_eq_dec.
End Z_EQ_POS.

(** Folding over a mapped list is the same as folding over the original
    list with the mapping function composed into the step function. *)
Lemma fold_left_map:
  forall {A B C} (f: A -> B) (g: C -> B -> C) xs a,
    fold_left g (map f xs) a = fold_left (fun a x => g a (f x)) xs a.
Proof.
  (* The accumulator changes at each step, so keep it universally
     quantified in the induction hypothesis. *)
  intros A B C f g xs. induction xs as [|x xs IH]; intros a; simpl.
  - reflexivity.
  - apply IH.
Qed.

(** [BijTree X TTree] turns a [TREE] keyed by [X.t] into a [TREE] keyed by
    [X.s]. Keys are converted with [X.t_of_s] before accessing [TTree].
    They are converted back with [X.s_of_t] wherever keys are handed out
    ([map], [elements], [fold]). *)
Module BijTree (X:TYPE_EQ) (TTree: TREE with Definition elt := X.t) <: TREE.
  Module P := TYPE_EQ_PROP(X).
  Hint Resolve P.injective P.injective' P.tinjective P.tinjective'.
  Definition elt: Type := X.s.
  Definition elt_eq: forall (a b: elt), {a = b} + {a <> b} := X.eq.
  Definition t: Type -> Type := TTree.t.
  Definition empty: forall (A: Type), t A := TTree.empty.
  Definition get A i g : option A := TTree.get (X.t_of_s i) g.
  Definition set A i v g : t A := TTree.set (X.t_of_s i) v g.
  Definition remove A i g : t A := TTree.remove (X.t_of_s i) g.

  Lemma gempty: forall (A: Type) (i: elt), get i (empty A) = None.
  Proof.
    intros A i. unfold get, empty. apply TTree.gempty.
  Qed.

  Lemma gss: forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = Some x.
  Proof.
    intros A i x m. unfold get, set. apply TTree.gss.
  Qed.

  Lemma gso: forall (A: Type) (i j: elt) (x: A) (m: t A),
    i <> j -> get i (set j x m) = get i m.
  Proof.
    intros A i j x m NE. unfold get, set.
    (* Distinct keys stay distinct after conversion. *)
    apply TTree.gso. apply P.injective. exact NE.
  Qed.

  Lemma gsspec: forall (A: Type) (i j: elt) (x: A) (m: t A),
    get i (set j x m) = if elt_eq i j then Some x else get i m.
  Proof.
    intros A i j x m.
    destruct (elt_eq i j) as [E|NE]; [rewrite E; apply gss | apply gso; exact NE].
  Qed.

  Lemma gsident: forall (A: Type) (i: elt) (m: t A) (v: A),
    get i m = Some v -> set i v m = m.
  Proof.
    intros A i m v H. unfold get in H. unfold set.
    apply TTree.gsident. exact H.
  Qed.

  Lemma grs: forall (A: Type) (i: elt) (m: t A), get i (remove i m) = None.
  Proof.
    intros A i m. unfold get, remove. apply TTree.grs.
  Qed.

  Lemma gro: forall (A: Type) (i j: elt) (m: t A),
    i <> j -> get i (remove j m) = get i m.
  Proof.
    intros A i j m NE. unfold get, remove.
    apply TTree.gro. apply P.injective. exact NE.
  Qed.

  Lemma grspec: forall (A: Type) (i j: elt) (m: t A),
    get i (remove j m) = if elt_eq i j then None else get i m.
  Proof.
    intros A i j m.
    destruct (elt_eq i j) as [E|NE]; [rewrite E; apply grs | apply gro; exact NE].
  Qed.

  Definition beq A cmp (t1 t2: t A) : bool :=
    TTree.beq cmp t1 t2.
  Lemma beq_correct:
    forall (A: Type) (eqA: A -> A -> bool) (t1 t2: t A),
      beq eqA t1 t2 = true <->
      (forall (x: elt),
         match get x t1, get x t2 with
           | None, None => True
           | Some y1, Some y2 => eqA y1 y2 = true
           | _, _ => False
         end).
  Proof.
    intros A eqA t1 t2. unfold beq, get. split.
    - intros H x. exact (proj1 (@TTree.beq_correct A eqA t1 t2) H (X.t_of_s x)).
    - intros H. apply (proj2 (@TTree.beq_correct A eqA t1 t2)). intros x.
      (* Every [X.t] key is the image of some [X.s] key. *)
      rewrite <- (X.surj x). exact (H (X.s_of_t x)).
  Qed.

  (** Applying a function to all data of a tree. *)
  Definition map A B (f: elt -> A -> B) (g: t A) : t B :=
    TTree.map (fun e => (f (X.s_of_t e))) g.

  Lemma gmap: forall (A B: Type) (f: elt -> A -> B) (i: elt) (m: t A),
    get i (map f m) = option_map (f i) (get i m).
  Proof.
    intros A B f i m. unfold get, map. rewrite TTree.gmap.
    cbv beta. rewrite X.inj. reflexivity.
  Qed.

  (** Same as [map], but the function does not receive the [elt] argument. *)
  Definition map1 A B (f: A -> B) g := TTree.map1 f g.
  Lemma gmap1:
    forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
    get i (map1 f m) = option_map f (get i m).
  Proof.
    intros A B f i m. unfold get, map1. apply TTree.gmap1.
  Qed.

  (* Filtering data that match a given predicate is not provided by this
     functor; the corresponding declarations are kept commented out:
       Variable filter1: forall (A: Type), (A -> bool) -> t A -> t A.
       Hypothesis gfilter1:
         forall (A: Type) (pred: A -> bool) (i: elt) (m: t A),
         get i (filter1 pred m) =
         match get i m with None => None | Some x => if pred x then Some x else None end. *)

  (** Applying a function pairwise to all data of two trees. *)
  Definition combine A B C (f: option A -> option B -> option C) g1 g2 : t C :=
    TTree.combine f g1 g2.
  Lemma gcombine:
    forall (A B C: Type) (f: option A -> option B -> option C),
    f None None = None ->
    forall (m1: t A) (m2: t B) (i: elt),
    get i (combine f m1 m2) = f (get i m1) (get i m2).
  Proof.
    intros A B C f H m1 m2 i. unfold get, combine.
    apply TTree.gcombine. exact H.
  Qed.
  Lemma combine_commut:
    forall (A B: Type) (f g: option A -> option A -> option B),
    (forall (i j: option A), f i j = g j i) ->
    forall (m1 m2: t A),
    combine f m1 m2 = combine g m2 m1.
  Proof.
    intros A B f g H m1 m2. unfold combine.
    apply TTree.combine_commut. exact H.
  Qed.

  (** Enumerating the bindings of a tree. *)
  Definition elements A g : list (elt * A) :=
    List.map (fun q => (X.s_of_t (fst q), snd q))
    (TTree.elements g).
  Lemma elements_correct:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    get i m = Some v -> In (i, v) (elements m).
  Proof.
    intros A m i v H. unfold get in H. unfold elements.
    (* The binding [(i, v)] is the image of [(X.t_of_s i, v)]. *)
    apply in_map_iff. exists (X.t_of_s i, v). split.
    - simpl. rewrite X.inj. reflexivity.
    - apply TTree.elements_correct. exact H.
  Qed.
  Lemma elements_complete:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    In (i, v) (elements m) -> get i m = Some v.
  Proof.
    intros A m i v H. unfold elements in H. unfold get.
    apply in_map_iff in H. destruct H as [[k w] [E IN]].
    simpl in E. inversion E. subst. rewrite X.surj.
    apply TTree.elements_complete. exact IN.
  Qed.
  Lemma elements_keys_norepet:
    forall (A: Type) (m: t A),
    list_norepet (List.map (@fst elt A) (elements m)).
  Proof.
    intros A m.
    (* The keys are the images under [X.s_of_t] of the keys of [TTree]. *)
    assert (K: List.map (@fst elt A) (elements m)
               = List.map X.s_of_t (List.map (@fst X.t A) (TTree.elements m))).
    { unfold elements. rewrite !map_map. reflexivity. }
    rewrite K. apply list_map_norepet.
    - apply TTree.elements_keys_norepet.
    - intros x y _ _ NE. apply P.tinjective. exact NE.
  Qed.

  (** Folding a function over all bindings of a tree. *)
  Definition fold A B f (g: t A) b : B :=
    TTree.fold (fun x k v => f x (X.s_of_t k) v) g b.
  Lemma fold_spec:
    forall (A B: Type) (f: B -> elt -> A -> B) (v: B) (m: t A),
    fold f m v =
    List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
  Proof.
    intros A B f v m. unfold fold, elements.
    rewrite TTree.fold_spec, fold_left_map. reflexivity.
  Qed.

  (** Folding a function over all rhs of bindings of a tree. *)
  Definition fold1 A B f (g: t A) b : B :=
    TTree.fold1 (fun x v => f x v) g b.
  Lemma fold1_spec:
    forall (A B: Type) (f: B -> A -> B) (v: B) (m: t A),
    fold1 f m v =
    List.fold_left (fun a p => f a (snd p)) (elements m) v.
  Proof.
    intros A B f v m. unfold fold1, elements.
    rewrite TTree.fold1_spec, fold_left_map. reflexivity.
  Qed.

End BijTree.

(** Finite maps keyed by integers [Z]: [PTree] transported along the
    bijection [Z_EQ_POS] by the [BijTree] functor. *)
Module ZTree <: TREE with Definition elt := Z := BijTree Z_EQ_POS PTree.

(** [SumTree L R] is a [TREE] keyed by the disjoint sum [L.elt + R.elt].
    It is a pair made of an [L]-tree, which holds the [inl] keys, and an
    [R]-tree, which holds the [inr] keys. *)
Module SumTree (L:TREE) (R:TREE) <: TREE.

  Definition elt := (L.elt + R.elt)%type.
  Definition elt_eq: forall (a b: elt), {a = b} + {a <> b}.
  Proof.
    intros [a|a] [b|b].
    - destruct (L.elt_eq a b) as [E|NE].
      + left. rewrite E. reflexivity.
      + right. congruence.
    - right. discriminate.
    - right. discriminate.
    - destruct (R.elt_eq a b) as [E|NE].
      + left. rewrite E. reflexivity.
      + right. congruence.
  Defined.
  Definition t (A: Type) : Type := (L.t A * R.t A)%type.
  Definition empty A : t A := (L.empty A, R.empty A).
  Definition get A k m : option A :=
    match k with
      | inl x => L.get x (fst m)
      | inr x => R.get x (snd m)
    end.
  Definition set A k v m : t A :=
    match k with
      | inl x => (L.set x v (fst m), snd m)
      | inr x => (fst m, R.set x v (snd m))
    end.
  Definition remove A k m : t A :=
    match k with
      | inl x => (L.remove x (fst m), snd m)
      | inr x => (fst m, R.remove x (snd m))
    end.

  Lemma gempty:
    forall (A: Type) (i: elt), get i (empty A) = None.
  Proof.
    intros A [i|i]; unfold get, empty; simpl; [apply L.gempty | apply R.gempty].
  Qed.
  Lemma gss:
    forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = Some x.
  Proof.
    intros A [i|i] x m; unfold get, set; simpl; [apply L.gss | apply R.gss].
  Qed.
  Lemma gso:
    forall (A: Type) (i j: elt) (x: A) (m: t A),
    i <> j -> get i (set j x m) = get i m.
  Proof.
    intros A [i|i] [j|j] x m NE; unfold get, set; simpl.
    - apply L.gso. congruence.
    - reflexivity.
    - reflexivity.
    - apply R.gso. congruence.
  Qed.
  Lemma gsspec:
    forall (A: Type) (i j: elt) (x: A) (m: t A),
    get i (set j x m) = if elt_eq i j then Some x else get i m.
  Proof.
    intros A i j x m.
    destruct (elt_eq i j) as [E|NE]; [rewrite E; apply gss | apply gso; exact NE].
  Qed.
  Lemma gsident:
    forall (A: Type) (i: elt) (m: t A) (v: A),
    get i m = Some v -> set i v m = m.
  Proof.
    intros A [i|i] [l r] v H; unfold get in H; unfold set; simpl in *.
    - rewrite L.gsident; [reflexivity | exact H].
    - rewrite R.gsident; [reflexivity | exact H].
  Qed.

  Lemma grs:
    forall (A: Type) (i: elt) (m: t A), get i (remove i m) = None.
  Proof.
    intros A [i|i] m; unfold get, remove; simpl; [apply L.grs | apply R.grs].
  Qed.
  Lemma gro:
    forall (A: Type) (i j: elt) (m: t A),
    i <> j -> get i (remove j m) = get i m.
  Proof.
    intros A [i|i] [j|j] m NE; unfold get, remove; simpl.
    - apply L.gro. congruence.
    - reflexivity.
    - reflexivity.
    - apply R.gro. congruence.
  Qed.
  Lemma grspec:
    forall (A: Type) (i j: elt) (m: t A),
    get i (remove j m) = if elt_eq i j then None else get i m.
  Proof.
    intros A i j m.
    destruct (elt_eq i j) as [E|NE]; [rewrite E; apply grs | apply gro; exact NE].
  Qed.

  (** Extensional equality between trees: both components must be equal. *)
  Definition beq A (cmp: A -> A -> bool) (m1 m2: t A) : bool :=
    let '(l1, r1) := m1 in
    let '(l2, r2) := m2 in
    L.beq cmp l1 l2 && R.beq cmp r1 r2.
  Lemma beq_correct:
    forall (A: Type) (eqA: A -> A -> bool) (t1 t2: t A),
      beq eqA t1 t2 = true <->
      (forall (x: elt),
         match get x t1, get x t2 with
           | None, None => True
           | Some y1, Some y2 => eqA y1 y2 = true
           | _, _ => False
         end).
  Proof.
    intros A eqA [l1 r1] [l2 r2]. unfold beq. simpl. split.
    - intros H x. apply andb_prop in H. destruct H as [HL HR].
      destruct x as [x|x]; unfold get; simpl.
      + exact (proj1 (@L.beq_correct A eqA l1 l2) HL x).
      + exact (proj1 (@R.beq_correct A eqA r1 r2) HR x).
    - intros H. apply andb_true_intro. split.
      + apply (proj2 (@L.beq_correct A eqA l1 l2)). intros x.
        exact (H (@inl _ _ x)).
      + apply (proj2 (@R.beq_correct A eqA r1 r2)). intros x.
        exact (H (@inr _ _ x)).
  Qed.

  (** Applying a function to all data of a tree. *)
  Definition map A B (f: elt -> A -> B) (m: t A) : t B :=
    let '(l, r) := m in
    (L.map (fun k => f (inl _ k)) l,
     R.map (fun k => f (inr _ k)) r).
  Lemma gmap:
    forall (A B: Type) (f: elt -> A -> B) (i: elt) (m: t A),
    get i (map f m) = option_map (f i) (get i m).
  Proof.
    intros A B f [i|i] [l r]; unfold get, map; simpl;
      [rewrite L.gmap | rewrite R.gmap]; reflexivity.
  Qed.

  Definition map1 A B (f: A -> B) (m: t A) : t B :=
    (L.map1 f (fst m), R.map1 f (snd m)).
  Lemma gmap1:
    forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
    get i (map1 f m) = option_map f (get i m).
  Proof.
    intros A B f [i|i] m; unfold get, map1; simpl; [apply L.gmap1 | apply R.gmap1].
  Qed.

  Definition combine A B C (f: option A -> option B -> option C) (m1: t A) (m2: t B) : t C :=
    let '(l1, r1) := m1 in
    let '(l2, r2) := m2 in
    (L.combine f l1 l2, R.combine f r1 r2).
  Lemma gcombine:
    forall (A B C: Type) (f: option A -> option B -> option C),
    f None None = None ->
    forall (m1: t A) (m2: t B) (i: elt),
    get i (combine f m1 m2) = f (get i m1) (get i m2).
  Proof.
    intros A B C f H [l1 r1] [l2 r2] [i|i]; unfold get, combine; simpl;
      [apply L.gcombine | apply R.gcombine]; exact H.
  Qed.
  Lemma combine_commut:
    forall (A B: Type) (f g: option A -> option A -> option B),
    (forall (i j: option A), f i j = g j i) ->
    forall (m1 m2: t A),
    combine f m1 m2 = combine g m2 m1.
  Proof.
    intros A B f g H [l1 r1] [l2 r2]. unfold combine. simpl.
    f_equal; [apply L.combine_commut | apply R.combine_commut]; exact H.
  Qed.

  (** Enumerating the bindings of a tree: the [inl] bindings come first,
      followed by the [inr] bindings. *)
  Definition elements A (m: t A) : list (elt * A) :=
    (List.map (fun k => (inl _ (fst k), snd k)) (L.elements (fst m)))
      ++
    (List.map (fun k => (inr _ (fst k), snd k)) (R.elements (snd m)))
  .
  Lemma elements_correct:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    get i m = Some v -> In (i, v) (elements m).
  Proof.
    intros A [l r] [i|i] v H; unfold get in H; simpl in H;
      unfold elements; simpl; apply in_or_app.
    - left. apply in_map_iff. exists (i, v).
      split; [reflexivity | apply L.elements_correct; exact H].
    - right. apply in_map_iff. exists (i, v).
      split; [reflexivity | apply R.elements_correct; exact H].
  Qed.
  Lemma elements_complete:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    In (i, v) (elements m) -> get i m = Some v.
  Proof.
    intros A [l r] i v H. unfold elements in H. simpl in H. unfold get.
    apply in_app_or in H.
    destruct H as [H|H]; apply in_map_iff in H; destruct H as [[k w] [E IN]];
      simpl in E; inversion E; subst; simpl.
    - apply L.elements_complete. exact IN.
    - apply R.elements_complete. exact IN.
  Qed.
  Lemma elements_keys_norepet:
    forall (A: Type) (m: t A),
    list_norepet (List.map (@fst elt A) (elements m)).
  Proof.
    intros A [l r].
    (* The keys are the [L]-keys tagged [inl] followed by the [R]-keys tagged [inr]. *)
    assert (K: List.map (@fst elt A) (elements (l, r))
               = List.map (@inl L.elt R.elt) (List.map (@fst L.elt A) (L.elements l))
                 ++ List.map (@inr L.elt R.elt) (List.map (@fst R.elt A) (R.elements r))).
    { unfold elements. simpl. rewrite map_app, !map_map. reflexivity. }
    rewrite K. apply list_norepet_app. split; [|split].
    - apply list_map_norepet.
      + apply L.elements_keys_norepet.
      + intros x y _ _ NE E. apply NE. congruence.
    - apply list_map_norepet.
      + apply R.elements_keys_norepet.
      + intros x y _ _ NE E. apply NE. congruence.
    - unfold list_disjoint. intros x y Hx Hy E.
      apply in_map_iff in Hx. destruct Hx as [a [Ea _]].
      apply in_map_iff in Hy. destruct Hy as [b [Eb _]].
      congruence.
  Qed.

  (** Folding over all bindings: the [L] part first, then the [R] part. *)
  Definition fold A B (f: B -> elt -> A -> B) (m: t A) (v: B) : B :=
    (R.fold (fun b k => f b (inr _ k)) (snd m)
    (L.fold (fun b k => f b (inl _ k)) (fst m) v)).
  Lemma fold_spec:
    forall (A B: Type) (f: B -> elt -> A -> B) (v: B) (m: t A),
    fold f m v =
    List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
  Proof.
    intros A B f v [l r]. unfold fold, elements. simpl.
    rewrite R.fold_spec, L.fold_spec, fold_left_app, !fold_left_map.
    reflexivity.
  Qed.

  Definition fold1 A B (f: B -> A -> B) (m: t A) (v: B) : B :=
    (R.fold1 (fun b => f b) (snd m)
    (L.fold1 (fun b => f b) (fst m) v)).
  Lemma fold1_spec:
    forall (A B: Type) (f: B -> A -> B) (v: B) (m: t A),
    fold1 f m v =
    List.fold_left (fun a p => f a (snd p)) (elements m) v.
  Proof.
    intros A B f v [l r]. unfold fold1, elements. simpl.
    rewrite R.fold1_spec, L.fold1_spec, fold_left_app, !fold_left_map.
    reflexivity.
  Qed.

End SumTree.