"""
SLURMTaskQueue
Defines classes for executing a collection of tasks in a single SLURM job. A
task is defined as a command to be run in parallel using a given number of
SLURM tasks, as would be run with `ibrun -n`.
"""
import argparse as ap
import copy
import json
import linecache as lc
import logging
import os
import pdb
import re
import shutil
import stat
import subprocess
import sys
import tempfile
import time
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Union
from pythonjsonlogger import jsonlogger
from pyslurmtq.Slot import Slot
from pyslurmtq.Task import Task
from pyslurmtq.utils import (compact_int_list, expand_int_list,
filter_res)
# Package metadata.
__author__ = "Carlos del-Castillo-Negrete"
__copyright__ = "Carlos del-Castillo-Negrete"
__license__ = "MIT"
class SLURMTaskQueue:
    """
    Implements a Task Queue of :class:`Task` objects to be executed in parallel
    on available compute nodes according to the SLURM environment variables.

    Attributes
    ----------
    task_slots : List(:class:`Slot`)
        List of task slots available. This is parsed upon initialization from
        SLURM environment variables SLURM_JOB_NODELIST and SLURM_TASKS_PER_NODE.
        Slot indices are consecutive across hosts, in nodelist order.
    workdir : pathlib.Path
        Path to directory to store files for tasks executed, if the tasks
        themselves don't specify their own work directories. Defaults to a
        directory with the prefix `.stq-job{SLURM_JOB_ID}-` in the
        current working directory. The queue's JSON log (`tq_log.json`) and
        summaries (`task_summary.txt`, `slot_summary.txt`) are written here.
    delay : float
        Number of seconds to pause between iterations of updating the queue.
        Default is 1 second. Note this affects the poll rate of tasks running
        in the queue.
    task_max_runtime : float
        Max run time, in seconds, any individual task in the queue can run for.
    max_runtime : float
        Max run time, in seconds, for execution of `run()` to empty the queue.
    summary_interval : float
        Minimum number of seconds between the periodic summary files written
        to `workdir` while `run()` executes. Default is 60 seconds.
    task_count : int
        Running counter, starting from 0, of total tasks that pass through the
        queue. The current count is used for the task_id of the next task added
        to the queue, so that a task's task_id corresponds to the order in
        which it was added to the queue.
    running_time : float
        Total running time of the queue when `run()` is executed.
    queue : List(:class:`Task`)
        List of :class:`Task` waiting to run. Populated via the `enqueue()`
        and `enqueue_from_json()` methods.
    running : List(:class:`Task`)
        List of :class:`Task` that are currently running.
    completed : List(:class:`Task`)
        List of :class:`Task` that completed running successfully, in that
        the process executing them returned a 0 exit code.
    errored : List(:class:`Task`)
        List of :class:`Task` that failed to run successfully in that the
        processes executing them returned a non-zero exit code.
    timed_out : List(:class:`Task`)
        List of :class:`Task` that were terminated, either because their
        runtime exceeded `task_max_runtime` or because they were still running
        when `max_runtime` was reached.
    invalid : List(:class:`Task`)
        List of :class:`Task` that were not run because they require more
        cores than there are task slots in the queue.
    """

    def __init__(
        self,
        tasks: Optional[List[dict]] = None,
        task_file: Optional[str] = None,
        workdir: Optional[Union[str, Path]] = None,
        task_max_runtime: float = 1e10,
        max_runtime: float = 1e10,
        delay: float = 1,
        loglevel: int = logging.DEBUG,
        summary_interval: float = 60,
    ):
        # Default workdir for executing tasks if a task doesn't specify one.
        if workdir is None:
            self.workdir = Path(
                tempfile.mkdtemp(
                    prefix=f'.stq-job{os.environ["SLURM_JOB_ID"]}-',
                    dir=Path.cwd(),
                )
            )
        else:
            self.workdir = Path(workdir)
            self.workdir.mkdir(parents=True, exist_ok=True)

        # Set-up job logging: JSON records appended to <workdir>/tq_log.json.
        # NOTE(review): getLogger(__name__) is shared module-wide, so each
        # queue instance in the same process adds its own file handler to the
        # same logger and every log file also receives the other instances'
        # records. Consider a per-instance logger if multiple queues coexist.
        self._logger = logging.getLogger(__name__)
        log_handler = logging.FileHandler(self.workdir / "tq_log.json")
        log_handler.setFormatter(
            jsonlogger.JsonFormatter(
                "%(asctime)s %(name)s - %(levelname)s:%(message)s"
            )
        )
        self._logger.addHandler(log_handler)
        self._logger.setLevel(loglevel)

        # Task slots are derived from the SLURM allocation.
        self.task_slots = []
        self._init_task_slots()

        # Queue runtime constants.
        self.delay = delay
        self.task_max_runtime = task_max_runtime
        self.max_runtime = max_runtime
        self.summary_interval = summary_interval

        # Task queue state.
        self.task_count = 0
        self.running_time = 0.0
        self.queue = []
        self.running = []
        self.completed = []
        self.errored = []
        self.timed_out = []
        self.invalid = []

        # Enqueue initial tasks, file first, then explicit list.
        if task_file is not None:
            self.enqueue_from_json(task_file)
        if tasks is not None:
            self.enqueue(tasks)
        self._logger.info(f"Queue initialized: {self}", extra=self.__dict__)

    def __str__(self):
        """
        Return a one-line summary of the workdir, per-host slot occupancy
        (free/busy slot indices) and the task ids in each non-empty queue
        state.
        """

        def sc(tasks):
            return compact_int_list(sorted(t.task_id for t in tasks))

        states = [
            ("queued", self.queue),
            ("running", self.running),
            ("completed", self.completed),
            ("timed_out", self.timed_out),
            ("errored", self.errored),
            ("invalid", self.invalid),
        ]
        queue_str = ", ".join(
            f"{label}=[{sc(tasks)}]" for label, tasks in states if tasks
        )

        # dict.fromkeys keeps first-seen host order, so output is deterministic.
        hosts = list(dict.fromkeys(s.host for s in self.task_slots))
        slots = []
        for h in hosts:
            free = [s.idx for s in self.task_slots if s.host == h and s.is_free()]
            busy = [
                s.idx for s in self.task_slots if s.host == h and not s.is_free()
            ]
            slots.append(
                f"{h}: (FREE: [{compact_int_list(free)}], "
                f"BUSY: [{compact_int_list(busy)}])"
            )
        s = f"(workdir: {self.workdir}, "
        s += f"slots: [{', '.join(slots)}], "
        s += f"queue-state:[{queue_str}])"
        return s

    def _init_task_slots(self):
        """
        Initialize available task slots from SLURM environment variables.

        SLURM_JOB_NODELIST (e.g. ``"c1-[001-003,005],c2-010"``) is expanded
        into an ordered list of host names, and SLURM_TASKS_PER_NODE (e.g.
        ``"48(x2),24"``) into one task count per host, in the same order. One
        :class:`Slot` is created per task, with a slot index that is
        consecutive across all hosts.

        Raises
        ------
        KeyError
            If either SLURM environment variable is not set.
        ValueError
            If SLURM_TASKS_PER_NODE describes more hosts than the nodelist.
        """
        hl = []
        slurm_nodelist = os.environ["SLURM_JOB_NODELIST"]
        self._logger.debug(
            f"Parsing SLURM_JOB_NODELIST {slurm_nodelist}",
            extra={"SLURM_JOB_NODELIST": slurm_nodelist},
        )
        # Split on commas that are not inside a bracketed range.
        host_groups = re.split(r",\s*(?![^\[\]]*\])", slurm_nodelist)
        for hg in host_groups:
            h, _, ns = hg.partition("-")
            if not ns:
                # No "-<number>" suffix: the group is a single literal host.
                hl.append(hg)
                continue
            ns = ns[1:-1] if ns.startswith("[") else ns
            # Preserve zero padding, e.g. "001" stays 3 digits wide.
            padding = min(len(x) for x in re.split(r"[,-]", ns))
            hl += [f"{h}-{str(x).zfill(padding)}" for x in expand_int_list(ns)]
        self._logger.debug(f"Parsed nodelist {hl}", extra={"hl": hl})

        slurm_tph = os.environ["SLURM_TASKS_PER_NODE"]
        self._logger.debug(f"Parsing SLURM_TASKS_PER_NODE {slurm_tph}")
        # Expand "N(xM)" groups into M consecutive hosts with N tasks each.
        tasks_per_host = []
        for tph in slurm_tph.split(","):
            ntasks, _, mult = tph.partition("(x")
            reps = int(mult.rstrip(")")) if mult else 1
            tasks_per_host += [int(ntasks)] * reps
        if len(tasks_per_host) > len(hl):
            raise ValueError(
                f"SLURM_TASKS_PER_NODE ({slurm_tph}) describes "
                f"{len(tasks_per_host)} hosts but SLURM_JOB_NODELIST "
                f"({slurm_nodelist}) has only {len(hl)}"
            )

        total_idx = 0
        for host, ntasks in zip(hl, tasks_per_host):
            for _ in range(ntasks):
                self.task_slots.append(Slot(host, total_idx))
                self._logger.debug(f"Initialized slot {self.task_slots[-1]}")
                total_idx += 1
        self._logger.debug(f"Initialized {len(self.task_slots)} task slots")

    def _request_slots(self, task):
        """
        Try to start `task` on a contiguous run of free task slots.

        Scans from slot 0 for the first window of `task.cores` consecutive
        free slots. If found, the task is executed at that slot offset and
        every slot in the window is marked as occupied by the task.

        Returns
        -------
        bool
            True if the task was started, False if no window was free.
        """
        cores = task.cores
        n_slots = len(self.task_slots)
        start = 0
        while True:
            if start + cores > n_slots:
                return False
            busy = next(
                (
                    i
                    for i in range(start, start + cores)
                    if not self.task_slots[i].is_free()
                ),
                None,
            )
            if busy is None:
                break
            # Any window containing slot `busy` is unusable; resume after it.
            start = busy + 1

        self._logger.debug(
            f"Starting {task.task_id} at slot index {start}", extra=task.__dict__
        )
        task.execute(start, cores)
        self._logger.info(
            f"{task.task_id} running on process {task.sub_proc.pid}",
            extra=task.__dict__,
        )
        # Mark slots as occupied by this task.
        for s in self.task_slots[start : start + cores]:
            self._logger.debug(f"Occupying slot{s}", extra=s.__dict__)
            s.occupy(task)
            self._logger.debug(f"Slot{s} occupied", extra=s.__dict__)
        return True

    def _release_slots(self, task_id):
        """Release every busy slot whose most recent task has `task_id`."""
        for s in self.task_slots:
            if not s.is_free() and s.tasks[-1].task_id == task_id:
                self._logger.debug(f"Releasing slot {s}", extra=s.__dict__)
                s.release()
                self._logger.debug(f"Slot {s} released", extra=s.__dict__)

    def _start_queued(self):
        """
        Start queued tasks.

        Queued tasks are visited in decreasing order of cores required. A task
        needing more cores than there are slots in total is moved to the
        invalid list. Otherwise `_request_slots` is called to see if there is
        room for it; tasks that fit are moved to the running list, and tasks
        that don't stay queued. Smaller tasks may start even when a larger one
        visited earlier could not.
        """
        # sorted() with reverse=True is stable: ties keep their queue order.
        tqueue = sorted(self.queue, key=lambda x: x.cores, reverse=True)
        started = 0
        for task in tqueue:
            if task.cores > len(self.task_slots):
                self._logger.warning(
                    f"Task {task} too large. Adding to invalid list.",
                    extra=task.__dict__,
                )
                task.err_msg = "Invalid task (too many cores for queue): "
                task.err_msg += f"{task.cores}>{len(self.task_slots)}"
                self.queue.remove(task)
                self.invalid.append(task)
                continue
            if self._request_slots(task):
                self._logger.info(
                    f"Successfully found resources for task {task}",
                    extra=task.__dict__,
                )
                self.queue.remove(task)
                self.running.append(task)
                started += 1
            else:
                self._logger.debug(
                    f"Unable to find resources for {task}.", extra=task.__dict__
                )
        if started > 0:
            self._logger.info(f"Started {started} tasks", extra=self.__dict__)

    def _terminate_and_release(self, task, msg):
        """
        Terminate a task, record `msg` as its error, move it to `timed_out`
        and release its slots. The caller is responsible for removing the task
        from `running`.
        """
        self._logger.error(msg, extra=task.__dict__)
        task.terminate()
        task.err_msg = msg
        self.timed_out.append(task)
        self._logger.info(f"Releasing slots assigned to task {task.task_id}")
        self._release_slots(task.task_id)

    def _update(self):
        """
        Update status of running tasks by polling them with `get_rc()`.

        Finished tasks are moved to the completed (rc == 0) or errored
        (rc != 0) lists and their slots released. Tasks still running past
        `task_max_runtime` are terminated and moved to `timed_out`.
        """
        still_running = []
        for t in self.running:
            rc = t.get_rc()
            if rc is None:
                rt = time.time() - t.start_ts
                if rt > self.task_max_runtime:
                    msg = f"Task {t.task_id} has exceeded task max runtime {rt}"
                    self._terminate_and_release(t, msg)
                else:
                    still_running.append(t)
            elif rc == 0:
                self._logger.info(
                    f"{t.task_id} DONE: {t.running_time:5.3f}s", extra=t.__dict__
                )
                self.completed.append(t)
                self._release_slots(t.task_id)
            else:
                msg = f"{t.task_id} FAILED: rt = {t.running_time:5.3f}s, "
                msg += f"rc = {t.rc}, err file (last_line) = {t.err_msg}"
                self._logger.error(msg, extra=t.__dict__)
                self.errored.append(t)
                self._release_slots(t.task_id)
        finished = len(self.running) - len(still_running)
        if finished > 0:
            self.running = still_running
            self._logger.info(f"{finished} tasks finished", extra=self.__dict__)

    def _save_summary(self):
        """Save task and slot summaries to workdir."""
        self.summary_by_task(
            print_res=False, fname=str(self.workdir / "task_summary.txt")
        )
        self.summary_by_slot(
            print_res=False, fname=str(self.workdir / "slot_summary.txt")
        )

    def enqueue(self, task_list: List[dict], cores: int = 1):
        """
        Add a list of tasks to the queue.

        Each task is a dictionary with at minimum a `cmnd` field indicating
        the command to be executed in parallel using a corresponding number of
        `cores`, which defaults to the passed in value if not specified per
        task configuration. The dictionaries in `task_list` are not modified.
        Entries for which :class:`Task` raises ValueError are logged and
        skipped.

        Parameters
        ----------
        task_list : List[dict]
            List of dictionaries, one per task with the following fields:
            'cmnd' : required, parallel command to execute
            'cores' : optional, number of cores to use on this task
            'workdir' : optional, defaults to the queue's `workdir`
            'pre' : optional, serial command to run prior to running 'cmnd'
            'post' : optional, serial command to run after running 'cmnd'
            'cdir' : optional, passed through to :class:`Task`
        cores : int
            Default number of cores to use for each task if not specified
            within task configuration.
        """
        self._logger.debug(f"Enqueuing {len(task_list)} tasks.")
        for i, t in enumerate(task_list):
            self._logger.debug(f"Attempting to create task {self.task_count}", extra=t)
            try:
                # .get (not .pop) so the caller's dictionaries stay intact.
                task = Task(
                    self.task_count,
                    t.get("cmnd", None),
                    t.get("workdir", self.workdir),
                    t.get("cores", cores),
                    t.get("pre", None),
                    t.get("post", None),
                    t.get("cdir", None),
                )
            except ValueError as v:
                self._logger.error(f"Bad task in list at idx {i}: {v}", extra=t)
                continue
            self._logger.debug(f"Enqueing {task}", extra=task.__dict__)
            self.queue.append(task)
            self.task_count += 1

    def enqueue_from_json(self, filename, cores=1):
        """
        Add a list of tasks to the queue from a JSON file.

        The json file must contain a list of configurations, each with at
        minimum a `cmnd` field indicating the command to be executed in
        parallel using a corresponding number of `cores`, which defaults to
        the passed in value if not specified per task configuration.

        Parameters
        ----------
        filename : str
            Path to json file containing list of json configurations, one per
            task to add to the queue.
        cores : int
            Default number of cores to use for each task if not specified
            within task configuration.
        """
        self._logger.debug(f"Loading json task file {filename}")
        with open(filename, "r") as fp:
            task_list = json.load(fp)
        self._logger.debug(f"Found {len(task_list)} tasks.")
        self.enqueue(task_list, cores=cores)

    def run(self):
        """
        Run tasks and wait for all tasks in queue to complete, or until
        `max_runtime` is exceeded.

        Task/slot summaries are written to `workdir` at start, at most every
        `summary_interval` seconds while running, and at the end. When
        `max_runtime` is exceeded, running tasks are terminated and moved to
        `timed_out`; any still-queued tasks stay in `queue`.
        """
        self.start_ts = time.time()
        self._logger.info("Starting launcher job", extra=self.__dict__)
        self._save_summary()
        # Elapsed time (relative to start_ts) at which the last summary was
        # written; must share the same time base as `elapsed` below.
        last_summary = 0.0
        while True:
            elapsed = time.time() - self.start_ts
            if elapsed - last_summary > self.summary_interval:
                self._save_summary()
                last_summary = elapsed
            if elapsed >= self.max_runtime:
                msg = f"Exceeded max runtime : {elapsed}>{self.max_runtime}"
                self._logger.info(msg, extra=self.__dict__)
                for t in self.running:
                    self._terminate_and_release(t, msg)
                # Terminated tasks now live in timed_out only.
                self.running = []
                break
            self._logger.debug("Starting queued tasks")
            self._start_queued()
            self._logger.debug("Updating task lists")
            self._update()
            time.sleep(self.delay)
            if not self.running and not self.queue:
                self._logger.info("Running and queue are empty.")
                break
        self.running_time = time.time() - self.start_ts
        self._logger.info("Queue run finished", extra=self.__dict__)
        self._save_summary()

    def read_log(self) -> List[Dict[str, Union[str, int, float]]]:
        """
        Read the JSON log file (`workdir/tq_log.json`), one record per line.

        Returns
        -------
        List[Dict[str, Union[str, int, float]]]
            List of dictionaries containing the log information for each entry.
        """
        with open(self.workdir / "tq_log.json", "r") as f:
            return [json.loads(line) for line in f if line.strip()]

    def get_log(
        self,
        fields: Optional[List[str]] = None,
        search: Optional[str] = None,
        match: Optional[str] = None,
        print_log: bool = True,
    ) -> List[Dict[str, Union[str, int, float]]]:
        """
        Get and optionally print log entries.

        Parameters
        ----------
        fields : List[str], optional
            List of fields to include in the summary. Defaults to ["asctime",
            "levelname", "message"].
        search : str, optional
            String to search for in the summary. Defaults to None.
        match : str, optional
            Regular expression to match against the search string. Defaults to
            None.
        print_log : bool, optional
            Whether to print the log to the console. Defaults to True.

        Returns
        -------
        List[Dict[str, Union[str, int, float]]]
            List of dictionaries containing the log information for each entry.
        """
        if fields is None:
            fields = ["asctime", "levelname", "message"]
        return filter_res(
            self.read_log(),
            fields=fields,
            search=search,
            match=match,
            print_res=print_log,
        )

    def summary_by_task(
        self,
        fields: Optional[List[str]] = None,
        search: Optional[str] = None,
        match: str = r".",
        all_fields: bool = False,
        print_res: bool = True,
        fname: Optional[str] = None,
    ) -> List[Dict[str, Union[str, int, float]]]:
        """
        Summarize queue stats by task.

        Every task in the queue (running, queued, completed, errored,
        timed_out, invalid, in that order) contributes one row with a
        "status" column followed by the requested task attributes.

        Parameters
        ----------
        fields : List[str], optional
            List of fields to include in the summary. Defaults to ["task_id",
            "running_time", "cores", "command"].
        search : str, optional
            String to search for in the summary. Defaults to None.
        match : str, optional
            Regular expression to match against the search string. Defaults to
            ".".
        all_fields : bool, optional
            Whether to include all available fields in the summary. Defaults
            to False.
        print_res : bool, optional
            Whether to print the summary to the console. Defaults to True.
        fname : str, optional
            File name to write the summary to. Defaults to None.

        Returns
        -------
        List[Dict[str, Union[str, int, float]]]
            List of dictionaries containing the summary information for each
            task.

        Raises
        ------
        ValueError
            If `fields` contains a name that is not an available field.
        """
        if fields is None:
            fields = ["task_id", "running_time", "cores", "command"]
        avail_fields = [
            "task_id",
            "command",
            "cores",
            "pre",
            "post",
            "cdir",
            "workdir",
            "execfile",
            "logfile",
            "errfile",
            "slots",
            "start_ts",
            "end_ts",
            "running_time",
            "err_msg",
        ]
        bad_fields = [f for f in fields if f not in avail_fields]
        if bad_fields:
            msg = f"Invalid fields {bad_fields}. Available {avail_fields}"
            raise ValueError(msg)
        fields = avail_fields if all_fields else list(fields)

        def task_row(status, task):
            # "status" first, then the requested attributes in order.
            row = {"status": status}
            row.update((f, getattr(task, f)) for f in fields)
            return row

        states = (
            ("running", self.running),
            ("queued", self.queue),
            ("completed", self.completed),
            ("errored", self.errored),
            ("timed_out", self.timed_out),
            ("invalid", self.invalid),
        )
        task_info = [
            task_row(status, task) for status, tasks in states for task in tasks
        ]
        return filter_res(
            task_info,
            fields=["status"] + fields,
            search=search,
            match=match,
            print_res=print_res,
            output_file=fname,
        )

    def summary_by_slot(
        self,
        fields: Optional[List[str]] = None,
        search: Optional[str] = None,
        match: str = r".",
        all_fields: bool = False,
        print_res: bool = True,
        fname: Optional[str] = None,
    ) -> List[Dict[str, Union[int, str, List[int], float]]]:
        """
        Summarize queue stats by slots.

        Parameters
        ----------
        fields : List[str], optional
            List of fields to include in the summary. Defaults to
            ["idx", "host", "status", "num_tasks", "task_ids", "free_time",
            "busy_time"].
        search : str, optional
            String to search for in the summary. Defaults to None.
        match : str, optional
            Regular expression to match against the search string. Defaults to
            ".".
        all_fields : bool, optional
            Whether to include all available fields in the summary. Defaults
            to False.
        print_res : bool, optional
            Whether to print the summary to the console. Defaults to True.
        fname : str, optional
            File name to write the summary to. Defaults to None.

        Returns
        -------
        List[Dict[str, Union[int, str, List[int], float]]]
            List of dictionaries containing the summary information for each
            slot.

        Raises
        ------
        ValueError
            If `fields` contains a name that is not an available field.
        """
        avail_fields = [
            "idx",
            "host",
            "status",
            "num_tasks",
            "task_ids",
            "free_time",
            "busy_time",
        ]
        if fields is None:
            fields = list(avail_fields)
        bad_fields = [f for f in fields if f not in avail_fields]
        if bad_fields:
            msg = f"Invalid fields {bad_fields}. Available {avail_fields}"
            raise ValueError(msg)
        fields = avail_fields if all_fields else list(fields)
        slot_info = [
            {
                "idx": s.idx,
                "host": s.host,
                "status": "FREE" if s.is_free() else "BUSY",
                "num_tasks": len(s.tasks),
                "task_ids": [t.task_id for t in s.tasks],
                "free_time": s.free_time,
                "busy_time": s.busy_time,
            }
            for s in self.task_slots
        ]
        return filter_res(
            slot_info,
            fields=fields,
            search=search,
            match=match,
            print_res=print_res,
            output_file=fname,
        )

    def cleanup(self):
        """
        Clean-up Task Queue by removing workdir.

        Log handlers writing into the workdir are detached from the logger and
        closed first, so the log file is not held open (or written to) after
        the directory is removed.
        """
        workdir = os.path.abspath(self.workdir)
        for handler in list(self._logger.handlers):
            fname = getattr(handler, "baseFilename", None)
            if fname is not None and os.path.dirname(fname) == workdir:
                self._logger.removeHandler(handler)
                handler.close()
        shutil.rmtree(str(self.workdir))