Source code for forest_helper

from ws3 import common
from ws3 import opt
from concurrent.futures import ProcessPoolExecutor #, as_completed
from multiprocessing import get_context


# Name of the multiprocessing start method handed to `get_context` when
# `PersistentWorkerPool.__enter__` creates its process pool.
# NOTE(review): the "fork" start method is not offered on every platform
# (e.g. Windows) -- confirm the supported targets.
MP_CONTEXT = "fork"

# Per-process state for the `_gen_vars_m1` workers. These are (re)assigned by
# `init_worker_gen_vars` and read by `worker_gen_vars`.
_GLOBAL_MODEL_GEN_VARS = None  # deserialized model object
_GLOBAL_COEFF_FUNCS_GEN_VARS = None  # dict of deserialized coefficient functions
_GLOBAL_WORKERS_GEN_VARS = 1  # worker count passed to the initializer


def choose_max_batch_factor(workers):
    """Pick an adaptive ``max_batch_factor`` for :func:`auto_batch`.

    The factor grows step-wise with the number of workers so that larger pools
    receive proportionally more (and therefore smaller) batches.

    :param workers: Number of worker processes available.
    :type workers: int
    :return: 2 for up to 2 workers, 4 for up to 8, 8 for up to 16, else 16.
    :rtype: int

    Examples:

    >>> choose_max_batch_factor(1)
    2
    >>> choose_max_batch_factor(8)
    4
    >>> choose_max_batch_factor(32)
    16

    Edge cases:

    * No validation is performed and no exception is raised: any value
      ``<= 2`` (including zero and negative values) yields 2.
    """
    # (inclusive upper bound on workers, factor), checked in ascending order.
    thresholds = ((2, 2), (8, 4), (16, 8))
    for upper_bound, factor in thresholds:
        if workers <= upper_bound:
            return factor
    return 16
def auto_batch(tasks, workers, max_batch_factor=None, size_fn=lambda x: 1.):
    """Split tasks into batches for parallel processing.

    Tasks are sorted by estimated size (descending, stable for ties) and then
    assigned greedily: each task goes to the batch with the smallest current
    load (lowest batch index wins ties). Empty batches are dropped, and any
    batch holding more than twice the nominal batch size is cut into
    consecutive chunks of the nominal size.

    :param tasks: Tasks to batch. Any iterable is accepted; it is materialized
        into a list once. ``None`` or an empty iterable yields ``[]``.
    :type tasks: list
    :param workers: Number of workers (cores) that will be used later to
        process batches
    :type workers: int
    :param max_batch_factor: Scaling parameter (larger value yields more
        smaller batches). When ``None`` it is chosen by
        :func:`choose_max_batch_factor`, defaults to None
    :type max_batch_factor: int, optional
    :param size_fn: Task size estimation function returning a number used for
        the greedy task sort. ``None`` means every task has size 1, defaults
        to `lambda x: 1.`
    :type size_fn: function, optional
    :return: List of task batches
    :rtype: list[list]
    """
    import heapq

    tasks = [] if tasks is None else list(tasks)
    if not tasks:
        return []

    if max_batch_factor is None:
        max_batch_factor = choose_max_batch_factor(workers)
    if size_fn is None:
        def size_fn(_task):
            return 1

    target_batches = max(1, workers * max_batch_factor)
    batch_size = max(1, len(tasks) // target_batches)

    # Evaluate every task's size exactly once; the stable sort keeps the input
    # order among tasks of equal size.
    sized_tasks = sorted(
        ((size_fn(task), task) for task in tasks),
        key=lambda pair: pair[0],
        reverse=True,
    )

    # Min-heap of (current load, batch index). Comparing tuples makes the
    # lowest index win among equally loaded batches, which matches a
    # "first minimum" linear scan while costing O(log k) per task.
    batches = [[] for _ in range(target_batches)]
    load_heap = [(0, idx) for idx in range(target_batches)]  # sorted => valid heap
    for size, task in sized_tasks:
        load, idx = load_heap[0]
        batches[idx].append(task)
        heapq.heapreplace(load_heap, (load + size, idx))

    # Split overly long batches so that no single batch becomes huge; batches
    # left empty (fewer tasks than batches) are skipped.
    final_batches = []
    for batch in batches:
        if not batch:
            continue
        if len(batch) > batch_size * 2:
            final_batches.extend(
                batch[start:start + batch_size]
                for start in range(0, len(batch), batch_size)
            )
        else:
            final_batches.append(batch)
    return final_batches
def worker_summarize_tree_batch(args):
    """Summarize a batch of trees into coverage constraints and leaf outputs.

    For every ``(i, tree)`` pair one result tuple is produced. Each path of the
    tree contributes one variable named ``x_<leaf_id>`` (``leaf_id`` is read
    from the last node of the path) with coefficient 1.0 in the coverage row
    and the leaf's ``z_coeff_key`` data value in the objective coefficients.

    :param args: [batch, z_coeff_key], where batch is an iterable of
        ``(i, tree)`` pairs
    :type args: list
    :return: [(cname, coeffs, z_coeffs), ...]
    :rtype: list[tuple]
    """
    batch, z_coeff_key = args
    results = []
    for i, tree in batch:
        cname = f'cov_{common.hex_id(i)}'
        coeffs = {}
        z_coeffs = {}
        for path in tree.paths():
            # Only the leaf (last node of the path) is needed here.
            leaf = path[-1]
            vname = f"x_{leaf.data('leaf_id')}"
            coeffs[vname] = 1.0
            z_coeffs[vname] = leaf.data(z_coeff_key)
        results.append((cname, coeffs, z_coeffs))
    return results
def sanitize_func(f):
    """Make a version of f that is safe to serialize via dill in `spawn` mode

    A plain function is rebuilt around the same code object with an empty
    globals dict and ``__module__`` set to ``'__main__'``; its name, positional
    defaults, keyword-only defaults and closure are carried over. A
    ``functools.partial`` is rebuilt around the sanitized wrapped function
    with the same bound arguments.

    Because the globals dict is empty, the sanitized function cannot resolve
    module-level names from the module it was defined in.

    :param f: Function to sanitize
    :type f: function
    :return: Sanitized function
    :rtype: function
    :raises TypeError: If f is neither a plain Python function nor a
        ``functools.partial``
    """
    import functools
    import types

    if isinstance(f, functools.partial):
        return functools.partial(sanitize_func(f.func), *f.args, **(f.keywords or {}))

    if isinstance(f, types.FunctionType):
        new_f = types.FunctionType(
            f.__code__,
            {},  # empty globals dict -- no module context
            name=f.__name__,
            argdefs=f.__defaults__,
            closure=f.__closure__,
        )
        # The FunctionType constructor has no parameter for keyword-only
        # defaults; without this copy `def g(x, *, k=1)` would lose `k=1`.
        if f.__kwdefaults__ is not None:
            new_f.__kwdefaults__ = dict(f.__kwdefaults__)
        new_f.__module__ = '__main__'
        return new_f

    raise TypeError(f"Don't know how to sanitize function of type {type(f)}")


from concurrent.futures import ThreadPoolExecutor, as_completed
def init_worker_gen_vars(blob_bytes_local, serialized_funcs_local, workers=1):
    """Initializer for `_gen_vars_m1` workers: load model and coefficient functions once.

    The deserialized objects are stored in module-level globals of the worker
    process, together with the requested worker count.

    ``None`` is accepted for both payloads (these are the defaults of
    `PersistentWorkerPool`): a missing model blob leaves the model global set
    to ``None`` and missing functions yield an empty dict. `dill` is only
    imported when there is something to deserialize.

    :param blob_bytes_local: Serialized `ForestModel` object, or None
    :type blob_bytes_local: bytes
    :param serialized_funcs_local: dict of serialized functions keyed on
        `coeff_funcs` keys, or None
    :type serialized_funcs_local: dict[str, bytes]
    :param workers: Number of workers, defaults to 1
    :type workers: int, optional
    """
    global _GLOBAL_MODEL_GEN_VARS, _GLOBAL_COEFF_FUNCS_GEN_VARS, _GLOBAL_WORKERS_GEN_VARS

    model = None
    coeff_funcs = {}
    if blob_bytes_local is not None or serialized_funcs_local:
        # NOTE(review): dill.loads executes arbitrary code embedded in the
        # payload; the payloads are expected to come from the parent process.
        import dill
        if blob_bytes_local is not None:
            model = dill.loads(blob_bytes_local)
        coeff_funcs = {
            key: dill.loads(f_bytes)
            for key, f_bytes in (serialized_funcs_local or {}).items()
        }

    _GLOBAL_MODEL_GEN_VARS = model
    _GLOBAL_COEFF_FUNCS_GEN_VARS = coeff_funcs
    _GLOBAL_WORKERS_GEN_VARS = workers
def worker_gen_vars(tasks, acodes):
    """Worker for building trees in `_gen_vars_m1`.

    Uses the model and coefficient functions stored in this process's
    module-level globals (set by `init_worker_gen_vars`). For every
    ``(dtk, age)`` task the model is reset, the period-1 area of that
    development type and age is looked up, and -- unless the area is falsy --
    a tree is built with `model._bld_tree_m1`.

    :param tasks: list of (dtk, age) tuples to process
    :type tasks: list[((str, ...), int)]
    :param acodes: list of action codes to use when building trees
    :type acodes: list[str]
    :return: list of (dtk, age, tree) tuples; tasks with no area are skipped
    :rtype: list[((str, ...), int, Tree)]
    """
    model = _GLOBAL_MODEL_GEN_VARS
    coeff_funcs = _GLOBAL_COEFF_FUNCS_GEN_VARS

    results = []
    for dtk, age in tasks:
        # Reset before every task, including tasks that end up skipped.
        model.reset()
        area = model.dtypes[dtk].area(1, age)
        if not area:
            continue
        tree = model._bld_tree_m1(
            area,
            dtk,
            age,
            coeff_funcs,
            tree=None,
            period=1,
            acodes=acodes,
            compile_c_ycomps=True,
        )
        results.append((dtk, age, tree))
    return results
# ----------------------------
# Worker functions for _cmp_cflw_m1 parallel execution
# ----------------------------
def worker_cmp_cflw_batch(args):
    """Worker function to process batches of tasks for `_cmp_cflw_m1`

    For every tree, path, output key and period one row is emitted. The value
    is looked up in the per-period mapping stored on the path's last node
    under the output key; periods missing from that mapping give 0.0.

    :param args: (batch, cflw_keys, periods), where batch is an iterable of
        ``(i, tree)`` pairs
    :type args: list[list, dict, list]
    :return: list of (t, o, i, j, value) tuples, where j is the tuple of
        'acode' values along the path
    :rtype: list[(int, str, tuple, tuple, float)]
    """
    batch, cflw_keys, periods = args
    rows = []
    for tree_key, tree in batch:
        for path in tree.paths():
            acodes = tuple(node.data('acode') for node in path)
            for out_key in cflw_keys:
                per_period = path[-1].data(out_key)
                rows.extend(
                    (t, out_key, tree_key, acodes, per_period.get(t, 0.0))
                    for t in periods
                )
    return rows
def worker_cmp_cflw_phase3(args):
    """Worker function to compute (name, coeffs, sense, rhs) tuples for Phase 3 of `_cmp_cflw_m1`.

    Two rows are built for period ``t`` and output ``o``: a lower-bound row
    with coefficients ``mu - (1 - eps) * mu_ref`` (sense ``>=``) and an
    upper-bound row with coefficients ``mu - (1 + eps) * mu_ref``
    (sense ``<=``), both with right-hand side 0.

    :param args: (t, o, mu_t_o, mu_ref_o, eps, xnames); mu_t_o and mu_ref_o
        map keys to values and xnames maps the same keys to variable names
    :type args: tuple(int, str, dict, dict, float, dict)
    :return: [lower-bound row, upper-bound row], each a
        (constraint_name, coeffs, sense, 0.) tuple
    :rtype: list[(str, dict, str, float)]
    """
    t, o, mu_t_o, mu_ref_o, eps, xnames = args

    var_names = [xnames[key] for key in mu_t_o]
    values = list(mu_t_o.values())
    references = [mu_ref_o[key] for key in mu_t_o]

    lower_coeffs = {
        name: value - (1 - eps) * ref
        for name, value, ref in zip(var_names, values, references)
    }
    lower_row = (f'flw-lb_{t:03d}_{o}', lower_coeffs, opt.SENSE_GEQ, 0.0)

    upper_coeffs = {
        name: value - (1 + eps) * ref
        for name, value, ref in zip(var_names, values, references)
    }
    upper_row = (f'flw-ub_{t:03d}_{o}', upper_coeffs, opt.SENSE_LEQ, 0.0)

    return [lower_row, upper_row]
def worker_cmp_cflw_phase3_batch(batch):
    """Worker function to process batches of phase 3 tasks for `_cmp_cflw_m1`

    Every task is passed to `worker_cmp_cflw_phase3` and the returned rows are
    concatenated in task order.

    :param batch: list of tasks (tuples)
    :type batch: list[tuple]
    :return: flat list of result rows
    :rtype: list[tuple]
    """
    return [row for task in batch for row in worker_cmp_cflw_phase3(task)]
# ----------------------------
# Worker functions for _cmp_cgen_m1 parallel execution
# ----------------------------
def worker_cmp_cgen_batch(args):
    """Worker function to process batches of tasks for `_cmp_cgen_m1`

    For every tree, path, output key and period one row is emitted. The value
    comes from the ``{period: value}`` mapping stored on the path's leaf (last
    node) under the output key; periods missing from that mapping give 0.0.

    :param args: (batch, cgen_keys, periods), where batch is an iterable of
        ``(i, tree)`` pairs
    :type args: list[list, dict, list]
    :return: list of (t, o, i, j, value) tuples, where j is the tuple of
        'acode' values along the path
    :rtype: list[(int, str, tuple, tuple, float)]
    """
    batch, cgen_keys, periods = args
    rows = []
    for tree_key, tree in batch:
        for path in tree.paths():
            acodes = tuple(node.data('acode') for node in path)
            leaf = path[-1]
            for out_key in cgen_keys:
                by_period = leaf.data(out_key)
                rows.extend(
                    (t, out_key, tree_key, acodes, by_period.get(t, 0.0))
                    for t in periods
                )
    return rows
def worker_cmp_cgen_phase3(args):
    """Worker function to compute constraint rows for Phase 3 of `_cmp_cgen_m1`.

    The coefficient dict maps ``x_<hex id of key>`` to the value of every
    entry in ``mu_t_o``. A lower-bound row is emitted when ``lb`` is given and
    contains period ``t``; likewise an upper-bound row for ``ub``. Both rows
    refer to the same coefficient dict object.

    :param args: (t, o, mu_t_o, lb, ub); keys of mu_t_o are (i, j) and lb/ub
        map periods to bound values (or are None)
    :type args: tuple
    :return: [(name, coeffs, sense, rhs), ...] with zero, one or two rows
    :rtype: list[(str, dict, str, float)]
    """
    t, o, mu_t_o, lb, ub = args

    # NOTE: keys in mu_t_o are (i, j)
    coeffs = {}
    for key, value in mu_t_o.items():
        coeffs['x_%s' % common.hex_id(key)] = value

    has_lower = lb is not None and t in lb
    has_upper = ub is not None and t in ub

    rows = []
    if has_lower:
        rows.append((f'gen-lb_{t:03d}_{o}', coeffs, opt.SENSE_GEQ, lb[t]))
    if has_upper:
        rows.append((f'gen-ub_{t:03d}_{o}', coeffs, opt.SENSE_LEQ, ub[t]))
    return rows
def worker_cmp_cgen_phase3_batch(batch):
    """Process a batch of Phase 3 CGEN tasks.

    Every task is passed to `worker_cmp_cgen_phase3` and the returned rows are
    concatenated in task order.

    :param batch: list of tasks (tuples)
    :type batch: list[tuple]
    :return: flat list of result rows
    :rtype: list[tuple]
    """
    return [row for task in batch for row in worker_cmp_cgen_phase3(task)]
class PersistentWorkerPool:
    """Context manager wrapping a persistent ProcessPoolExecutor.

    When more than one worker is requested, entering the context creates a
    process pool whose workers are initialized through `init_worker_gen_vars`
    with the serialized model and coefficient functions. Otherwise no pool is
    created and the context yields the current `executor` attribute (``None``
    unless a pool was created earlier).
    """

    def __init__(self, workers, blob_bytes=None, serialized_funcs=None):
        """Constructor

        :param workers: Number of workers
        :type workers: int
        :param blob_bytes: Serialized `ForestModel` object, defaults to None
        :type blob_bytes: bytes, optional
        :param serialized_funcs: dict of serialized `coeff_funcs` functions,
            defaults to None
        :type serialized_funcs: dict[str, bytes], optional
        """
        self.executor = None
        self.workers = workers
        self.blob_bytes = blob_bytes
        self.serialized_funcs = serialized_funcs

    def __enter__(self):
        """Create the persistent worker pool executor if more than one worker is requested

        :return: The executor, or the current `executor` attribute (normally
            None) when `workers` is not greater than 1
        :rtype: ProcessPoolExecutor
        """
        if not (self.workers > 1):
            return self.executor

        init_args = (self.blob_bytes, self.serialized_funcs, self.workers)
        self.executor = ProcessPoolExecutor(
            max_workers=self.workers,
            mp_context=get_context(MP_CONTEXT),
            initializer=init_worker_gen_vars,
            initargs=init_args,
        )
        return self.executor

    def __exit__(self, exc_type, exc_value, traceback):
        """Shut down the persistent pool executor when the with block exits"""
        if self.executor is None:
            return
        self.executor.shutdown()