"""
Attribution result container.
Provides the AttributionResult dataclass for storing and visualising
attribution analysis results.
"""
from dataclasses import dataclass, field
from typing import Literal, Optional
import pandas as pd
# Display symbol used in text output for each supported variable; at present
# every symbol is simply the variable name itself.
_VAR_SYMBOLS = {name: name for name in ("T2", "q2", "U10")}

# Unit string appended to reported values, keyed by variable name.
_UNITS = dict(T2="degC", q2="g/kg", U10="m/s")

# Components listed in the text summary (and summed for the closure check),
# in display order, keyed by variable name.
_MAIN_COMPONENTS = dict(
    T2=["T_ref", "flux_total", "resistance", "air_props"],
    q2=["q_ref", "flux_total", "resistance", "air_props"],
    U10=["forcing", "roughness", "stability"],
)

# Components plotted when the caller does not choose any: currently the main
# components without the ``*_ref`` entries. A fresh list is built per variable
# so no two variables share the same list object.
_DEFAULT_PLOT_COMPONENTS = {
    var: [comp for comp in comps if not comp.endswith("_ref")]
    for var, comps in _MAIN_COMPONENTS.items()
}
# --- Attribution result container ---
@dataclass
class AttributionResult:
    """
    Container for attribution analysis results.

    Attributes
    ----------
    variable : str
        Name of attributed variable ('T2', 'q2', or 'U10')
    contributions : pd.DataFrame
        Timeseries of contributions from each component
    summary : pd.DataFrame
        Summary statistics (mean, std, min, max) for each component
    metadata : dict
        Additional context (period, scenarios, parameters)
    """

    variable: str
    contributions: pd.DataFrame
    summary: pd.DataFrame
    metadata: dict = field(default_factory=dict)

    def __repr__(self) -> str:
        """
        Generate clean text representation of attribution results.

        The text contains the mean total change, the mean contribution of
        each main component (with its percentage of the total), the flux
        sub-components for variables other than 'U10', and the closure
        residual (mean total minus the mean row-sum of the main components
        present in ``contributions``).

        Raises
        ------
        KeyError
            If ``summary`` has no ``("delta_total", "mean")`` entry.
        """
        var_symbol = _VAR_SYMBOLS.get(self.variable, self.variable)
        unit = _UNITS.get(self.variable, "")
        total = self.summary.loc["delta_total", "mean"]

        def share(value):
            # Percentage of the total; reported as 0 when the total is
            # numerically zero to avoid dividing by ~0.
            return 100 * value / total if abs(total) > 1e-10 else 0

        # Header and total change
        lines = [
            f"{var_symbol} Attribution Results",
            "=" * 40,
            f"Mean delta_{var_symbol}: {total:+.3f} {unit}",
            "",
            "Component Breakdown:",
            "-" * 40,
        ]

        # Component breakdown (only components that have summary statistics)
        main_components = [
            c
            for c in _MAIN_COMPONENTS.get(self.variable, [])
            if c in self.summary.index
        ]
        for comp in main_components:
            val = self.summary.loc[comp, "mean"]
            lines.append(f" {comp:15s}: {val:+.3f} {unit} ({share(val):5.1f}%)")

        # Flux sub-components (if hierarchical - T2/q2 only)
        if self.variable != "U10":
            flux_comps = [
                c
                for c in self.summary.index
                if c.startswith("flux_") and c != "flux_total"
            ]
            if flux_comps:
                lines.append("")
                lines.append(" Flux breakdown:")
                for comp in flux_comps:
                    val = self.summary.loc[comp, "mean"]
                    label = comp.replace("flux_", " d")
                    lines.append(
                        f" {label:13s}: {val:+.3f} {unit} ({share(val):5.1f}%)"
                    )

        # Closure check
        lines.append("")
        avail_components = [
            c for c in main_components if c in self.contributions.columns
        ]
        sum_components = self.contributions[avail_components].sum(axis=1).mean()
        residual = total - sum_components
        lines.append(f"Closure residual: {residual:.2e} {unit}")
        return "\n".join(lines)

    def plot(
        self,
        kind: Literal["bar", "diurnal", "line", "heatmap"] = "bar",
        ax=None,
        components: Optional[list[str]] = None,
        **kwargs,
    ):
        """
        Visualise attribution results.

        Parameters
        ----------
        kind : str, optional
            Plot type:
            - 'bar': Stacked bar of mean contributions (default)
            - 'diurnal': Ensemble diurnal cycle with IQR shading
            - 'line': Time series of all contributions
            - 'heatmap': Month x hour heatmap of total change
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        components : list of str, optional
            Components to include. If None, uses main components.
            Not used by the 'heatmap' kind, which always shows
            ``delta_total``.
        **kwargs
            ``figsize`` (default ``(8, 5)``) is used when a new figure is
            created and is ignored when ``ax`` is given. All remaining
            keyword arguments are passed to the primary plotting call
            (``ax.bar`` for 'bar', ``ax.plot`` for 'diurnal' and 'line',
            ``ax.imshow`` for 'heatmap') and override the built-in
            styling defaults.

        Returns
        -------
        fig, ax : tuple
            Figure and axes objects

        Raises
        ------
        ValueError
            If ``kind`` is not one of the supported plot types, or if
            'diurnal'/'heatmap' is requested while ``contributions`` does
            not have a DatetimeIndex.
        """
        # Validate the request before importing matplotlib or creating a
        # figure, so a bad call never leaves an empty figure behind.
        valid_kinds = ("bar", "diurnal", "line", "heatmap")
        if kind not in valid_kinds:
            raise ValueError(
                f"Unknown plot kind {kind!r}; expected one of {valid_kinds}."
            )
        has_time_index = isinstance(self.contributions.index, pd.DatetimeIndex)
        if kind == "diurnal" and not has_time_index:
            raise ValueError(
                "Diurnal plot requires a DatetimeIndex. "
                "This result appears to be an aggregate comparison "
                "(e.g., from diagnose_* with method='diurnal'). "
                "Use kind='bar' instead for aggregate results."
            )
        if kind == "heatmap" and not has_time_index:
            raise ValueError(
                "Heatmap plot requires a DatetimeIndex. "
                "This result appears to be an aggregate comparison. "
                "Use kind='bar' instead for aggregate results."
            )

        import matplotlib.pyplot as plt

        # figsize belongs to figure creation, not to the artist calls.
        figsize = kwargs.pop("figsize", (8, 5))
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.get_figure()

        # Default components depend on variable type
        if components is None:
            components = [
                c
                for c in _DEFAULT_PLOT_COMPONENTS.get(self.variable, [])
                if c in self.contributions.columns
            ]
        unit = _UNITS.get(self.variable, "")
        palette = plt.cm.Set2

        if kind == "bar":
            self._plot_bar(ax, components, unit, palette, kwargs)
        elif kind == "diurnal":
            self._plot_diurnal(ax, components, unit, palette, kwargs)
        elif kind == "line":
            self._plot_line(ax, components, unit, palette, kwargs)
        else:  # "heatmap" (kind was validated above)
            self._plot_heatmap(fig, ax, unit, kwargs)
        return fig, ax

    def _plot_bar(self, ax, components, unit, palette, style) -> None:
        """Draw one bar per component showing its mean contribution."""
        means = self.summary.loc[components, "mean"]
        positions = range(len(components))
        bar_style = {
            "color": palette(positions),
            "edgecolor": "black",
            "linewidth": 0.5,
            **style,  # caller-supplied options win over the defaults
        }
        bars = ax.bar(positions, means.values, **bar_style)
        ax.set_xticks(positions)
        ax.set_xticklabels(
            [c.replace("_", "\n") for c in components], rotation=0, fontsize=9
        )
        ax.axhline(y=0, color="black", linewidth=0.5)
        ax.set_ylabel(f"Contribution ({unit})")
        ax.set_title(f"{self.variable} Attribution")
        # Add value labels: above positive bars, below negative ones
        for bar, val in zip(bars, means.values):
            height = bar.get_height()
            ax.annotate(
                f"{val:+.2f}",
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 3 if height >= 0 else -10),
                textcoords="offset points",
                ha="center",
                va="bottom" if height >= 0 else "top",
                fontsize=8,
            )

    def _plot_diurnal(self, ax, components, unit, palette, style) -> None:
        """Draw the per-hour median of each component with IQR shading."""
        df = self.contributions[components].copy()
        # Fractional hour keeps sub-hourly time steps in separate groups
        df["hour"] = df.index.hour + df.index.minute / 60
        grouped = df.groupby("hour")
        for i, comp in enumerate(components):
            color = palette(i)
            median = grouped[comp].median()
            q25 = grouped[comp].quantile(0.25)
            q75 = grouped[comp].quantile(0.75)
            # The group keys are the x positions; taking them from the
            # statistic itself guarantees x and y stay aligned.
            hours = median.index.values
            line_style = {"label": comp, "color": color, **style}
            ax.plot(hours, median.values, **line_style)
            ax.fill_between(hours, q25.values, q75.values, alpha=0.3, color=color)
        ax.axhline(y=0, color="black", linewidth=0.5, linestyle="--")
        ax.set_xlabel("Hour of day")
        ax.set_ylabel(f"Contribution ({unit})")
        ax.set_title(f"{self.variable} Attribution - Diurnal Cycle")
        ax.legend(loc="best", fontsize=8)
        ax.set_xlim(0, 24)
        ax.set_xticks(range(0, 25, 3))

    def _plot_line(self, ax, components, unit, palette, style) -> None:
        """Draw the full time series of each component."""
        for i, comp in enumerate(components):
            line_style = {
                "label": comp,
                "color": palette(i),
                "alpha": 0.7,
                **style,
            }
            ax.plot(
                self.contributions.index,
                self.contributions[comp],
                **line_style,
            )
        ax.axhline(y=0, color="black", linewidth=0.5, linestyle="--")
        ax.set_xlabel("Time")
        ax.set_ylabel(f"Contribution ({unit})")
        ax.set_title(f"{self.variable} Attribution - Time Series")
        ax.legend(loc="best", fontsize=8)

    def _plot_heatmap(self, fig, ax, unit, style) -> None:
        """Draw the month x hour mean of ``delta_total`` as a heatmap."""
        import matplotlib.colors as mcolors

        df = self.contributions.copy()
        df["month"] = df.index.month
        df["hour"] = df.index.hour
        pivot = df.pivot_table(
            values="delta_total", index="hour", columns="month", aggfunc="mean"
        )
        # Diverging colormap centred at zero. The pandas reduction skips the
        # NaN cells left by missing month/hour combinations; an all-zero or
        # all-NaN field falls back to a unit range because TwoSlopeNorm
        # requires vmin < vcenter < vmax.
        vmax = float(pivot.abs().max().max())
        if not vmax > 0:  # also true when vmax is NaN
            vmax = 1.0
        norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
        image_style = {
            "aspect": "auto",
            "cmap": "RdBu_r",
            "norm": norm,
            "origin": "lower",
            **style,
        }
        im = ax.imshow(pivot.values, **image_style)
        _month_labels = ["J", "F", "M", "A", "M", "J",
                         "J", "A", "S", "O", "N", "D"]
        ax.set_yticks(range(len(pivot.index)))
        ax.set_yticklabels(pivot.index)
        ax.set_xticks(range(len(pivot.columns)))
        ax.set_xticklabels([_month_labels[m - 1] for m in pivot.columns])
        ax.set_xlabel("Month")
        ax.set_ylabel("Hour")
        ax.set_title(f"{self.variable} Attribution - Seasonal-Diurnal Pattern")
        fig.colorbar(im, ax=ax, label=f"delta_{self.variable} ({unit})")

    def to_dataframe(self) -> pd.DataFrame:
        """Return a copy of the contributions DataFrame."""
        return self.contributions.copy()