Source code for exotic_miri.reference.set_custom_linearity

import numpy as np
from jwst import datamodels
from jwst.stpipe import Step
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors


class SetCustomLinearity(Step):
    """ Set a custom linearity correction. """

    spec = """
    group_idx_start_fit = integer(default=10) # first group index included in linear fit
    group_idx_end_fit = integer(default=40) # end group index included in linear fit
    group_idx_start_derive = integer(default=10) # first group index included in poly deviation
    group_idx_end_derive = integer(default=-1) # end group index included in poly deviation
    row_idx_start_used = integer(default=350) # first row index to be included in derivation
    row_idx_end_used = integer(default=386) # end row index to be included in derivation
    draw_corrections = boolean(default=False) # draw corrections
    """

    def process(self, input):
        """ Make self-calibrated linearity corrections per amplifier.

        This step uses the uncal.fits data to create a new linearity
        model. This model can then be passed to the
        jwst.calwebb_detector1.linearity_step via the arg
        'override_linearity'. The correction involves extrapolating
        a linear fit to an assumed linear/"well-behaved" section of
        the ramps, and then fitting a polynomial to the residuals.
        The polynomial has the constant- and linear-term coefficients
        fixed at 0 and 1 respectively. Recommended usage requires
        a large number of groups, >~40, although this is still
        experimental.

        Parameters
        ----------
        input: jwst.datamodels.RampModel
            This is an uncal.fits loaded data segment.
        group_idx_start_fit: integer
            The first group index included in the linear fit. This
            corresponds to the start of the section of the ramp which
            is assumed to be well behaved. Negative values count back
            from the number of groups. Default is 10.
        group_idx_end_fit: integer
            The end group index of the linear fit (exclusive, like a
            Python slice end). This corresponds to the end of the
            section of the ramp which is assumed to be well behaved.
            Negative values count back from the number of groups.
            Default is 40.
        group_idx_start_derive: integer
            The first group index included in the derived linearity
            correction. Negative values count back from the number of
            groups. Default is 10.
        group_idx_end_derive: integer
            The end group index of the derived linearity correction
            (exclusive). Negative values count back from the number of
            groups, so the default of -1 excludes only the final
            group. Default is -1.
        row_idx_start_used: integer
            The first row index included in the derived linearity
            correction. Default is 350.
        row_idx_end_used: integer
            The end row index (exclusive, slice semantics) included in
            the derived linearity correction. Default is 386.
        draw_corrections: boolean
            Plot the derived linearity correction.

        Returns
        -------
        linearity : jwst.datamodels.LinearityModel
            The linearity model which can be passed to other steps.
            None is returned if the input is not a RampModel.

        Raises
        ------
        ValueError
            If the configured group ranges select no groups for the
            derivation, or fewer than two groups for the linear fit.
        """
        with datamodels.open(input) as input_model:

            # Check input model type.
            if not isinstance(input_model, datamodels.RampModel):
                self.log.error("Input is a {} which was not expected for "
                               "CustomLinearityStep, skipping step.".format(
                                str(type(input_model))))
                return None

            amplifier_idxs, amplifier_dns, amplifier_fs = \
                self._collect_amplifier_samples(input_model)

            # F = c0 + c1 * DN + c2 * DN**2 + c3 * DN**3 + c4 * DN**4.
            amplifier_ccs = self._fit_amplifier_corrections(
                amplifier_dns, amplifier_fs)

            if self.draw_corrections:
                self.draw_amplifier_corrections(
                    amplifier_idxs, amplifier_dns, amplifier_fs,
                    amplifier_ccs)

            # Use default reference file as template for custom file.
            self.log.info("Building custom linearity datamodel.")
            linearity_model = self._build_linearity_model(
                input_model, amplifier_ccs)

        return linearity_model

    @staticmethod
    def _resolve_index(idx, length):
        """Map a Python-style (possibly negative) index onto a non-negative one."""
        return idx + length if idx < 0 else idx

    @staticmethod
    def _group_major_ramps(data, groups, row_slice, col):
        """Return data[:, groups, row_slice, col] as (len(groups), n_ints * n_rows).

        ``data`` is indexed as a 4-D (ints, groups, rows, cols) array, as
        in the original step. The group axis is moved to the front
        explicitly, rather than relying on NumPy's advanced/basic indexing
        reordering, so each result row holds all (integration, row)
        samples of one group in a consistent order.
        """
        ramps = data[:, :, row_slice, col][:, groups]
        return np.moveaxis(ramps, 1, 0).reshape(groups.shape[0], -1)

    def _collect_amplifier_samples(self, input_model):
        """Gather paired (DN, linear-model F) samples for each of four amplifiers.

        Returns
        -------
        amplifier_idxs : list of int
            Amplifier index of each sampled column.
        amplifier_dns : list of 4 numpy.ndarray (float64)
            Measured DN samples, indexed by amplifier.
        amplifier_fs : list of 4 numpy.ndarray (float64)
            Linear-model predictions for the same samples, indexed by
            amplifier.
        """
        data = input_model.data
        n_groups = data.shape[1]

        # Exclude grps beyond help, e.g., final. Negative ends are resolved
        # against the number of groups (np.arange(10, -1) would be empty).
        groups_all = np.arange(
            self._resolve_index(self.group_idx_start_derive, n_groups),
            self._resolve_index(self.group_idx_end_derive, n_groups))
        groups_fit = np.arange(
            self._resolve_index(self.group_idx_start_fit, n_groups),
            self._resolve_index(self.group_idx_end_fit, n_groups))
        if groups_all.size == 0:
            raise ValueError(
                "No groups selected for derivation: group_idx_start_derive="
                "{}, group_idx_end_derive={}, n_groups={}.".format(
                    self.group_idx_start_derive, self.group_idx_end_derive,
                    n_groups))
        if groups_fit.size < 2:
            raise ValueError(
                "At least two groups are required for the linear fit: "
                "group_idx_start_fit={}, group_idx_end_fit={}, "
                "n_groups={}.".format(
                    self.group_idx_start_fit, self.group_idx_end_fit,
                    n_groups))
        row_slice = slice(self.row_idx_start_used, self.row_idx_end_used)

        # Sampled columns and their amplifier indices (column % 4).
        amplifier_cols = [34, 35, 36, 37, 38]
        amplifier_idxs = [2, 3, 0, 1, 2]

        dns_parts = [[] for _ in range(4)]
        fs_parts = [[] for _ in range(4)]
        # Design matrix [group, 1] for evaluating each linear fit.
        design_all = np.vstack((groups_all, np.ones(groups_all.shape)))
        for amp_idx, amp_col in zip(amplifier_idxs, amplifier_cols):
            # Get linear section of ramps for fitting and all for calibration.
            ramps_all = self._group_major_ramps(
                data, groups_all, row_slice, amp_col)
            ramps_fit = self._group_major_ramps(
                data, groups_fit, row_slice, amp_col)

            # Fit each linear section with a linear model (one column of
            # ramps_fit per ramp), then evaluate it over all derive groups.
            lin_coeffs = np.polyfit(groups_fit, ramps_fit, 1)
            lin_ramps = np.matmul(lin_coeffs.T, design_all)

            # Save F and DN values per amplifier, ramp-major order.
            dns_parts[amp_idx].append(ramps_all.T.ravel())
            fs_parts[amp_idx].append(lin_ramps.ravel())

        # float64 so the DN**4 terms cannot overflow if data is integer-typed.
        amplifier_dns = [np.concatenate(p).astype(np.float64)
                         for p in dns_parts]
        amplifier_fs = [np.concatenate(p).astype(np.float64)
                        for p in fs_parts]
        return amplifier_idxs, amplifier_dns, amplifier_fs

    @staticmethod
    def _fit_amplifier_corrections(amplifier_dns, amplifier_fs):
        """Fit F = c4*DN**4 + c3*DN**3 + c2*DN**2 + DN per amplifier.

        The linear and constant terms are fixed at 1 and 0; only the
        higher-order terms are solved by least squares on F - DN.

        Returns
        -------
        list of numpy.ndarray
            Per amplifier, coefficients ordered highest power first:
            [c4, c3, c2, c1=1, c0=0].
        """
        p_fix = np.array([1., 0.])
        amplifier_ccs = []
        for dns, fs in zip(amplifier_dns, amplifier_fs):
            x = np.asarray(dns, dtype=np.float64)
            y = np.asarray(fs, dtype=np.float64)
            y_fix = p_fix[0] * x + p_fix[1]
            xx_fit = np.vstack((x**4, x**3, x**2)).T
            p_fit = np.linalg.lstsq(xx_fit, y - y_fix, rcond=None)[0]
            amplifier_ccs.append(np.concatenate([p_fit, p_fix]))
        return amplifier_ccs

    def _build_linearity_model(self, input_model, amplifier_ccs):
        """Build a LinearityModel holding the per-amplifier coefficients.

        The default linearity reference file is used as a template. Every
        fourth column starting at the amplifier index receives that
        amplifier's coefficients, in ascending power order (c0 first).
        """
        linearity_ref_name = self.get_reference_file(input_model, "linearity")
        linearity_model = datamodels.LinearityModel(linearity_ref_name)

        coeffs = np.zeros((5, 1024, 1032))
        for amp_idx, corr_coeffs in enumerate(amplifier_ccs):
            # Fit coefficients are highest power first; flip to c0..c4.
            for coeff_idx, coeff in enumerate(np.flip(corr_coeffs)):
                coeffs[coeff_idx, :, amp_idx::4] = coeff

        # Overwrite the first four columns with zero coefficients.
        # NOTE(review): presumably these are reference-pixel columns — confirm.
        coeffs[:, :, :4] = 0.

        linearity_model.coeffs = coeffs
        linearity_model.dq = np.zeros((1024, 1032), dtype=np.uint32)
        linearity_model.meta.ref_file = input_model.meta.ref_file
        return linearity_model

    def finalize_result(self, res, ref):
        """
        :meta private:
        """
        # Required to enable ref model to be returned.
        # Overwrites base class method.
        pass

    def linearity_correction(self, dn, coeffs):
        """Evaluate the fitted correction polynomial at ``dn``.

        ``coeffs`` is ordered highest power first, [c4, c3, c2, c1, c0],
        as produced by the per-amplifier fit.
        """
        return coeffs[4] + coeffs[3] * dn + coeffs[2] * dn**2 \
            + coeffs[1] * dn**3 + coeffs[0] * dn**4

    def draw_amplifier_corrections(self, amplifier_idxs, amplifier_dns,
                                   amplifier_fs, amplifier_ccs):
        """Plot per-amplifier samples and fitted corrections.

        Parameters
        ----------
        amplifier_idxs : sequence of int
            Amplifier index of each sampled column. Not needed for the
            plot; it is kept for backward compatibility. Curves are
            labelled by their position in ``amplifier_ccs``, which is
            indexed by amplifier.
        amplifier_dns, amplifier_fs : sequence of 4 array-likes
            Measured DN and linear-model F samples per amplifier.
        amplifier_ccs : sequence of 4 array-likes
            Correction coefficients per amplifier, highest power first.
        """
        _fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(11, 7))
        amp_colors = ["#003f5c", "#7a5195", "#ef5675", "#ffa600"]
        for amp_idx in range(4):
            dns = np.asarray(amplifier_dns[amp_idx], dtype=np.float64)
            fs = np.asarray(amplifier_fs[amp_idx], dtype=np.float64)
            color = amp_colors[amp_idx]
            ax1.scatter(dns, fs, c=color, alpha=0.005)
            ax3.scatter(dns, fs - dns, c=color, alpha=0.005)
            dn_grid = np.linspace(0, np.max(dns), 1000)
            # Label by the amplifier whose coefficients are plotted
            # (previously amplifier_idxs[amp_idx], which mislabelled curves).
            ax2.plot(dn_grid,
                     self.linearity_correction(dn_grid,
                                               amplifier_ccs[amp_idx]),
                     c=color,
                     label="Amplifier {} correction".format(amp_idx))

        xs = np.concatenate([np.asarray(amplifier_dns[i], dtype=np.float64)
                             for i in range(4)])
        ys = np.concatenate([np.asarray(amplifier_fs[i], dtype=np.float64)
                             for i in range(4)]) - xs
        ax4.hexbin(xs, ys, gridsize=(30, 30),
                   norm=mcolors.PowerNorm(gamma=0.2))

        ax1.set_xlabel("DN")
        ax1.set_ylabel("Corrected DN")
        ax3.set_xlabel("DN")
        ax3.set_ylabel("Linear model - DN")
        ax2.set_xlabel("DN")
        ax2.set_ylabel("Model corrected DN")
        ax2.set_xlim(ax1.get_xlim())
        ax2.set_ylim(ax1.get_ylim())
        ax2.legend(loc="upper left")
        ax4.set_xlabel("DN")
        ax4.set_ylabel("Linear model - DN")
        ax4.set_xlim(ax3.get_xlim())
        ax4.set_ylim(ax3.get_ylim())
        plt.tight_layout()
        plt.show()