(** * Module Cshmdefgen *)

Require Import Coqlib.
Require Import Maps.
Require Import Errors.
Require Import Integers.
Require Import Globalenvs.
Require Import Memdata.
Require Import Floats.
Require Import AST.
Require Import Ctypes.
Require Import Cop.
Require Import Cminor.
Require Import Csharpminorannot.
Require Cshmstackgen.
Import Utf8.
Import Annotations.
Require Import ArithLib.
Import Util.

Section TRANSF.

  (** Selects which annotation identifiers are turned into run-time checks;
      annotated loads/stores whose identifier is not selected are left
      unchecked (their annotation is still stripped). *)
  Variable live_annot: ident -> bool.

  (** Global environment, used to resolve global symbols named in
      annotations (see [convert_annotation]). *)
  Variable ge: genv.
  (** Identifier of the global array that receives, at each function entry,
      the address of the function's local variable (see [prelude]). *)
  Variable STK: ident.
  (** Identifier of the global 32-bit word holding the current byte offset
      of the top entry of [STK] (incremented by 4 in [prelude], decremented
      by 4 in [epilogue]). *)
  Variable SIZE: ident.
  
  Fixpoint collect_annotation (e: expr): list (memory_chunk * annotation * expr) :=
    match e with
      | Econst c => nil
      | Evar v => nil
      | Eaddrof a => nil
      | Eunop op a => collect_annotation a
      | Ebinop op a1 a2 => (collect_annotation a1) ++ (collect_annotation a2)
      | Eload ((n, _) as alpha) kappa a =>
        if live_annot n
        then (kappa, alpha, a)::(collect_annotation a)
        else collect_annotation a
    end.

  Fixpoint make_bound (κ: memory_chunk) (e: expr) (base: Z) (n: nat) :=
    match n with
      | S n =>
        let rec := make_bound κ e (base + 1) n in
        if Zdivides_dec (align_chunk κ) base
        then Ebinop Oadd e (Econst (Ointconst (Int.repr base)))::rec
        else rec
      | O => nil
    end.

  Definition orify (el: list expr): expr :=
    match el with
      | nil => Econst (Ointconst (Int.zero))
      | e::el =>
        List.fold_left (Ebinop Oor) el e
    end.

  Definition access_stk (n: nat) := Eload (xH, nil) Mint32 (Ebinop Oadd (Eaddrof STK) (Ebinop Osub (Eload (xH, nil) Mint32 (Eaddrof SIZE)) (Econst (Ointconst (Int.repr (4 * Z.of_nat n)))))).
  
  (** List + option monad: bind ([flat_map_o]). *)

  Fixpoint flat_map_o {X Y} (f: X -> option (list Y)) (m: list X) : option (list Y) :=
    match m with
    | nil => Some nil
    | x :: m =>
      do y <- f x;
      do tl <- flat_map_o f m;
      Some (y ++ tl)
    end.

  Definition is_singleton (alpha: list ablock): bool :=
    match alpha with
    | nil => false
    | _::nil => true
    | (ABglobal id _)::alpha => fold_left (fun acc x => match x with ABlocal _ _ _ => false | ABglobal id' _ => ((ident_eq id id') && acc) end) alpha true
    | (ABlocal d id _)::alpha => fold_left (fun acc x => match x with ABglobal _ _ => false | ABlocal d' id' _ => ((eq_nat_dec d d') && (ident_eq id id') && acc) end) alpha true
    end.

  Definition convert_annotation (κ: memory_chunk) (e: expr) (α: list ablock) : option (list expr) :=
    match α with
    | nil => None
    | _ =>
      if is_singleton α then
        flat_map_o
        (λ a,
         let '(base, bound) := match a with ABlocal _ _ r | ABglobal _ r => r end in
         do ptr <- match a with ABlocal n _ _ => Some (access_stk n) | ABglobal g _ => do _ <- Genv.find_symbol ge g; Some (Eaddrof g) end;
         Some (if Int.eq_dec base bound then (Ebinop (Ocmpu Ceq) (Ebinop Oadd ptr (Econst (Ointconst base))) e)::nil else (Ebinop Oand (Ebinop (Ocmpu Cle) (Ebinop Oadd ptr (Econst (Ointconst base))) e) (Ebinop (Ocmpu Cle) e (Ebinop Oadd ptr (Econst (Ointconst bound)))))::nil))
        α
      else
        flat_map_o
        (λ a,
         let '(base, bound) := match a with ABlocal _ _ r | ABglobal _ r => r end in
         do ptr <- match a with ABlocal n _ _ => Some (access_stk n) | ABglobal g _ => do _ <- Genv.find_symbol ge g; Some (Eaddrof g) end;
         let possible_addr := make_bound κ ptr (Int.unsigned base) (S (Z.to_nat (Int.unsigned bound - Int.unsigned base))) in
         Some (map (Ebinop (Ocmpu Ceq) e) possible_addr))
        α
    end.

  Definition inhab (opt: option AST.typ): option expr :=
    match opt with
    | None => None
    | Some AST.Tint => Some (Econst (Ointconst Int.zero))
    | Some AST.Tfloat => Some (Econst (Ofloatconst Float.zero))
    | Some AST.Tlong => Some (Econst (Olongconst Int64.zero))
    | Some AST.Tsingle => Some (Econst (Osingleconst Float32.zero))
    | Some AST.Tany32 => Some (Econst (Ointconst Int.zero))
    | Some AST.Tany64 => Some (Econst (Olongconst Int64.zero))
    end.
      
  Fixpoint make_ifs (new_temp: ident) (opt: option AST.typ) (l: list (memory_chunk * annotation * expr)) (k: stmt) :=
    match l with
      | nil => k
      | x::l =>
        match x with
          | (kappa, alpha, e) =>
            match convert_annotation kappa (Evar new_temp) (snd alpha) with
              | None => make_ifs new_temp opt l k
              | Some el => make_ifs new_temp opt l (Sseq (Sset new_temp e) (Sifthenelse (orify el) k (Sifthenelse (Ebinop Odiv (Econst (Ointconst (Int.repr (Zpos (fst alpha))))) (Econst (Ointconst Int.zero))) (Sreturn (inhab opt)) (Sreturn (inhab opt)))))
            end
        end
    end.

  Local Open Scope error_monad_scope.

  Fixpoint remove_annot_expr (e: expr): res expr :=
    match e with
      | Econst _ | Evar _ | Eaddrof _ => OK e
      | Eunop op a =>
        do a <- remove_annot_expr a;
        OK (Eunop op a)
      | Ebinop op a1 a2 =>
        do a1 <- remove_annot_expr a1;
        do a2 <- remove_annot_expr a2;
        OK (Ebinop op a1 a2)
      | Eload (n, alpha) kappa a =>
        do a <- remove_annot_expr a;
        OK (Eload (n, nil) kappa a)
    end.

  Fixpoint remove_annot_exprlist (el: list expr): res (list expr) :=
    match el with
      | nil => OK nil
      | e::el =>
        do e <- remove_annot_expr e;
        do el <- remove_annot_exprlist el;
        OK (e::el)
    end.

  Fixpoint remove_annot_stmt (s: stmt): res stmt :=
        match s with
      | Sskip => OK Sskip
      | Sset id e =>
        do e <- remove_annot_expr e;
        OK (Sset id e)
      | Sstore (n, alpha) kappa e1 e2 =>
        do e1 <- remove_annot_expr e1;
        do e2 <- remove_annot_expr e2;
        OK (Sstore (n, nil) kappa e1 e2)
      | Scall optid sig e el =>
        do e <- remove_annot_expr e;
        do el <- remove_annot_exprlist el;
        OK (Scall optid sig e el)
      | Sbuiltin optid ef el =>
        do el <- remove_annot_exprlist el;
        OK (Sbuiltin optid ef el)
      | Sseq s1 s2 =>
        do s1 <- remove_annot_stmt s1;
        do s2 <- remove_annot_stmt s2;
        OK (Sseq s1 s2)
      | Sifthenelse e s1 s2 =>
        do e <- remove_annot_expr e;
        do s1 <- remove_annot_stmt s1;
        do s2 <- remove_annot_stmt s2;
        OK (Sifthenelse e s1 s2)
      | Sloop s =>
        do s <- remove_annot_stmt s;
        OK (Sloop s)
      | Sblock s =>
        do s <- remove_annot_stmt s;
        OK (Sblock s)
      | Sexit n => OK (Sexit n)
      | Sswitch islong e ls =>
        do e <- remove_annot_expr e;
        do ls <- remove_annot_lstmt ls;
        OK (Sswitch islong e ls)
      | Sreturn None => OK (Sreturn None)
      | Sreturn (Some e) =>
        do e <- remove_annot_expr e;
        OK (Sreturn (Some e))
      | Slabel lbl s =>
        do s <- remove_annot_stmt s;
        OK (Slabel lbl s)
      | Sgoto lbl => OK (Sgoto lbl)
    end
  with remove_annot_lstmt (ls: lbl_stmt): res lbl_stmt :=
         match ls with
           | LSnil => OK LSnil
           | LScons lbl s ls =>
             do s <- remove_annot_stmt s;
             do ls <- remove_annot_lstmt ls;
             OK (LScons lbl s ls)
         end.

  Definition epilogue: stmt :=
    Sstore (xH, nil) Mint32 (Eaddrof SIZE) (Ebinop Osub (Eload (xH, nil) Mint32 (Eaddrof SIZE)) (Econst (Ointconst (Int.repr 4)))).
    
  Fixpoint transl_stmt (opt: option AST.typ) (tret new_temp: ident) (s: stmt): res stmt :=
    match s with
      | Sskip => OK Sskip
      | Sset id e =>
        let checks := collect_annotation e in
        remove_annot_stmt (make_ifs new_temp opt checks s)
      | Sstore ((n, _) as alpha) kappa e1 e2 =>
        let checks := collect_annotation e1 in
        let checks := collect_annotation e2 ++ checks in
        remove_annot_stmt (make_ifs new_temp opt (if live_annot n then (kappa, alpha, e1)::checks else checks) s)
      | Scall optid sig e el =>
        let checks := collect_annotation e in
        let checks' := fold_left (fun acc x => (collect_annotation x) ++ acc) el nil in
        remove_annot_stmt (make_ifs new_temp opt (checks ++ checks') (Scall optid sig e el))
      | Sbuiltin optid ef el =>
        let checks := fold_left (fun acc x => (collect_annotation x) ++ acc) el nil in
        remove_annot_stmt (make_ifs new_temp opt checks (Sbuiltin optid ef el))
      | Sseq s1 s2 =>
        do s1 <- transl_stmt opt tret new_temp s1;
        do s2 <- transl_stmt opt tret new_temp s2;
        remove_annot_stmt (Sseq s1 s2)
      | Sifthenelse e s1 s2 =>
        do s1 <- transl_stmt opt tret new_temp s1;
        do s2 <- transl_stmt opt tret new_temp s2;
        let checks := collect_annotation e in
        remove_annot_stmt (make_ifs new_temp opt checks (Sifthenelse e s1 s2))
      | Sloop s =>
        do s <- transl_stmt opt tret new_temp s;
        remove_annot_stmt (Sloop s)
      | Sblock s =>
        do s <- transl_stmt opt tret new_temp s;
        remove_annot_stmt (Sblock s)
      | Sexit n => remove_annot_stmt (Sexit n)
      | Sswitch islong e ls =>
        do ls <- transl_lstmt opt tret new_temp ls;
        let checks := collect_annotation e in
        remove_annot_stmt (make_ifs new_temp opt checks (Sswitch islong e ls))
      | Sreturn None => remove_annot_stmt (Sseq epilogue (Sreturn None))
      | Sreturn (Some e) =>
        let checks := collect_annotation e in
        remove_annot_stmt (make_ifs new_temp opt checks (Sseq (Sset tret e) (Sseq epilogue (Sreturn (Some (Evar tret))))))
      | Slabel lbl s =>
        do s <- transl_stmt opt tret new_temp s;
        remove_annot_stmt (Slabel lbl s)
      | Sgoto lbl => remove_annot_stmt (Sgoto lbl)
    end
  with transl_lstmt (opt: option AST.typ) (tret new_temp: ident) (ls: lbl_stmt): res lbl_stmt :=
         match ls with
           | LSnil => remove_annot_lstmt LSnil
           | LScons lbl s ls =>
             do s <- transl_stmt opt tret new_temp s;
             do ls <- transl_lstmt opt tret new_temp ls;
             remove_annot_lstmt (LScons lbl s ls)
         end.

  Definition prelude (sp: ident): stmt :=
    Sseq (Sstore (xH, nil) Mint32 (Eaddrof SIZE) (Ebinop Oadd (Eload (xH, nil) Mint32 (Eaddrof SIZE)) (Econst (Ointconst (Int.repr 4))))) (Sstore (xH, nil) Mint32 (Ebinop Oadd (Eaddrof STK) (Eload (xH, nil) Mint32 (Eaddrof SIZE))) (Eaddrof sp)).

  Definition transl_function (f: function): res function :=
    let tret := Psucc (List.fold_left (fun acc x => if plt acc x then x else acc) (f.(fn_params) ++ f.(fn_temps)) xH) in
    let new_temp := Psucc tret in
    match f.(fn_vars) with
    | sp::nil =>
        do tbody <- transl_stmt f.(fn_sig).(sig_res) tret new_temp f.(fn_body);
        OK (mkfunction f.(fn_sig) f.(fn_params) f.(fn_vars) (new_temp::tret::f.(fn_temps)) (Sseq (prelude (fst sp)) (Sseq tbody epilogue)))
    | _ => Error (msg "function does not have only one local variable")
    end.
  
  Definition transl_fundef (fd: fundef): res fundef :=
    match fd with
      | Internal f =>
        do tf <- transl_function f;
        OK (Internal tf)
      | External tf => OK fd
    end.

End TRANSF.

(** Global variable backing [STK]: 512 bytes declared with [Init_space]. *)
Definition STK_globvar : globvar unit :=
  mkglobvar tt (Init_space 512 :: nil) false false.

(** Global variable backing [SIZE], initialised to [-4]: the first [prelude]
    adds 4 before storing, so the first entry lands at offset 0 of [STK]. *)
Definition SIZE_globvar : globvar unit :=
  mkglobvar tt (Init_int32 (Int.repr (-4)) :: nil) false false.

Local Open Scope error_monad_scope.
Arguments OK {A} _.

(** Translate a whole program.  Function definitions go through
    [transl_fundef], and [OK] is passed as the variable transformer.  Two
    globals are appended: [STK] (two above the largest existing definition
    identifier) and [SIZE] (one above [STK]). *)
Definition transl_program (live_annot: ident -> bool) (p: program): res program :=
  let env := Genv.globalenv p in
  let def_ids := List.map fst p.(prog_defs) in
  let STK := Psucc (Psucc (List.fold_left Pmax def_ids xH)) in
  let SIZE := Psucc STK in
  AST.transform_partial_augment_program
    (transl_fundef live_annot env STK SIZE)
    OK
    ((STK, Gvar STK_globvar) :: (SIZE, Gvar SIZE_globvar) :: nil)
    p.(prog_main)
    p.