#!/usr/bin/env python3
"""
PWE Dark Halo Simulation: Stage-1 (Publication-Ready Core)
=========================================================

This script implements the canonical Stage-1 evolution used in the
Entanglement Compression Theory (ECT) halo-formation simulation.

Purpose
-------
- Evolve a 2D complex field ψ(x, y, t) under the Primordial Wave Equation (PWE)
  with a real compression functional C[ρ].
- Demonstrate the emergence of a dark-matter-like halo from a seeded
  central Gaussian core with angular asymmetry, phase winding, and noise.
- Save a small set of PNG frames (ρ = |ψ|^2) plus the final ψ field for
  inspection, analysis, or use as input to later stages.

Key features
------------
- 2D periodic box, N = 320, L = 2π (natural units).
- Split-step evolution: linear term handled in Fourier space, compression term
  handled in real space.
- Norm conservation enforced by explicit L² renormalization (∑|ψ|² = 1 over
  the grid) after each step.
- Minimal dependencies: NumPy and Matplotlib only.
"""

import numpy as np
import numpy.fft as fft
import matplotlib.pyplot as plt

# ==============================
# Global constants and tags
# ==============================
# Filename prefix shared by the Stage-1 output files (PNG frames, final .npy).
TAG_STAGE1: str = "ECT_Dark_Halo_Stage_1"

# Matplotlib colormap name used when rendering the density frames.
CMAP: str = "inferno"


# ==============================
# Dummy logger (logging disabled)
# ==============================
def make_logger(path: str):
    """
    Return a no-op ``(log, handle)`` pair kept for interface compatibility.

    Earlier development builds of Stage-1 wrote diagnostics to a text log.
    In this release build logging is switched off:

    - ``path`` is accepted but ignored; no file is opened or created.
    - ``log(msg)`` discards its argument and returns ``None``.
    - ``handle.close()`` has nothing to release and returns ``None``.
    """

    class DummyFH:
        """Stand-in for a file handle; ``close`` is a no-op."""

        def close(self):
            return None

    def log(msg: str):
        """Discard ``msg`` (logging is disabled for release)."""
        return None

    handle = DummyFH()
    return log, handle


# ==============================
# Stage-1 evolution
# ==============================
def run_stage1():
    """
    Run Stage-1 of the PWE-based halo simulation.

    This function:
    - Builds a 2D periodic grid in a square box of side L = 2π.
    - Constructs an initial condition consisting of:
        * A central Gaussian core (isotropic: the ELLIP parameter is
          defined but not applied in this version).
        * m = 2 phase winding (topological charge).
        * Low-order angular asymmetries (m = 1 and m = 3).
        * Seeded complex Gaussian noise.
    - Evolves ψ under the Primordial Wave Equation with a real
      compression functional C(ρ), using a symmetric split-step scheme.
    - Outputs (relative paths, i.e. the current working directory):
        * ECT_Dark_Halo_Stage_1_<step>.png density snapshots, one for each
          step listed in ``save_steps``.
        * ECT_Dark_Halo_Stage_1_psi_final.npy containing the final ψ field
          (stored as complex64).

    Returns
    -------
    psi : np.ndarray (complex128)
        Final complex field ψ(x, y, t_final) after Stage-1 evolution,
        normalized so that ∑|ψ|² = 1 over the grid points.
    """

    # ------------------------------
    # Logging (disabled)
    # ------------------------------
    # make_logger returns a no-op log function and a dummy handle; the path
    # argument is ignored and no file is created.
    log, fh = make_logger(f"{TAG_STAGE1}.txt")
    # log("Starting Stage-1 (guided scaffold evolution).")

    # ------------------------------
    # Grid definition
    # ------------------------------
    # N: number of grid points per dimension.
    # L: physical size of the box (2π, natural units).
    N = 320
    L = 2 * np.pi
    dx = L / N

    # Real-space coordinates on a uniform periodic grid.
    # indexing="ij" means arrays are indexed [ix, iy].
    x = np.linspace(0, L, N, endpoint=False)
    y = np.linspace(0, L, N, endpoint=False)
    X, Y = np.meshgrid(x, y, indexing="ij")

    # ------------------------------
    # PWE and compression parameters
    # ------------------------------
    # alpha: kinetic coefficient (controls dispersion strength).
    # beta : compression coupling strength (weights C[ρ] term).
    # c0   : local compression weight (log-density term).
    # c2   : curvature compression weight (Laplacian of log-density).
    # eps  : regularization scale inside ln_ε to avoid singularities at nodes.
    # dt   : time step for split-step integration.
    # total_steps: number of time steps in Stage-1 evolution.
    alpha = 1.0
    beta  = 1.0
    c0    = 1.0
    c2    = 0.12
    eps   = 1e-3
    dt    = 0.045
    total_steps = 698

    # Steps at which to save PNG frames (ρ snapshots).
    # These were chosen empirically to capture key morphological transitions.
    # Kept as a set because it is only used for membership tests in the loop.
    save_steps = {
        10, 100, 200, 300, 325, 400, 450, 475, 500, 550, 600,
        660, 670, 690, 692, 693, 694, 695, 696, 697, 698,
    }

    # ------------------------------
    # Initial condition parameters
    # ------------------------------
    # RING_RADIUS : legacy ring-IC radius; defined but NOT used below.
    # RING_WIDTH  : legacy ring-IC width; defined but NOT used below.
    # M_WINDING   : phase winding number (topological charge).
    # ASYM_M1     : m = 1 angular asymmetry amplitude.
    # ASYM_M3     : m = 3 angular asymmetry amplitude.
    # NOISE_AMP   : standard deviation of the real and imaginary parts of
    #               the complex Gaussian noise.
    # SEED        : RNG seed for reproducibility.
    # ELLIP       : ellipticity factor; defined but NOT applied below, so the
    #               Gaussian core is isotropic in this version.
    RING_RADIUS = 0.9
    RING_WIDTH  = 0.18
    M_WINDING   = 2
    ASYM_M1     = 0.08
    ASYM_M3     = 0.06
    NOISE_AMP   = 0.03
    SEED        = 20251111
    ELLIP       = 1.08

    # ------------------------------
    # Fourier-space grid for Laplacian
    # ------------------------------
    # Angular wave numbers kx, ky consistent with the periodic box
    # (fftfreq returns cycles per unit length; multiply by 2π).
    kx = fft.fftfreq(N, d=dx) * 2 * np.pi
    ky = fft.fftfreq(N, d=dx) * 2 * np.pi
    KX, KY = np.meshgrid(kx, ky, indexing="ij")
    K2 = KX**2 + KY**2

    def laplacian(f):
        """
        Spectral Laplacian on the periodic grid.

        Parameters
        ----------
        f : np.ndarray
            Real-valued scalar field on the grid.

        Returns
        -------
        np.ndarray
            Real-valued Laplacian ∇² f computed using FFTs.
        """
        return fft.ifftn(-K2 * fft.fftn(f)).real

    def ln_eps(u):
        """
        Regularized logarithm ln_ε(u) = 0.5 * ln(u² + ε²).

        This avoids singular behavior at nodes (u ≈ 0). For |u| ≫ ε it
        reduces to ln|u|. Used on density ρ = |ψ|², so ln_ε(ρ) ≈ ln ρ
        only where ρ ≫ eps.
        """
        return 0.5 * np.log(u * u + eps * eps)

    # ------------------------------
    # Initial wavefunction ψ(x, y, t = 0)
    # ------------------------------
    # Center of the box.
    cx, cy = L / 2, L / 2
    dx0, dy0 = X - cx, Y - cy

    # Polar coordinates relative to center.
    theta = np.arctan2(dy0, dx0)
    r = np.hypot(dx0, dy0)

    # Core Gaussian: sets the primary amplitude profile.
    CORE_SIGMA = 0.55
    core = np.exp(-(r * r) / (2 * CORE_SIGMA * CORE_SIGMA))

    # Phase winding: e^{i m θ} generates a topological charge m.
    phase = np.exp(1j * M_WINDING * theta)

    # Low-order angular asymmetries: m = 1 and m = 3 cosine modes.
    # These break perfect symmetry and seed structure in the halo.
    asym = 1 + ASYM_M1 * np.cos(theta) + ASYM_M3 * np.cos(3 * theta)

    # Complex Gaussian noise: small amplitude fluctuations that seed
    # fine-grained structure and break residual symmetry.
    rng = np.random.default_rng(SEED)
    noise = NOISE_AMP * (
        rng.normal(size=(N, N)) + 1j * rng.normal(size=(N, N))
    )

    # Combine core, asymmetry, phase, and noise.
    psi = core * asym * phase + noise

    # Enforce discrete L² normalization: ∑ |ψ|² = 1 over all grid points
    # (a plain sum, not the area integral ∑ |ψ|² dx²).
    # NOTE(review): with this normalization on a 320² grid, a rough estimate
    # puts the peak ρ at a few ×1e-4, i.e. at or below eps = 1e-3. In that
    # case ln_ε(ρ) operates mostly in its regularized regime rather than
    # as ln ρ. Confirm this is intended.
    psi /= np.linalg.norm(psi)

    # log("Stage-1: initial condition constructed and normalized.")

    # ------------------------------
    # Precompute linear evolution phase
    # ------------------------------
    # Linear part: exp(-i α k² dt). This is applied in Fourier space
    # for each time step (split-step method).
    linear_phase = np.exp(-1j * alpha * K2 * dt)

    def compute_C_and_rho(psi_arr):
        """
        Compute compression field C[ρ] and density ρ = |ψ|².

        Parameters
        ----------
        psi_arr : np.ndarray (complex)
            Current wavefunction ψ on the grid.

        Returns
        -------
        C : np.ndarray (float)
            Real-valued compression field C(x, y).
        rho : np.ndarray (float)
            Probability density ρ(x, y) = |ψ|².
        """
        # Density: ρ = |ψ|².
        rho = psi_arr.real**2 + psi_arr.imag**2

        # Regularized log density.
        le = ln_eps(rho)

        # Compression functional:
        # C = -c0 * ln_ε(ρ) + c2 ∇² ln_ε(ρ).
        C = -c0 * le + c2 * laplacian(le)
        return C, rho

    def step_pwe(psi_arr):
        """
        Advance ψ by one time step under the PWE with real compression.

        Uses a symmetric split-step:
        1) Half-step compression in real space.
        2) Full-step linear propagation in Fourier space.
        3) Half-step compression in real space.
        4) Renormalize ψ to enforce L² conservation.

        Note: step 1 multiplies the passed-in array in place. The caller
        rebinds ``psi`` to the return value, so this is harmless here.
        """
        # First half compression step.
        C, _ = compute_C_and_rho(psi_arr)
        psi_arr *= np.exp(-1j * beta * C * (dt / 2))

        # Linear (kinetic) step in Fourier space.
        psi_arr = fft.ifftn(fft.fftn(psi_arr) * linear_phase)

        # Second half compression step.
        C, _ = compute_C_and_rho(psi_arr)
        psi_arr *= np.exp(-1j * beta * C * (dt / 2))

        # Renormalize to maintain global L² normalization.
        psi_arr /= np.linalg.norm(psi_arr)
        return psi_arr

    def save_png(name, rho):
        """
        Save a snapshot of ρ as a PNG with fixed framing.

        The field is min-max rescaled to [0, 1] per frame, so colors are
        not comparable in absolute terms between frames.

        Parameters
        ----------
        name : str
            Output filename.
        rho : np.ndarray
            Density field to visualize.
        """
        rmin, rmax = float(rho.min()), float(rho.max())
        img = (rho - rmin) / (rmax - rmin + 1e-12)

        fig = plt.figure(figsize=(8, 8), dpi=100)
        ax = plt.axes([0, 0, 1, 1])
        ax.set_axis_off()
        # NOTE(review): rho is indexed [ix, iy] (meshgrid indexing="ij"),
        # and imshow maps the first array axis to the vertical direction.
        # The frame therefore shows x vertically and y horizontally, which
        # is a transpose of the usual x-right / y-up view. Pass img.T if the
        # conventional orientation is wanted. Left unchanged here to keep
        # the published frames identical.
        ax.imshow(img, origin="lower", cmap=CMAP)
        fig.savefig(name, bbox_inches="tight", pad_inches=0)
        plt.close(fig)

    # ------------------------------
    # Time evolution loop
    # ------------------------------
    for k in range(1, total_steps + 1):
        # Advance ψ by one PWE step.
        psi = step_pwe(psi)

        # Diagnostics logging is disabled for publication.
        # if (k % 20 == 0) or (k in save_steps):
        #     _, rho_diag = compute_C_and_rho(psi)
        #     log(f"[Stage-1 step {k}] ...")

        # Save snapshots at selected steps.
        if k in save_steps:
            # Only ρ is needed here. It is computed with the same expression
            # as compute_C_and_rho, which skips the two FFTs needed for C.
            rho_snap = psi.real**2 + psi.imag**2
            fname = f"{TAG_STAGE1}_{k}.png"
            save_png(fname, rho_snap)
            # log(f"Saved Stage-1 frame -> {fname}")

    # ------------------------------
    # Final field output
    # ------------------------------
    # Save final ψ for analysis or as input to higher stages.
    # Note: stored as complex64 to reduce file size. This is ample precision
    # for visualization and qualitative study. The returned psi keeps the
    # full complex128 precision.
    np.save(f"{TAG_STAGE1}_psi_final.npy", psi.astype(np.complex64))
    # log("Saved Stage-1 final psi.")

    # Close dummy logger.
    fh.close()

    return psi


# ==============================
# Script entry point
# ==============================
if __name__ == "__main__":
    # Run Stage-1 and write, into the current working directory:
    # - ECT_Dark_Halo_Stage_1_<step>.png frames for each step in `save_steps`.
    # - ECT_Dark_Halo_Stage_1_psi_final.npy with the final ψ field.
    run_stage1()
