Source code for katsdpsigproc.rfi.twodflag

################################################################################
# Copyright (c) 2014-2020, 2025, National Research Foundation (SARAO)
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use
# this file except in compliance with the License. You may obtain a copy
# of the License at
#
#   https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

"""Library to contain 2d RFI flagging routines and other RFI related functions."""

import concurrent.futures
import multiprocessing

import numba
import numpy as np

from . import MAD_NORMAL


def _as_min_dtype(value):
    """Wrap an integer in a 0-d numpy array of the narrowest dtype that will hold it.

    Non-negative values that fit in 8, 16 or 32 bits get the matching
    unsigned type; anything else (including negative values) falls back to
    ``np.int64``.

    This is used because in some cases an array must later be allocated with
    the same dtype, and choosing the narrowest type keeps that array small.
    """
    for bits, dtype in ((8, np.uint8), (16, np.uint16), (32, np.uint32)):
        if 0 <= value < 2**bits:
            return np.array(value, dtype)
    return np.array(value, np.int64)


def _asbool(data):
    """Reinterpret or convert `data` as a boolean array.

    `data` is expected to hold only 0's and 1's. Arrays whose elements are
    one byte wide are returned as a view (no copy); any other element size
    is converted, which makes a copy.
    """
    single_byte = data.dtype.itemsize == 1
    return data.view(np.bool_) if single_byte else data.astype(np.bool_)


@numba.extending.overload(_asbool, jit_options=dict(nogil=True))
def _overload_asbool(data):
    """Numba overload of :func:`_asbool`, selected at compile time from the dtype.

    Boolean and 8-bit integer arrays are viewed as bool; every other element
    type is converted with ``astype``.
    """
    elem = data.dtype
    is_byte_int = isinstance(elem, numba.types.Integer) and elem.bitwidth == 8
    if isinstance(elem, numba.types.Boolean) or is_byte_int:
        return lambda data: data.view(np.bool_)
    return lambda data: data.astype(np.bool_)


@numba.jit(nopython=True, nogil=True)
def _average_freq(in_data, in_flags, factor):
    """Do several preconditioning steps.

    1. Converts complex data to real (magnitude).
    2. Flags data with non-finite values (NaN or +/-infinity).
    3. Sets the value of flagged elements to zero.
    4. Does frequency averaging by a factor of `factor`.
    5. Transposes the data ordering so that baseline is the first,
       slowest-varying axis.

    Parameters
    ----------
    in_data : ndarray, real or complex
        Visibilities or their magnitudes, with shape (time, frequency, baseline).
    in_flags : ndarray, bool
        Flags corresponding to the visibilities. This can safely be a type
        other than bool, where non-zero values indicate flagged data.
    factor : int
        Amount by which to decimate in frequency. This must be a numpy 0-d
        array (so that the dtype can be extracted).

    Returns
    -------
    avg_data : ndarray, float32
        Averaged magnitudes, with shape (baseline, time, averaged frequency).
        Elements with no valid input samples are zero.
    avg_flags : ndarray, bool
        True where `avg_data` had no valid (unflagged, finite) input samples.

    Raises
    ------
    ValueError
        If `in_data` and `in_flags` have different shapes.
    """
    if in_data.shape != in_flags.shape:
        raise ValueError("shape mismatch")
    n_time, n_freq, n_bl = in_data.shape
    a_freq = (n_freq + factor - 1) // factor  # Round up: last bin may be partial
    out_shape = (n_bl, n_time, a_freq)
    avg_data = np.zeros(out_shape, np.float32)
    # At most `factor` samples land in a bin, so factor's dtype can hold the count
    avg_weight = np.zeros(out_shape, factor.dtype)
    # TODO: might need to do this through a temporary buffer to avoid cache
    # aliasing problems.
    for i in range(n_time):
        for j in range(n_freq):
            jout = j // factor
            for k in range(n_bl):
                data = np.abs(in_data[i, j, k])
                # Reject infinities as well as NaNs: the magnitude of a value
                # with an infinite component is +inf (not NaN), and a single
                # infinity would poison the average and the rolling sums of
                # the later smoothing filters.
                if not in_flags[i, j, k] and np.isfinite(data):
                    avg_data[k, i, jout] += data
                    avg_weight[k, i, jout] += 1
    for i in range(n_bl):
        for j in range(n_time):
            for k in range(a_freq):
                flag = avg_weight[i, j, k] == 0
                if flag:
                    avg_data[i, j, k] = 0  # Avoid divide by zero and a NaN
                else:
                    avg_data[i, j, k] /= avg_weight[i, j, k]
                # Replace weight with flag (in-place) to save memory
                avg_weight[i, j, k] = flag
    return avg_data, _asbool(avg_weight)


@numba.jit(nopython=True, nogil=True)
def _time_median(data, flags):
    """Compute the per-channel median over time of the unflagged values.

    Each channel is treated independently. A channel with every sample
    flagged gets the value 0 and is flagged in the output.

    The time axis is retained in the outputs with length 1.

    Parameters
    ----------
    data : ndarray, real
        Visibilities, with shape (time, frequency)
    flags : ndarray, bool
        Flags corresponding to `data`

    Returns
    -------
    out_data : ndarray, real
        Median of `data` for each frequency, with shape (1, frequency)
    out_flags : ndarray, bool
        Flags corresponding to `out_data`
    """
    n_time, n_freq = data.shape
    out_data = np.empty((1, n_freq), data.dtype)
    out_flags = np.zeros((1, n_freq), np.bool_)
    # Unflagged samples of channel f are packed at the front of packed[f]
    packed = np.empty((n_freq, n_time), data.dtype)
    n_valid = np.zeros(n_freq, np.uint32)
    for t in range(n_time):
        for f in range(n_freq):
            if flags[t, f]:
                continue
            packed[f, n_valid[f]] = data[t, f]
            n_valid[f] += 1
    for f in range(n_freq):
        if n_valid[f] != 0:
            out_data[0, f] = np.median(packed[f, : n_valid[f]])
        else:
            out_data[0, f] = 0  # No data
            out_flags[0, f] = True
    return out_data, out_flags


@numba.jit(nopython=True, nogil=True)
def _median_abs(data, flags):
    """Return the median of the absolute values of the unflagged elements of `data`.

    NaN is returned if every element is flagged.
    """
    kept = np.empty(data.size, data.dtype)
    count = np.int64(0)
    for pos in np.ndindex(data.shape):
        if flags[pos]:
            continue
        kept[count] = np.abs(data[pos])
        count += 1
    if count > 0:
        return np.median(kept[:count])
    return np.nan


@numba.jit(nopython=True, nogil=True)
def _median_abs_axis0(data, flags):
    """Take the median along axis 0 of the absolute values of unflagged elements.

    Positions where every element along axis 0 is flagged yield NaN.

    The first dimension is kept in the output as a dimension of size 1 (to
    avoid issues with numba converting 0d arrays to scalars).
    """
    lead = data.shape[:1]
    rest = data.shape[1:]
    # Axis 0 is moved to the end so that the samples for one output position
    # are adjacent; the valid ones are packed at the front.
    packed = np.empty(rest + lead, data.dtype)
    n_valid = np.zeros(rest, np.uint32)
    result = np.empty((1,) + rest, data.dtype)
    for i in range(data.shape[0]):
        for j in np.ndindex(rest):
            src = (i,) + j
            if flags[src]:
                continue
            packed[j + (n_valid[j],)] = np.abs(data[src])
            n_valid[j] += 1
    for j in np.ndindex(rest):
        dst = (0,) + j
        if n_valid[j] == 0:
            result[dst] = np.nan
        else:
            result[dst] = np.median(packed[j + (slice(0, n_valid[j]),)])
    return result


@numba.jit(nopython=True, nogil=True)
def _linearly_interpolate_nans1d(data):
    """Fill the NaNs of a 1D array in-place by linear interpolation.

    Runs of NaNs between two valid values are replaced by a straight line
    joining those values. Leading and trailing runs are filled by repeating
    the first/last valid element. If there is no valid element at all, the
    whole array is set to zero.

    Parameters
    ----------
    data : ndarray, real
        Data to interpolate, 1D. It is modified in-place.
    """
    size = data.size
    # Locate the first valid value
    first = 0
    while first < size and np.isnan(data[first]):
        first += 1
    if first == size:
        data[:] = 0
        return
    data[:first] = data[first]  # Extrapolate backwards
    pos = first + 1
    while pos < size:
        if not np.isnan(data[pos]):
            pos += 1
            continue
        left = pos - 1  # Last valid element before the gap
        # Locate the next valid value (or the end of the array)
        right = pos + 1
        while right < size and np.isnan(data[right]):
            right += 1
        if right == size:
            data[pos:] = data[left]  # Extrapolate forwards
        else:
            base = data[left]
            slope = (data[right] - base) / (right - left)
            for i in range(pos, right):
                data[i] = base + (i - left) * slope
        pos = right


@numba.jit(nopython=True, nogil=True)
def _linearly_interpolate_nans(data):
    """Fill NaNs in `data` in-place by interpolating linearly along frequency.

    Each row (time) is processed independently with
    :func:`_linearly_interpolate_nans1d`, so extrapolation repeats the
    first/last valid element of the row.

    Parameters
    ----------
    data : ndarray, real
        Data to interpolate, with shape (time, frequency).
    """
    n_rows = data.shape[0]
    for row in range(n_rows):
        _linearly_interpolate_nans1d(data[row])


@numba.jit(nopython=True, nogil=True)
def _box_gaussian_filter1d(data, r, out, passes):
    """Implement :func:`_box_gaussian_filter` along the first axis of an array.

    The data are convolved `passes` times with a boxcar of width
    ``2 * r + 1`` and the result is divided by ``(2 * r + 1) ** passes``.
    Values outside the array are taken as zero. If `data` is empty along the
    first axis, or `passes` is zero, the data are copied to `out` unchanged.

    It is safe to use this function in-place i.e. with `out` equal to `data`
    (the work is done in a padded temporary copy).

    Parameters
    ----------
    data : ndarray, real
        Input data, with at least 1 dimension.
    r : int
        Radius of the box filter
    out : ndarray, real
        Output data, with same shape as `data`.
    passes : int
        Number of boxcar filters to apply
    """
    K = passes
    if data.shape[0] == 0 or K == 0:
        out[:] = data[:]
        return
    d = 2 * r + 1
    # Pad on left with zeros. Each pass replaces padded[i] by a sum that
    # looks forward from i, which moves the signal r elements towards index
    # 0; after K passes the result is therefore aligned with padded[0].
    padding = r * K
    # TODO: hoist memory allocations into caller
    padded = np.empty((data.shape[0] + padding,) + data.shape[1:], data.dtype)
    padded[:padding] = 0
    padded[padding:] = data
    prev_start = padding  # First element with valid data
    s = np.zeros(data.shape[1:], np.float64)
    for p in range(1, K + 1):
        # On each pass, padded[i] is replaced by the sum of padded[i : i + d]
        # from the previous pass. The accumulator is kept in double precision
        # to avoid excessive accumulation of errors.
        s[()] = 0
        start = padding - 2 * r * p
        stop = start + data.shape[0] + 2 * padding
        start = max(start, 0)
        stop = min(stop, padded.shape[0])
        # For i in [start, tail) the element entering the window, i + 2 * r,
        # is inside the array. Clamp to `start`: when the padded array is
        # shorter than 2 * r (passes == 1 with fewer than r samples) the
        # unclamped value is negative, and the trailing loop below would
        # walk negative (wrapped-around) indices and corrupt the result.
        tail = max(start, min(stop, padded.shape[0] - 2 * r))
        # Prime the running sum with the part of the first window that
        # precedes the entering element. Everything before prev_start is
        # still zero, so it does not need to be added.
        for i in range(prev_start, min(start + 2 * r, padded.shape[0])):
            s += padded[i]
        for i in range(start, tail):
            for j in np.ndindex(data.shape[1:]):
                s[j] += padded[(i + 2 * r,) + j]
                prev = padded[(i,) + j]
                padded[(i,) + j] = s[j]
                s[j] -= prev
        # Past `tail` the entering element is off the end of the array
        # (i.e. zero), so only the leaving element is subtracted.
        for i in range(tail, stop):
            for j in np.ndindex(data.shape[1:]):
                prev = padded[(i,) + j]
                padded[(i,) + j] = s[j]
                s[j] -= prev
        prev_start = start
    # Normalise the K box sums into K box averages
    for idx in np.ndindex(out.shape):
        out[idx] = padded[idx] / data.dtype.type(d) ** K


@numba.jit(nopython=True, nogil=True)
def _box_gaussian_filter(data, sigma, out, passes=4):
    """Apply an approximate Gaussian filter to 2D `data`.

    The Gaussian is approximated by repeated filtering with a boxcar
    function; see [Get13]_ for details. The filter has finite support, and
    values outside the boundary are taken as zero. An axis whose computed
    box radius is zero is left unfiltered.

    This function is not suitable when the input contains non-finite values,
    or very large variations in magnitude, as it internally computes a rolling
    sum. It also quantizes the requested sigma.

    .. [Get13] Pascal Getreuer, A Survey of Gaussian Convolution Algorithms,
       Image Processing On Line, 3 (2013), pp. 286-310.

    Parameters
    ----------
    data : ndarray
        Input data to filter (2D)
    sigma : ndarray
        Standard deviation of the Gaussian filter, per axis
    out : ndarray
        Output data, with the same shape as the input
    passes : int
        Number of boxcar filters to apply

    Raises
    ------
    ValueError
        If `sigma` does not have one element per dimension of `data`.
    """
    if len(sigma) != data.ndim:
        raise ValueError("sigma has wrong number of elements")
    assert data.ndim == 2
    # Box radius per axis, truncated to an integer
    radii = (0.5 * np.sqrt(12.0 * sigma**2 / passes + 1)).astype(np.int_)
    filtered = False
    if radii[0] > 0:
        # Process chunks of columns. See _sum_threshold for explanation.
        chunk = 256
        n_cols = data.shape[1]
        for lo in range(0, n_cols, chunk):
            cols = slice(lo, min(lo + chunk, n_cols))
            _box_gaussian_filter1d(data[:, cols], radii[0], out[:, cols], passes)
        data = out  # The second axis must filter the result of the first
        filtered = True
    if radii[1] > 0:
        for row in range(data.shape[0]):
            _box_gaussian_filter1d(data[row], radii[1], out[row], passes)
        filtered = True
    if not filtered:
        # Neither axis was filtered, so out has not been written yet
        out[:] = data[:]


@numba.jit(nopython=True, nogil=True)
def masked_gaussian_filter(data, flags, sigma, out, passes=4):
    """Apply an approximate Gaussian filter to an image, ignoring flagged values.

    Flagged values do not contribute to the result, and values outside the
    grid are treated as if flagged. This is achieved by filtering both the
    zero-filled data and a 0/1 weight image, and dividing the former by the
    latter. See :func:`_box_gaussian_filter` for a number of caveats.

    Where the filtered weight is exactly zero (the finite support of the
    Gaussian approximation contains no unflagged values) the output is NaN.

    Parameters
    ----------
    data : ndarray, 2D
        Input data to filter
    flags : ndarray, bool
        True values correspond to elements of `data` to be ignored
    sigma : ndarray of floats
        Standard deviation of the Gaussian filter, per axis. It is passed
        unchanged to :func:`_box_gaussian_filter`.
    out : ndarray, 2D
        Output array, with same shape as `data`
    passes : int
        Number of boxcar filters to apply

    Raises
    ------
    ValueError
        If `flags` or `out` does not have the same shape as `data`.
    """
    if data.shape != flags.shape:
        raise ValueError("shape mismatch between data and flags")
    if data.shape != out.shape:
        raise ValueError("shape mismatch between data and out")
    weight = np.empty_like(data)
    for idx in np.ndindex(data.shape):
        if flags[idx]:
            weight[idx] = 0
            out[idx] = 0
        else:
            weight[idx] = 1
            out[idx] = data[idx]
    _box_gaussian_filter(weight, sigma, weight, passes=passes)
    _box_gaussian_filter(out, sigma, out, passes=passes)
    for idx in np.ndindex(out.shape):
        w = weight[idx]
        if w != 0:
            out[idx] /= w
        else:
            # Numeric instability can make out non-zero (but tiny) even
            # where the filtered weight is zero, which would make the ratio
            # +/-inf rather than NaN. Set to NaN explicitly in this case.
            out[idx] = np.nan
@numba.jit(nopython=True, nogil=True)
def _get_background2d(data, flags, iterations, spike_width, reject_threshold, freq_chunk_ends):
    """Determine a smooth background over a 2D array.

    This is done by iteratively convolving the data with elliptical Gaussians
    with linearly decreasing width from `iterations`*`spike_width` down to
    `spike_width`. Outliers greater than `reject_threshold`*sigma from the
    background are masked on each iteration.

    Initial weights are set to zero at positions specified in `flags`.
    After the final iteration a final Gaussian smoothed background is
    computed and any stray NaNs in the background are interpolated in
    frequency (axis 1) for each timestamp (axis 0). The NaNs can appear when
    the convolving Gaussian is completely covering masked data as the sum of
    convolved weights will be zero.

    Parameters
    ----------
    data : 2D ndarray, float
        The input data array to be smoothed, with shape (time, frequency).
    flags : 2D ndarray, boolean
        Flags corresponding to `data`. The array is copied, so the caller's
        flags are not modified.
    iterations : int
        Number of iterations of Gaussian filtering
    spike_width : ndarray, float
        Two-element array containing the 1-sigma radius of the Gaussian filter
        in each axis.
    reject_threshold : float
        Number of standard deviations above which to flag data.
    freq_chunk_ends : ndarray, int
        Endpoints of intervals in which to compute noise estimates
        independently. This array must start with 0 and end with the number of
        channels, and be strictly increasing. The values are used as slice
        bounds along axis 1.

    Returns
    -------
    background : 2D ndarray
        Smooth background, with the same shape and dtype as `data`
    """
    n_time, n_freq = data.shape
    flags = flags.copy()  # Gets modified
    background = np.empty_like(data)
    for extend_factor in range(iterations, 0, -1):
        sigma = extend_factor * spike_width
        masked_gaussian_filter(data, flags, sigma, background)
        for c in range(freq_chunk_ends.size - 1):
            sub = (slice(None, None), slice(freq_chunk_ends[c], freq_chunk_ends[c + 1]))
            sub_data = data[sub]
            sub_flags = flags[sub]
            # Convert background to an absolute value residual, in-place
            sub_residual = background[sub]
            for t in range(n_time):
                for f in range(sub_data.shape[1]):
                    sub_residual[t, f] = np.abs(sub_data[t, f] - sub_residual[t, f])
            threshold = _median_abs(sub_residual, sub_flags)
            threshold *= MAD_NORMAL * reject_threshold
            for t in range(n_time):
                for f in range(sub_data.shape[1]):
                    # sub_residual can contain NaNs, but only where the flags already apply
                    if sub_residual[t, f] > threshold:
                        sub_flags[t, f] = True
    # Compute final background
    masked_gaussian_filter(data, flags, spike_width, background)
    # Remove NaNs via linear interpolation
    _linearly_interpolate_nans(background)
    return background


@numba.jit(nopython=True, nogil=True)
def _convolve_flags(in_values, scale, threshold, out_flags, window):
    """Flag values with a threshold, and smear the flags.

    This is rolled into a single function for efficient implementation, but
    there are logically several steps:

    - For each value v in `in_values`, flag it if ``v * scale > threshold``.
    - Convolve the flags by a box filter of size `window`, expanding the width.
    - Logical OR these new flags into `out_flags`.

    The convolution runs along the first axis; all other axes are treated
    independently.

    Parameters
    ----------
    in_values : ndarray
        Values to threshold
    scale : float
        Factor applied to each value before it is compared to the threshold
    threshold : ndarray
        Thresholds. Only index 0 along the first axis is read; the remaining
        axes are indexed like the trailing axes of `in_values`.
    out_flags : ndarray, bool
        Flags to update in place. It may have at most
        ``in_values.shape[0] + window - 1`` entries along the first axis.
    window : int
        Width of the box filter
    """
    cum_size = in_values.shape[0] + 2 * window - 1
    # TODO: could preallocate this externally
    cum = np.empty((cum_size,) + (in_values.shape[1:]), np.uint32)  # Cumulative flagged values
    cum[:window] = 0
    for i in range(in_values.shape[0]):
        for j in np.ndindex(in_values.shape[1:]):
            flag = in_values[(i,) + j] * scale > threshold[(0,) + j]
            cum[(window + i,) + j] = cum[(window + i - 1,) + j] + flag
    # numba doesn't seem to fully support negative indices, hence
    # the addition of cum_size.
    cum[cum_size - (window - 1) :] = cum[cum_size - window]
    for i in range(out_flags.shape[0]):
        for j in np.ndindex(out_flags.shape[1:]):
            out_flags[(i,) + j] |= cum[(i + window,) + j] - cum[(i,) + j] != 0


@numba.jit(nopython=True, nogil=True)
def _sum_threshold1d(input_data, input_flags, output_flags, windows, outlier_nsigma, rho, chunks):
    """Implement :func:`_sum_threshold`.

    It operates along the first axis, and writes its result into
    `output_flags` (which has the same shape as `input_data`). See
    :func:`_sum_threshold` for a description of the other parameters.
    """
    for ci in range(chunks.size - 1):
        chunk_slice = slice(chunks[ci], chunks[ci + 1])
        chunk_data = input_data[chunk_slice]
        chunk_flags = input_flags[chunk_slice]

        # Get standard deviation using MAD and set up initial threshold
        threshold = _median_abs_axis0(chunk_data, chunk_flags)
        threshold_scale = outlier_nsigma * MAD_NORMAL
        for idx in np.ndindex(threshold.shape):
            if np.isnan(threshold[idx]):
                threshold[idx] = np.inf
            else:
                threshold[idx] *= threshold_scale

        # Pad the chunk on each side so that windows straddling the chunk
        # boundary are handled, clamping to the extent of the first axis.
        padded_slice = slice(
            max(chunks[ci] - np.max(windows) + 1, 0),
            min(chunks[ci + 1] + np.max(windows) - 1, input_data.shape[0]),
        )
        padded_data = input_data[padded_slice]

        # TODO: can pre-allocate these outside the loop (but will need resizing)
        output_flags_pos = np.zeros(padded_data.shape, np.bool_)
        output_flags_neg = np.zeros(padded_data.shape, np.bool_)
        for window in windows:
            # The threshold for this iteration is calculated from the initial threshold
            # using the equation from Offringa (2010).
            tf = pow(rho, np.log2(window))
            # Get the thresholds
            thisthreshold = threshold / tf
            # Set already flagged values to be the +/- value of the
            # threshold if they are outside the threshold, and take
            # a cumulative sum.
            cum_data = np.empty((padded_data.shape[0] + 1,) + padded_data.shape[1:], np.float64)
            cum_data[0] = 0
            for i in range(padded_data.shape[0]):
                for j in np.ndindex(padded_data.shape[1:]):
                    idx = (i,) + j
                    clamped = padded_data[idx]
                    limit = thisthreshold[(0,) + j]
                    if output_flags_pos[idx] and clamped > limit:
                        clamped = limit
                    elif output_flags_neg[idx] and clamped < -limit:
                        clamped = -limit
                    cum_data[(i + 1,) + j] = cum_data[idx] + clamped

            # Calculate a rolling sum array from the data with the window for this iteration,
            # which is later scaled by rolling_scale to give the rolling average.
            avgarray = cum_data[window:] - cum_data[:-window]
            rolling_scale = np.float32(1.0 / window)

            # Work out the flags from the average data above the current threshold,
            # convolve them, and combine with current flags.
            _convolve_flags(avgarray, rolling_scale, thisthreshold, output_flags_pos, window)

            # Work out the flags from the average data below the current threshold,
            # convolve them, and OR with current flags.
            _convolve_flags(avgarray, -rolling_scale, thisthreshold, output_flags_neg, window)

        # Extract just the portion of output_flags_pos/neg corresponding to the
        # chunk itself, without the padding
        rel_slice = slice(
            chunk_slice.start - padded_slice.start,
            chunk_slice.stop - padded_slice.start,
        )
        output_flags[chunk_slice] = output_flags_pos[rel_slice] | output_flags_neg[rel_slice]


@numba.jit(nopython=True, nogil=True)
def _sum_threshold(input_data, input_flags, axis, windows, outlier_nsigma, rho, chunks=None):
    """Apply the SumThreshold method along the given axis of `input_data`.

    Parameters
    ----------
    input_data : ndarray, real
        Deviations from the background. The implementation is optimised for
        2D (and does not currently work in 1D), but higher dimensions are
        supported.
    input_flags : ndarray, bool
        Input flags. Used as a mask when computing the initial
        standard deviations of the input data.
    axis : int
        The axis on which to apply the SumThreshold operation. In the current
        implementation, must be 0 or 1.
    windows : ndarray, int
        Window sizes to average data in each SumThreshold step
    outlier_nsigma : float
        Number of standard deviations at which to flag
    rho : float
        Parameter controlling the relationship between threshold and window size
    chunks : ndarray, int
        Boundaries between chunks in which each chunk has a separate noise
        estimation. This array must start with 0, be strictly increasing, and
        end with ``input_data.shape[axis]``. If not given, the whole axis is
        treated as a single chunk.

    Returns
    -------
    output_flags : ndarray, bool
        The derived flags

    Raises
    ------
    ValueError
        If `axis` is out of range for `input_data`, or is not 0 or 1
    """
    if chunks is None:
        chunks = np.array([0, input_data.shape[axis]])
    output_flags = np.empty(input_data.shape, np.bool_)
    if axis < 0 or axis >= input_data.ndim:
        raise ValueError("axis is out of range")
    elif axis == 1:
        for i in range(input_data.shape[0]):
            _sum_threshold1d(
                input_data[i],
                input_flags[i],
                output_flags[i],
                windows,
                outlier_nsigma,
                rho,
                chunks,
            )
    elif axis == 0:
        # The operation is independent of the other dimensions, but we process
        # them in chunks to be cache friendly. The step size should be big
        # enough that whole cache lines are used (even if the alignment is
        # poor), but small enough that multiple rows fit in L1. This heuristic
        # value assumes a 2D input.
        step = 256
        for i in range(0, input_data.shape[1], step):
            sub = slice(i, min(i + step, input_data.shape[1]))
            _sum_threshold1d(
                input_data[:, sub],
                input_flags[:, sub],
                output_flags[:, sub],
                windows,
                outlier_nsigma,
                rho,
                chunks,
            )
    else:
        raise ValueError("axis must be 0 or 1")
    return output_flags


@numba.jit(nopython=True, nogil=True, cache=True)
def _get_flags_impl(
    in_data,
    in_flags,
    out_flags,
    outlier_nsigma,
    windows_time,
    windows_freq,
    background_reject,
    background_iterations,
    spike_width_time,
    spike_width_freq,
    time_extend,
    freq_extend,
    freq_chunk_ends,
    average_freq,
    flag_all_time_frac,
    flag_all_freq_frac,
    rho,
):
    """Compute flags for a batch of baselines.

    `in_data`, `in_flags` and `out_flags` are indexed as
    (time, frequency, baseline). The result is written into `out_flags`;
    elements where `in_data` is NaN are always flagged. The remaining
    parameters are passed on to :func:`_get_baseline_flags`.
    """
    n_time, n_freq, n_bl = in_data.shape
    # Average `in_data` in frequency. This is done unconditionally, because it
    # also does other useful steps (see the documentation).
    data, flags = _average_freq(in_data, in_flags, average_freq)

    # Output flags, in baseline-major order
    tmp_flags = np.empty((n_bl, n_time, n_freq), np.bool_)

    # Do operations independently per baseline.
    for bl in range(data.shape[0]):
        _get_baseline_flags(
            data[bl],
            flags[bl],
            tmp_flags[bl],
            outlier_nsigma,
            windows_time,
            windows_freq,
            background_reject,
            background_iterations,
            spike_width_time,
            spike_width_freq,
            time_extend,
            freq_extend,
            freq_chunk_ends,
            average_freq,
            flag_all_time_frac,
            flag_all_freq_frac,
            rho,
        )

    # Transpose the output flags and explicitly flag nans from input
    for t in range(n_time):
        for f in range(n_freq):
            for bl in range(n_bl):
                out_flags[t, f, bl] = tmp_flags[bl, t, f] or np.isnan(in_data[t, f, bl])


@numba.jit(nopython=True, nogil=True)
def _combine_flags(spec_flags, time_flags, freq_flags, time_extend, out):
    """Combine several sources of flags and smear them in time.

    Parameters
    ----------
    spec_flags : 2D ndarray, bool
        Flags with shape (1, frequency); the single row applies to all times
    time_flags, freq_flags : 2D ndarray, bool
        Flags with shape (time, frequency)
    time_extend : numpy integer scalar
        Width of the convolution kernel for time smearing (should be odd).
        Its dtype is also used for the internal cumulative sums.
    out : 2D ndarray, bool
        Output flags
    """
    n_time, n_freq = time_flags.shape
    # Combine spec_flags, time_flags and freq_flags, and take a cumulative sum
    # along the time axis for the purposes of convolution.
    # The sums use the (possibly narrow) dtype of `time_extend` and may wrap
    # around; that is harmless because only the (in)equality of two sums at
    # most `time_extend` rows apart is tested below.
    flag_sum = np.empty((n_time + 1, n_freq), time_extend.dtype)
    flag_sum[0] = 0
    for t in range(n_time):
        for f in range(n_freq):
            flag = spec_flags[0, f] or time_flags[t, f] or freq_flags[t, f]
            flag_sum[t + 1, f] = flag_sum[t, f] + flag

    # Difference the cumulative sums to get time-smeared flags.
    time_delta_lo = -(time_extend // 2)
    time_delta_hi = time_delta_lo + time_extend
    for t in range(n_time):
        # Rows to difference, clamping to the data limits
        t0 = max(t + time_delta_lo, 0)
        t1 = min(t + time_delta_hi, n_time)
        for f in range(n_freq):
            out[t, f] = flag_sum[t0, f] != flag_sum[t1, f]


@numba.jit(nopython=True, nogil=True)
def _unaverage_freq(flags, freq_extend, average_freq, flag_all_time_frac, flag_all_freq_frac, out):
    """Perform final processing for a single baseline.

    1. Flags are replicated to undo the effect of frequency averaging.
    2. Flags are smeared in frequency, using a kernel of width `freq_extend`.
    3. Times and frequencies where more than `flag_all_freq_frac` or
       `flag_all_time_frac` of the values are already flagged become fully
       flagged.

    Parameters
    ----------
    flags : 2D ndarray, bool
        Flags with shape (time, averaged frequency)
    freq_extend : int
        Width of the convolution kernel for frequency smearing
    average_freq : int
        Number of original channels represented by each channel of `flags`
    flag_all_time_frac : float
        A channel is flagged at all times if more than this fraction of its
        times are flagged
    flag_all_freq_frac : float
        A time is flagged in all channels if more than this fraction of its
        channels are flagged
    out : 2D ndarray, bool
        Output flags with shape (time, original frequency)
    """
    # Frequency replication and smearing
    n_time, n_freq = flags.shape
    orig_freq = out.shape[-1]
    flag_sum = np.empty(orig_freq + 1, np.int32)
    flag_sum_time = np.zeros(orig_freq, np.int32)
    freq_delta_lo = -(freq_extend // 2)
    freq_delta_hi = freq_delta_lo + freq_extend
    for t in range(n_time):
        flag_sum[0] = 0
        for f in range(orig_freq):
            flag_sum[f + 1] = flag_sum[f] + flags[t, f // average_freq]
        # Take differences of the cumulative sums to get smearing
        tot = 0
        for f in range(orig_freq):
            f0 = max(f + freq_delta_lo, 0)
            f1 = min(f + freq_delta_hi, orig_freq)
            flag = flag_sum[f1] != flag_sum[f0]
            out[t, f] = flag
            tot += flag
            flag_sum_time[f] += flag
        # If too much is flagged, flag the entire time
        if tot > flag_all_freq_frac * orig_freq:
            out[t, :] = True

    # Flag all times if too much is flagged. This should be rare, so we
    # write in columns even though that's normally an unfriendly access
    # pattern.
    for f in range(orig_freq):
        if flag_sum_time[f] > n_time * flag_all_time_frac:
            out[:, f] = True


@numba.jit(nopython=True, nogil=True)
def _get_baseline_flags(
    data,
    flags,
    out_flags,
    outlier_nsigma,
    windows_time,
    windows_freq,
    background_reject,
    background_iterations,
    spike_width_time,
    spike_width_freq,
    time_extend,
    freq_extend,
    freq_chunk_ends,
    average_freq,
    flag_all_time_frac,
    flag_all_freq_frac,
    rho,
):
    """Compute flags for a single baseline.

    It is called after frequency averaging, but writes back un-averaged
    results. Both `data` and `flags` are used as scratch space and are
    modified in place.

    Parameters
    ----------
    data : ndarray, real
        Visibility magnitudes, with shape (time, frequency)
    flags : ndarray, bool
        User-input flags corresponding to `data`
    out_flags : ndarray, bool
        Returned flags (which will have more channels than `data` if
        `average_freq` is greater than 1).
    outlier_nsigma : float
        Number of standard deviations at which to flag
    windows_time : array, int
        Size of averaging windows to use in the SumThreshold method in time
    windows_freq : array, int
        Size of averaging windows to use in the SumThreshold method in frequency
    background_reject : float
        Number of sigma to reject outliers when backgrounding
    background_iterations : int
        Number of iterations to use when determining a smooth background, after each
        iteration data in excess of `background_reject`*`sigma` are masked
    spike_width_time : float
        Characteristic width in dumps to smooth over when backgrounding. This is
        the one-sigma width of the convolving Gaussian in axis 0.
    spike_width_freq : float
        Characteristic width in channels to smooth over when backgrounding. This is
        the one-sigma width of the convolving Gaussian in axis 1.
    time_extend : int
        Size of kernel in time to convolve with flags after detection
    freq_extend : int
        Size of kernel in frequency to convolve with flags after detection
    freq_chunk_ends : ndarray, int
        Endpoints of intervals in which to compute noise estimates
        independently. This array must start with 0 and end with the number of
        channels, and be strictly increasing.
    average_freq : int
        Number of channels to average frequency before flagging. Flags will be extended
        to the frequency shape of the input data before being returned
    flag_all_time_frac : float
        Fraction of data flagged above which to extend flags to all data in time axis.
    flag_all_freq_frac : float
        Fraction of data flagged above which to extend flags to all data in frequency axis.
    rho : float
        Parameter controlling the relationship between threshold and window size
    """
    n_time, n_freq = data.shape
    # Generate median spectrum, background it, and flag it
    spec_data, spec_flags = _time_median(data, flags)
    spec_background = _get_background2d(
        spec_data,
        spec_flags,
        background_iterations,
        np.array((0.0, spike_width_freq)),
        background_reject,
        freq_chunk_ends,
    )
    spec_data -= spec_background
    spec_flags = _sum_threshold(
        spec_data, spec_flags, 1, windows_freq, outlier_nsigma, rho, freq_chunk_ends
    )
    # Broadcast spectral flags to per-timestamp
    flags |= spec_flags

    # Get and subtract 2D background
    background = _get_background2d(
        data,
        flags,
        background_iterations,
        np.array((spike_width_time, spike_width_freq)),
        background_reject,
        freq_chunk_ends,
    )
    data -= background

    # SumThreshold along time axis
    time_flags = _sum_threshold(data, flags, 0, windows_time, outlier_nsigma, rho)
    # SumThreshold along frequency axis - with time flags in the input flags
    flags |= time_flags
    freq_flags = _sum_threshold(data, flags, 1, windows_freq, outlier_nsigma, rho, freq_chunk_ends)

    # Combine flag sources and do time smearing. We overwrite 'flags' since the
    # previous result is no longer needed.
    _combine_flags(spec_flags, time_flags, freq_flags, time_extend, flags)
    _unaverage_freq(
        flags,
        freq_extend,
        average_freq,
        flag_all_time_frac,
        flag_all_freq_frac,
        out_flags,
    )


def _get_flags_mp(in_data, in_flags, flagger):
    """Callback function for ProcessPoolExecutor.

    It allocates its own storage for the output, runs
    ``flagger._get_flags`` on the given batch and returns the flags.
    """
    out_flags = np.empty_like(in_flags)
    flagger._get_flags(in_data, in_flags, out_flags)
    return out_flags
class SumThresholdFlagger:
    """Flagger that detects spikes in both frequency and time axes.

    It uses the SumThreshold method (Offringa, A., MNRAS, 405, 155-167, 2010).
    The full algorithm does the following:

    1. Average the data in the frequency dimension (axis 1) into bins of
       size `self.average_freq`
    2. Divide the data into overlapping sub-chunks in frequency which are
       backgrounded and thresholded independently
    3. Flag a 1d spectrum median filtered in time to get fainter contaminated
       channels.
    4. Derive a smooth 2d background through each chunk
    5. SumThreshold the background subtracted chunks in time and frequency
    6. Extend derived flags in time and frequency, via self.freq_extend and
       self.time_extend
    7. Extend flags to all times and frequencies in cases when more than
       a given fraction of samples are flagged (via `self.flag_all_time_frac`
       and `self.flag_all_freq_frac`)

    Parameters
    ----------
    outlier_nsigma : float
        Number of sigma to reject outliers when thresholding
    windows_time : array, int
        Size of averaging windows to use in the SumThreshold method in time
    windows_freq : array, int
        Size of averaging windows to use in the SumThreshold method in frequency
    background_reject : float
        Number of sigma to reject outliers when backgrounding
    background_iterations : int
        Number of iterations to use when determining a smooth background, after each
        iteration data in excess of `background_reject`*`sigma` are masked
    spike_width_time : float
        Characteristic width in dumps to smooth over when backgrounding. This is
        the one-sigma width of the convolving Gaussian in axis 0.
    spike_width_freq : float
        Characteristic width in channels to smooth over when backgrounding. This is
        the one-sigma width of the convolving Gaussian in axis 1.
    time_extend : int
        Size of kernel in time to convolve with flags after detection
    freq_extend : int
        Size of kernel in frequency to convolve with flags after detection
    freq_chunks : int
        Number of equal-sized chunks to independently flag in frequency. Smaller
        chunks will be less affected by variations in the band in the frequency domain.
    average_freq : int
        Number of channels to average frequency before flagging. Flags will be extended
        to the frequency shape of the input data before being returned
    flag_all_time_frac : float
        Fraction of data flagged above which to extend flags to all data in time axis.
    flag_all_freq_frac : float
        Fraction of data flagged above which to extend flags to all data in frequency axis.
    rho : float
        Falloff exponent for SumThreshold
    """

    # NOTE: the list defaults below are shared between instances. This class
    # only ever reads them, so they are kept to leave the signature unchanged.
    def __init__(
        self,
        outlier_nsigma=4.5,
        windows_time=[1, 2, 4, 8],
        windows_freq=[1, 2, 4, 8],
        background_reject=2.0,
        background_iterations=1,
        spike_width_time=12.5,
        spike_width_freq=10.0,
        time_extend=3,
        freq_extend=3,
        freq_chunks=10,
        average_freq=1,
        flag_all_time_frac=0.6,
        flag_all_freq_frac=0.8,
        rho=1.3,
    ):
        self.outlier_nsigma = outlier_nsigma
        self.windows_time = windows_time
        # Scale the frequency windows, and remove possible duplicates
        windows_freq = np.ceil(np.array(windows_freq, dtype=np.float32) / average_freq)
        self.windows_freq = np.unique(windows_freq.astype(np.int_))
        self.background_reject = background_reject
        self.background_iterations = background_iterations
        self.spike_width_time = spike_width_time
        # Scale spike_width by average_freq
        self.spike_width_freq = spike_width_freq / average_freq
        self.time_extend = _as_min_dtype(time_extend)
        self.freq_extend = _as_min_dtype(freq_extend)
        self.freq_chunks = freq_chunks
        self.average_freq = _as_min_dtype(average_freq)
        self.flag_all_time_frac = flag_all_time_frac
        self.flag_all_freq_frac = flag_all_freq_frac
        self.rho = rho

    def _clip_windows(self, n_time, n_channels):
        """Drop SumThreshold windows that are longer than the data.

        Parameters
        ----------
        n_time : int
            Number of timestamps in the data (axis 0)
        n_channels : int
            Number of channels after frequency averaging

        Returns
        -------
        windows_time : ndarray, int
            The elements of `self.windows_time` that are at most `n_time`
        windows_freq : ndarray, int
            The elements of `self.windows_freq` that are at most `n_channels`
        """
        windows_time = np.array([w for w in self.windows_time if w <= n_time], np.int_)
        windows_freq = np.array([w for w in self.windows_freq if w <= n_channels], np.int_)
        return windows_time, windows_freq

    def _get_flags(self, in_data, in_flags, out_flags):
        """Flag a batch of baselines.

        The batches are doled out by :meth:`get_flags`, either to an executor
        pool or directly. The batching is important because it affects memory
        access patterns. The batch size should not be too large, as otherwise
        it will overload the cache.

        This function is the interface between Python code and numba code,
        and takes care of conditioning the parameters into a form that the
        numba code can consume. All the actual work is done in
        :func:`_get_flags_impl`.

        Parameters
        ----------
        in_data : 3D array
            Input data, in (time, frequency, baseline) order
        in_flags : 3D array, boolean
            Input flags, with the same shape as `in_data`
        out_flags : 3D array, boolean
            Output flags, written in place
        """
        average_freq = int(self.average_freq)  # Avoid numpy overflow errors
        averaged_channels = (in_data.shape[1] + average_freq - 1) // average_freq
        # Set up frequency chunks
        freq_chunk_ends = np.linspace(0, averaged_channels, self.freq_chunks + 1).astype(np.int_)
        # Clip the windows to the available time and frequency range. Time is
        # axis 0 of the data; frequency is axis 1 (after averaging).
        windows_time, windows_freq = self._clip_windows(in_data.shape[0], averaged_channels)

        _get_flags_impl(
            in_data,
            in_flags,
            out_flags,
            self.outlier_nsigma,
            windows_time,
            windows_freq,
            self.background_reject,
            self.background_iterations,
            self.spike_width_time,
            self.spike_width_freq,
            self.time_extend,
            self.freq_extend,
            freq_chunk_ends,
            self.average_freq,
            self.flag_all_time_frac,
            self.flag_all_freq_frac,
            self.rho,
        )

    def get_flags(self, data, flags, pool=None, chunk_size=None, is_multiprocess=None):
        """Compute flags in data array.

        It has optional input `flags` of the same shape that denote samples in
        `data` to ignore when backgrounding and deriving thresholds.

        This can run in parallel if given a
        :class:`concurrent.futures.Executor`. Performance is generally better
        with a :class:`~concurrent.futures.ThreadPoolExecutor`. While a
        :class:`~concurrent.futures.ProcessPoolExecutor` is supported, it is
        usually limited by the speed at which the data can be pickled and
        transferred to the other processes.

        Parameters
        ----------
        data : 3D array
            The input visibility data, in (time, frequency, baseline) order. It may
            also contain just the magnitudes.
        flags : 3D array, boolean
            Input flags.
        pool : :class:`concurrent.futures.Executor`, optional
            Worker pool for parallel computation. If not specified,
            computation will be done serially.
        chunk_size : int, optional
            Number of baselines to process at a time. If not specified,
            heuristics are used to pick a reasonable value. Values above 16
            give diminishing returns and much larger values may actually
            reduce performance. Power-of-two sizes are likely to perform best.
        is_multiprocess : bool, optional
            If `pool` behaves like
            :class:`concurrent.futures.ProcessPoolExecutor` (in particular, if
            it makes copies of its arguments) then this must be set to
            ``True`` to invoke a slower path that ensures that results are
            returned and reassembled. If unspecified, it defaults to true for
            :class:`concurrent.futures.ProcessPoolExecutor` and false for all
            other types. Thus, it only needs to be specified when using an
            object that isn't a
            :class:`concurrent.futures.ProcessPoolExecutor` but behaves like
            one.

        Returns
        -------
        out_flags : 3D array, boolean, same shape as `data`
            Derived flags (True=flagged)

        Raises
        ------
        ValueError
            If `data` and `flags` have different shapes, or are not 3D
        """
        if data.shape != flags.shape:
            raise ValueError("Shape mismatch")
        if data.ndim != 3:
            raise ValueError("data has wrong number of dimensions")

        out_flags = np.empty(flags.shape, np.bool_)
        n_bl = data.shape[-1]
        if not chunk_size:
            chunk_size = 16
            if pool is not None:
                # Make sure there is enough parallelism. There is no way to
                # query the number of workers in a pool, so we'll just assume
                # it is equal to cpu_count. We want at least 4 tasks per CPU
                # to avoid load imbalances.
                workers = multiprocessing.cpu_count()
                while chunk_size > 1 and chunk_size * workers * 4 > n_bl:
                    chunk_size //= 2
        if pool is not None and is_multiprocess is None:
            is_multiprocess = isinstance(pool, concurrent.futures.ProcessPoolExecutor)

        futures = []
        outputs = {}
        try:
            for i in range(0, n_bl, chunk_size):
                chunk_data = data[..., i : i + chunk_size]
                chunk_flags = flags[..., i : i + chunk_size]
                chunk_out = out_flags[..., i : i + chunk_size]
                if pool is not None and is_multiprocess:
                    future = pool.submit(_get_flags_mp, chunk_data, chunk_flags, self)
                    outputs[future] = chunk_out
                    futures.append(future)
                elif pool is not None:
                    futures.append(pool.submit(self._get_flags, chunk_data, chunk_flags, chunk_out))
                else:
                    self._get_flags(chunk_data, chunk_flags, chunk_out)
            # Wait for all the futures to complete, and raise any exception.
            # In multiprocessing mode, copy results back.
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                if is_multiprocess:
                    outputs[future][:] = result
            return out_flags
        finally:
            # If there's an exception, stop any work we can
            for future in futures:
                future.cancel()