(* Theory IsaSAT_Restart_Heuristics_LLVM *)

theory IsaSAT_Restart_Heuristics_LLVM
  imports IsaSAT_Restart_Heuristics_Defs IsaSAT_Setup_LLVM
     IsaSAT_VMTF_State_LLVM IsaSAT_Rephase_State_LLVM
     IsaSAT_Arena_Sorting_LLVM
     IsaSAT_Restart_Reduce_LLVM
     IsaSAT_Inprocessing_LLVM
     IsaSAT_Proofs_LLVM
begin

(* Hide the short name frefI. Later in this theory the fact is still referred
   to through its qualified name Sepref_Rules.frefI. *)
hide_fact (open) Sepref_Rules.frefI

(*TODO Move*)
(* Alternative characterisation of trail_set_zeroed_until_state as a monadic
   program: take the trail out of the state with extract_trail_wl_heur, apply
   trail_set_zeroed_until k to it, and store the result back with
   update_trail_wl_heur. The sepref_def that follows unfolds this equation. *)
lemma trail_set_zeroed_until_state_alt_def:
  RETURN oo trail_set_zeroed_until_state = (λk S. do {
    let (M, S) = extract_trail_wl_heur S;
    let M = trail_set_zeroed_until k M;
    RETURN (update_trail_wl_heur M S)
  })
  (* Proof: unfold the original definition, then let auto finish using the
     state_extractors simp rules and the isasat_int_splits split rules. *)
  unfolding trail_set_zeroed_until_state_def
  by  (auto simp: state_extractors
    intro!: ext split: isasat_int_splits)

(* Synthesised implementation of trail_set_zeroed_until_state (uncurried).
   Arguments use sint64_nat_assn for the first parameter and
   isasat_bounded_assn for the state; the result is again an
   isasat_bounded_assn state.
   NOTE(review): the trailing k/d letters in the signature look like
   keep/destroy annotations whose markup was lost -- confirm against the
   original theory text. *)
sepref_def trail_set_zeroed_until_state
  is uncurry (RETURN oo trail_set_zeroed_until_state)
  ::  sint64_nat_assnk *a isasat_bounded_assnd a isasat_bounded_assn
  (* Rewrite to the extract/update form before running sepref. *)
  unfolding trail_set_zeroed_until_state_alt_def
  by sepref


(* Expresses trail_zeroed_until_state through the generic trail reader
   read_trail_wl_heur applied to trail_zeroed_until.
   Proof: auto with the two definitions, read_all_st_def, and the
   isasat_int_splits split rules. *)
lemma trail_zeroed_until_state_alt_def:
  RETURN o trail_zeroed_until_state = read_trail_wl_heur (RETURN  trail_zeroed_until)
  by (auto intro!: ext simp: trail_zeroed_until_state_def trail_zeroed_until_def
    read_all_st_def split: isasat_int_splits)

(* Code-level reader built from read_trail_wl_heur_code.
   NOTE(review): the right-hand side uses count_decided_pol_impl although the
   constant is named after trail_zeroed_until; the definition
   trail_zeroed_until_state_fast_code further down uses trail_zeroed_until_impl
   instead. This looks like a copy-paste leftover -- confirm whether this
   constant is still used anywhere before changing or removing it. *)
definition trail_zeroed_until_state_impl where
  trail_zeroed_until_state_impl = read_trail_wl_heur_code count_decided_pol_impl

(* Register the state accessors as operations known to sepref. *)
sepref_register extract_trail_wl_heur count_decided_pol trail_zeroed_until_state trail_set_zeroed_until_state


(* Code-level reader for trail_zeroed_until on the full solver state: the
   generic read_trail_wl_heur_code instantiated with trail_zeroed_until_impl. *)
definition trail_zeroed_until_state_fast_code :: twl_st_wll_trail_fast2  _ where
  trail_zeroed_until_state_fast_code = read_trail_wl_heur_code trail_zeroed_until_impl


(* Instantiate the locale read_trail_param_adder0 with
     f      = trail_zeroed_until_impl          (implementation)
     f'     = RETURN o trail_zeroed_until      (abstract function)
     x_assn = sint64_nat_assn                  (result assertion)
     P      = trivially true precondition.
   The two "rewrites" equations replace the generic reader terms by
   trail_zeroed_until_state and trail_zeroed_until_state_fast_code in the
   facts exported by the interpretation. *)
global_interpretation trail_zeroed_until: read_trail_param_adder0 where
  f = trail_zeroed_until_impl and
  f' = RETURN o trail_zeroed_until and
  x_assn = sint64_nat_assn and
  P = (λS. True)
  rewrites read_trail_wl_heur (RETURN o trail_zeroed_until) = RETURN o trail_zeroed_until_state and
  read_trail_wl_heur_code trail_zeroed_until_impl = trail_zeroed_until_state_fast_code
  apply unfold_locales
  (* Locale assumption: discharged by the refinement theorem of the impl. *)
  apply (rule trail_zeroed_until_impl.refine)
  (* First rewrite equation (abstract side). *)
  subgoal
    by (auto simp: read_all_st_def trail_zeroed_until_state_def intro!: ext
      split: isasat_int_splits)
  (* Second rewrite equation (code side): holds by definition. *)
  subgoal
    by (auto simp: trail_zeroed_until_state_fast_code_def)
  done

(* Make the resulting refinement available to sepref, with the trivial
   precondition simplified via lambda_comp_true. *)
lemmas [sepref_fr_rules] = trail_zeroed_until.refine[unfolded lambda_comp_true]
(* Declare the code equation (with read_all_st_code_def unfolded) as llvm_code. *)
lemmas [unfolded inline_direct_return_node_case, llvm_code] =
  trail_zeroed_until_state_fast_code_def[unfolded read_all_st_code_def]


  (*End Move*)


(* Implementations of the five restart-flag constants. Each one takes a unit
   argument, returns the constant under word_assn, and is synthesised by
   sepref after unfolding the constant's definition. *)

(* FLAG_restart *)
sepref_def FLAG_restart_impl
  is uncurry0 (RETURN FLAG_restart)
  :: unit_assnk a word_assn
  unfolding FLAG_restart_def
  by sepref

(* FLAG_no_restart *)
sepref_def FLAG_no_restart_impl
  is uncurry0 (RETURN FLAG_no_restart)
  :: unit_assnk a word_assn
  unfolding FLAG_no_restart_def
  by sepref

(* FLAG_GC_restart *)
sepref_def FLAG_GC_restart_impl
  is uncurry0 (RETURN FLAG_GC_restart)
  :: unit_assnk a word_assn
  unfolding FLAG_GC_restart_def
  by sepref

(* FLAG_Reduce_restart *)
sepref_def FLAG_Reduce_restart_impl
  is uncurry0 (RETURN FLAG_Reduce_restart)
  :: unit_assnk a word_assn
  unfolding FLAG_Reduce_restart_def
  by sepref

(* FLAG_Inprocess_restart *)
sepref_def FLAG_Inprocess_restart_impl
  is uncurry0 (RETURN FLAG_Inprocess_restart)
  :: unit_assnk a word_assn
  unfolding FLAG_Inprocess_restart_def
  by sepref

(* Code-level reader for end_of_restart_phase on the full solver state: the
   generic read_heur_wl_heur_code instantiated with end_of_restart_phase_impl. *)
definition end_of_restart_phase_st_impl :: twl_st_wll_trail_fast2  _ where
  end_of_restart_phase_st_impl = read_heur_wl_heur_code end_of_restart_phase_impl

(* Instantiate the locale read_heur_param_adder0 for end_of_restart_phase
   (result under word_assn, trivially true precondition). The "rewrites"
   equations replace the generic reader terms by end_of_restart_phase_st and
   end_of_restart_phase_st_impl in the exported facts. *)
global_interpretation end_of_restart_phase: read_heur_param_adder0 where
  f' = RETURN o end_of_restart_phase and
  f = end_of_restart_phase_impl and
  x_assn = word_assn and
  P = λ_. True
  rewrites read_heur_wl_heur (RETURN o end_of_restart_phase) = RETURN o end_of_restart_phase_st and
    read_heur_wl_heur_code end_of_restart_phase_impl = end_of_restart_phase_st_impl
  apply unfold_locales
  (* Locale assumption: the refinement fact for end_of_restart_phase_impl. *)
  apply (rule end_of_restart_phase_impl_refine)
  (* Abstract-side rewrite equation. *)
  subgoal by (auto simp: read_all_st_def end_of_restart_phase_st_def intro!: ext
    split: isasat_int_splits)
  (* Code-side rewrite equation: holds by definition. *)
  subgoal by (auto simp: end_of_restart_phase_st_impl_def)
  done
 
(* Code-level reader for the end of the rephasing phase on the full solver
   state: read_heur_wl_heur_code instantiated with
   end_of_rephasing_phase_heur_stats_impl. *)
definition end_of_rephasing_phase_st_impl :: twl_st_wll_trail_fast2  _ where
  end_of_rephasing_phase_st_impl = read_heur_wl_heur_code end_of_rephasing_phase_heur_stats_impl

(* Instantiate the locale read_heur_param_adder0 for
   end_of_rephasing_phase_heur (result under word_assn, trivially true
   precondition). The "rewrites" equations replace the generic reader terms by
   end_of_rephasing_phase_st and end_of_rephasing_phase_st_impl. *)
global_interpretation end_of_rephasing_phase: read_heur_param_adder0 where
  f' = RETURN o end_of_rephasing_phase_heur and
  f = end_of_rephasing_phase_heur_stats_impl and
  x_assn = word_assn and
  P = λ_. True
  rewrites read_heur_wl_heur (RETURN o end_of_rephasing_phase_heur) = RETURN o end_of_rephasing_phase_st and
    read_heur_wl_heur_code end_of_rephasing_phase_heur_stats_impl = end_of_rephasing_phase_st_impl
  apply unfold_locales
  (* Locale assumption: discharged with the fact heur_refine (defined outside
     this theory). *)
  apply (rule heur_refine)
  (* Abstract-side rewrite equation. *)
  subgoal by (auto simp: read_all_st_def end_of_rephasing_phase_st_def intro!: ext
    split: isasat_int_splits)
  (* Code-side rewrite equation: holds by definition. *)
  subgoal by (auto simp: end_of_rephasing_phase_st_impl_def)
  done


(* Make the refinement facts of both interpretations available to sepref. *)
lemmas [sepref_fr_rules] = end_of_restart_phase.refine end_of_rephasing_phase.refine
(* Declare both state-level definitions (with read_all_st_code_def unfolded)
   as llvm_code equations. *)
lemmas [unfolded inline_direct_return_node_case, llvm_code] =
  end_of_restart_phase_st_impl_def[unfolded read_all_st_code_def]
  end_of_rephasing_phase_st_impl_def[unfolded read_all_st_code_def]

(* Register the phase-update operations with sepref. *)
sepref_register incr_restart_phase incr_restart_phase_end
  update_restart_phases


(* Step-by-step form of update_restart_phases on the state S:
     1. read the global conflict count (lcount);
     2. take the heuristics and the VMTF component out of the state;
     3. apply switch_bump_heur to the VMTF component;
     4. apply incr_restart_phase, then incr_restart_phase_end lcount;
     5. enable the "reluctant" heuristic if the current restart phase is
        STABLE_MODE, otherwise disable it;
     6. apply swap_emas;
     7. write the VMTF component and the heuristics back into the state.
   NOTE(review): the bind arrows in the "heur  RETURN ..." lines are missing
   from this copy of the text -- confirm against the original theory. *)
lemma update_restart_phases_alt_def:
  update_restart_phases = (λS. do {
     let lcount = get_global_conflict_count S;
     let (heur, S) = extract_heur_wl_heur S;
     let (vm, S) = extract_vmtf_wl_heur S;
     let vm = switch_bump_heur vm;
     heur  RETURN (incr_restart_phase heur);
     heur  RETURN (incr_restart_phase_end lcount heur);
     heur  RETURN (if current_restart_phase heur = STABLE_MODE then heuristic_reluctant_enable heur else heuristic_reluctant_disable heur);
     heur  RETURN (swap_emas heur);
     RETURN (update_heur_wl_heur heur (update_vmtf_wl_heur vm S))
  })
  (* Proof: auto with the definition, state_extractors and the state splits. *)
  by (auto simp: update_restart_phases_def state_extractors split: isasat_int_splits intro!: ext)

(* Synthesised implementation of update_restart_phases on
   isasat_bounded_assn states, obtained by sepref from the step-by-step
   alternative definition. *)
sepref_def update_restart_phases_impl
  is update_restart_phases
  :: isasat_bounded_assnd a isasat_bounded_assn
  unfolding update_restart_phases_alt_def
  by sepref

(* Register upper_restart_bound_reached with sepref. *)
sepref_register upper_restart_bound_reached

(* Synthesised implementation of upper_restart_bound_reached: takes an
   isasat_bounded_assn state and returns a bool1_assn result. Before sepref
   runs, the definition is unfolded and the occurrences of the restart count
   and global conflict count are folded back into the state-level accessors
   get_restart_count_st / get_global_conflict_count (the [symmetric] rules). *)
sepref_def upper_restart_bound_reached_fast_impl
  is (RETURN o upper_restart_bound_reached)
  :: isasat_bounded_assnk a bool1_assn
  unfolding upper_restart_bound_reached_def PR_CONST_def
    fold_tuple_optimizations get_restart_count_st_def[symmetric]
    get_global_conflict_count_def[symmetric]
  supply [[goals_limit = 1]]
  by sepref

(* Register max_restart_decision_lvl with sepref. *)
sepref_register max_restart_decision_lvl

(* Constant minimum_number_between_restarts, returned under word_assn. *)
sepref_def minimum_number_between_restarts_impl
  is uncurry0 (RETURN minimum_number_between_restarts)
  :: unit_assnk a word_assn
  unfolding minimum_number_between_restarts_def
  by sepref

(* Constant max_restart_decision_lvl, returned under uint32_nat_assn; numeric
   literals are annotated for width 32 via annot_unat_const TYPE(32).
   NOTE(review): the generated constant is called uint32_nat_assn_impl, which
   names the assertion rather than the function it implements. Renaming would
   change the generated constant and its .refine fact, so it is left as is --
   confirm whether the name is intentional. *)
sepref_def uint32_nat_assn_impl
  is uncurry0 (RETURN max_restart_decision_lvl)
  :: unit_assnk a uint32_nat_assn
  unfolding max_restart_decision_lvl_def
  apply (annot_unat_const TYPE(32))
  by sepref

(* Synthesised implementation of GC_required_heur (uncurried): takes an
   isasat_bounded_assn state and a uint64_nat_assn argument and returns a
   bool1_assn result. of_nat_snat is supplied locally as an import-param rule,
   and numeric literals are annotated for width 64 via
   annot_snat_const TYPE(64) before sepref runs. *)
sepref_def GC_required_heur_fast_code
  is uncurry GC_required_heur
  :: isasat_bounded_assnk *a uint64_nat_assnk a bool1_assn
  supply [[goals_limit=1]] of_nat_snat[sepref_import_param]
  unfolding GC_required_heur_def
  apply (annot_snat_const TYPE(64))
  by sepref

(* Synthesised implementation of GC_units_required: takes an
   isasat_bounded_assn state and returns a bool1_assn result. of_nat_snat is
   supplied locally as an import-param rule. *)
sepref_def GC_units_required_heur_fast_code
  is RETURN o GC_units_required
  :: isasat_bounded_assnk a bool1_assn
  supply [[goals_limit=1]] of_nat_snat[sepref_import_param]
  unfolding GC_units_required_def
  by sepref

(* Register should_inprocess_or_unit_reduce_st with sepref. *)
sepref_register should_inprocess_or_unit_reduce_st

(* Synthesised implementation of should_inprocess_or_unit_reduce_st
   (uncurried): takes an isasat_bounded_assn state and a bool1_assn argument
   and returns a bool1_assn result. Both its own definition and
   should_inprocess_st_def are unfolded before sepref runs. *)
sepref_def should_inprocess_or_unit_reduce_st
  is uncurry (RETURN oo should_inprocess_or_unit_reduce_st)
  :: isasat_bounded_assnk *a bool1_assnk a bool1_assn
  unfolding should_inprocess_or_unit_reduce_st_def should_inprocess_st_def
  by sepref

(* Register the EMA accessors with sepref. *)
sepref_register ema_get_value get_fast_ema_heur get_slow_ema_heur

(* Synthesised implementation of restart_required_heur (uncurried, four
   arguments): an isasat_bounded_assn state and three uint64_nat_assn values;
   the result is returned under word_assn.
   The precondition relates learned_clss_count S to unat64_max.
   NOTE(review): the comparison operator of the precondition is missing from
   this copy of the text (presumably "at most") -- confirm against the
   original theory.
   Preparation before sepref: unfold the definition, fold the slow/fast EMA
   reads into get_slow_ema_heur_st / get_fast_ema_heur_st, rewrite one
   constant with unat_const_fold at width 32, and annotate the remaining
   numeric literals for width 64. *)
sepref_def restart_required_heur_fast_code
  is uncurry3 restart_required_heur
  :: [λ(((S, _), _), _). learned_clss_count S  unat64_max]a isasat_bounded_assnk *a
     uint64_nat_assnk *a uint64_nat_assnk *a uint64_nat_assnk  word_assn
  supply [[goals_limit=1]] isasat_fast_def[simp] clss_size_allcount_alt_def[simp]
    learned_clss_count_def[simp]
  unfolding restart_required_heur_def get_slow_ema_heur_st_def[symmetric]
    get_fast_ema_heur_st_def[symmetric]
  apply (rewrite in   _ unat_const_fold[where 'a=32])
(* apply (rewrite in ‹(_ >> 32) < ⌑› annot_unat_unat_upcast[where 'l=64])*)
  apply (annot_snat_const TYPE(64))
  by sepref

(*TODO Move to trail*)
(* Synthesised implementation of replace_reason_in_trail (uncurried, three
   arguments): a literal (unat_lit_assn), a sint64_nat_assn value and a trail
   (trail_pol_fast_assn); the result is again a trail_pol_fast_assn trail.
   Preparation before sepref: unfold the trail assertion and the definitions
   of replace_reason_in_trail / trail_update_reason_at, annotate numeric
   literals for width 64, and rewrite the list_update occurrence with
   annot_index_of_atm. *)
sepref_def replace_reason_in_trail_code
  is uncurry2 replace_reason_in_trail
  :: unat_lit_assnk *a (sint64_nat_assn)k *a trail_pol_fast_assnd a trail_pol_fast_assn
  supply [[goals_limit=1]]
  unfolding trail_pol_fast_assn_def replace_reason_in_trail_def trail_update_reason_at_def
  apply (annot_snat_const TYPE(64))
  apply (rewrite at list_update _ _ _ annot_index_of_atm)
  by sepref

(*END Move*)
(* Step-by-step form of isasat_replace_annot_in_trail L C on the state S:
     1. take the clause counters (lcount) and the trail (M) out of the state;
     2. apply clss_size_resetUS0 to the counters;
     3. run replace_reason_in_trail L C on the trail;
     4. write the new trail and the new counters back into the state.
   NOTE(review): the bind arrow in the "M  replace_reason_in_trail ..." line
   is missing from this copy of the text -- confirm against the original. *)
lemma isasat_replace_annot_in_trail_alt_def:
  isasat_replace_annot_in_trail L C = (λS. do {
    let (lcount, S) = extract_lcount_wl_heur S;
    let (M, S) = extract_trail_wl_heur S;
    let lcount = clss_size_resetUS0 lcount;
    M  replace_reason_in_trail L C M;
    RETURN (update_trail_wl_heur M (update_lcount_wl_heur lcount S))
  })
  (* Proof: auto with the definition, state_extractors and the state splits. *)
  by (auto simp: isasat_replace_annot_in_trail_def state_extractors
        intro!: ext split: isasat_int_splits)
(* Register isasat_replace_annot_in_trail with sepref. *)
sepref_register isasat_replace_annot_in_trail
(* Synthesised implementation of isasat_replace_annot_in_trail (uncurried,
   three arguments): a literal (unat_lit_assn), a sint64_nat_assn value and an
   isasat_bounded_assn state; the result is again an isasat_bounded_assn
   state. sepref runs on the step-by-step alternative definition. *)
sepref_def isasat_replace_annot_in_trail_code
  is uncurry2 isasat_replace_annot_in_trail
  :: unat_lit_assnk *a (sint64_nat_assn)k *a isasat_bounded_assnd a isasat_bounded_assn
  supply [[goals_limit=1]]
  unfolding isasat_replace_annot_in_trail_alt_def
  by sepref

(* Register remove_one_annot_true_clause_one_imp_wl_D_heur with sepref. *)
sepref_register remove_one_annot_true_clause_one_imp_wl_D_heur

(* Helper rule: from isasat_fast b and a relation between
   learned_clss_count xb and learned_clss_count b, conclude a relation between
   learned_clss_count xb and unat64_max. It is supplied as an [intro] rule to
   the sepref_def of remove_one_annot_true_clause_one_imp_wl_D_heur_code.
   NOTE(review): the implication arrows and comparison operators are missing
   from this copy of the text (presumably "at most" in both places) -- confirm
   against the original theory. *)
lemma remove_one_annot_true_clause_one_imp_wl_D_heurI:
  isasat_fast b 
       learned_clss_count xb  learned_clss_count b 
        learned_clss_count xb  unat64_max
 by (auto simp: isasat_fast_def)


(* Synthesised implementation of
   remove_one_annot_true_clause_one_imp_wl_D_heur (uncurried): takes a
   sint64_nat_assn value and an isasat_bounded_assn state and returns a pair
   of a sint64_nat_assn value and an isasat_bounded_assn state.
   The precondition relates learned_clss_count S to unat64_max.
   NOTE(review): the comparison operator is missing from this copy of the
   text (presumably "at most") -- confirm against the original theory.
   Preparation before sepref: unfold the definition, fold trail lookups and
   reason lookups into the state-level accessors isasat_trail_nth_st /
   get_the_propagation_reason_pol_st, rewrite one constant with
   snat_const_fold(1) at width 64, and annotate the remaining numeric literals
   for width 64. The helper rule ...heurI is supplied as an intro rule. *)
sepref_def remove_one_annot_true_clause_one_imp_wl_D_heur_code
  is uncurry remove_one_annot_true_clause_one_imp_wl_D_heur
  :: [λ(C, S). learned_clss_count S  unat64_max]a
       sint64_nat_assnk *a isasat_bounded_assnd  sint64_nat_assn ×a isasat_bounded_assn
  supply [[goals_limit=1]] remove_one_annot_true_clause_one_imp_wl_D_heurI[intro]
  unfolding remove_one_annot_true_clause_one_imp_wl_D_heur_def
    isasat_trail_nth_st_def[symmetric] get_the_propagation_reason_pol_st_def[symmetric]
    fold_tuple_optimizations
  apply (rewrite in _ =  snat_const_fold(1)[where 'a=64])
  apply (annot_snat_const TYPE(64))
  by sepref

(* Register remove_one_annot_true_clause_imp_wl_D_heur with sepref. *)
sepref_register remove_one_annot_true_clause_imp_wl_D_heur

(* Helper rule: from a relation between learned_clss_count x and unat64_max
   together with the loop invariant
   remove_one_annot_true_clause_imp_wl_D_heur_inv x (a1', a2'), conclude the
   corresponding relation for learned_clss_count a2'. It is supplied as an
   [intro] rule to the sepref_def of
   remove_one_annot_true_clause_imp_wl_D_heur_code.
   NOTE(review): the implication arrows and comparison operators are missing
   from this copy of the text (presumably "at most") -- confirm against the
   original theory. *)
lemma remove_one_annot_true_clause_imp_wl_D_heurI:
  learned_clss_count x  unat64_max 
       remove_one_annot_true_clause_imp_wl_D_heur_inv x (a1', a2') 
       learned_clss_count a2'  unat64_max
  by (auto simp: isasat_fast_def remove_one_annot_true_clause_imp_wl_D_heur_inv_def)

(* Synthesised implementation of remove_one_annot_true_clause_imp_wl_D_heur
   on isasat_bounded_assn states.
   The precondition mentions length (get_clauses_wl_heur S) together with
   snat64_max, and learned_clss_count S together with unat64_max.
   NOTE(review): the comparison operators and the connective between the two
   parts are missing from this copy of the text (presumably two "at most"
   bounds joined by a conjunction) -- confirm against the original theory.
   Preparation before sepref: unfold the definition, fold the trail length
   and the position-of-level lookup into the state-level accessors
   isasat_length_trail_st / get_pos_of_level_in_trail_imp_st, and annotate
   numeric literals for width 32. The helper rule ...heurI is supplied as an
   intro rule. *)
sepref_def remove_one_annot_true_clause_imp_wl_D_heur_code
  is remove_one_annot_true_clause_imp_wl_D_heur
  :: [λS. length (get_clauses_wl_heur S)  snat64_max  
          learned_clss_count S  unat64_max]a isasat_bounded_assnd  isasat_bounded_assn
  supply [[goals_limit=1]] remove_one_annot_true_clause_imp_wl_D_heurI[intro]
  unfolding remove_one_annot_true_clause_imp_wl_D_heur_def
    isasat_length_trail_st_def[symmetric] get_pos_of_level_in_trail_imp_st_def[symmetric]
  apply (annot_unat_const TYPE(32))
  by sepref


(* Register number_clss_to_keep with sepref. *)
sepref_register number_clss_to_keep


(* Sepref rule for the conversion unat: the abstract function RETURN o unat
   is implemented by Mreturn o id, i.e. the value is passed through unchanged,
   going from word64_assn to uint64_nat_assn. *)
lemma [sepref_fr_rules]:
  (Mreturn o id, RETURN o unat)  word64_assnk a uint64_nat_assn
proof -
  (* Auxiliary simp fact about the separation-logic predicate that shows up
     in the Hoare-triple goal below. *)
  have [simp]: (λs. xa. ((xa = unat x) ∧* (xa = unat x)) s) = True
    by (intro ext)
     (auto intro!: exI[of _ unat x] simp: pure_true_conv pure_part_pure_eq pred_lift_def
      simp flip: import_param_3)
  (* Turn the refinement goal into a Hoare triple, run the VCG, and close the
     remaining goal by unfolding the relation and assertion definitions. *)
  show ?thesis
    apply sepref_to_hoare
    apply (vcg)
    apply (auto simp: unat_rel_def unat.rel_def br_def pred_lift_def ENTAILS_def pure_true_conv simp flip: import_param_3 pure_part_def)
    done
qed

(* Synthesised implementation of number_clss_to_keep_impl: takes an
   isasat_bounded_assn state and returns a sint64_nat_assn result.
   Preparation before sepref: unfold the definition, fold/unfold the tvdom
   length accessors (length_tvdom_def used right-to-left, then
   length_tvdom_aivdom_def), and annotate numeric literals for width 64. *)
sepref_def number_clss_to_keep_fast_code
  is number_clss_to_keep_impl
  :: isasat_bounded_assnk a sint64_nat_assn
  supply [[goals_limit = 1]]
  unfolding number_clss_to_keep_impl_def length_tvdom_def[symmetric] length_tvdom_aivdom_def
  apply (annot_snat_const TYPE(64))
  by sepref

(* Refinement between number_clss_to_keep_impl and number_clss_to_keep:
   arguments are related by Id, results by nat_rel inside nres_rel. Proved by
   unfolding both definitions; the hidden fact frefI is referred to by its
   qualified name. *)
lemma number_clss_to_keep_impl_number_clss_to_keep:
  (number_clss_to_keep_impl, number_clss_to_keep)  Sepref_Rules.freft Id (λ_. nat_relnres_rel)
  by (auto simp: number_clss_to_keep_impl_def number_clss_to_keep_def Let_def intro!: Sepref_Rules.frefI nres_relI)

(* Sepref rule for the abstract number_clss_to_keep: obtained by composing
   (hfcomp) the synthesised refinement of number_clss_to_keep_fast_code with
   the impl-to-abstract refinement lemma for number_clss_to_keep_impl. *)
lemma number_clss_to_keep_fast_code_refine[sepref_fr_rules]:
  (number_clss_to_keep_fast_code, number_clss_to_keep)  (isasat_bounded_assn)k a snat_assn
  using hfcomp[OF number_clss_to_keep_fast_code.refine
    number_clss_to_keep_impl_number_clss_to_keep, simplified]
  by auto

(*TODO Move to IsaSAT_Setup2*)
(* Trial LLVM export of three of the functions, done inside an experiment
   block. NOTE(review): presumably a check that code generation succeeds for
   these functions -- confirm. *)
experiment
begin
  export_llvm restart_required_heur_fast_code access_avdom_at_fast_code
  trail_zeroed_until_state_fast_code
end

end