# Source code for ewoksid02.utils.secondaryscattering

import time
from typing import Optional, Tuple

import numexpr
import numpy
from scipy.signal import convolve as oaconvolve_numpy

from .io import get_persistent_array_window_wagon, get_persistent_array_mask

# Optional GPU support: import CuPy, its overlap-add convolution and the
# project helper that logs GPU memory. All three imports must succeed for
# the GPU path to be flagged as available.
try:
    import cupy
    from cupyx.scipy.signal import oaconvolve as oaconvolve_cupy
    from .cupyutils import log_allocated_gpu_memory
except ImportError:
    # Fallback when any of the imports above fails.
    # NOTE: `oaconvolve_cupy` and `log_allocated_gpu_memory` may be left
    # undefined in this branch, so GPU code paths must not be called when
    # CUPY_AVAILABLE is False.
    CUPY_AVAILABLE = False
    CUPY_MEM_POOL = None
    # Bind the name `cupy` to numpy so that later references to it in this
    # module (e.g. `cupy.ndarray` in annotations) still resolve.
    cupy = numpy
else:
    CUPY_AVAILABLE = True
    # Handle on CuPy's default device memory pool.
    CUPY_MEM_POOL = cupy.get_default_memory_pool()
import logging

from .caving import (
    _process_data_caving_cupy,
    process_data_caving,
    _mask_caving,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# The level is set to INFO at import time, so the per-step timing messages
# emitted with logger.info() pass this logger's level check.
logger.setLevel(logging.INFO)


def shift_window(
    array_window: numpy.ndarray,
    cx: int,
    cy: int,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
    """
    Shift the window pattern so that its maximum lands on the beam center.

    The position of the maximum of ``array_window`` is taken as the current
    center of the pattern. The pattern is translated so that this maximum
    ends up at ``(cy, cx)``. Pixels whose source would fall outside the
    array are filled with the nearest edge value and are flagged in the
    returned mask.

    Inputs:
        - array_window (numpy.ndarray): 2-dimensional window pattern to be shifted.
        - cx (int): X-coordinate (column) of the direct beam.
        - cy (int): Y-coordinate (row) of the direct beam.
    Outputs:
        Tuple[numpy.ndarray, numpy.ndarray]: The shifted window and a boolean
        mask with the same shape that is True where the shifted window is NOT
        reliable (source pixel on or beyond the array border).
    """
    y_size, x_size = array_window.shape
    y, x = numpy.meshgrid(numpy.arange(y_size), numpy.arange(x_size), indexing="ij")

    # Current center of the pattern = position of its maximum
    y_center_window, x_center_window = numpy.unravel_index(
        numpy.argmax(array_window), array_window.shape
    )
    deltax = int(cx) - x_center_window
    deltay = int(cy) - y_center_window

    # Source coordinates of every destination pixel
    x_shifted = x - deltax
    y_shifted = y - deltay

    # Clip to the last valid index (size - 1). Clipping to a smaller bound
    # would sample wrong pixels inside the region the mask declares valid.
    x_clipped = numpy.clip(x_shifted, 0, x_size - 1)
    y_clipped = numpy.clip(y_shifted, 0, y_size - 1)
    window_shifted = array_window[y_clipped, x_clipped]

    # A pixel is valid when its source lies strictly inside the array border
    mask_x = numpy.logical_and(0 < x_shifted, x_shifted < x_size - 1)
    mask_y = numpy.logical_and(0 < y_shifted, y_shifted < y_size - 1)
    mask_valid = mask_x & mask_y

    # Shifting and clipping mask are ok (in 4 shift directions)
    return window_shifted, ~mask_valid
def _process_dataset_2scat_numpy(
    dataset_signal: numpy.ndarray,  # 3dim
    array_window: numpy.ndarray,  # Original, non-shifted
    Center_1: int,
    Center_2: int,
    WindowRoiSize: int = 120,
    Dummy: Optional[int] = -10,
    dataset_variance: Optional[numpy.ndarray] = None,
    clip_data: bool = True,
    use_numexpr: bool = False,
    pre_caving: bool = False,
    log: bool = False,
    **kwargs,
) -> Tuple[
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
]:
    """
    Secondary scattering correction of a stack of frames on the CPU.

    The region of ``2 * WindowRoiSize`` pixels around the beam center of each
    frame is convolved with the (shifted) window pattern; the result is
    subtracted from the signal and added to the variance.

    Inputs:
        - dataset_signal (numpy.ndarray): 3-dimensional stack (frame, y, x).
        - array_window (numpy.ndarray): Window pattern, not yet shifted.
        - Center_1 (int): Beam center along the last axis (x).
        - Center_2 (int): Beam center along the second axis (y).
        - WindowRoiSize (int): Half-size of the region sliced around the center.
        - Dummy (int): Value that marks invalid pixels; such pixels are kept as they are.
        - dataset_variance (numpy.ndarray, optional): Variance stack, same shape as the signal.
        - clip_data (bool): Set to Dummy the pixels flagged by ``shift_window``.
        - use_numexpr (bool): Evaluate the corrections with numexpr instead of numpy.
        - pre_caving (bool): Run ``process_data_caving`` on signal (and variance) first.
        - log (bool): Log the time spent in every step.
        - **kwargs: Forwarded to ``process_data_caving`` when ``pre_caving`` is True.
    Outputs:
        Tuple of (corrected signal, corrected variance, corrected sigma,
        secondary scattering signal). Variance and sigma are None when
        ``dataset_variance`` is None.
    """
    if log:
        logger.info("Using numpy for secondary scattering correction")
    t0_ = time.perf_counter()

    Center_1 = int(Center_1)
    Center_2 = int(Center_2)
    WindowRoiSize = int(WindowRoiSize)

    if pre_caving:
        params_caving_numpy = {
            "Center_1": Center_1,
            "Center_2": Center_2,
            "Dummy": Dummy,
            "algorithm": "numpy",
            "return_mask": False,
            "log": log,
            **kwargs,
        }
        dataset_signal = process_data_caving(
            data=dataset_signal,
            **params_caving_numpy,
        )
        if dataset_variance is not None:
            dataset_variance = process_data_caving(
                data=dataset_variance,
                **params_caving_numpy,
            )

    # 1) Slice the original data around the center, this will be the convolution kernel
    # NOTE(review): assumes the ROI lies fully inside the frame; a center closer
    # than WindowRoiSize to the border yields negative slice bounds - confirm.
    t0 = time.perf_counter()
    subdataset_signal = dataset_signal[
        :,
        Center_2 - WindowRoiSize : Center_2 + WindowRoiSize,
        Center_1 - WindowRoiSize : Center_1 + WindowRoiSize,
    ].copy()
    t1 = time.perf_counter()

    # 2) Cover the dummy values in the subdataset (dummy values jeopardize the convolution)
    numpy.copyto(subdataset_signal, 0.0, where=subdataset_signal == Dummy)
    t2 = time.perf_counter()

    # 3) Shift the window
    array_window_shifted, mask_clip = shift_window(
        array_window=array_window,
        cx=Center_1,
        cy=Center_2,
    )
    t3 = time.perf_counter()

    # 4) Perform the convolution all across the dataset (3-dimensional)
    # `oaconvolve_numpy` is whatever the module imports bind to that name.
    signal_2scat = numpy.array(
        [
            oaconvolve_numpy(array_window_shifted, subdata, mode="same")
            for subdata in subdataset_signal
        ]
    )
    t4 = time.perf_counter()

    # 5) Calculate the corrected signal, variance and sigma.
    # Dummy pixels are left untouched in both implementations.
    dataset_variance_corrected = None
    dataset_sigma_corrected = None
    if use_numexpr:
        # numexpr resolves the names in these expressions from the local
        # variables of this function: do not rename them.
        dataset_signal_corrected = numexpr.evaluate(
            "where(dataset_signal == Dummy, Dummy, dataset_signal - signal_2scat)"
        )
        if dataset_variance is not None:
            dataset_variance_corrected = numexpr.evaluate(
                "where(dataset_variance == Dummy, Dummy, dataset_variance + signal_2scat + 0.0)"
            )
            dataset_sigma_corrected = numexpr.evaluate(
                "where(dataset_variance == Dummy, Dummy, sqrt(dataset_variance_corrected))"
            )
    else:
        dataset_signal_corrected = numpy.where(
            dataset_signal == Dummy, dataset_signal, dataset_signal - signal_2scat
        )
        if dataset_variance is not None:
            dataset_variance_corrected = numpy.where(
                dataset_variance == Dummy,
                dataset_variance,
                dataset_variance + signal_2scat + 0.0,
            )
            dataset_sigma_corrected = numpy.where(
                dataset_variance == Dummy,
                dataset_variance,
                numpy.sqrt(dataset_variance_corrected),
            )
    t5 = time.perf_counter()

    # 6) Clip the data that could not be corrected
    if clip_data:
        if use_numexpr:
            dataset_signal_corrected = numexpr.evaluate(
                "where(mask_clip, Dummy, dataset_signal_corrected)"
            )
            if dataset_variance_corrected is not None:
                dataset_variance_corrected = numexpr.evaluate(
                    "where(mask_clip, Dummy, dataset_variance_corrected)"
                )
            if dataset_sigma_corrected is not None:
                dataset_sigma_corrected = numexpr.evaluate(
                    "where(mask_clip, Dummy, dataset_sigma_corrected)"
                )
        else:
            numpy.copyto(dataset_signal_corrected, Dummy, where=mask_clip)
            if dataset_variance_corrected is not None:
                numpy.copyto(dataset_variance_corrected, Dummy, where=mask_clip)
            if dataset_sigma_corrected is not None:
                numpy.copyto(dataset_sigma_corrected, Dummy, where=mask_clip)
    t6 = time.perf_counter()

    if log:
        nb_frames = len(dataset_signal)
        logger.info(
            f" 1) Subdata slicing per frame: {(t1 - t0) / nb_frames * 1000:.4f} ms"
        )
        logger.info(
            f" 2) Mask subdata per frame shifting: {(t2 - t1) / nb_frames * 1000:.4f} ms"
        )
        logger.info(f" 3) Window shifting: {(t3 - t2) * 1000:.4f} ms")
        logger.info(
            f" 4) Convolution per frame: {(t4 - t3) / nb_frames * 1000:.4f} ms"
        )
        logger.info(
            f" 5) Correction calculation per frame: {(t5 - t4) / nb_frames * 1000:.4f} ms"
        )
        logger.info(
            f" 6) Data clipping per frame: {(t6 - t5) / nb_frames * 1000:.4f} ms"
        )
        logger.info(
            f" 7) Total 2scat per frame: {(t6 - t0) / nb_frames * 1000:.4f} ms"
        )
        logger.info(
            f"Total time 2scat+cave per frame: {(t6 - t0_) / nb_frames*1000:.4f} ms"
        )

    return (
        dataset_signal_corrected,
        dataset_variance_corrected,
        dataset_sigma_corrected,
        signal_2scat,
    )


def _process_dataset_2scat_cupy(
    dataset_signal: numpy.ndarray,  # 3dim
    array_window: numpy.ndarray,  # Original, non-shifted
    Center_1: int,
    Center_2: int,
    WindowRoiSize: int = 120,
    Dummy: Optional[int] = -10,
    dataset_variance: Optional[numpy.ndarray] = None,
    clip_data: bool = True,
    pre_caving: bool = False,
    filename_mask_static: str = None,
    filename_mask_reference: str = None,
    flip_caving: bool = False,
    flip_horizontally_preference: bool = True,
    **kwargs,
) -> Tuple[
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
]:
    """
    Secondary scattering correction of a stack of frames on the GPU.

    Everything that does not depend on the frame (shifted window, clipping
    mask and, when ``pre_caving`` is True, the caving masks and coordinate
    grids) is prepared once. Frames are then uploaded and processed one by
    one, and the results are copied back into host arrays.

    Inputs:
        - dataset_signal (numpy.ndarray): 3-dimensional stack (frame, y, x).
        - array_window (numpy.ndarray): Window pattern, not yet shifted.
        - Center_1 (int): Beam center along the last axis (x).
        - Center_2 (int): Beam center along the second axis (y).
        - WindowRoiSize (int): Half-size of the region sliced around the center.
        - Dummy (int): Value that marks invalid pixels.
        - dataset_variance (numpy.ndarray, optional): Variance stack.
        - clip_data (bool): Set to Dummy the pixels flagged by ``shift_window``.
        - pre_caving (bool): Cave every frame before the correction.
        - filename_mask_static (str, optional): Mask file loaded for the caving step.
        - filename_mask_reference (str, optional): Mask file used to build the caving masks.
        - flip_caving (bool): Also build the horizontal and vertical caving masks.
        - flip_horizontally_preference (bool): Forwarded to the caving routine.
        - **kwargs: Only ``binning`` is read (used when loading the masks).
    Outputs:
        Tuple of host arrays (corrected signal, corrected variance, corrected
        sigma, secondary scattering signal). Variance and sigma are None when
        ``dataset_variance`` is None.
    """
    log_allocated_gpu_memory()

    Center_1 = int(Center_1)
    Center_2 = int(Center_2)
    WindowRoiSize = int(WindowRoiSize)

    # Frame-independent inputs of the caving step, built only when needed
    if pre_caving:
        data_shape = dataset_signal.shape[1:]
        binning = kwargs.get("binning")
        y_vector, x_vector = numpy.meshgrid(
            numpy.arange(data_shape[0]),
            numpy.arange(data_shape[1]),
            indexing="ij",
            sparse=True,
        )
        x_shifted = 2 * Center_1 - x_vector + 1
        y_shifted = 2 * Center_2 - y_vector + 1

        mask_static_cupy = None
        if filename_mask_static:
            mask_static_cupy = get_persistent_array_mask(
                filename_mask=filename_mask_static,
                data_signal_shape=data_shape,
                binning=binning,
                use_cupy=True,
            )

        mask_reference = None
        if filename_mask_reference:
            mask_reference = get_persistent_array_mask(
                filename_mask=filename_mask_reference,
                data_signal_shape=data_shape,
                binning=binning,
                use_cupy=False,  # We don't use Cupy here
            )

        mask_pixels_available_centrosymmetric_cupy = _mask_caving(
            data_shape,
            Center_1,
            Center_2,
            mask_reference,
            vertical_symmetry=True,
            horizontal_symmetry=True,
            use_cupy=True,
        )
        mask_pixels_available_horizontal_cupy = None
        mask_pixels_available_vertical_cupy = None
        if flip_caving:
            mask_pixels_available_horizontal_cupy = _mask_caving(
                data_shape,
                Center_1,
                Center_2,
                mask_reference,
                vertical_symmetry=False,
                horizontal_symmetry=True,
                use_cupy=True,
            )
            mask_pixels_available_vertical_cupy = _mask_caving(
                data_shape,
                Center_1,
                Center_2,
                mask_reference,
                vertical_symmetry=True,
                horizontal_symmetry=False,
                use_cupy=True,
            )

        params_caving_cupy = {
            "mask_pixels_available_centrosymmetric_cupy": mask_pixels_available_centrosymmetric_cupy,
            "x_shifted_cupy": cupy.asarray(x_shifted),
            "y_shifted_cupy": cupy.asarray(y_shifted),
            "Dummy": Dummy,
            "mask_static_cupy": mask_static_cupy,
            "flip_caving": flip_caving,
            "flip_horizontally_preference": flip_horizontally_preference,
            "mask_pixels_available_horizontal_cupy": mask_pixels_available_horizontal_cupy,
            "mask_pixels_available_vertical_cupy": mask_pixels_available_vertical_cupy,
            "x_vector_cupy": cupy.asarray(x_vector),
            "y_vector_cupy": cupy.asarray(y_vector),
        }

    # Frame-independent inputs of the correction itself
    array_window_shifted, mask_clip = shift_window(
        array_window=array_window,
        cx=Center_1,
        cy=Center_2,
    )
    params_2scat_cupy = {
        "array_window_cupy": cupy.asarray(array_window_shifted),
        "Center_1": Center_1,
        "Center_2": Center_2,
        "WindowRoiSize": WindowRoiSize,
        "Dummy": Dummy,
        "clip_data": clip_data,
        "mask_clip_cupy": cupy.asarray(mask_clip),
    }

    # Host buffers for the results.
    # NOTE(review): the buffers inherit the dtype of the inputs, so results
    # are cast to that dtype on assignment - confirm inputs are floating point.
    dataset_signal_corrected = numpy.zeros_like(dataset_signal)
    dataset_signal_2scat = numpy.zeros_like(dataset_signal)
    dataset_variance_corrected = None
    dataset_sigma_corrected = None
    if dataset_variance is not None:
        dataset_variance_corrected = numpy.zeros_like(dataset_variance)
        dataset_sigma_corrected = numpy.zeros_like(dataset_variance)

    for index_frame, data_signal in enumerate(dataset_signal):
        data_signal_cupy = cupy.asarray(data_signal)
        data_variance_cupy = None
        if dataset_variance is not None:
            data_variance_cupy = cupy.asarray(dataset_variance[index_frame])

        if pre_caving:
            data_signal_cupy = _process_data_caving_cupy(
                data_cupy=data_signal_cupy,
                **params_caving_cupy,
            )
            if data_variance_cupy is not None:
                data_variance_cupy = _process_data_caving_cupy(
                    data_cupy=data_variance_cupy,
                    **params_caving_cupy,
                )

        (
            data_signal_corrected_cupy,
            data_variance_corrected_cupy,
            data_sigma_corrected_cupy,
            signal_2scat_cupy,
        ) = _process_data_2scat_cupy(
            data_signal=data_signal_cupy,
            data_variance=data_variance_cupy,
            **params_2scat_cupy,
        )

        # Copy the results of this frame back to the host
        dataset_signal_corrected[index_frame] = data_signal_corrected_cupy.get()
        dataset_signal_2scat[index_frame] = signal_2scat_cupy.get()
        if data_variance_corrected_cupy is not None:
            dataset_variance_corrected[index_frame] = data_variance_corrected_cupy.get()
        if data_sigma_corrected_cupy is not None:
            dataset_sigma_corrected[index_frame] = data_sigma_corrected_cupy.get()

    return (
        dataset_signal_corrected,
        dataset_variance_corrected,
        dataset_sigma_corrected,
        dataset_signal_2scat,
    )


def _process_data_2scat_cupy(
    data_signal: numpy.ndarray,
    array_window_cupy: cupy.ndarray,  # already shifted window
    Center_1: int,
    Center_2: int,
    WindowRoiSize: int = 120,
    Dummy: Optional[int] = -10,
    data_variance: Optional[numpy.ndarray] = None,
    clip_data: bool = True,
    mask_clip_cupy: Optional[cupy.ndarray] = None,
    **kwargs,
) -> Tuple[
    Optional[cupy.ndarray],
    Optional[cupy.ndarray],
    Optional[cupy.ndarray],
    Optional[cupy.ndarray],
]:
    """
    Secondary scattering correction of a single frame on the GPU.

    Inputs:
        - data_signal: 2-dimensional frame, either a numpy or a cupy array.
        - array_window_cupy (cupy.ndarray): Window pattern, already shifted.
        - Center_1 (int): Beam center along the last axis (x).
        - Center_2 (int): Beam center along the first axis (y).
        - WindowRoiSize (int): Half-size of the region sliced around the center.
        - Dummy (int): Value that marks invalid pixels; such pixels are kept as they are.
        - data_variance (optional): Variance frame, either a numpy or a cupy array.
        - clip_data (bool): Set to Dummy the pixels flagged in ``mask_clip_cupy``.
        - mask_clip_cupy (cupy.ndarray, optional): Boolean mask of the pixels to clip.
          When None, nothing is clipped.
    Outputs:
        Tuple of device arrays (corrected signal, corrected variance,
        corrected sigma, secondary scattering signal). Variance and sigma
        are None when ``data_variance`` is None.
    """
    # asarray is a no-op for arrays that are already on the device
    data_signal_cupy = cupy.asarray(data_signal)
    data_variance_cupy = None
    if data_variance is not None:
        data_variance_cupy = cupy.asarray(data_variance)

    # 1) Slice the original data around the center, this will be the convolution kernel
    subdata_signal_cupy = data_signal_cupy[
        Center_2 - WindowRoiSize : Center_2 + WindowRoiSize,
        Center_1 - WindowRoiSize : Center_1 + WindowRoiSize,
    ].copy()

    # 2) Cover the dummy values in the subdataset (dummy values jeopardize the convolution)
    cupy.copyto(subdata_signal_cupy, 0.0, where=subdata_signal_cupy == Dummy)

    # 3) Perform the convolution
    signal_2scat_cupy = oaconvolve_cupy(
        array_window_cupy, subdata_signal_cupy, mode="same"
    )

    # 4) Calculate the corrected signal, variance and sigma
    data_signal_corrected_cupy = cupy.where(
        data_signal_cupy == Dummy,
        data_signal_cupy,
        data_signal_cupy - signal_2scat_cupy,
    )
    data_variance_corrected_cupy = None
    data_sigma_corrected_cupy = None
    if data_variance_cupy is not None:
        data_variance_corrected_cupy = cupy.where(
            data_variance_cupy == Dummy,
            data_variance_cupy,
            data_variance_cupy + signal_2scat_cupy + 0.0,
        )
        data_sigma_corrected_cupy = cupy.where(
            data_variance_cupy == Dummy,
            data_variance_cupy,
            cupy.sqrt(data_variance_corrected_cupy),
        )

    # 5) Clip the data that could not be corrected.
    # Without a mask there is nothing to clip: copyto(..., where=None) would
    # otherwise overwrite the whole frame with Dummy.
    if clip_data and mask_clip_cupy is not None:
        cupy.copyto(data_signal_corrected_cupy, Dummy, where=mask_clip_cupy)
        if data_variance_corrected_cupy is not None:
            cupy.copyto(data_variance_corrected_cupy, Dummy, where=mask_clip_cupy)
        if data_sigma_corrected_cupy is not None:
            cupy.copyto(data_sigma_corrected_cupy, Dummy, where=mask_clip_cupy)

    return (
        data_signal_corrected_cupy,
        data_variance_corrected_cupy,
        data_sigma_corrected_cupy,
        signal_2scat_cupy,
    )
def process_dataset_2scat(
    dataset_signal: numpy.ndarray,
    filename_window_wagon: str,
    Center_1: float,
    Center_2: float,
    WindowRoiSize: int = 120,
    Dummy: Optional[int] = -10,
    dataset_variance: Optional[numpy.ndarray] = None,
    algorithm_2scat: str = "numpy",
    clip_data: bool = True,
    pre_caving: bool = True,
    filename_mask_static: str = None,
    filename_mask_reference: str = None,
    flip_caving: bool = False,
    flip_horizontally_preference: bool = True,
    **kwargs,
) -> Tuple[
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
    Optional[numpy.ndarray],
]:
    """
    Calculate the secondary scattering correction for the given dataset.

    Parameters:
        dataset_signal (numpy.ndarray): Data to be corrected. Either a stack of
            frames (3-dimensional) or a single frame (2-dimensional).
        filename_window_wagon (str): Path to the window pattern file.
        Center_1 (float): Beam center along the last axis (x).
        Center_2 (float): Beam center along the second-to-last axis (y).
        WindowRoiSize (int): Half-size of the region around the center used
            for the correction. Defaults to 120.
        Dummy (int, optional): Dummy value for masked pixels. Defaults to -10.
        dataset_variance (numpy.ndarray, optional): Variance of the dataset.
        algorithm_2scat (str): Key of ALGORITHMS_AVAILABLE. Falls back to
            DEFAULT_ALGORITHM when unknown, or when "cupy" is requested but
            CuPy is not available. Defaults to "numpy".
        clip_data (bool): Set to Dummy the pixels that cannot be corrected.
        pre_caving (bool): Apply caving before the correction. Defaults to True.
        filename_mask_static (str, optional): Forwarded to the algorithm.
        filename_mask_reference (str, optional): Forwarded to the algorithm.
        flip_caving (bool): Forwarded to the algorithm.
        flip_horizontally_preference (bool): Forwarded to the algorithm.
        **kwargs: Forwarded to the algorithm; ``binning`` is also used to load
            the window pattern.

    Returns:
        Tuple of (corrected signal, corrected variance, corrected sigma,
        secondary scattering signal). All four are None when the correction
        cannot be performed. For a 2-dimensional input the results are
        2-dimensional as well.
    """
    results = (None, None, None, None)

    if dataset_signal is None:
        logger.error("Dataset is None. Sec. scattering correction cannot be performed")
        return results

    if dataset_signal.ndim not in (2, 3):
        logger.error(
            f"Dataset with shape {dataset_signal.shape} must be 2 or 3-dimensional"
        )
        return results

    if filename_window_wagon is None:
        logger.error(
            "Window pattern data is None. Cannot perform secondary scattering correction"
        )
        return results

    # The algorithms work on stacks of frames: handle a single frame as a
    # stack of one and undo it on the way out.
    single_frame = dataset_signal.ndim == 2
    if single_frame:
        dataset_signal = dataset_signal[numpy.newaxis]
        if dataset_variance is not None and dataset_variance.ndim == 2:
            dataset_variance = dataset_variance[numpy.newaxis]

    # Load the additional data
    binning = kwargs.get("binning")
    data_signal_shape = dataset_signal[0].shape
    array_window_wagon = get_persistent_array_window_wagon(
        filename_window_wagon=filename_window_wagon,
        data_signal_shape=data_signal_shape,
        datatype=dataset_signal.dtype,
        binning=binning,
        use_cupy=False,  # shift_window is a numpy method
    )
    if array_window_wagon is None:
        logger.error(
            f"{filename_window_wagon} could not be loaded. Cannot perform secondary scattering correction"
        )
        return results

    # Pick the algorithm, falling back to the default when needed
    if algorithm_2scat not in ALGORITHMS_AVAILABLE:
        logger.warning(
            f"Algorithm '{algorithm_2scat}' is not available. Using '{DEFAULT_ALGORITHM}' instead."
        )
        algorithm_2scat = DEFAULT_ALGORITHM
    elif algorithm_2scat == "cupy" and not CUPY_AVAILABLE:
        logger.warning(f"CuPy is not available. Using {DEFAULT_ALGORITHM} instead.")
        algorithm_2scat = DEFAULT_ALGORITHM
    use_cupy = algorithm_2scat == "cupy"

    params_2scat = {
        "dataset_signal": dataset_signal,
        "dataset_variance": dataset_variance,
        "Center_1": Center_1,
        "Center_2": Center_2,
        "array_window": array_window_wagon,
        "WindowRoiSize": WindowRoiSize,
        "Dummy": Dummy,
        "clip_data": clip_data,
        "pre_caving": pre_caving,
        "filename_mask_static": filename_mask_static,
        "filename_mask_reference": filename_mask_reference,
        "flip_caving": flip_caving,
        "flip_horizontally_preference": flip_horizontally_preference,
        "use_cupy": use_cupy,
        **kwargs,
    }
    results = ALGORITHMS_AVAILABLE[algorithm_2scat]["algorithm"](**params_2scat)

    if single_frame:
        results = tuple(None if result is None else result[0] for result in results)
    return results
# Name of the implementation used when the requested one cannot be used
DEFAULT_ALGORITHM = "numpy"

# Registry of the implementations of the secondary scattering correction.
# Each entry holds the function to call and whether it runs with CuPy.
ALGORITHMS_AVAILABLE = {
    "numpy": {
        "algorithm": _process_dataset_2scat_numpy,
        "use_cupy": False,
    },
    "cupy": {
        "algorithm": _process_dataset_2scat_cupy,
        "use_cupy": True,
    },
}