import Mathlib

/-!
# LambdaDynamics

Formalises the basic invariants of the multiplicative renormalisation update used for
λ-weights:

  λᵢ' = (λᵢ * Sᵢ) / (∑ⱼ λⱼ * Sⱼ)

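The goal is not (yet) to prove convergence — only that the update is *well-formed*:
- nonnegativity is preserved, provided the weights `λ` and the scores `S` are nonnegative;
- the weights remain normalised (sum to 1), provided the denominator `∑ⱼ λⱼ * Sⱼ` is nonzero.

The nonzero hypothesis is genuinely needed. Under Lean's convention `x / 0 = 0`, a zero
denominator makes the update return the zero vector.

This is the minimal sanity layer. Whenever the hypotheses hold, the update keeps the weights
on the probability simplex (nonnegative and summing to 1).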
-/

namespace AIAlignment

noncomputable section

open scoped BigOperators

variable {ι : Type} [Fintype ι]

/-- Normalising constant of the λ-update.

It is the score-weighted total `∑ⱼ lam j * S j`, summed over every index of the finite type `ι`.
No sign condition is imposed here. -/
def denom (lam S : ι → ℝ) : ℝ :=
  Finset.univ.sum fun j => lam j * S j

/-- One step of the multiplicative renormalisation update for λ.

Component `i` becomes `lam i * S i` divided by the shared denominator `denom lam S`.

The definition itself places no sign condition on `lam` or `S`; nonnegativity is only
assumed by the lemmas about it. If `denom lam S = 0`, Lean's convention `x / 0 = 0` makes
every component `0`. -/
def lambdaUpdate (lam S : ι → ℝ) : ι → ℝ :=
  fun i ↦ lam i * S i / denom lam S

/-- The denominator is nonnegative whenever every weight and every score is.

Each summand `lam j * S j` is a product of two nonnegative reals, and a finite sum of
nonnegative terms is nonnegative. -/
lemma denom_nonneg (lam S : ι → ℝ)
    (hlam : ∀ i, 0 ≤ lam i) (hS : ∀ i, 0 ≤ S i) : 0 ≤ denom lam S := by
  -- expose the sum hidden behind `denom` so `Finset.sum_nonneg` applies directly
  show 0 ≤ ∑ j, lam j * S j
  exact Finset.sum_nonneg fun j _ => mul_nonneg (hlam j) (hS j)

/-- The update preserves nonnegativity (strictly positive denominator version).

Each component is `lam i * S i / denom lam S`: a nonnegative numerator over a positive,
hence nonnegative, denominator.

The hypothesis `hden` is stronger than necessary and is retained for existing callers.
See `lambdaUpdate_nonneg'` for the version without it. -/
lemma lambdaUpdate_nonneg (lam S : ι → ℝ)
    (hlam : ∀ i, 0 ≤ lam i) (hS : ∀ i, 0 ≤ S i)
    (hden : 0 < denom lam S) : ∀ i, 0 ≤ lambdaUpdate lam S i := by
  intro i
  -- `lambdaUpdate lam S i` unfolds definitionally to `lam i * S i / denom lam S`
  exact div_nonneg (mul_nonneg (hlam i) (hS i)) hden.le

/-- The update preserves nonnegativity, with no hypothesis on the denominator.

`denom_nonneg` already gives `0 ≤ denom lam S` from the same nonnegativity assumptions, and
`div_nonneg` only needs a nonnegative denominator. This also covers `denom lam S = 0`,
where Lean's `x / 0 = 0` makes every component `0`. -/
lemma lambdaUpdate_nonneg' (lam S : ι → ℝ)
    (hlam : ∀ i, 0 ≤ lam i) (hS : ∀ i, 0 ≤ S i) : ∀ i, 0 ≤ lambdaUpdate lam S i :=
  fun i => div_nonneg (mul_nonneg (hlam i) (hS i)) (denom_nonneg lam S hlam hS)

/-- The updated weights sum to `1` whenever the denominator is nonzero.

No sign assumptions are needed. Every component shares the denominator `denom lam S`, so
the sum equals `denom lam S / denom lam S`. -/
lemma sum_lambdaUpdate_eq_one (lam S : ι → ℝ)
    (hden : denom lam S ≠ 0) : (∑ i, lambdaUpdate lam S i) = 1 := by
  -- unfold the update pointwise (definitional, so `show` suffices)
  show ∑ i, lam i * S i / denom lam S = 1
  -- factor the common denominator out of the sum: ∑ᵢ (aᵢ / d) = (∑ᵢ aᵢ) / d
  rw [← Finset.sum_div]
  -- the numerator `∑ i, lam i * S i` *is* `denom lam S` by definition, so the goal is
  -- `d / d = 1`; `exact` checks this up to unfolding `denom`
  exact div_self hden

end

end AIAlignment
