  1""" Plotting functionality """
  3from pathlib import Path
  5import matplotlib.patches as patches
  6import matplotlib.pyplot as plt
  7import numpy as np
  8import pandas as pd
  9from matplotlib.colors import LinearSegmentedColormap
 10from matplotlib.gridspec import GridSpec
 11from matplotlib.ticker import AutoMinorLocator
 13from foam import functions_for_mesa as ffm
 17def make_multipanel_plot(
 18    nr_panels=1,
 19    xlabel="",
 20    ylabels=[""],
 21    keys=None,
 22    title="",
 23    label_size=22,
 24    xlim=[],
 25    left_space=0.1,
 26    bottom_space=0.085,
 27    right_space=0.978,
 28    top_space=0.97,
 29    h_space=0.12,
 30    figure_size=[12, 8],
 32    """
 33    Make a multipanel figure for plots.
 35    Parameters
 36    ----------
 37    nr_panels: int
 38        The number of panels to add to the plot
 39    xlabel: string
 40        Name for x axis.
 41    ylabels: list of strings
 42        Names for y axes.
 43    keys:
 44        The keys corresponding to the dictionary entries of the axes
 45    title: string
 46        Title of the plot, no title if argument not given.
 47    left_space, bottom_space, right_space, top_space: float, optional
 48        The space that needs to be left open around the figure.
 49    h_space: float
 50        The size of the space in between both panels.
 51    figure_size: list of 2 floats
 52        Specify the dimensions of the figure in inch.
 53    label_size: float
 54        The size of the labels on the axes, and slightly smaller on the ticks
 55    xlim: list
 56        Lists of length 2, specifying the lower and upper limits of the x axe.
 57        Matplotlib sets the limits automatically if the arguments are not given.
 59    Returns
 60    ----------
 61    ax_dict: dict
 62        Dictionary of the axes
 63    fig: Figure
 64        The figure
 65    """
 66    if keys == None:  # make keys integers
 67        keys = range(nr_panels)
 69    fig = plt.figure(figsize=(figure_size[0], figure_size[1]))
 70    gs = GridSpec(nr_panels, 1)  # multiple rows, 1 column
 71    ax_dict = {}
 73    for i in range(0, nr_panels):
 74        if i == 0:
 75            ax = fig.add_subplot(gs[i : i + 1, 0])
 76            if len(xlim) == 2:
 77                ax.set_xlim(xlim[0], xlim[1])
 79        else:
 80            ax = fig.add_subplot(gs[i : i + 1, 0], sharex=ax_dict[keys[0]])
 82        ax_dict.update({keys[i]: ax})
 84        ax.set_ylabel(ylabels[i], size=label_size)
 85        ax.tick_params(labelsize=label_size - 2)
 87        if i == nr_panels - 1:
 88            ax.set_xlabel(xlabel, size=label_size)
 89        else:
 90            plt.setp(ax.get_xticklabels(), visible=False)
 92    ax_dict[keys[0]].set_title(title)
 93    plt.subplots_adjust(hspace=h_space, left=left_space, right=right_space, top=top_space, bottom=bottom_space)
 95    return ax_dict, fig
 99def corner_plot(
100    merit_values_file,
101    merit_values_file_error_ellipse,
102    fig_title,
103    observations_file,
104    label_size=20,
105    fig_output_dir="figures_correlation/",
106    percentile_to_show=0.5,
107    logg_or_logL="logL",
108    mark_best_model=False,
109    n_sigma_box=3,
110    grid_parameters=None,
111    axis_labels_dict={
112        "rot": r"$\Omega_{\mathrm{rot}}$ [d$^{-1}$]",
113        "M": r"M$_{\rm ini}$",
114        "Z": r"Z$_{\rm ini}$",
115        "logD": r"log(D$_{\rm env}$)",
116        "aov": r"$\alpha_{\rm CBM}$",
117        "fov": r"f$_{\rm CBM}$",
118        "Xc": r"$\rm X_c$",
119    },
121    """
122    Make a plot of all variables vs each other variable, showing the MLE values as colorscale.
123    A kiel/HR diagram is made, depending on if logg_obs or logL_obs is passed as a parameter.
124    The subplots on the diagonal show the distribution of that variable.
125    The list of variables is retrieved from columns of the merit_values_file,
126    where the first column is 'meritValue', which are the MLE values.
127    The resulting figure is saved afterwards in the specified location.
129    Parameters
130    ----------
131    merit_values_file: string
132        Path to the hdf5 file with the merit function values and parameters of the models in the grid.
133    merit_values_file_error_ellipse: string
134        Path to the hdf5 files with the merit function values and parameters of the models in the error ellipse.
135    observations_file: string
136        Path to the tsv file with observations, with a column for each observable and each set of errors.
137        Column names specify the observable, and "_err" suffix denotes that it's the error.
138    fig_title: string
139        Title of the figure and name of the saved png.
140    label_size: int
141        Size of the axis labels.
142    fig_output_dir: string
143        Output directory for the figures.
144    percentile_to_show: float
145        Percentile of models to show in the plots.
146    logg_or_logL: string
147        String 'logg' or 'logL' indicating whether log of surface gravity (g) or luminosity (L) is plot.
148    mark_best_model: boolean
149        Indicate the best model with a marker
150    grid_parameters: list of string
151        List of the parameters in the theoretical grid.
152    axis_labels_dict: dictionary
153        Keys are grid parameters, values are strings how those values should be shown on the axis labels
154    """
155    # Define custom colormap
156    cdict = {
157        "red": ((0.0, 1.0, 1.0), (0.5, 1.0, 1.0), (0.75, 1.0, 1.0), (1.0, 0.75, 0.75)),
158        "green": ((0.0, 1.0, 1.0), (0.25, 0.5, 0.5), (0.5, 0.0, 0.0), (1.0, 0.0, 0.0)),
159        "blue": ((0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (1.0, 1.0, 1.0)),
160    }
161    CustomCMap = LinearSegmentedColormap("CustomMap", cdict)
162    # theoretical models within the error ellipse
163    dataframe_theory_error_ellipse = pd.read_hdf(merit_values_file_error_ellipse, sep="\s+", header=0)
164    dataframe_theory_error_ellipse = dataframe_theory_error_ellipse.sort_values(
165        "meritValue", ascending=False
166    )  # Order from high to low, to plot lowest values last
168    # theoretical models
169    dataframe_theory = pd.read_hdf(merit_values_file)
170    dataframe_theory = dataframe_theory.sort_values(
171        "meritValue", ascending=False
172    )  # Order from high to low, to plot lowest values last
173    dataframe_theory = dataframe_theory.iloc[
174        int(dataframe_theory.shape[0] * (1 - percentile_to_show)) :
175    ]  # only plot the given percentage lowest meritValues
177    if (
178        dataframe_theory.iloc[0]["rot"]
179        == dataframe_theory.iloc[1]["rot"]
180        == dataframe_theory.iloc[2]["rot"]
181        == dataframe_theory.iloc[-1]["rot"]
182    ):  # rotation is fixed, don't plot it
183        # make new dataframe with only needed info
184        df_error_ellipse = dataframe_theory_error_ellipse.filter(["meritValue"] + grid_parameters)
185        df = dataframe_theory.filter(["meritValue"] + grid_parameters)
186        # Remove models in the error ellipse from the regular dataframe.
187        df = (
188            pd.merge(
189                df,
190                df_error_ellipse,
191                indicator=True,
192                how="outer",
193                on=grid_parameters,
194                suffixes=[None, "_remove"],
195            )
196            .query('_merge=="left_only"')
197            .drop(["meritValue_remove", "_merge"], axis=1)
198        )
200    else:  # rotation was varied, include it in the plots
201        # make new dataframe with only needed info
202        df_error_ellipse = dataframe_theory_error_ellipse.filter(["meritValue"] + ["rot"] + grid_parameters)
203        df = dataframe_theory.filter(["meritValue"] + ["rot"] + grid_parameters)
204        # Remove models in the error ellipse from the regular dataframe.
205        df = (
206            pd.merge(
207                df,
208                df_error_ellipse,
209                indicator=True,
210                how="outer",
211                on=grid_parameters,
212                suffixes=[None, "_remove"],
213            )
214            .query('_merge=="left_only"')
215            .drop(["meritValue_remove", "rot_remove", "_merge"], axis=1)
216        )
218    ax_dict = {}
219    # dictionary of dictionaries, holding the subplots of the figure, keys indicate position (row, column) of the subplot
220    nr_params = len(df.columns) - 1
221    for i in range(nr_params):
222        ax_dict.update({i: {}})
224    fig = plt.figure(figsize=(10, 8))
225    gs = GridSpec(nr_params, nr_params)  # multiple rows and columns
227    if mark_best_model:
228        # get the best model according to the point estimator
229        min_index = df_error_ellipse["meritValue"].idxmin(axis="index", skipna=True)
231    for ix in range(0, nr_params):
232        for iy in range(0, nr_params - ix):
233            if iy == 0:
234                share_x = None
235            else:
236                share_x = ax_dict[0][ix]
237            if (ix == 0) or (iy + ix == nr_params - 1):
238                share_y = None
239            else:
240                share_y = ax_dict[iy][0]
242            # create subplots and add them to the dictionary
243            ax = fig.add_subplot(gs[nr_params - iy - 1 : nr_params - iy, ix : ix + 1], sharex=share_x, sharey=share_y)
244            ax_dict[iy].update({ix: ax})
246            # manage visibility and size of the labels and ticks
247            ax.tick_params(labelsize=label_size - 4)
248            if ix == 0:
249                ax.set_ylabel(axis_labels_dict[df.columns[iy + 1]], size=label_size)
250                if iy == nr_params - 1:
251                    plt.setp(ax.get_yticklabels(), visible=False)
252            else:
253                plt.setp(ax.get_yticklabels(), visible=False)
254            if iy == 0:
255                ax.set_xlabel(axis_labels_dict[df.columns[nr_params - ix]], size=label_size)
256                ax.tick_params(axis="x", rotation=45)
257            else:
258                plt.setp(ax.get_xticklabels(), visible=False)
260            if iy + ix == nr_params - 1:  # make distribution plots on the diagonal subplots
261                values = sorted(np.unique(df.iloc[:, nr_params - ix]))
262                # determine edges of the bins for the histogram distribution plots
263                if df.columns[nr_params - ix] == "rot":
264                    domain = (values[0], values[-1])
265                    ax.hist(
266                        df_error_ellipse.iloc[:, nr_params - ix],
267                        bins=25,
268                        range=domain,
269                        density=False,
270                        cumulative=False,
271                        histtype="step",
272                    )
274                else:
275                    if len(values) > 1:
276                        bin_half_width = (values[0] + values[1]) / 2 - values[0]
277                    else:
278                        bin_half_width = 1e-3
279                    bin_edges = [values[0] - bin_half_width]
280                    for i in range(len(values) - 1):
281                        bin_edges.extend([(values[i] + values[i + 1]) / 2])
282                    bin_edges.extend([values[-1] + bin_half_width])
283                    ax.hist(
284                        df_error_ellipse.iloc[:, nr_params - ix],
285                        bins=bin_edges,
286                        density=False,
287                        cumulative=False,
288                        histtype="step",
289                    )
291                ax.tick_params(axis="y", left=False)
292                continue
294            im = ax.scatter(
295                df.iloc[:, nr_params - ix],
296                df.iloc[:, iy + 1],
297                c=np.log10(df.iloc[:, 0]),
298                cmap="Greys_r",
299            )
300            im = ax.scatter(
301                df_error_ellipse.iloc[:, nr_params - ix],
302                df_error_ellipse.iloc[:, iy + 1],
303                c=np.log10(dataframe_theory_error_ellipse["meritValue"]),
304                cmap=CustomCMap,
305            )
306            if mark_best_model:
307                ax.scatter(
308                    df_error_ellipse.loc[min_index][nr_params - ix],
309                    df.loc[min_index][iy + 1],
310                    color="white",
311                    marker="x",
312                )
313            # Adjust x an y limits of subplots
314            limit_adjust = (max(df.iloc[:, iy + 1]) - min(df.iloc[:, iy + 1])) * 0.08
315            if limit_adjust == 0:
316                limit_adjust = 0.1
317            ax.set_ylim(min(df.iloc[:, iy + 1]) - limit_adjust, max(df.iloc[:, iy + 1]) + limit_adjust)
318            limit_adjust = (max(df.iloc[:, nr_params - ix]) - min(df.iloc[:, nr_params - ix])) * 0.08
319            if limit_adjust == 0:
320                limit_adjust = 0.1
321            ax.set_xlim(
322                min(df.iloc[:, nr_params - ix]) - limit_adjust,
323                max(df.iloc[:, nr_params - ix]) + limit_adjust,
324            )
326    fig.align_labels()
327    # add subplot in top right for Kiel or HRD
328    ax_hrd = fig.add_axes([0.508, 0.65, 0.33, 0.33])  # X, Y, width, height
330    ax_hrd.set_xlabel(r"log(T$_{\mathrm{eff}}$ [K])", size=label_size)
331    ax_hrd.tick_params(labelsize=label_size - 4)
332    ax_hrd.invert_xaxis()
334    # Observations
335    if n_sigma_box != None:
336        obs_dataframe = pd.read_table(observations_file, sep="\s+", header=0, index_col="index")
337        if (("logL" in obs_dataframe.columns) or ("logg" in obs_dataframe.columns)) and (
338            "Teff" in obs_dataframe.columns
339        ):
340            if "logL" not in obs_dataframe.columns:
341                logg_or_logL = "logg"
343            # Observed spectroscopic error bar, only added if observational constraints were provided.
344            # To add the 1 and n-sigma spectro error boxes, calculate their width (so 2 and 2*n sigma wide)
345            width_logTeff_sigma = np.log10(
346                obs_dataframe["Teff"].iloc[0] + obs_dataframe["Teff_err"].iloc[0]
347            ) - np.log10(obs_dataframe["Teff"].iloc[0] - obs_dataframe["Teff_err"].iloc[0])
348            width_logTeff_nsigma = np.log10(
349                obs_dataframe["Teff"].iloc[0] + n_sigma_box * obs_dataframe["Teff_err"].iloc[0]
350            ) - np.log10(obs_dataframe["Teff"].iloc[0] - n_sigma_box * obs_dataframe["Teff_err"].iloc[0])
351            errorbox_1s = patches.Rectangle(
352                (
353                    np.log10(obs_dataframe["Teff"].iloc[0] - obs_dataframe["Teff_err"].iloc[0]),
354                    obs_dataframe[logg_or_logL].iloc[0] - obs_dataframe[f"{logg_or_logL}_err"].iloc[0],
355                ),
356                width_logTeff_sigma,
357                2 * obs_dataframe[f"{logg_or_logL}_err"].iloc[0],
358                linewidth=1.7,
359                edgecolor="cyan",
360                facecolor="none",
361                zorder=2.1,
362            )
363            errorbox_ns = patches.Rectangle(
364                (
365                    np.log10(obs_dataframe["Teff"].iloc[0] - n_sigma_box * obs_dataframe["Teff_err"].iloc[0]),
366                    obs_dataframe[logg_or_logL].iloc[0] - n_sigma_box * obs_dataframe[f"{logg_or_logL}_err"].iloc[0],
367                ),
368                width_logTeff_nsigma,
369                2 * n_sigma_box * obs_dataframe[f"{logg_or_logL}_err"].iloc[0],
370                linewidth=1.7,
371                edgecolor="cyan",
372                facecolor="none",
373                zorder=2.1,
374            )
375            ax_hrd.add_patch(errorbox_1s)
376            ax_hrd.add_patch(errorbox_ns)
378    if logg_or_logL == "logg":
379        ax_hrd.invert_yaxis()
381    im = ax_hrd.scatter(
382        dataframe_theory["logTeff"],
383        dataframe_theory[logg_or_logL],
384        c=np.log10(dataframe_theory["meritValue"]),
385        cmap="Greys_r",
386    )
387    im_error_ellipse = ax_hrd.scatter(
388        dataframe_theory_error_ellipse["logTeff"],
389        dataframe_theory_error_ellipse[logg_or_logL],
390        c=np.log10(dataframe_theory_error_ellipse["meritValue"]),
391        cmap=CustomCMap,
392    )
393    ax_hrd.set_ylabel(f"{logg_or_logL[:-1]} {logg_or_logL[-1]}")
394    if logg_or_logL == "logL":
395        ax_hrd.set_ylabel(r"log(L/L$_{\odot}$)", size=label_size)
396    elif logg_or_logL == "logg":
397        ax_hrd.set_ylabel(r"$\log\,g$ [dex]", size=label_size)
399    ax_hrd.xaxis.set_minor_locator(AutoMinorLocator(2))
400    ax_hrd.yaxis.set_minor_locator(AutoMinorLocator(2))
401    ax_hrd.tick_params(which="major", length=6)
402    ax_hrd.tick_params(which="minor", length=4)
404    if mark_best_model:
405        ax_hrd.scatter(
406            dataframe_theory_error_ellipse["logTeff"][min_index],
407            dataframe_theory_error_ellipse[logg_or_logL][min_index],
408            marker="x",
409            color="white",
410        )
412    # Add color bar
413    cax = fig.add_axes([0.856, 0.565, 0.04, 0.415])  # X, Y, width, height
414    cbar = fig.colorbar(im, cax=cax, orientation="vertical")
415    cax2 = fig.add_axes([0.856, 0.137, 0.04, 0.415])  # X, Y, width, height
416    cbar2 = fig.colorbar(
417        im_error_ellipse,
418        cax=cax2,
419        orientation="vertical",
420    )
422    # To prevent messing up colours due to automatic rescaling of colorbar
423    if dataframe_theory_error_ellipse.shape[0] == 1:
424        im_error_ellipse.set_clim(
425            np.log10(dataframe_theory_error_ellipse["meritValue"]),
426            np.log10(dataframe_theory_error_ellipse["meritValue"]) * 1.1,
427        )
429    if "_MD_" in fig_title:
430        cbar.set_label("log(MD)", rotation=90, size=label_size)
431        cbar2.set_label("log(MD)", rotation=90, size=label_size)
432    elif "_CS_" in fig_title:
433        cbar.set_label(r"log($\chi^2$)", rotation=90, size=label_size)
434        cbar2.set_label(r"log($\chi^2$)", rotation=90, size=label_size)
435    else:
436        cbar.set_label("log(merit function value)", rotation=90)
437    cbar.ax.tick_params(labelsize=label_size - 4)
438    cbar2.ax.tick_params(labelsize=label_size - 4)
439    fig.subplots_adjust(left=0.114, right=0.835, bottom=0.137, top=0.99)
441    # fig.suptitle(fig_title, horizontalalignment='left', size=20, x=0.28)
442    Path(fig_output_dir).mkdir(parents=True, exist_ok=True)
443    fig.savefig(f"{fig_output_dir}{fig_title}.png", dpi=400)
444    plt.clf()
445    plt.close(fig)
def plot_mesa_file(
    profile_file,
    x_value,
    y_value,
    ax=None,
    label_size=16,
    colour="",
    linestyle="solid",
    alpha=1,
    legend=True,
    label=None,
):
    """
    Plot the requested quantities for the given MESA profile or history file.

    Parameters
    ----------
    profile_file: string
        The path to the profile or history file to be used for the plotting.
    x_value: string
        The column plotted on the x axis.
        If x_value is 'mass' or 'radius', it is divided by its value in the first row of the file
        and the x axis is limited to [0, 1].
    y_value: string
        The column plotted on the y axis.
    ax: Axes
        Axes object on which the plot will be made. If None: make figure and axis within this function.
    label_size, alpha: float
        The size of the labels (and legend text) in the figure, and transparency of the plot.
    colour: string
        Colour of the plotted data; an empty string leaves the choice to matplotlib.
    linestyle: string
        Linestyle of the plotted data.
    label: string
        Label of the plotted data. Defaults to ``y_value``.
    legend: boolean
        Flag to enable or disable a legend on the figure.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    _, data = ffm.read_mesa_file(profile_file)
    x = np.asarray(data[x_value])
    y = np.asarray(data[y_value])
    if label is None:
        label = y_value
    if x_value in ("radius", "mass"):
        # Normalise by the first row, presumably the surface for a MESA profile — TODO confirm file ordering.
        x = x / x[0]
        ax.set_xlim(0, 1)

    plot_kwargs = {"label": label, "linestyle": linestyle, "alpha": alpha}
    if colour:  # only force a colour when one was given
        plot_kwargs["color"] = colour
    ax.plot(x, y, **plot_kwargs)

    if legend:
        ax.legend(loc="best", prop={"size": label_size})
    ax.set_xlabel(x_value, size=label_size)
    ax.set_ylabel(y_value, size=label_size)
def plot_mesh_histogram(
    profile_file,
    x_value="radius",
    ax=None,
    label_size=16,
    colour="",
    linestyle="solid",
    alpha=1,
    legend=True,
    label=None,
    bins=200,
):
    """
    Make a histogram of the mesh points in the MESA profile.

    The number of zones (``num_zones`` from the file header) is printed as a side effect.

    Parameters
    ----------
    profile_file: string
        The path to the profile file to be used for the plotting.
    x_value: string
        The column to use for the histogram.
        If x_value is 'mass' or 'radius', it is divided by its value in the first row of the file
        and the x axis is limited to [0, 1].
    ax: an axis object
        Axes object on which the plot will be made. If None: make figure and axis within this function.
    label_size: float
        The size of the labels (and legend text) in the figure.
    alpha: float
        Transparency of the plot.
    colour, linestyle, label: string
        Settings for the plot. An empty colour leaves the choice to matplotlib.
        If label is None, no legend is drawn regardless of ``legend``.
    legend: boolean
        Flag to enable or disable a legend on the figure.
    bins: int
        Number of bins used in the histogram.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    header, data = ffm.read_mesa_file(profile_file)
    print(f'Total zones of {profile_file} : {header["num_zones"]}')

    x = np.asarray(data[x_value])
    if label is None:
        # Without a label there is nothing to show in a legend.
        legend = False
    if x_value in ("radius", "mass"):
        # Normalise by the first row, presumably the surface for a MESA profile — TODO confirm file ordering.
        x = x / x[0]
        ax.set_xlim(0, 1)

    hist_kwargs = {"bins": bins, "histtype": "step", "label": label, "alpha": alpha, "linestyle": linestyle}
    if colour:  # only force a colour when one was given
        hist_kwargs["color"] = colour
    ax.hist(x, **hist_kwargs)

    if legend:
        ax.legend(loc="best", prop={"size": label_size})
    ax.set_xlabel(x_value, size=label_size)
    ax.set_ylabel("Meshpoints", size=label_size)
def plot_hrd(
    hist_file,
    ax=None,
    colour="blue",
    linestyle="solid",
    label="",
    label_size=16,
    Xc_marked=None,
    Teff_logscale=True,
    start_track_from_Xc=None,
    diagram="HRD",
):
    """
    Makes an HRD, sHRD or Kiel diagram from a provided MESA history file.

    Parameters
    ----------
    hist_file: string
        The path to the history file to be used for the plot.
    ax: Axes
        Axes object on which the plot will be made. If None: make figure and axis within this function.
    colour: string
        Colour of the plotted data.
    linestyle: string
        Linestyle of the plotted data.
    label: string
        Label of the plotted data.
    label_size: float
        The size of the labels in the figure.
    Xc_marked: list of floats
        Xc values (listed in increasing value) to mark with red dots. For each value, the latest model on the
        track whose central hydrogen fraction is still above it is marked, on the same axes as the track.
    Teff_logscale: boolean
        Plot effective temperature in logscale (True), or not (False).
    start_track_from_Xc: float
        Only start plotting the track from the first model with Xc below this value (e.g. to not plot the
        initial relaxation loop). The whole track is plotted if no model satisfies this.
    diagram: string
        Type of diagram that is plotted. Options are HRD (logL vs logTeff), sHRD (log(Teff^4/g) vs logTeff)
        or kiel (logg vs logTeff).

    Raises
    ------
    ValueError
        If ``diagram`` is not one of 'HRD', 'sHRD' or 'kiel'.
    """
    if diagram not in ("HRD", "sHRD", "kiel"):
        raise ValueError(f"Unknown diagram type '{diagram}', expected 'HRD', 'sHRD' or 'kiel'.")

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    _, data = ffm.read_mesa_file(hist_file)

    # From "data", extract the required columns as numpy arrays
    log_L = np.asarray(data["log_L"])
    log_Teff = np.asarray(data["log_Teff"])
    log_g = np.asarray(data["log_g"])
    center_h1 = np.asarray(data["center_h1"])

    if Teff_logscale:
        T = log_Teff
        ax.set_xlabel(r"log(T$_{\mathrm{eff}}$)", size=label_size)
    else:
        T = 10**log_Teff
        ax.set_xlabel(r"T$_{\mathrm{eff}}$ [K]", size=label_size)

    if diagram == "HRD":
        y_axis = log_L
        ax.set_ylabel(r"log(L/L$_{\odot}$)", size=label_size)
    elif diagram == "sHRD":
        # log(Teff^4/g) of the Sun, so the y axis is in solar units (presumably cgs values — confirm).
        log_Lsun = 10.61
        y_axis = 4 * log_Teff - log_g - log_Lsun
        ax.set_ylabel(
            r"$\log \left(\frac{{T_{\mathrm{eff}}}^4}{g}\right) \ (\mathscr{L}_\odot)$",
            size=label_size,
        )
    else:  # kiel
        y_axis = log_g
        ax.set_ylabel(r"log g [dex]", size=label_size)

    # Start plotting from the first model with Xc below start_track_from_Xc (whole track if there is none).
    start = 0
    if start_track_from_Xc is not None:
        below = np.flatnonzero(center_h1 < start_track_from_Xc)
        if below.size > 0:
            start = below[0]
    ax.plot(T[start:], y_axis[start:], color=colour, linestyle=linestyle, label=label)

    # Put specific marks on the diagram. Use the untruncated arrays so indices match center_h1,
    # and the diagram's own y quantity rather than always log_L.
    if Xc_marked is None or len(Xc_marked) == 0:
        return
    k = 0
    for i in range(len(center_h1) - 1, -1, -1):
        if center_h1[i] > Xc_marked[k]:
            ax.scatter(T[i], y_axis[i], marker="o", color="red", lw=2)
            k += 1
            if k >= len(Xc_marked):
                return
def plot_khd(hist_file, ax=None, number_mix_zones=8, xaxis="model_number"):
    """
    Makes a Kippenhahn plot from a provided MESA history file.

    For every model, each mixing zone k (bounded below by mix_qtop_{k-1}, or 0 for k=1, and above by
    mix_qtop_k) is drawn as a vertical line coloured by mix_type_k, in mass relative to the initial
    stellar mass. The stellar mass itself is drawn as a black line.

    Parameters
    ----------
    hist_file: string
        The path to the history file to be used for the plot.
    ax: Axes
        Axes object on which the plot will be made. If None: make figure and axis within this function.
    number_mix_zones: int
        Number of mixing zones included in the mesa history file.
    xaxis: string
        Quantity to put on the x-axis of the plot (e.g. model_number or star_age).
    """
    if ax is None:
        fig = plt.figure(figsize=(10, 4))
        ax = fig.add_subplot(111)

    _, data = ffm.read_mesa_file(hist_file)

    x_values = np.asarray(data[xaxis])
    m_star = np.asarray(data["star_mass"])
    m_ini = m_star[0]

    # Colour per mix_type value. The legend entries below label lightgrey as "Convective" and
    # cornflowerblue as "CBM".
    colours = {
        "-1": "w",
        "0": "w",
        "1": "lightgrey",
        "2": "b",
        "7": "g",
        "3": "cornflowerblue",
        "8": "red",
    }

    def zone_colour(mix_type):
        # int() so float-typed columns (1.0) still match the keys; mixing types not listed above are
        # drawn white like "no mixing" instead of aborting the whole plot with a KeyError.
        return colours.get(str(int(mix_type)), "w")

    # Draw zones from the outermost (number_mix_zones) inwards to zone 1, which starts at the centre.
    for zone in range(number_mix_zones, 0, -1):
        top = np.asarray(data[f"mix_qtop_{zone}"]) * m_star / m_ini
        bottom = 0 if zone == 1 else np.asarray(data[f"mix_qtop_{zone - 1}"]) * m_star / m_ini
        ax.vlines(
            x_values,
            bottom,
            top,
            color=[zone_colour(mix_type) for mix_type in data[f"mix_type_{zone}"]],
        )

    ax.plot(x_values, m_star / m_ini, lw=1, color="black", label=rf"{m_ini:.1f} $M_\odot$")

    ax.set_xlim(min(x_values) * 0.99, max(x_values) * 1.01)
    ax.set_ylim(0, 1.02)

    # Empty lines, only to make these colours appear in the legend.
    ax.plot([], [], lw=10, color="lightgrey", label=r"Convective")
    ax.plot([], [], lw=10, color="cornflowerblue", label=r"CBM")
    ax.legend(
        bbox_to_anchor=(0.15, 0.97),
        fontsize=10,
        frameon=False,
        fancybox=False,
        shadow=False,
        borderpad=0,
    )
    ax.set_ylabel(r"Relative Mass $\, m/M_\star$")
    ax.set_xlabel(xaxis)
    # Lay out the figure this axes belongs to, not whichever figure pyplot considers current.
    ax.figure.tight_layout()
