(** Module Switch *)

(** Multi-way branches ("switch" statements) and their compilation
    to comparison trees. *)

Require Import EqNat.
Require Import FMaps.
Require FMapAVL.
Require Import Coqlib.
Require Import Integers.
Require Import Ordered.

Module IntMap := FMapAVL.Make(OrderedInt).
Module IntMapF := FMapFacts.Facts(IntMap).

(** A multi-way branch is composed of a list of (key, action) pairs,
    plus a default action. *)

(** A switch table whose actions have type [A]: a list of pairs, each made
    of an integer key and the action associated with that key. *)
Definition table (A: Type) : Type := list (prod int A).

(** [switch_target n dfl cases] is the action selected by the value [n]:
    the action of the first pair in [cases] whose key is equal to [n]
    (as decided by [Int.eq]), or the default action [dfl] when no key
    matches. *)
Fixpoint switch_target {A: Type} (n: int) (dfl: A) (cases: table A)
                       {struct cases} : A :=
  match cases with
  | nil => dfl
  | (key, action) :: rem =>
      (* The first matching key wins; later duplicates are never reached. *)
      if Int.eq n key then action else switch_target n dfl rem
  end.

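(** Multi-way branches are translated to comparison trees.
    Each node of the tree performs either
    - an equality test against one key ([CTifeq]);
    - or a "less than" test against one key ([CTiflt]);
    - or a table lookup over a range of key values ([CTjumptable]).
    The leaves of the tree ([CTaction]) carry the selected action. *)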


(* Parameters shared by the definitions that follow: [A] is the type of the
   actions stored in tables and comparison trees, and [eqA] decides
   equality between two actions.
   NOTE(review): [Context] / [Variable] are normally used inside a
   [Section]; no [Section] opener is visible here -- confirm. *)
Context {A: Type}.
Variable eqA: forall (x y: A), {x=y} + {x<>y}.

(** Comparison trees.  The meaning of each constructor is the one given by
    [comptree_match]. *)
Inductive comptree: Type :=
  (* [CTaction act]: leaf, selects the action [act]. *)
  | CTaction: A -> comptree
  (* [CTifeq key act t]: selects [act] when the value equals [key],
     otherwise continues in [t]. *)
  | CTifeq: int -> A -> comptree -> comptree
  (* [CTiflt key t1 t2]: continues in [t1] when the value is below [key]
     for the unsigned comparison [Int.ltu], otherwise in [t2]. *)
  | CTiflt: int -> comptree -> comptree -> comptree
  (* [CTjumptable ofs sz tbl t]: when [value - ofs] is below [sz] (unsigned),
     looks up position [value - ofs] in [tbl]; otherwise continues in [t]. *)
  | CTjumptable: int -> int -> list A -> comptree -> comptree.

(** [comptree_match n t] evaluates the comparison tree [t] on the value [n].
    The result is [Some act] when an action is selected.  It can be [None]
    only through the jump-table case, when [list_nth_z] finds no entry at
    the requested position. *)
Fixpoint comptree_match (n: int) (t: comptree) {struct t}: option A :=
  match t with
  | CTaction act => Some act
  | CTifeq key act t' =>
      if Int.eq n key then Some act else comptree_match n t'
  | CTiflt key t1 t2 =>
      (* Unsigned comparison of the value against the pivot. *)
      if Int.ltu n key then comptree_match n t1 else comptree_match n t2
  | CTjumptable ofs sz tbl t' =>
      (* The table is indexed by the offset of [n] from [ofs]. *)
      if Int.ltu (Int.sub n ofs) sz
      then list_nth_z tbl (Int.unsigned (Int.sub n ofs))
      else comptree_match n t'
  end.

(** The translation from a table to a comparison tree is performed by
    untrusted Caml code (function [compile_switch], defined outside this
    file).  In Coq, we validate a posteriori the result of this function.
    In other terms, we now develop and prove correct Coq functions that
    take a table and a comparison tree, and check that their semantics
    are equivalent. *)

(** [split_lt pivot cases] partitions [cases] into the pairs whose key is
    below [pivot] for the unsigned comparison [Int.ltu] (first component)
    and all the other pairs (second component).  The relative order of the
    pairs is preserved in both components. *)
Fixpoint split_lt (pivot: int) (cases: table A)
                  {struct cases} : table A * table A :=
  match cases with
  | nil => (nil, nil)
  | (key, act) :: rem =>
      let (l, r) := split_lt pivot rem in
      if Int.ltu key pivot
      then ((key, act) :: l, r)
      else (l, (key, act) :: r)
  end.

(** [split_eq pivot cases] removes from [cases] every pair whose key is
    equal to [pivot].  The first component is [Some act], where [act] is
    the action of the first such pair in [cases], or [None] if there is no
    such pair; the second component is the list of remaining pairs, in
    their original order. *)
Fixpoint split_eq (pivot: int) (cases: table A)
                  {struct cases} : option A * table A :=
  match cases with
  | nil => (None, nil)
  | (key, act) :: rem =>
      let (same, others) := split_eq pivot rem in
      if Int.eq key pivot
      (* The head pair overrides any match found further down the list. *)
      then (Some act, others)
      else (same, (key, act) :: others)
  end.

(** [split_between ofs sz cases] separates the pairs of [cases] whose key
    [k] satisfies [Int.ltu (Int.sub k ofs) sz] from the others.  The former
    are collected in a finite map from keys to actions (first component);
    the latter are returned as a table, in their original order (second
    component). *)
Fixpoint split_between (ofs sz: int) (cases: table A)
                       {struct cases} : IntMap.t A * table A :=
  match cases with
  | nil => (IntMap.empty A, nil)
  | (key, act) :: rem =>
      let (inside, outside) := split_between ofs sz rem in
      if Int.ltu (Int.sub key ofs) sz
      (* Added after the recursive result, so the head pair takes
         precedence over a later pair with the same key. *)
      then (IntMap.add key act inside, outside)
      else (inside, (key, act) :: outside)
  end.

(** [refine_low_bound v lo] is [lo + 1] when [v] is equal to [lo], and [lo]
    otherwise. *)
Definition refine_low_bound (v lo: Z) : Z :=
  match zeq v lo with
  | left _ => lo + 1
  | right _ => lo
  end.

(** [refine_high_bound v hi] is [hi - 1] when [v] is equal to [hi], and [hi]
    otherwise. *)
Definition refine_high_bound (v hi: Z) : Z :=
  match zeq v hi with
  | left _ => hi - 1
  | right _ => hi
  end.

(** [validate_jumptable cases default tbl n] checks the jump table [tbl]
    entry by entry, starting at key [n]: the entry at position [i] must be
    equal (according to [eqA]) to the action bound to key [n + i] in the
    map [cases], or to [default] when that key is unbound. *)
Fixpoint validate_jumptable (cases: IntMap.t A) (default: A)
                            (tbl: list A) (n: int) {struct tbl} : bool :=
  match tbl with
  | nil => true
  | act :: rem =>
      eqA act (match IntMap.find n cases with Some a => a | None => default end)
      (* The next table entry corresponds to the next key. *)
      && validate_jumptable cases default rem (Int.add n Int.one)
  end.

(** [validate default cases t lo hi] returns [true] only if the comparison
    tree [t] selects the same action as the table [cases] (with default
    action [default]) for every value whose unsigned interpretation lies
    in the interval from [lo] to [hi], bounds included.  The interval is
    narrowed while descending into the tree. *)
Fixpoint validate (default: A) (cases: table A) (t: comptree)
                  (lo hi: Z) {struct t} : bool :=
  match t with
  | CTaction act =>
      match cases with
      | nil =>
          eqA act default
      | (key1, act1) :: _ =>
          (* Accepted only when the interval is reduced to the single
             value [key1], so that the first case is the one selected. *)
          zeq (Int.unsigned key1) lo && zeq lo hi && eqA act act1
      end
  | CTifeq pivot act t' =>
      match split_eq pivot cases with
      | (None, _) =>
          (* An equality test against a key absent from the table is
             rejected. *)
          false
      | (Some act', others) =>
          eqA act act'
          && validate default others t'
                      (refine_low_bound (Int.unsigned pivot) lo)
                      (refine_high_bound (Int.unsigned pivot) hi)
      end
  | CTiflt pivot t1 t2 =>
      match split_lt pivot cases with
      | (lcases, rcases) =>
          validate default lcases t1 lo (Int.unsigned pivot - 1)
          && validate default rcases t2 (Int.unsigned pivot) hi
      end
  | CTjumptable ofs sz tbl t' =>
      let tbl_len := list_length_z tbl in
      match split_between ofs sz cases with
      | (inside, outside) =>
          (* The table must have at least [sz] entries. *)
          zle (Int.unsigned sz) tbl_len
          && zle tbl_len Int.max_signed
          && validate_jumptable inside default tbl ofs
          && validate default outside t' lo hi
      end
  end.

(** Entry point of the validator: runs [validate] on the interval that goes
    from [0] to [Int.max_unsigned]. *)
Definition validate_switch
    (default: A) (cases: table A) (t: comptree) : bool :=
  validate default cases t 0 Int.max_unsigned.

(** Correctness proof for validation. *)

(** Characterization of [split_eq]: looking up [v] in [cases] amounts to
    comparing [v] with the pivot [n], then using either the action
    extracted for the pivot or the remaining cases. *)
Lemma split_eq_prop:
  forall v default n cases optact cases',
  split_eq n cases = (optact, cases') ->
  switch_target v default cases =
   (if Int.eq v n
    then match optact with Some act => act | None => default end
    else switch_target v default cases').
Proof.
  induction cases; simpl; intros until cases'.
  (* empty table *)
  intros. inversion H; subst. simpl.
  destruct (Int.eq v n); auto.
  (* non-empty table *)
  destruct a as [key act].
  case_eq (split_eq n cases). intros same other SEQ.
  rewrite (IHcases _ _ SEQ).
  predSpec Int.eq Int.eq_spec key n; intro EQ; inversion EQ; simpl.
  (* the head key is the pivot *)
  subst n. destruct (Int.eq v key). auto. auto.
  (* the head key differs from the pivot *)
  predSpec Int.eq Int.eq_spec v key.
  subst v. predSpec Int.eq Int.eq_spec key n. congruence. auto.
  auto.
Qed.

(** Characterization of [split_lt]: looking up [v] in [cases] amounts to
    looking it up in the left part when [v] is below the pivot [n]
    (unsigned), and in the right part otherwise. *)
Lemma split_lt_prop:
  forall v default n cases lcases rcases,
  split_lt n cases = (lcases, rcases) ->
  switch_target v default cases =
    (if Int.ltu v n
     then switch_target v default lcases
     else switch_target v default rcases).
Proof.
  induction cases; intros until rcases; simpl.
  (* empty table *)
  intro. inversion H; subst. simpl.
  destruct (Int.ltu v n); auto.
  (* non-empty table *)
  destruct a as [key act].
  case_eq (split_lt n cases). intros lc rc SEQ.
  rewrite (IHcases _ _ SEQ).
  case_eq (Int.ltu key n); intros; inv H0; simpl.
  (* the head key goes to the left part *)
  predSpec Int.eq Int.eq_spec v key.
  subst v. rewrite H. auto.
  auto.
  (* the head key goes to the right part *)
  predSpec Int.eq Int.eq_spec v key.
  subst v. rewrite H. auto.
  auto.
Qed.

(** Characterization of [split_between]: when [v - ofs] is below [sz]
    (unsigned), the action for [v] is the one found in the map [inside]
    (or the default action); otherwise it is obtained from the table
    [outside]. *)
Lemma split_between_prop:
  forall v default ofs sz cases inside outside,
  split_between ofs sz cases = (inside, outside) ->
  switch_target v default cases =
    (if Int.ltu (Int.sub v ofs) sz
     then match IntMap.find v inside with Some a => a | None => default end
     else switch_target v default outside).
Proof.
  induction cases; intros until outside; simpl.
  (* empty table *)
  intros. inv H. simpl. destruct (Int.ltu (Int.sub v ofs) sz); auto.
  (* non-empty table *)
  destruct a as [key act]. case_eq (split_between ofs sz cases). intros ins outs SEQ.
  rewrite (IHcases _ _ SEQ).
  case_eq (Int.ltu (Int.sub key ofs) sz); intros; inv H0; simpl.
  (* the head key falls inside the range *)
  rewrite IntMapF.add_o.
  predSpec Int.eq Int.eq_spec v key.
  subst v. rewrite H. rewrite dec_eq_true. auto.
  rewrite dec_eq_false; auto.
  (* the head key falls outside the range *)
  case_eq (Int.ltu (Int.sub v ofs) sz); intros; auto.
  predSpec Int.eq Int.eq_spec v key.
  subst v. congruence.
  auto.
Qed.

(** If [validate_jumptable] accepts [tbl] starting at key [base], then the
    entry of [tbl] at any valid position [v] is the action bound to key
    [base + v] in [cases], or the default action if that key is unbound. *)
Lemma validate_jumptable_correct_rec:
  forall cases default tbl base v,
  validate_jumptable cases default tbl base = true ->
  0 <= Int.unsigned v < list_length_z tbl ->
  list_nth_z tbl (Int.unsigned v) =
  Some(match IntMap.find (Int.add base v) cases with Some a => a | None => default end).
Proof.
  induction tbl; intros until v; simpl.
  (* empty table: no valid position *)
  unfold list_length_z; simpl. intros. omegaContradiction.
  (* non-empty table *)
  rewrite list_length_z_cons. intros. destruct (andb_prop _ _ H). clear H.
  exploit proj_sumbool_true; eauto. intro EQ. subst a. clear H1.
  destruct (zeq (Int.unsigned v) 0).
  (* position 0: the head entry *)
  unfold Int.add. rewrite e. rewrite Zplus_0_r. rewrite Int.repr_unsigned. auto.
  (* position > 0: shift both the base key and the position by one *)
  assert (Int.unsigned (Int.sub v Int.one) = Int.unsigned v - 1).
    unfold Int.sub. change (Int.unsigned Int.one) with 1.
    apply Int.unsigned_repr. split. omega.
    generalize (Int.unsigned_range_2 v). omega.
  replace (Int.add base v) with (Int.add (Int.add base Int.one) (Int.sub v Int.one)).
  rewrite <- IHtbl. rewrite H. auto. auto. rewrite H. omega.
  rewrite Int.sub_add_opp. rewrite Int.add_permut. rewrite Int.add_assoc.
  replace (Int.add Int.one (Int.neg Int.one)) with Int.zero.
  rewrite Int.add_zero. apply Int.add_commut.
  apply Int.mkint_eq. reflexivity.
Qed.

(** If [validate_jumptable] accepts [tbl] starting at key [ofs], and
    [v - ofs] is below [sz] (unsigned) with [sz] no larger than the length
    of [tbl], then the entry of [tbl] at position [v - ofs] is the action
    bound to [v] in [cases], or the default action if [v] is unbound. *)
Lemma validate_jumptable_correct:
  forall cases default tbl ofs v sz,
  validate_jumptable cases default tbl ofs = true ->
  Int.ltu (Int.sub v ofs) sz = true ->
  Int.unsigned sz <= list_length_z tbl ->
  list_nth_z tbl (Int.unsigned (Int.sub v ofs)) =
  Some(match IntMap.find v cases with Some a => a | None => default end).
Proof.
  intros.
  exploit Int.ltu_inv; eauto. intros.
  rewrite (validate_jumptable_correct_rec cases default tbl ofs).
  (* main goal: [ofs + (v - ofs)] simplifies to [v] *)
  rewrite Int.sub_add_opp. rewrite Int.add_permut. rewrite <- Int.sub_add_opp.
  rewrite Int.sub_idem. rewrite Int.add_zero. auto.
  (* side conditions of the rewriting: validation and position bounds *)
  auto.
  omega.
Qed.

(** Soundness of [validate]: if the tree [t] is accepted for the table
    [cases] on the interval from [lo] to [hi], then for every value [v]
    whose unsigned interpretation lies in that interval, evaluating [t] on
    [v] selects the action that the table associates with [v]. *)
Lemma validate_correct_rec:
  forall default v t cases lo hi,
  validate default cases t lo hi = true ->
  lo <= Int.unsigned v <= hi ->
  comptree_match v t = Some (switch_target v default cases).
Proof.
Opaque Int.sub.
  induction t; simpl; intros until hi.
  (* base case: action leaf *)
  destruct cases as [ | [key1 act1] cases1]; intros.
  exploit proj_sumbool_true; eauto. intros EQ; subst a; clear H.
  auto.
  destruct (andb_prop _ _ H). destruct (andb_prop _ _ H1). clear H H1.
  assert (Int.unsigned key1 = lo). eapply proj_sumbool_true; eauto.
  assert (lo = hi). eapply proj_sumbool_true; eauto.
  assert (Int.unsigned v = Int.unsigned key1). omega.
  exploit proj_sumbool_true. eexact H2. intros EQ; subst a; clear H2.
  simpl. unfold Int.eq. rewrite H5. rewrite zeq_true. auto.
  (* equality node *)
  case_eq (split_eq i cases). intros optact cases' EQ.
  destruct optact as [ act | ]. 2: congruence.
  intros. destruct (andb_prop _ _ H). clear H.
  rewrite (split_eq_prop v default _ _ _ _ EQ).
  predSpec Int.eq Int.eq_spec v i.
  f_equal. eapply proj_sumbool_true; eauto.
  eapply IHt. eauto.
  assert (Int.unsigned v <> Int.unsigned i).
    rewrite <- (Int.repr_unsigned v) in H.
    rewrite <- (Int.repr_unsigned i) in H.
    congruence.
  split.
  unfold refine_low_bound. destruct (zeq (Int.unsigned i) lo); omega.
  unfold refine_high_bound. destruct (zeq (Int.unsigned i) hi); omega.
  (* less-than node *)
  case_eq (split_lt i cases). intros lcases rcases EQ V RANGE.
  destruct (andb_prop _ _ V). clear V.
  rewrite (split_lt_prop v default _ _ _ _ EQ).
  unfold Int.ltu. destruct (zlt (Int.unsigned v) (Int.unsigned i)).
  eapply IHt1. eauto. omega.
  eapply IHt2. eauto. omega.
  (* jump-table node *)
  case_eq (split_between i i0 cases). intros ins outs EQ V RANGE.
  destruct (andb_prop _ _ V). clear V.
  destruct (andb_prop _ _ H). clear H.
  destruct (andb_prop _ _ H1). clear H1.
  rewrite (split_between_prop v _ _ _ _ _ _ EQ).
  case_eq (Int.ltu (Int.sub v i) i0); intros.
  eapply validate_jumptable_correct; eauto.
  eapply proj_sumbool_true; eauto.
  eapply IHt; eauto.
Qed.

(** A table and a comparison tree agree when, for every integer [v],
    evaluating the tree on [v] selects exactly the action that the table
    (with its default action) associates with [v]. *)
Definition table_tree_agree
    (default: A) (cases: table A) (t: comptree) : Prop :=
  forall v: int,
    comptree_match v t = Some (switch_target v default cases).

(** Main correctness result: a comparison tree accepted by
    [validate_switch] agrees with the table on every integer. *)
Theorem validate_switch_correct:
  forall default t cases,
  validate_switch default cases t = true ->
  table_tree_agree default cases t.
Proof.
  unfold validate_switch, table_tree_agree; intros.
  eapply validate_correct_rec; eauto.
  (* every integer has its unsigned value between 0 and [Int.max_unsigned] *)
  apply Int.unsigned_range_2.
Qed.