(** * Module UtilBase

    Utility definitions. *)

Require Import Bool.
Require Import Arith.
Require Import List.
Require Import TheoryList.
Require Import BinPos.
Require Import BinNat.
Require Import Ndec.

Require Import Coqlib.
Require Import Maps.
Require Import RTL.

Require Import UtilTacs.

(** [debug dbg a] evaluates [dbg a], binds the resulting [unit] to an unused
    local, and returns [a] unchanged; inside Coq it is equal to [a].
    NOTE(review): presumably [dbg] is meant to be realized by a side-effecting
    function in extracted code (e.g. printing) — confirm the extraction setup
    keeps this [let] rather than optimizing it away. *)
Definition debug {A: Type} (dbg: A -> unit) (a: A) : A :=
  let unused := dbg a in a.

(** [Pmem n t] is [true] exactly when key [n] has some binding in [t];
    the bound value itself is ignored. *)
Definition Pmem {A : Type} (n : positive) (t : PTree.t A) : bool :=
  match PTree.get n t with
    | Some _ => true
    | None => false
  end.

(** Membership of a node in a function's code ([Prop] version). *)
Definition f_In (n : node) (f : function) : Prop :=
  (fn_code f) ! n <> None.

(** Decides [f_In n f] by inspecting the lookup of [n] in [fn_code f]:
    a [Some] binding gives [left], [None] gives [right]. *)
Definition f_In_dec :
  forall f (n : node), {f_In n f} + {~ f_In n f}.
Proof.
  intros f n.
  unfold f_In.
  destruct ((fn_code f) ! n) as [instr|].
  (* bound: [Some instr <> None] *)
  left. discriminate.
  (* unbound: [None <> None] is refutable *)
  right. intro Hne. apply Hne. reflexivity.
Qed.

(** Sets of positives, represented as trees whose values are all [tt]. *)
Definition PSet := PTree.t unit.

(** The set containing only [p]. *)
Definition PSet_singleton (p : positive) :=
  PTree.set p tt (PTree.empty unit).

(** Set membership as a proposition, via the boolean test [Pmem]. *)
Definition PIn (p : positive) (t : PSet) : Prop :=
  Pmem p t = true.

(** [PIn p t] is by definition [Pmem p t = true], so the two agree. *)
Lemma Pmem_iff_PIn:
  forall p t,
    Pmem p t = true <-> PIn p t.
Proof.
  intros p t.
  unfold PIn.
  apply iff_refl.
Qed.

(** [PSet_union t1 t2] starts from [t2] and adds every key of [t1]. *)
Definition PSet_union (t1 t2 : PSet) : PSet :=
  PTree.fold (fun acc key _ => PTree.set key tt acc) t1 t2.

(** Keys of [fn_code f], computed by the [Maps] helper [PTree.xkeys]
    starting from the root position [xH].
    NOTE(review): [xkeys] is an internal [Maps] helper; do not rely on a
    particular ordering of the result without checking its definition. *)
Definition get_function_nodes (f : function) : list node :=
  PTree.xkeys (fn_code f) xH.

(** Number of bindings in [t], i.e. the length of [PTree.elements t]. *)
Definition ptree_cardinal {A : Type} (t : PTree.t A) : nat :=
  let bindings := PTree.elements t in
  length bindings.

(** Decidable equality on [unit]: both arguments can only be [tt]. *)
Definition unit_eq_dec : forall (x y : unit), {x = y} + {x <> y}.
Proof.
  intros [] [].
  left.
  reflexivity.
Qed.

Section OPTION_BEQ.
  Variable A : Type.
  Variable Aeq : forall (a1 a2 : A), {a1 = a2} + {a1 <> a2}.

  (** Boolean equality on [option A]: two [None]s are equal, two [Some]s are
      equal when [Aeq] says their payloads are, and mixed shapes differ. *)
  Definition option_beq (o1 o2 : option A) : bool :=
    match o1, o2 with
      | Some x, Some y => if Aeq x y then true else false
      | None, None => true
      | _, _ => false
    end.

End OPTION_BEQ.

(** [option_beq] specialised to optional positives (nodes). *)
Definition node_option_beq := option_beq positive positive_eq_dec.

(** [true] when [o1] and [o2] have the same shape (both [None] or both
    [Some]); payloads are never compared. *)
Definition option_none_beq {A : Type} (o1 o2 : option A) : bool :=
  match o1 with
    | Some _ => match o2 with Some _ => true | None => false end
    | None => match o2 with Some _ => false | None => true end
  end.

Section PAIR_BEQ.
  Variable A : Type.
  Variable Aeq : forall (a1 a2 : A), {a1 = a2} + {a1 <> a2}.
  Variable B : Type.
  Variable Beq : forall (b1 b2 : B), {b1 = b2} + {b1 <> b2}.

  (** Boolean equality on pairs: first components compared with [Aeq];
      only if they match are the second components compared with [Beq]. *)
  Definition pair_beq (p1 p2 : A * B) : bool :=
    match p1, p2 with
      | (a1, b1), (a2, b2) =>
        if Aeq a1 a2
        then (if Beq b1 b2 then true else false)
        else false
    end.

  (** Decidable equality on pairs, built from [Aeq] and [Beq]. Both deciders
      are consulted (first [Aeq], then [Beq]) before the answer is given. *)
  Definition pair_eq_dec:
    forall (p1 p2 : A * B),
      {p1 = p2} + {p1 <> p2}.
  Proof.
    intros [a1 b1] [a2 b2].
    destruct (Aeq a1 a2) as [Ha | Ha]; destruct (Beq b1 b2) as [Hb | Hb].
    (* both components equal *)
    left. subst. reflexivity.
    (* second components differ *)
    right. intro Heq. apply Hb. congruence.
    (* first components differ (second equal) *)
    right. intro Heq. apply Ha. congruence.
    (* both components differ *)
    right. intro Heq. apply Ha. congruence.
  Qed.

End PAIR_BEQ.

(** [option_beq] on optional [(N, positive)] pairs (presumably (N, node)). *)
Definition option_N_node_beq :=
  option_beq (N * positive) (pair_eq_dec N N_eq_dec positive positive_eq_dec).

(** [option_beq] on optional [((N, N), positive)] triples. *)
Definition option_N_N_node_beq :=
  option_beq (N * N * positive)
    (pair_eq_dec (N * N) (pair_eq_dec N N_eq_dec N N_eq_dec)
                 positive positive_eq_dec).

(** [option_mem n ol]: for [Some l], whether [n] occurs in [l] (via [mem]);
    for [None], always [true]. *)
Definition option_mem (n : positive) (ol : option (list positive)) : bool :=
  match ol with
    | Some l => mem positive_eq_dec n l
    | None => true
  end.

(** [adj_pairs' a l] pairs every element of [a :: l] with its successor,
    e.g. [adj_pairs' x (y :: z :: nil) = (x, y) :: (y, z) :: nil]. *)
Fixpoint adj_pairs' {A : Type} (a : A) (l : list A) : list (A * A) :=
  match l with
    | next :: rest => (a, next) :: adj_pairs' next rest
    | nil => nil
  end.

(** Consecutive pairs of [l]; empty when [l] has fewer than two elements. *)
Definition adj_pairs {A : Type} (l : list A) : list (A * A) :=
  match l with
    | hd :: tl => adj_pairs' hd tl
    | nil => nil
  end.
  
(** Elements of [l1] that do not occur in [l2], kept in their [l1] order
    (membership tested with [TheoryList.mem]). *)
Definition diff (l1 l2 : list node) : list node :=
  let absent_from_l2 := fun m => negb (TheoryList.mem positive_eq_dec m l2) in
  List.filter absent_from_l2 l1.

(** Decides [incl l1 l2] given decidable equality [Aeq] on elements.
    Recurses on [l1]: the tail is decided first, then the head is looked up
    in [l2] with [In_dec]. *)
Definition incl_dec:
  forall {A : Type} (Aeq: forall (a1 a2 : A), {a1 = a2} + {a1 <> a2}) (l1 l2 : list A),
    {incl l1 l2} + {~ incl l1 l2}.
Proof.
  intros A Aeq l1 l2.
  unfold incl.
  induction l1 as [| x rest IH].
  (* nil is included in anything *)
  left. intros a Ha. simpl in Ha. contradiction.
  (* cons: decide the tail first *)
  destruct IH as [Hincl | Hnincl].
  destruct (In_dec Aeq x l2) as [Hin | Hnin].
  (* tail included and head present *)
  left. intros a Ha. simpl in Ha. destruct Ha as [Ha | Ha].
  subst a. exact Hin.
  apply Hincl. exact Ha.
  (* head missing from l2 *)
  right. intro Hall. apply Hnin. apply Hall. simpl. left. reflexivity.
  (* tail not included *)
  right. intro Hall. apply Hnincl. intros a Ha. apply Hall. simpl. right. exact Ha.
Qed.

(** Decides strict order [n1 < n2] on [N] by case analysis on both arguments;
    the positive/positive case defers to [plt]. [lia2] and [trim] are not
    standard Coq tactics (presumably defined in UtilTacs). The definition ends
    with [Qed], so it is opaque to reduction inside Coq. *)
Definition nlt : forall (n1 n2 : N), ({n1 < n2} + {~ n1 < n2})%N.
Proof.
  intros.
  destruct n1, n2.
  (* 0 < 0 is false *)
  right; intro; trim.
  (* 0 < pos holds *)
  left; lia2.
  (* pos < 0 is false *)
  right; lia2.
  (* pos vs pos: decided by [plt] *)
  destruct (plt p p0); trim.
Qed.

(** Decides [n1 <= n2] on [N] by testing [Nleb n1 n2] and translating the
    boolean through [Nleb_Nle]. *)
Definition nle : forall (n1 n2 : N), ({n1 <= n2} + {~ n1 <= n2})%N.
Proof.
  intros n1 n2.
  case_eq (Nleb n1 n2); intro Hleb.
  (* Nleb n1 n2 = true: the order holds *)
  left. apply Nleb_Nle. exact Hleb.
  (* Nleb n1 n2 = false: a proof of n1 <= n2 would force it to true *)
  right. intro Hle. apply Nleb_Nle in Hle. congruence.
Qed.

(** [fid]: the identity function, as a notation. *)
Notation fid := (fun s''' => s''').
(** [flatten]: concatenates a list of lists ([flat_map] with the identity). *)
Notation flatten := (flat_map fid).
(** [count l x]: number of occurrences of positive [x] in [l]
    ([count_occ] specialised to [positive_eq_dec]). *)
Notation count := (count_occ positive_eq_dec).