Source code for socca.models.bridge.base

"""Bridge/filament model for intracluster structures.

This module provides models for describing elongated emission structures
such as intracluster bridges and filaments connecting galaxy clusters.
"""

from functools import partial
import inspect
import types

import warnings

import jax
import jax.numpy as jp
import numpyro.distributions

import numpy as np

from ..base import Component
from ..misc import Background, Point
from ..radial import Beta, PolyExpoRefact, Power, TopHat
from .. import config


class Bridge(Component):
    """
    Base class for bridge/filament emission models.

    The Bridge model describes elongated emission structures by combining
    a radial profile (perpendicular to the bridge axis) with a parallel
    profile (along the bridge axis). This factorized approach allows for
    flexible modeling of intracluster bridges and large-scale filaments.

    Parameters
    ----------
    radial : Profile, optional
        Profile describing the surface brightness distribution perpendicular
        to the bridge major axis. Default is Beta().
    parallel : Profile, optional
        Profile describing the surface brightness distribution along the
        bridge major axis. Default is TopHat().
    **kwargs : dict
        Additional keyword arguments including:

        xc : float, optional
            Right ascension of the bridge centroid in degrees.
        yc : float, optional
            Declination of the bridge centroid in degrees.
        rs : float, optional
            Scale radius for both radial and parallel profiles in degrees.
        Is : float, optional
            Scale intensity of the bridge (same units as image).
        theta : float, optional
            Position angle of the bridge major axis in radians
            (measured east from north).
        e : float, optional
            Axis ratio controlling the relative extent of radial vs
            parallel profiles. Default is 0.5.

    Attributes
    ----------
    radial : Profile
        The radial profile component.
    parallel : Profile
        The parallel profile component.

    Notes
    -----
    The Bridge class is an abstract base that defines the interface and
    parameter handling. Use SimpleBridge or MesaBridge for concrete
    implementations with specific profile combination rules.

    The scale radius `rs` is shared between the parallel profile (where
    it is used directly) and the radial profile (where it is scaled by
    the factor `1 - e` to control the axis ratio).

    See Also
    --------
    SimpleBridge : Multiplicative combination of radial and parallel profiles.
    MesaBridge : Harmonic mean combination for mesa-like profiles.
    """

    def __init__(self, radial=Beta(), parallel=TopHat(), **kwargs):
        super().__init__(**kwargs)

        self.xc = kwargs.get("xc", config.Bridge.xc)
        self.yc = kwargs.get("yc", config.Bridge.yc)
        self.rs = kwargs.get("rs", config.Bridge.rs)
        self.Is = kwargs.get("Is", config.Bridge.Is)
        self.theta = kwargs.get("theta", config.Bridge.theta)
        self.e = kwargs.get("e", config.Bridge.e)

        self.units.update(
            dict(
                xc="deg",
                yc="deg",
                rs="deg",
                Is="image",
                theta="rad",
                e="",
            )
        )
        self.description.update(
            dict(
                xc="Right ascension of the bridge centroid",
                yc="Declination of the bridge centroid",
                rs="Scale radius of the bridge profiles",
                Is="Scale intensity of the bridge profiles",
                theta="Position angle of the bridge major axis",
                e="Projected ellipticity of the bridge profiles",
            )
        )

        self.radial = radial
        self.parallel = parallel

        for attr in [self.radial, self.parallel]:
            if isinstance(attr, PolyExpoRefact):
                raise TypeError(
                    "PolyExpoRefact is not supported for Bridge. "
                    "Use PolyExponential or another profile type instead."
                )
            if isinstance(attr, (Background, Point)):
                raise TypeError(
                    f"{type(attr).__name__} is not supported for Bridge."
                )

        for key in ["xc", "yc", "theta", "e", "cbox"]:
            bkey = f"_{key}" if f"_{key}" in self.radial.__dict__ else key
            if bkey in self.radial.__dict__:
                delattr(self.radial, bkey)
            self.radial.units.pop(key, None)
            bkey = f"_{key}" if f"_{key}" in self.parallel.__dict__ else key
            if bkey in self.parallel.__dict__:
                delattr(self.parallel, bkey)
            self.parallel.units.pop(key, None)

        for key in ["_scale_amp", "_scale_radius"]:
            rname = getattr(self.radial, key)
            pname = getattr(self.parallel, key)
            delattr(self.radial, rname)
            delattr(self.parallel, pname)
            self.radial.units.pop(rname, None)
            self.parallel.units.pop(pname, None)

        if self.radial.id != self.id:
            type(self).idcls -= 1
            idmin = np.minimum(
                int(self.radial.id.replace("comp_", "")),
                int(self.id.replace("comp_", "")),
            )
            self.id = f"comp_{idmin:02d}"
            self.radial.id = self.id

        for attr in [self.parallel, self.radial]:
            attr.addparameter(attr._scale_amp, 1.00)

        self.parallel.addparameter(
            self.parallel._scale_radius,
            eval(f"lambda {self.id}_rs: {self.id}_rs"),
        )

        self.radial.addparameter(
            self.radial._scale_radius,
            eval(
                f"lambda {self.id}_rs, {self.id}_e: "
                f"{self.id}_rs * (1.00 - {self.id}_e)"
            ),
        )

        self._rkw = [
            f"r_{key}"
            for key in list(
                inspect.signature(self.radial.profile).parameters.keys()
            )
            if key not in ["r", "z"]
        ]
        self._zkw = [
            f"z_{key}"
            for key in list(
                inspect.signature(self.parallel.profile).parameters.keys()
            )
            if key not in ["r", "z"]
        ]

        for key in self.radial.hyper:
            self.hyper.append(f"radial.{key}")

        for key in self.parallel.hyper:
            self.hyper.append(f"parallel.{key}")

        self.units.update(
            {
                f"radial.{key}": self.radial.units[key]
                for key in self.radial.units.keys()
                if key
                not in [self.radial._scale_radius, self.radial._scale_amp]
            }
        )
        self.units.update(
            {
                f"parallel.{key}": self.parallel.units[key]
                for key in self.parallel.units.keys()
                if key
                not in [self.parallel._scale_radius, self.parallel._scale_amp]
            }
        )

        self.description.update(
            {
                f"radial.{key}": self.radial.description[key]
                for key in self.radial.description.keys()
                if key
                not in [self.radial._scale_radius, self.radial._scale_amp]
            }
        )
        self.description.update(
            {
                f"parallel.{key}": self.parallel.description[key]
                for key in self.parallel.description.keys()
                if key
                not in [self.parallel._scale_radius, self.parallel._scale_amp]
            }
        )

        self.profile = None

    def getmap(self, img, convolve=False):
        """
        Generate a two-dimensional model image of the bridge.

        Parameters
        ----------
        img : Image
            Image object defining the coordinate grid and optional PSF.
        convolve : bool, optional
            If True, convolve the model with the image PSF. Default is False.

        Returns
        -------
        ndarray
            Two-dimensional model image evaluated on the image grid.

        Raises
        ------
        ValueError
            If any model parameter is set to None or contains a prior
            distribution instead of a fixed value.

        Notes
        -----
        The model is computed by evaluating the combined radial and parallel
        profiles on a rotated coordinate grid centered on the bridge position.
        """
        kwarg = {}
        for key in list(inspect.signature(self.profile).parameters.keys()):
            if key not in ["r", "z"]:
                if key.startswith("r_"):
                    val = getattr(self.radial, key.replace("r_", ""))
                elif key.startswith("z_"):
                    val = getattr(self.parallel, key.replace("z_", ""))
                else:
                    continue

                if callable(val):
                    sig = inspect.signature(val)
                    params = list(sig.parameters.keys())
                    if params:
                        args = [
                            getattr(self, p.replace(f"{self.id}_", ""))
                            for p in params
                        ]
                        val = val(*args)

                kwarg[key] = val

        kwarg["xc"] = self.xc
        kwarg["yc"] = self.yc
        kwarg["Is"] = self.Is
        kwarg["theta"] = self.theta

        for key in kwarg.keys():
            if isinstance(kwarg[key], numpyro.distributions.Distribution):
                raise ValueError(
                    "Priors must be fixed values, not distributions."
                )
            if kwarg[key] is None:
                raise ValueError(
                    f"keyword {key} is set to None. "
                    f"Please provide a valid value."
                )

        mgrid = self._evaluate(img, **kwarg)

        if convolve:
            if img.psf is None:
                warnings.warn(
                    "No PSF defined, so no convolution will be performed."
                )
            else:
                mgrid = img.convolve(mgrid)
        return mgrid

    @staticmethod
    @partial(jax.jit, static_argnames=["grid"])
    def getgrid(grid, xc, yc, theta):
        """
        Compute rotated coordinate grids for bridge evaluation.

        Parameters
        ----------
        grid : WCSgrid
            Coordinate grid from the Image object.
        xc : float
            Right ascension of bridge centroid in degrees.
        yc : float
            Declination of bridge centroid in degrees.
        theta : float
            Position angle of bridge major axis in radians.

        Returns
        -------
        rgrid : ndarray
            Coordinate grid along the radial (perpendicular) direction.
        zgrid : ndarray
            Coordinate grid along the parallel (major axis) direction.

        Notes
        -----
        The transformation accounts for spherical coordinate projection
        using the cosine of the declination at the bridge centroid.
        """
        sint = jp.sin(theta)
        cost = jp.cos(theta)

        zgrid = (
            -(grid.x - xc) * jp.cos(jp.deg2rad(yc)) * sint
            - (grid.y - yc) * cost
        )
        rgrid = (grid.x - xc) * jp.cos(jp.deg2rad(yc)) * cost - (
            grid.y - yc
        ) * sint

        return rgrid, zgrid

    def _build_kwargs(self, pars, comp_prefix):
        """
        Build keyword arguments for _evaluate from the full parameters dict.

        Parameters
        ----------
        pars : dict
            Full parameters dictionary with prefixed keys.
        comp_prefix : str
            Component prefix (e.g., 'comp_00').

        Returns
        -------
        dict
            Keyword arguments for _evaluate including geometric and profile
            parameters with r_ and z_ prefixes.
        """
        kwarg = {}
        for key in self._rkw:
            attr_name = key.replace("r_", "")
            pars_key = f"{comp_prefix}_radial.{attr_name}"
            if pars_key in pars:
                kwarg[key] = pars[pars_key]
            else:
                val = getattr(self.radial, attr_name)
                if callable(val):
                    sig = inspect.signature(val)
                    params = list(sig.parameters.keys())
                    if params:
                        args = [pars[p] for p in params]
                        val = val(*args)
                kwarg[key] = val

        for key in self._zkw:
            attr_name = key.replace("z_", "")
            pars_key = f"{comp_prefix}_parallel.{attr_name}"
            if pars_key in pars:
                kwarg[key] = pars[pars_key]
            else:
                val = getattr(self.parallel, attr_name)
                if callable(val):
                    sig = inspect.signature(val)
                    params = list(sig.parameters.keys())
                    if params:
                        args = [pars[p] for p in params]
                        val = val(*args)
                kwarg[key] = val

        kwarg["xc"] = pars[f"{comp_prefix}_xc"]
        kwarg["yc"] = pars[f"{comp_prefix}_yc"]
        kwarg["theta"] = pars[f"{comp_prefix}_theta"]
        kwarg["Is"] = pars[f"{comp_prefix}_Is"]

        return kwarg

    def _evaluate(self, img, **kwarg):
        """
        Evaluate bridge model on the given grid with explicit parameters.

        This internal method computes the bridge surface brightness using
        the provided geometric and profile parameters. It is used by both
        getmap() and Model.getmodel() to avoid code duplication.

        Parameters
        ----------
        img : Image
            Image object containing grid and WCS information.
        **kwarg : dict
            All parameters including geometric (xc, yc, theta) and
            profile-specific parameters with r_ and z_ prefixes.

        Returns
        -------
        ndarray
            2D array of surface brightness values, averaged over subpixels.
        """
        xc = kwarg.pop("xc")
        yc = kwarg.pop("yc")
        theta = kwarg.pop("theta")

        rgrid, zgrid = self.getgrid(img.grid, xc, yc, theta)
        mgrid = self.profile(rgrid, zgrid, **kwarg)
        return jp.mean(mgrid, axis=0)

    def parameters(self):
        """
        Print a summary of all model parameters.

        Displays model parameters organized by category (base, radial,
        parallel) along with their units, current values, and descriptions.
        Hyperparameters are shown in a separate section.
        """
        keyout = [key for key in self.units.keys() if key not in self.hyper]

        if len(keyout) > 0:
            maxlen = np.max(
                np.array(
                    [
                        len(f"{key} [{self.units[key]}]")
                        for key in keyout + self.hyper
                    ]
                )
            )

            print("\nModel parameters")
            print("=" * 16)
            for key in keyout:
                keylen = maxlen - len(f" [{self.units[key]}]")
                if key.startswith("radial."):
                    kvalue = getattr(self.radial, key.replace("radial.", ""))
                elif key.startswith("parallel."):
                    kvalue = getattr(
                        self.parallel, key.replace("parallel.", "")
                    )
                else:
                    kvalue = getattr(self, key)

                if kvalue is None:
                    kvalue = None
                elif isinstance(kvalue, numpyro.distributions.Distribution):
                    kvalue = f"Distribution: {kvalue.__class__.__name__}"
                elif isinstance(
                    kvalue, (types.LambdaType, types.FunctionType)
                ):
                    kvalue = "Tied parameter"
                else:
                    kvalue = f"{kvalue:.4E}"

                print(
                    f"{key:<{keylen}} [{self.units[key]}] : "
                    + f"{kvalue}".ljust(10)
                    + f" | {self.description[key]}"
                )

            if len(self.hyper) > 0:
                print("\nHyperparameters")
                print("=" * 15)
                for key in self.hyper:
                    keylen = maxlen - len(f" [{self.units[key]}]")
                    if key.startswith("radial."):
                        kvalue = getattr(
                            self.radial, key.replace("radial.", "")
                        )
                    elif key.startswith("parallel."):
                        kvalue = getattr(
                            self.parallel, key.replace("parallel.", "")
                        )
                    else:
                        kvalue = getattr(self, key)
                    kvalue = None if kvalue is None else f"{kvalue:.4E}"
                    print(
                        f"{key:<{keylen}} [{self.units[key]}] : "
                        + f"{kvalue}".ljust(10)
                        + f" | {self.description[key]}"
                    )
        else:
            print("No parameters defined.")

    def parlist(self):
        """
        Return a list of all parameter names.

        Returns
        -------
        list of str
            Parameter names including base, radial, and parallel parameters.
        """
        return list(self.units.keys())


[docs] class SimpleBridge(Bridge): """ Simple bridge model with multiplicative profile combination. The SimpleBridge combines radial and parallel profiles multiplicatively, producing a surface brightness distribution of the form: I(r, z) = Is * f_radial(r) * f_parallel(z) This creates structures where the emission is the product of the two independent profile functions. Parameters ---------- radial : Profile, optional Profile describing emission perpendicular to the bridge axis. Default is Beta(). parallel : Profile, optional Profile describing emission along the bridge axis. Default is TopHat(), producing a uniform distribution along the bridge length. **kwargs : dict Additional keyword arguments passed to Bridge base class. See Also -------- Bridge : Base class with parameter descriptions. MesaBridge : Alternative with mesa-like profile combination. Examples -------- >>> from socca.models import SimpleBridge >>> bridge = SimpleBridge() >>> bridge.parameters() Model parameters ================ xc [deg] : None | Right ascension of bridge centroid yc [deg] : None | Declination of bridge centroid rs [deg] : None | Scale radius of the bridge profiles Is [image] : None | Scale intensity of the bridge theta [rad] : 0.0000E+00 | Position angle of bridge major axis e [] : 5.0000E-01 | Projected ellipticity of the bridge profiles radial.alpha [] : 2.0000E+00 | Radial exponent radial.beta [] : 5.5000E-01 | Slope parameter """ def __init__(self, radial=Beta(), parallel=TopHat(), **kwargs): super().__init__(radial=radial, parallel=parallel, **kwargs) _profile = [ "rfoo.profile(r,{0})".format(",".join(self._rkw)), "zfoo.profile(z,{0})".format(",".join(self._zkw)), ] _profile = "*".join(_profile) _profile = "lambda rfoo,zfoo,r,z,Is,{0},{1}: Is*{2}".format( ",".join(self._rkw), ",".join(self._zkw), _profile ) self.profile = jax.jit( partial(eval(_profile), self.radial, self.parallel) ) self._initialized = True
[docs] class MesaBridge(Bridge): """ Mesa bridge model with harmonic mean profile combination. The MesaBridge combines radial and parallel profiles using a harmonic mean, producing a mesa-like (flat-topped) surface brightness distribution: I(r, z) = Is / (1/f_radial(r) + 1/f_parallel(z)) This creates smooth transitions between the flat central region and the declining edges, resembling a mesa or table-top shape. Parameters ---------- radial : Profile, optional Profile for perpendicular direction. Default is Beta with alpha=8.0 and beta=1.0 for steep edges. parallel : Profile, optional Profile for parallel direction. Default is Power with alpha=8.0 for steep drop-off. **kwargs : dict Additional keyword arguments passed to Bridge base class. Notes ----- The default parameters are chosen to produce a characteristic mesa-like shape with steep edges, suitable for modeling intracluster bridges with relatively uniform central emission. See Also -------- Bridge : Base class with parameter descriptions. SimpleBridge : Alternative with multiplicative profile combination. Examples -------- >>> from socca.models import MesaBridge >>> bridge = MesaBridge() >>> bridge.parameters() References ---------- Hincks, A. D., et al., MRANS, 510, 3335 (2022) https://scixplorer.org/abs/2022MNRAS.510.3335H/abstract Model parameters ================ xc [deg] : None | Right ascension of bridge centroid yc [deg] : None | Declination of bridge centroid rs [deg] : None | Scale radius of the bridge profiles Is [image] : None | Scale intensity of the bridge theta [rad] : 0.0000E+00 | Position angle of bridge major axis e [] : 5.0000E-01 | Axis ratio of the bridge profiles radial.alpha [] : 8.0000E+00 | Radial exponent radial.beta [] : 1.0000E+00 | Slope parameter parallel.alpha [] : 8.0000E+00 | Power law slope """ def __init__(self, radial=None, parallel=None, **kwargs): if radial is None: radial = Beta( alpha=config.MesaBridge.r_alpha, beta=config.MesaBridge.r_beta, ) if parallel is None: parallel = Power( alpha=config.MesaBridge.z_alpha, ) super().__init__(radial=radial, parallel=parallel, **kwargs) _profile = [ "1.00/rfoo.profile(r,{0})".format(",".join(self._rkw)), "1.00/zfoo.profile(z,{0})".format(",".join(self._zkw)), ] _profile = "+".join(_profile) _profile = "lambda rfoo,zfoo,r,z,{0},{1}: 1.00/({2})".format( ",".join(self._rkw), ",".join(self._zkw), _profile ) self.profile = jax.jit( partial(eval(_profile), self.radial, self.parallel) ) self._initialized = True