import Mathlib

/-!
# PenaltyERT

Lean 4 formalisation of the *core* ERT confidence function and a calibration-loss term.

We use Peter’s preferred confidence-as-a-limit function:

  T(k, x) = 1 - exp(-k*x)

with assumptions `k > 0` and `x ≥ 0`, and an expressed confidence `hatT` that is
clipped/normalised to `[0,1]`.

We then prove:
- `0 ≤ T ≤ 1` (bounded, well-defined)
- the calibration loss `(T - hatT)^2` is nonnegative
- and is bounded above by 1 when `hatT ∈ [0,1]`.

This provides a mathematically clean “boundedness / positivity” result for the ERT core.
It does *not* formalise the full motivated-reasoning term (MR); that can be added later as
an abstract nonnegative term.
-/

namespace AIAlignment

noncomputable section

open Real

/-- Confidence-as-a-limit function: `T k x = 1 - exp (-k * x)`.

The definition is total on the reals and carries no side conditions of its own. The bounds are
proved separately: `T_le_one` gives `T k x ≤ 1` for all real `k`, `x`, and `T_nonneg` gives
`0 ≤ T k x` under the hypotheses `0 < k` and `0 ≤ x`. -/
def T (k x : ℝ) : ℝ := 1 - Real.exp (-k * x)

/-- A simple calibration loss: the squared difference between the target confidence `T k x` and
the reported confidence `hatT`.

No range restriction on `hatT` is built into the definition; `Lcal_le_one` takes `0 ≤ hatT` and
`hatT ≤ 1` as explicit hypotheses when it needs them. -/
def Lcal (k x hatT : ℝ) : ℝ := (T k x - hatT)^2

/-- `T k x ≤ 1` for all real `k` and `x`: `T` is `1` minus an exponential, and exponentials are
nonnegative. No sign assumptions on `k` or `x` are needed. -/
lemma T_le_one (k x : ℝ) : T k x ≤ 1 := by
  -- The subtracted term is nonnegative, so subtracting it cannot push the value above `1`.
  have hexp_nonneg : 0 ≤ Real.exp (-k * x) := Real.exp_nonneg _
  unfold T
  linarith

/-- `0 ≤ T k x` whenever `0 < k` and `0 ≤ x`: the exponent `-k * x` is then nonpositive, so the
exponential is at most `1`. -/
lemma T_nonneg (k x : ℝ) (hk : 0 < k) (hx : 0 ≤ x) : 0 ≤ T k x := by
  -- Step 1: the exponent is nonpositive, since `-k * x = -(k * x)` and `k * x ≥ 0`.
  have hexponent : -k * x ≤ 0 := by
    rw [neg_mul]
    exact neg_nonpos.mpr (mul_nonneg hk.le hx)
  -- Step 2: `exp t ≤ 1 ↔ t ≤ 0`.
  have hexp_le : Real.exp (-k * x) ≤ 1 := Real.exp_le_one_iff.mpr hexponent
  -- Step 3: `0 ≤ 1 - e` follows from `e ≤ 1`.
  unfold T
  linarith

/-- For `0 < k` and `0 ≤ x`, the confidence `T k x` lies in the closed unit interval `[0, 1]`.
Packages `T_nonneg` and `T_le_one` as a single conjunction. -/
lemma T_between_zero_one (k x : ℝ) (hk : 0 < k) (hx : 0 ≤ x) : 0 ≤ T k x ∧ T k x ≤ 1 := by
  constructor
  · -- lower bound: needs the sign hypotheses
    exact T_nonneg k x hk hx
  · -- upper bound: unconditional
    exact T_le_one k x

/-- The calibration loss is nonnegative for all real arguments, because it is a square. -/
lemma Lcal_nonneg (k x hatT : ℝ) : 0 ≤ Lcal k x hatT := by
  unfold Lcal
  exact sq_nonneg (T k x - hatT)

/-- The calibration loss is at most `1` when `0 < k`, `0 ≤ x`, and the reported confidence
`hatT` lies in `[0, 1]`.

Both `T k x` and `hatT` are then in `[0, 1]`, so their difference `d` satisfies `-1 ≤ d ≤ 1`,
and hence `d ^ 2 ≤ 1`. -/
lemma Lcal_le_one (k x hatT : ℝ)
    (hk : 0 < k) (hx : 0 ≤ x)
    (hhat0 : 0 ≤ hatT) (hhat1 : hatT ≤ 1) : Lcal k x hatT ≤ 1 := by
  -- Both bounds on the target confidence, obtained in one step.
  obtain ⟨hT_lo, hT_hi⟩ := T_between_zero_one k x hk hx
  -- The difference of two points of `[0, 1]` lies in `[-1, 1]`.
  have hdiff_hi : T k x - hatT ≤ 1 := by linarith
  have hdiff_lo : -1 ≤ T k x - hatT := by linarith
  -- `1 - d ^ 2 = (1 - d) * (1 + d)` is a product of two nonnegative factors.
  have hfactor : 0 ≤ (1 - (T k x - hatT)) * (1 + (T k x - hatT)) :=
    mul_nonneg (by linarith) (by linarith)
  unfold Lcal
  nlinarith [hfactor]

end

end AIAlignment
