# Source code for ewoksid02.scripts.saxs.main
import logging
import time
import os
from pprint import pprint
from ewokstools.submit import (
finish_queue,
save_and_execute,
get_valid_home_directory,
)
from ewokstools.iterators import (
get_datasets_list_id02,
)
from ewokstools.parsers import (
generate_params_from_yaml_file,
get_params_from_cli,
)
from ...resources import WORKFLOW_SAXS_LOOP
from ...tasks.averagetask import AverageTask
from ...tasks.azimuthaltask import AzimuthalTask
from ...tasks.cavingtask import CavingBeamstopTask
from ...tasks.normalizationtask import NormalizationTask
from ...tasks.secondaryscatteringtask import SecondaryScatteringTask
from ...tasks.id02processingtask import ID02ProcessingTask
from ...tasks.gallerytask import SaveGalleryTask
from ...utils.blissdata import LIMA_URL_TEMPLATE_ID02, get_lima_url_template_args_id02
from ..utils import (
SLURM_JOB_PARAMETERS_SAXS,
ID02_SLURM_PRE_SCRIPT,
ID02_SLURM_POST_SCRIPT,
ID02_WORKER_MODULE,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
[docs]
def _split_node_names(value) -> list:
    """Turn a ``to_process``/``to_save`` value into a list of node names.

    Accepts ``None`` or an empty value (no node selected), a
    whitespace-separated string, or an already-split iterable of names.
    """
    if not value:
        return []
    if isinstance(value, str):
        return value.split()
    return list(value)


def _collect_task_inputs(
    task_class, node_id, params, common_keys, extra_keys=None
) -> list:
    """Build the input entries targeted at a single task node.

    :param task_class: task class providing ``input_names()`` and
        ``class_registry_name()``.
    :param node_id: identifier of the node in the workflow (e.g. ``"norm"``).
    :param params: mapping of all the available parameters.
    :param common_keys: input names already sent to every node; they are
        skipped here.
    :param extra_keys: mapping ``{parameter name: task input name}`` for
        parameters whose name differs from the task input they feed.
    :returns: list of ``{"name", "value", "task_identifier", "id"}`` dicts.
    """
    values = {
        key: params[key]
        for key in task_class.input_names()
        if params.get(key) is not None and key not in common_keys
    }
    # Renamed parameters are only forwarded when truthy (not merely non-None)
    for key_template, key_node in (extra_keys or {}).items():
        value = params.get(key_template)
        if value:
            values[key_node] = value

    task_identifier = task_class.class_registry_name()
    return [
        {
            "name": name,
            "value": value,
            "task_identifier": task_identifier,
            "id": node_id,
        }
        for name, value in values.items()
    ]


def get_saxs_inputs(
    **kwargs,
) -> list:
    """Compile and return the list of inputs to be used on an ewoks SAXS/WAXS pipeline.

    Besides every key matching an input name of the involved tasks, the
    following keys of ``kwargs`` are read:

    - ``dataset_filename``, ``filename_scan``, ``scan_nb``, ``detector_name``,
      ``collection_name``: sent to all nodes (data/lima locations).
    - ``filename_maskgaps``, ``filename_maskbeamstop``,
      ``algorithm_normalization``, ``algorithm_2scat``, ``algorithm_cave``:
      mapped onto the matching task inputs.
    - ``to_process``, ``to_save``: node names to process / save
      (whitespace-separated string or iterable of names).
    - ``processed_filename_template`` (fallback ``processed_filename_scan``)
      and ``tag``: used to derive one output filename per node.
    - ``save_gallery``: ``do_save`` flag of the gallery task (default ``True``).

    :returns: list of input dictionaries, in this order: inputs for all nodes,
        per-task inputs, ``do_process``/``do_save`` flags, processing
        filenames and gallery inputs. The last two groups are only added when
        a processed filename is available.
    """
    inputs = []

    ###########
    # Add all nodes inputs
    ###########
    node_common_keys = ID02ProcessingTask.input_names()
    inputs_dict_all_nodes = {
        k: kwargs[k] for k in node_common_keys if kwargs.get(k) is not None
    }
    inputs_dict_all_nodes.update(
        {
            "filename_data": kwargs.get("dataset_filename"),
            "filename_lima": kwargs.get("filename_scan"),
            "lima_url_template": LIMA_URL_TEMPLATE_ID02,
            "lima_url_template_args": get_lima_url_template_args_id02(
                scan_number=kwargs.get("scan_nb"),
                detector_name=kwargs.get("detector_name"),
                collection_name=kwargs.get("collection_name"),
            ),
        }
    )
    inputs += [
        {"name": k, "value": v, "all": True} for k, v in inputs_dict_all_nodes.items()
    ]

    #############
    # Add the inputs of every processing task
    #############
    # (task class, node id, {parameter name: task input name})
    task_nodes = (
        (
            NormalizationTask,
            "norm",
            {
                "filename_maskgaps": "filename_mask_normalization",
                "algorithm_normalization": "algorithm_normalization",
            },
        ),
        (
            SecondaryScatteringTask,
            "2scat",
            {
                "filename_maskgaps": "filename_mask_static",
                "filename_maskbeamstop": "filename_mask_reference",
                "algorithm_2scat": "algorithm_2scat",
            },
        ),
        (
            CavingBeamstopTask,
            "cave",
            {
                "filename_maskbeamstop": "filename_mask_static",
                "algorithm_cave": "algorithm_cave",
            },
        ),
        (
            AzimuthalTask,
            "azim",
            {"filename_maskbeamstop": "filename_mask_azimuthal"},
        ),
        (AverageTask, "ave", {}),
    )
    for task_class, node_id, extra_keys in task_nodes:
        inputs += _collect_task_inputs(
            task_class, node_id, kwargs, node_common_keys, extra_keys
        )

    #############
    # Add flag inputs
    #############
    # `or`-based normalization: the keys may be present with a None value
    to_process = _split_node_names(kwargs.get("to_process"))
    to_save = _split_node_names(kwargs.get("to_save"))
    nodes = ["norm", "2scat", "cave", "azim", "ave", "scalers"]
    inputs += [
        {"name": "do_process", "id": node, "value": node in to_process}
        for node in nodes
    ]
    inputs += [
        {"name": "do_save", "id": node, "value": node in to_save} for node in nodes
    ]

    ##############
    # Add processing filenames inputs
    ##############
    processing_filename_template = kwargs.get(
        "processed_filename_template"
    ) or kwargs.get("processed_filename_scan")
    if not processing_filename_template:
        # Without an output filename there is nothing to derive for the
        # processing nodes nor for the gallery.
        return inputs

    processing_filename_template = processing_filename_template.replace(
        ".h5", "{tag}.h5"
    )
    # A None tag must not end up as the literal "None" in the filenames
    tag = kwargs.get("tag") or ""
    if tag:
        tag = f"_{tag}"
    inputs_dict_filenames = {
        node: processing_filename_template.format(tag=f"{tag}_{node}")
        for node in nodes
    }
    inputs += [
        {"name": "processing_filename", "value": value, "id": task_id}
        for task_id, value in inputs_dict_filenames.items()
    ]

    # Gallery task: reads the result written by the "ave" node.
    # The filename is looked up first to avoid nesting quotes in the f-string,
    # which is only valid syntax from Python 3.12 on.
    filename_ave = inputs_dict_filenames["ave"]
    inputs_gallery = {
        "nxdata_url": f"{filename_ave}::/entry_0000/PyFAI/result_ave",
        "do_save": kwargs.get("save_gallery", True),
    }
    task_identifier_gallery = SaveGalleryTask.class_registry_name()
    inputs += [
        {
            "name": k,
            "value": v,
            "task_identifier": task_identifier_gallery,
            "id": "save_gallery",
        }
        for k, v in inputs_gallery.items()
    ]
    return inputs
[docs]
def main(args):
    """Main function to trigger the SAXS/WAXS pipeline.

    Parameters are gathered, in increasing priority, from the .yaml files
    listed in ``args.FILES``, from the .h5 files listed in ``args.FILES``
    (appended to ``bliss_filenames``) and from the command line options.
    One workflow is then saved/executed per dataset found.

    :param args: parsed command line arguments; ``args.FILES`` is the list of
        .yaml/.h5 files given on the command line.
    """
    saxs_parameters = {}
    # Stays empty when no dataset is found, so finish_queue() can still be called
    saxs_dataset_parameters = {}

    # 1) Get parameters from .yaml files provided in the command line
    for saxs_parameters_from_yaml in generate_params_from_yaml_file(args.FILES):
        saxs_parameters.update(saxs_parameters_from_yaml)

    # 2) Add more bliss filenames from the command line.
    #    The key may be missing (no .yaml file, or a .yaml without it), None,
    #    or a single filename given as a string.
    bliss_filenames = saxs_parameters.get("bliss_filenames") or []
    if isinstance(bliss_filenames, str):
        bliss_filenames = [bliss_filenames]
    saxs_parameters["bliss_filenames"] = [
        *bliss_filenames,
        *(file for file in args.FILES if file.endswith(".h5")),
    ]

    # 3) Get parameters from the command line
    reprocess_parameters_from_cli = get_params_from_cli(args)
    saxs_parameters.update(reprocess_parameters_from_cli)

    # 4) If no bliss filename was provided, try through user input
    if not saxs_parameters.get("bliss_filenames"):
        answer = input(
            "No bliss filenames provided. Please enter the filenames (comma-separated): "
        )
        # Tolerate spaces around the commas and ignore empty entries
        saxs_parameters["bliss_filenames"] = [
            name.strip() for name in answer.split(",") if name.strip()
        ]

    # 5) Iterate through the bliss saving objects
    dataset_list = get_datasets_list_id02(**saxs_parameters)
    nb_datasets = len(dataset_list)
    print(
        f"\033[92mFound {nb_datasets} datasets in {saxs_parameters['bliss_filenames']}\033[0m"
    )
    filenames_dataset = [
        dataset_info["dataset_filename"] for dataset_info in dataset_list
    ]
    print("\033[92m", end="")
    pprint(filenames_dataset)
    print("\033[0m", end="")

    if nb_datasets > 10:
        logger.warning(
            "More than 10 datasets found in this file. You have 10 seconds to cancel..."
        )
        time.sleep(10)

    # Options that fall back to the ID02 defaults when not provided
    default_script_options = (
        ("worker_module", ID02_WORKER_MODULE),
        ("pre_script_python", ID02_SLURM_PRE_SCRIPT),
        ("post_script_python", ID02_SLURM_POST_SCRIPT),
    )

    for nb_submitted, dataset_info in enumerate(dataset_list, start=1):
        # Dataset-specific values take precedence over the global parameters
        saxs_dataset_parameters = {
            **saxs_parameters,
            **dataset_info,
        }

        # Mark the outputs of non-submitted runs with a "dryrun" suffix
        tag = saxs_dataset_parameters.get("tag", "")
        dryrun = "dryrun" if not saxs_dataset_parameters.get("submit") else ""
        tag = "_".join(filter(None, [tag, dryrun]))
        saxs_dataset_parameters["tag"] = tag

        # Take slurm parameters (popped: they are passed as their own argument).
        # `or {}` covers a key that is present with a None value.
        slurm_job_parameters = {
            **SLURM_JOB_PARAMETERS_SAXS,
            **(saxs_dataset_parameters.pop("slurm_job_parameters", None) or {}),
        }

        # slurm scripts options
        for option, default in default_script_options:
            if not saxs_dataset_parameters.get(option):
                saxs_dataset_parameters[option] = default

        # User resolution order: $SLURM_USER, `slurm_user` parameter, $USER
        user_name = os.environ.get(
            "SLURM_USER", saxs_dataset_parameters.get("slurm_user", None)
        )
        user_name = user_name or os.environ.get("USER")
        saxs_dataset_parameters["slurm_log_directory"] = (
            f"{get_valid_home_directory(user_name=user_name)}/.ewoksid02/slurm_logs"
        )

        save_and_execute(
            workflow=WORKFLOW_SAXS_LOOP,
            inputs=get_saxs_inputs(**saxs_dataset_parameters),
            slurm_job_parameters=slurm_job_parameters,
            processing_name=tag,
            **saxs_dataset_parameters,
            execution_kwargs={
                "engine": "ppf",
                "pool_type": "thread",
            },
        )
        print(
            f"\033[92mSubmitted {nb_submitted}/{nb_datasets} datasets for reprocessing: {nb_submitted / nb_datasets * 100:.2f}%\033[0m"
        )

    # Called once after all the submissions, with the parameters of the last
    # dataset (or an empty dict when no dataset was found).
    finish_queue(**saxs_dataset_parameters)