"""Functions for step level parallelism."""
from __future__ import annotations
from abc import abstractmethod, abstractproperty
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import cached_property, reduce, wraps
import itertools
import math
import re
from typing import (
Any, Callable, Dict, Iterable, List, Pattern, Set, Tuple, TypedDict, Union
)
from typing_extensions import Literal, NotRequired
from kedro.pipeline.node import Node
from kedro.pipeline import node
from kedro_partitioned.utils.constants import MAX_NODES, MAX_WORKERS
from kedro_partitioned.utils.other import (
nonefy,
truthify,
)
from kedro_partitioned.utils.string import (
get_filepath_without_extension,
)
from kedro_partitioned.utils.iterable import (
firstorlist,
partition,
tolist,
unique,
optionaltolist,
)
from kedro.pipeline import Pipeline
from kedro_partitioned.utils.typing import T, Args, IsFunction
_Partitioned = Dict[str, Callable[[], Any]]
class _Template(TypedDict):
    """Configuration for Configurator/partition matching.

    Attributes:
        pattern: A string of the possible partition subpaths containing
            {} as placeholders for the `Configurator` target replacement
        hierarchy: A list of strings of the pattern placeholders declared
            in the order of the most hierarchical, to the less.
        any: A dictionary of placeholder by its specific regex for `*` targets
    """

    # subpath pattern with {key} placeholders, e.g. 'subpath/{x}'
    pattern: str
    # optional precedence order of the placeholders (defaults to
    # declaration order when absent)
    hierarchy: NotRequired[List[str]]
    # regex used for a key when its configurator target is '*'
    # (falls back to '.*' when the key is missing)
    any: Dict[str, str]
class _Configurator(TypedDict, total=False):
    """Specifies how a Configurator for multinodes must be declared.

    Attributes:
        target: A list of replacements for the `_Template` pattern.
            Should be declared in the same order as the pattern occurrences.
        cached: Whether a partition should not be processed.
        data: Everything that is gonna be passed to the multinode function
    """

    # one entry per {key} in the template pattern; '*' matches anything,
    # a list means "any of these literal values"
    target: List[Union[str, List[str], Literal['*']]]
    # True removes the matched partitions from the slices (skip processing)
    cached: NotRequired[bool]
    # payload forwarded to the multinode function when this target matches
    data: dict
class _Configurators(TypedDict):
    """TypedDict for a container of `Configurator`s.

    This describes how a configurator entry must be declared in the
    parameters files.
    """

    # pattern/hierarchy declaration shared by every configurator below
    template: _Template
    # candidate configurators, matched against each partition subpath
    configurators: List[_Configurator]
class ConfiguratorFinder:
"""Finds the best configurator given a list of configurators.
Given a configurators dict, this class will find the best configurator for
a given subpath.
Attributes:
keys (List[str]): Keys between {} in the template.
weights (List[int]): Weights of the keys given a hierarchy in
the template. The higher the weight, the more specific the
configurator is.
scores (Dict[int, int]): Scores of the possible configurator targets.
"""
NOT_FOUND = -1
ANY = '*'
def __init__(self, configurators: _Configurators):
"""Initializes the ConfiguratorFinder.
Args:
configurators (_Configurators): Configurators to use.
"""
self._template = configurators['template']
self._configurators = configurators['configurators']
@cached_property
def keys(self) -> List[str]:
"""List of keys between {} in the template.
Returns:
List[str]: Keys between {} in the template.
"""
return re.findall(re.compile(r'\{(.+?)\}'), self._template['pattern'])
@cached_property
def weights(self) -> List[int]:
"""Weight for each key in the template.
Returns:
List[int]
"""
keys = self.keys
hierarchy = self._template.get('hierarchy', keys)
return [hierarchy.index(key) for key in keys]
@cached_property
def scores(self) -> Dict[int, int]:
"""Builds a dict of <hash of a target possibility>: <score>.
Returns:
Dict[int, int]: Possible scores for a target possibility
Note:
Already sorted in crescent order
"""
return {
hash(tuple(reversed(possibility))): score
for score, possibility in
enumerate(itertools.product([False, True], repeat=len(self.keys)))
}
def _build_regex(self, configurator: _Configurator) -> Pattern:
"""Builds a regex for a configurator target.
Args:
configurator (_Configurator): Configurator to build the regex for.
Returns:
Pattern: Regex for the configurator target.
"""
def parse_target(key: str, target: str) -> List[str]:
if isinstance(target, str):
if target == self.ANY:
return [self._template.get('any', {}).get(key, '.*')]
else:
return [target]
else:
return target
targets = (
f'({"|".join(parse_target(key, target))})'
for key, target in zip(self.keys, configurator['target'])
)
key_regex = '|'.join([r'\{%s\}' % k for k in self.keys])
return r'^' + re.sub(
key_regex, lambda _: next(targets), self._template['pattern']
) + r'$'
def _score_match(self, path: str, configurator: _Configurator) -> int:
"""Computes the score of a configurator if it matches the path.
Args:
path (str): Path to match the configurator with.
configurator (_Configurator): Configurator to score.
Returns:
int: Score of the configurator.
"""
regex = self._build_regex(configurator)
targets = configurator['target']
weights = self.weights
if re.match(regex, path):
raw_possibility = [group != self.ANY for group in targets]
possibility = hash(
tuple(key for _, key in sorted(zip(weights, raw_possibility)))
)
return self.scores[possibility]
else:
return self.NOT_FOUND
def __getitem__(self, path: str) -> Union[_Configurator, None]:
"""Finds the best configurator for a given path.
Args:
path (str): Path to find the best configurator for.
Returns:
Union[_Configurator, None]: Best configurator for the path.
"""
configurator = max(
self._configurators, key=lambda x: self._score_match(path, x)
)
return None if self._score_match(
path, configurator
) == self.NOT_FOUND else configurator
class _CustomizedFuncNode(Node):
    """Base class for nodes whose executed callable differs from `_func`.

    Subclasses implement the abstract `func` property; `run` temporarily
    swaps `self._func` by it so Kedro executes the customized callable,
    while the node keeps reporting the original function elsewhere.
    """

    def __init__(
        self,
        func: Callable,
        inputs: Union[None, str, List[str], Dict[str, str]],
        outputs: Union[None, str, List[str], Dict[str, str]],
        name: str = None,
        tags: Union[str, Iterable[str]] = None,
        confirms: Union[str, List[str]] = None,
        namespace: str = None
    ):
        """Initializes the node, keeping a reference to the original func.

        Args:
            func (Callable): Function the node is declared with.
            inputs (Union[None, str, List[str], Dict[str, str]]): Node inputs.
            outputs (Union[None, str, List[str], Dict[str, str]]): Node
                outputs.
            name (str, optional): Node name. Defaults to None.
            tags (Union[str, Iterable[str]], optional): Node tags.
                Defaults to None.
            confirms (Union[str, List[str]], optional): DataSets confirmed
                before the node executes. Defaults to None.
            namespace (str, optional): Node namespace. Defaults to None.
        """
        self._original_func = func
        super().__init__(
            func,
            inputs,
            outputs,
            name=name,
            tags=tags,
            confirms=confirms,
            namespace=namespace
        )

    # `abc.abstractproperty` is deprecated since Python 3.3; the stacked
    # `@property` + `@abstractmethod` form is the modern equivalent
    @property
    @abstractmethod
    def func(self) -> Callable:
        """Callable actually executed by `run`.

        Returns:
            Callable
        """

    def run(self, inputs: Dict[str, Any] = None) -> Dict[str, Any]:
        """Runs the node with the customized `func` swapped in.

        Args:
            inputs (Dict[str, Any], optional): Node inputs. Defaults to None.

        Returns:
            Dict[str, Any]: Outputs produced by the customized callable.
        """
        self._func = self.func
        try:
            return super().run(inputs)
        finally:
            # restore even when the run raises, so the node stays reusable
            self._func = self._original_func
class _SlicerNode(_CustomizedFuncNode):
    """Splits partitioned datasets partitions and store them in a metadata.

    The node intersects/filters the partition names of its partitioned
    inputs and emits a list with `slice_count` partition-name lists, one
    per multinode slice.

    Example:
        Balance loads

        >>> n = _SlicerNode(2, 'a', 'b', 'x')
        >>> n
        Node(nonefy, ['a'], 'b-slicer', 'x')
        >>> dictionary = {'a': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4}}
        >>> n.run(inputs=dictionary)
        {'b-slicer': [['subpath/a'], ['subpath/b']]}
        >>> dictionary = {'a': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4,
        ...     'subpath/c.txt': lambda: 5}}
        >>> n.run(inputs=dictionary)
        {'b-slicer': [['subpath/a', 'subpath/b'], ['subpath/c']]}
        >>> n = _SlicerNode(3, 'a', 'b', 'x')
        >>> dictionary = {'a': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4,
        ...     'subpath/c.txt': lambda: 5}}
        >>> n.run(inputs=dictionary)
        {'b-slicer': [['subpath/a'], ['subpath/b'], ['subpath/c']]}

        With multiple inputs

        >>> n = _SlicerNode(2, ['a', 'b'], 'c', 'x')
        >>> dictionary = {'a': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4},
        ...     'b': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4}}
        >>> n.run(inputs=dictionary)
        {'c-slicer': [['subpath/a'], ['subpath/b']]}

        Intersect partitions

        >>> dictionary = {'a': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4},
        ...     'b': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4,
        ...     'subpath/c.txt': lambda: 5}}
        >>> n.run(inputs=dictionary)
        {'c-slicer': [['subpath/a'], ['subpath/b']]}

        Using configurators

        >>> n = _SlicerNode(2, ['a', 'b'], 'c', 'x', configurator='params:z')
        >>> dictionary = {'a': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4},
        ...     'b': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4,
        ...     'subpath/c.txt': lambda: 5},
        ...     'params:z': {
        ...         'template': {'pattern': 'subpath/{x}'},
        ...         'configurators': [
        ...             {'target': ['a'], 'cached': True,
        ...                 'data': 1},
        ...             {'target': ['*'], 'cached': False,
        ...                 'data': 1}]}}
        >>> n.run(inputs=dictionary)
        {'c-slicer': [['subpath/b'], []]}
    """

    SLICER_SUFFIX = '-slicer'

    def __init__(
        self,
        slice_count: int,
        partitioned_inputs: Union[None, str, List[str]],
        partitioned_outputs: str,
        name: str,
        tags: Union[str, Iterable[str]] = None,
        confirms: Union[str, List[str]] = None,
        namespace: str = None,
        filter: IsFunction[str] = truthify,
        configurator: str = None,
    ):
        """Initializes the slicer node.

        Args:
            slice_count (int): Number of slices the partitions are split in.
            partitioned_inputs (Union[None, str, List[str]]): Partitioned
                inputs whose partitions are sliced.
            partitioned_outputs (str): Output name the slicer suffix is
                appended to.
            name (str): Node name.
            tags (Union[str, Iterable[str]], optional): Node tags.
                Defaults to None.
            confirms (Union[str, List[str]], optional): DataSets confirmed
                before execution. Defaults to None.
            namespace (str, optional): Node namespace. Defaults to None.
            filter (IsFunction[str], optional): Predicate applied to each
                partition name; partitions returning False are dropped.
                Defaults to truthify.
            configurator (str, optional): Name of the configurator input.
                Defaults to None.
        """
        self._partitioned_inputs = partitioned_inputs
        self._slice_count = slice_count
        self._original_output = partitioned_outputs
        self._filter = filter
        self._configurator = configurator
        super().__init__(
            func=nonefy,
            inputs=tolist(partitioned_inputs) + optionaltolist(configurator),
            outputs=self._add_slicer_suffix(partitioned_outputs),
            name=name,
            tags=tags,
            confirms=confirms,
            namespace=namespace,
        )

    @property
    def slice_count(self) -> int:
        # number of slices the partitions are distributed into
        return self._slice_count

    @property
    def original_output(self) -> str:
        # output name as passed in __init__, without the slicer suffix
        return self._original_output

    @property
    def json_output(self) -> str:
        # suffixed output name holding the slices metadata
        return self._outputs

    # required for inheritance: rebuilds the node from its init params
    def _copy(self, **overwrite_params: Any) -> _SlicerNode:
        params = {
            'partitioned_inputs': self._partitioned_inputs,
            'partitioned_outputs': self._original_output,
            'slice_count': self._slice_count,
            'name': self._name,
            'namespace': self._namespace,
            'tags': self._tags,
            'confirms': self._confirms,
            'configurator': self._configurator,
            'filter': self._filter,
        }
        params.update(overwrite_params)
        return self.__class__(**params)

    @classmethod
    def _add_slicer_suffix(cls, string: str) -> str:
        """Returns the same string with `-slicer` suffix.

        Args:
            string (str)

        Returns:
            str

        Example:
            >>> _SlicerNode._add_slicer_suffix('test')
            'test-slicer'
            >>> _SlicerNode._add_slicer_suffix(
            ...     'test-slicer')
            'test-slicer'
        """
        return (
            string if string.endswith(cls.SLICER_SUFFIX) else
            f'{string}{cls.SLICER_SUFFIX}'
        )

    def _intersect_partitioneds(self,
                                partitioneds: List[_Partitioned]) -> List[str]:
        """Takes only the matching partitions (required for the input `zip`).

        Args:
            partitioneds (List[Partitioned]): partitioned dictionaries

        Returns:
            List[str]: Partition names (extension-less) present in every
                partitioned input.
        """
        # compare partitions by their extension-less subpath so inputs
        # stored with different file extensions still intersect
        partitioned_sets = [{
            get_filepath_without_extension(path)
            for path in partitioned
        } for partitioned in partitioneds]
        return list(
            reduce(
                lambda inter, curr: inter.intersection(curr), partitioned_sets
            )
        )

    def _calc_slice_bound(self, partition_count: int, slice_id: int) -> int:
        """Calculates the bounds of the subset of partitions for node.

        Args:
            partition_count (int): size of the partitions dictionary
            slice_id (int): current slice id

        Returns:
            int
        """
        # ceil keeps the slices contiguous and exhaustive even when
        # partition_count is not divisible by slice_count
        return math.ceil((partition_count / self._slice_count) * slice_id)

    def _slice_partitions(
        self,
        partitions: List[str],
        slice_id: int,
    ) -> List[str]:
        """Returns a subset of the original partitions.

        Args:
            partitions (List[str])
            slice_id (int)

        Returns:
            List[str]
        """
        partition_count = len(partitions)
        # half-open interval [bound(slice_id), bound(slice_id + 1))
        return partitions[self._calc_slice_bound(
            partition_count,
            slice_id,
        ):self._calc_slice_bound(
            partition_count,
            slice_id + 1,
        )]

    def _apply_filter(self, intersection: List[str]) -> List[str]:
        # drops the partitions rejected by the user supplied filter
        return [p for p in intersection if self._filter(p)]

    @classmethod
    def _extract_args_part(cls, args: tuple,
                           nargs: int) -> Tuple[tuple, tuple]:
        # splits `args` into (first nargs, remaining) parts
        return args[:nargs], args[nargs:]

    def _extract_args(
        self, args: tuple
    ) -> Tuple[List[_Partitioned], _Configurators]:
        # the configurator, when declared, is always the last input
        if self._configurator is None:
            return args, {}
        else:
            partitioneds, configurators = self._extract_args_part(args, -1)
            return partitioneds, configurators[0]

    def _filter_cached(
        self, configurators: _Configurators, intersection: List[str]
    ) -> List[str]:
        # removes partitions whose best matching configurator is cached
        if self._configurator:
            configurator_finder = ConfiguratorFinder(configurators)
            return [
                p for p in intersection
                if not (configurator_finder[p] or {}).get('cached', False)
            ]
        else:
            return intersection

    @property
    def func(self) -> Callable:
        """Function building the slices metadata.

        Returns:
            Callable: Returns a list of partition-name lists, one per slice.
        """
        def fn(*args: Any) -> List[List[str]]:
            partitioneds, configurators = self._extract_args(args)
            intersection = self._intersect_partitioneds(partitioneds)
            intersection = self._apply_filter(intersection)
            intersection = self._filter_cached(configurators, intersection)
            # sorted so the slice assignment is deterministic between runs
            intersection = sorted(intersection)
            return [
                self._slice_partitions(intersection, i)
                for i in range(self._slice_count)
            ]
        return fn
class _MultiNode(_CustomizedFuncNode):
    """Node to process a slice of a partitioned dataset.

    Example:
        >>> lbn = _SlicerNode(2, 'a', 'b', 'x')
        >>> dictionary = {'a': {'subpath/a.txt': lambda: 3,
        ...     'subpath/b.txt': lambda: 4}}
        >>> lb = lbn.run(inputs=dictionary)
        >>> def fn(x: int) -> int: return x+10
        >>> n = _MultiNode(slicer=lbn,
        ...     func=fn,
        ...     partitioned_inputs='a',
        ...     partitioned_outputs='b',
        ...     slice_id=0,
        ...     slice_count=2,
        ...     name='x')
        >>> dictionary['b-slicer'] = [['subpath/a'], ['subpath/b']]
        >>> n.run(inputs=dictionary)
        {'b-slice-0': {'subpath/a': 13}}

        Multiple inputs

        >>> dictionary['b'] = {'subpath/a.txt': lambda: 4,
        ...     'subpath/b.txt': lambda: 5}
        >>> def fn(x: int, y: int) -> list: return [x+10, y+20]
        >>> n = _MultiNode(slicer=lbn,
        ...     func=fn,
        ...     partitioned_inputs=['a', 'b'],
        ...     partitioned_outputs='c',
        ...     slice_id=0,
        ...     slice_count=2,
        ...     name='x')
        >>> n.run(inputs=dictionary)
        {'c-slice-0': {'subpath/a': [13, 24]}}

        Multiple outputs

        >>> n = _MultiNode(slicer=lbn,
        ...     func=fn,
        ...     partitioned_inputs=['a', 'b'],
        ...     partitioned_outputs=['c', 'd'],
        ...     slice_id=0,
        ...     slice_count=2,
        ...     name='x')
        >>> n.run(inputs=dictionary)
        {'c-slice-0': {'subpath/a': 13}, 'd-slice-0': {'subpath/a': 24}}

        Other inputs

        >>> dictionary['e'] = 100
        >>> def fn(x: int, y: int, e: int) -> list: return [x+e, y+e]
        >>> n = _MultiNode(slicer=lbn,
        ...     func=fn,
        ...     partitioned_inputs=['a', 'b'],
        ...     partitioned_outputs=['c', 'd'],
        ...     other_inputs=['e'],
        ...     slice_id=0,
        ...     slice_count=2,
        ...     name='x')
        >>> n.run(inputs=dictionary)
        {'c-slice-0': {'subpath/a': 103}, 'd-slice-0': {'subpath/a': 104}}

        Configurators

        >>> def fn(x: int, y: int, conf, e: int) -> list:
        ...     return [x+e+conf['add'], y+e+conf['add']]
        >>> n = _MultiNode(slicer=lbn,
        ...     func=fn,
        ...     partitioned_inputs=['a', 'b'],
        ...     partitioned_outputs=['c', 'd'],
        ...     other_inputs=['e'],
        ...     slice_id=0,
        ...     slice_count=1,
        ...     name='x',
        ...     configurator='param:conf')
        >>> dictionary['param:conf'] = {
        ...     'template': {'pattern': 'subpath/{ab}'},
        ...     'configurators': [{'target': ['a'], 'data': {'add': 100}},
        ...         {'target': ['*'], 'data': {'add': 20}}]}
        >>> n.run(inputs=dictionary)
        {'c-slice-0': {'subpath/a': 203}, 'd-slice-0': {'subpath/a': 204}}
    """

    SLICE_SUFFIX = '-slice-'

    def __init__(
        self,
        slicer: _SlicerNode,
        func: Callable,
        name: str,
        partitioned_inputs: Union[str, List[str]],
        partitioned_outputs: Union[None, str, List[str]],
        slice_id: int,
        slice_count: int,
        other_inputs: Union[None, str, List[str]] = [],
        tags: Union[str, Iterable[str]] = None,
        confirms: Union[str, List[str]] = None,
        namespace: str = None,
        previous_nodes: List[_MultiNode] = [],
        configurator: str = None,
    ):
        """Initializes the multinode slice.

        Args:
            slicer (_SlicerNode): Slicer this slice consumes metadata from.
            func (Callable): User function applied to each partition.
            name (str): Base name, suffixed with the slice id.
            partitioned_inputs (Union[str, List[str]]): Partitioned inputs.
            partitioned_outputs (Union[None, str, List[str]]): Partitioned
                outputs, suffixed with the slice id.
            slice_id (int): Index of this slice.
            slice_count (int): Total number of slices.
            other_inputs (Union[None, str, List[str]], optional): Regular
                inputs forwarded to func. Defaults to [].
            tags (Union[str, Iterable[str]], optional): Node tags.
                Defaults to None.
            confirms (Union[str, List[str]], optional): DataSets confirmed
                before execution. Defaults to None.
            namespace (str, optional): Node namespace. Defaults to None.
            previous_nodes (List[_MultiNode], optional): Multinodes of
                previous layers, used to rewire inputs to sliced outputs.
                Defaults to [].
            configurator (str, optional): Name of the configurator input.
                Defaults to None.

        Note:
            `other_inputs` and `previous_nodes` use mutable defaults; they
            are only read here, never mutated.
        """
        self._slicer = slicer
        self._partitioned_inputs = partitioned_inputs
        self._other_inputs = other_inputs
        self._partitioned_outputs = partitioned_outputs
        self._slice_id = slice_id
        self._slice_count = slice_count
        self._configurator = configurator
        # rewires partitioned inputs produced by previous layers to their
        # sliced copies (must run before building the input list below)
        self._point_to_matches(previous_nodes)
        super().__init__(
            func=func,
            inputs=([self.slicer_output] + tolist(partitioned_inputs)
                    + optionaltolist(configurator) + tolist(other_inputs)),
            outputs=self.partitioned_outputs,
            name=self._add_slice_suffix(name),
            tags=tags,
            confirms=confirms,
            namespace=namespace,
        )

    def _calc_match_index(self, input_count: int) -> int:
        """Assigns a previous output copy to an input of this layer.

        Args:
            input_count (int): number of slices for a previous output

        Returns:
            int
        """
        return math.ceil(self.slice_id * (input_count / self.slice_count))

    def _point_to_matches(self, previous_nodes: List[_MultiNode]) -> List[str]:
        """Points to the partitioned output copies from previous nodes.

        Args:
            previous_nodes (List[_MultiNode])

        Returns:
            List[str]

        Note:
            This is not necessary, it is only made for visualization purposes.
            i.e. all nodes can point to the slice 0 if a previous node outputs
            an input
        """
        # how many previous slices produce each partitioned output name
        inputs = Counter(
            inp for n in previous_nodes
            for inp in n.original_partitioned_outputs
        )
        self._partitioned_inputs = [
            self.add_slice_suffix(
                input, self._calc_match_index(inputs[input])
            ) if inputs[input] > 0 else input
            for input in self.original_partitioned_inputs
        ]

    # required for inheritance
    def _copy(self, **overwrite_params: Any) -> _MultiNode:
        params = {
            'slicer': self._slicer,
            'func': self._func,
            'partitioned_inputs': self._partitioned_inputs,
            'other_inputs': self._other_inputs,
            'partitioned_outputs': self._partitioned_outputs,
            'slice_id': self._slice_id,
            'slice_count': self._slice_count,
            'name': self._name,
            'namespace': self._namespace,
            'tags': self._tags,
            'confirms': self._confirms,
            'configurator': self._configurator,
        }
        params.update(overwrite_params)
        return self.__class__(**params)

    def _validate_inputs(self, func: Any, inputs: Any):
        # the slicer json/configurator inputs are not parameters of the
        # user function, so having more inputs than parameters is fine;
        # fewer inputs than parameters is still an error and re-raised
        try:
            super()._validate_inputs(func, inputs)
        except TypeError as e:
            # NOTE(review): relies on kedro's TypeError message containing
            # two bracketed lists (expected vs passed) — confirm on upgrade
            expected, passed = re.findall(r'(\[.*?\])', str(e))
            if len(expected) > len(passed):
                raise e

    @property
    def slicer_output(self) -> str:
        """Returns the load balancer json output.

        Returns:
            str
        """
        return self._slicer.json_output

    @classmethod
    def add_slice_suffix(cls, string: Union[str, List[str]],
                         slice_id: int) -> Union[str, List[str]]:
        """Adds a `{SLICE_SUFFIX}{slice_id}` at the end of a string.

        Args:
            string (Union[str, List[str]])
            slice_id (int)

        Returns:
            Union[str, List[str]]: Same shape as `string`, suffixed.

        Example:
            >>> _MultiNode.add_slice_suffix('test', 1)
            'test-slice-1'
            >>> _MultiNode.add_slice_suffix('test-slice-1', 1)
            'test-slice-1'
        """
        # idempotent: strings already ending in '-slice-<n>' are kept
        return firstorlist([
            el if re.search(rf'{cls.SLICE_SUFFIX}\d+$', el) else
            f'{el}{cls.SLICE_SUFFIX}{slice_id}' for el in tolist(string)
        ])

    def _add_slice_suffix(
        self,
        string: Union[str, List[str]],
    ) -> Union[str, List[str]]:
        # convenience wrapper using this node's own slice id
        return self.add_slice_suffix(string, self.slice_id)

    @property
    def slice_id(self) -> int:
        """Index of the current multinode.

        Returns:
            int
        """
        return self._slice_id

    @property
    def slice_count(self) -> int:
        """Size of this multinode set.

        Returns:
            int
        """
        return self._slice_count

    @property
    def original_partitioned_inputs(self) -> List[str]:
        """Partitioned inputs as list.

        Returns:
            List[str]
        """
        return tolist(self._partitioned_inputs)

    @property
    def partitioned_inputs(self) -> List[str]:
        # same as original_partitioned_inputs (inputs are rewired in place
        # by _point_to_matches during __init__)
        return tolist(self._partitioned_inputs)

    @property
    def other_inputs(self) -> List[str]:
        """Regular inputs (provided to all multinodes).

        Returns:
            List[str]
        """
        return tolist(self._other_inputs or [])

    @property
    def original_partitioned_outputs(self) -> List[str]:
        """Partitioned outputs passed in init.

        Returns:
            List[str]
        """
        return tolist(self._partitioned_outputs)

    @property
    def partitioned_outputs(self) -> List[str]:
        """Partitioned outputs adding slice suffix.

        Returns:
            List[str]
        """
        return [
            self._add_slice_suffix(output)
            for output in self.original_partitioned_outputs
        ]

    @property
    def outputs(self) -> List[str]:
        """Node outputs adding slice suffixes.

        Returns:
            List[str]
        """
        return [
            self._add_slice_suffix(output)
            if output in self.original_partitioned_outputs else output
            for output in super().outputs
        ]

    def _intersect_partitioneds(
        self, slice: Set[str], partitioneds: List[_Partitioned]
    ) -> List[_Partitioned]:
        """Takes only the matching partitions (required for the input `zip`).

        Args:
            partitioneds (List[Partitioned]): partitioned dictionaries

        Returns:
            List[Partitioned]
        """
        # keep only the partitions assigned to this slice, comparing by
        # extension-less subpath
        return [{
            path: partitioned[path]
            for path in partitioned
            if get_filepath_without_extension(path) in slice
        } for partitioned in partitioneds]

    def _get_slice(self, slices: List[List[str]]) -> Set[str]:
        # partition names the slicer assigned to this slice id
        return set(slices[self.slice_id])

    def _slice_inputs(
        self, slices: List[List[str]], partitioneds: List[_Partitioned]
    ) -> List[_Partitioned]:
        """Returns the partitioned dictionaries sliced for this node.

        Args:
            partitioneds (List[Partitioned]): original partitioned
                dictionaries

        Returns:
            List[Partitioned]
        """
        slice = self._get_slice(slices)
        return self._intersect_partitioneds(slice, partitioneds)

    @classmethod
    def _extract_args_part(cls, args: List[Any],
                           nargs: int) -> Tuple[List[Any], List[Any]]:
        # splits `args` into (first nargs, remaining) parts
        return args[:nargs], args[nargs:]

    def _extract_slices(self,
                        args: List[Any]) -> Tuple[List[List[str]], List[Any]]:
        # the slicer metadata is always the first input
        slices, args = self._extract_args_part(args, 1)
        return slices[0], args

    def _extract_partitioneds(
        self, args: List[Any]
    ) -> Tuple[List[_Partitioned], List[Any]]:
        # the partitioned inputs come right after the slicer metadata
        return self._extract_args_part(args, len(self.partitioned_inputs))

    def _extract_configurators(
        self, args: List[Any]
    ) -> Tuple[_Configurators, List[Any]]:
        # the configurator (when declared) follows the partitioned inputs
        if self._configurator is None:
            return {}, args
        else:
            configurator, rest = self._extract_args_part(args, 1)
            return configurator[0], rest

    def _extract_args(
        self, args: List[Any]
    ) -> Tuple[List[List[str]],
               List[_Partitioned],
               Union[None, Dict[str, _Configurator]],
               List[Any],
               ]:
        # input order: slicer json, partitioned inputs, configurator,
        # then any remaining regular inputs
        slices, args = self._extract_slices(args)
        partitioneds, args = self._extract_partitioneds(args)
        configurators, args = self._extract_configurators(args)
        other_inputs = args
        return slices, partitioneds, configurators, other_inputs

    @property
    def func(self) -> Callable:
        """Original `func`, but adding the partition loop.

        Returns:
            Callable
        """
        @wraps(self._func)
        def fn(*args: Any) -> Any:
            slices, partitioneds, configurators, other_inputs =\
                self._extract_args(args)
            partitioneds = self._slice_inputs(slices, partitioneds)
            if self._configurator:
                configurator_finder = ConfiguratorFinder(configurators)
            # one output dict per partitioned output name
            outputs = [dict() for _ in range(len(self.partitioned_outputs))]
            if partitioneds[0]:
                # zip walks the partitioned inputs in lockstep, one
                # matching partition from each input per iteration
                for partitions in zip(
                    *[partition.items() for partition in partitioneds]
                ):
                    # partitions[i][j]
                    # i = partitioned partitions
                    # j = key == 0, value == 1
                    partition = get_filepath_without_extension(
                        partitions[0][0]
                    )
                    self._logger.info(
                        f'Processing "{partition}" on "{self.name}"'
                    )
                    configurator = []
                    if self._configurator:
                        possible_configurator = configurator_finder[partition]
                        if possible_configurator is None:
                            self._logger.warning(
                                f'No configurator found for "{partition}"'
                            )
                        else:
                            target = possible_configurator['target']
                            configurator = [possible_configurator['data']]
                            self._logger.info(
                                f'Using configurator "{target}" for '
                                f'"{partition}"'
                            )
                    # loads every input partition concurrently (partition
                    # values are loader callables)
                    with ThreadPoolExecutor() as pool:
                        inputs = pool.map(lambda p: p[1](), partitions)
                    fn_return = self._original_func(
                        *inputs, *configurator, *other_inputs
                    )
                    if len(self.partitioned_outputs) > 1:
                        # multiple outputs: func returned one value per
                        # partitioned output
                        for i, _ in enumerate(self.partitioned_outputs):
                            outputs[i][partition] = fn_return[i]
                    else:
                        outputs[0][partition] = fn_return
            return outputs
        return fn
class _SynchronizationNode(_CustomizedFuncNode):
    """Barrier node to prevent multinode dependants to run out of order.

    It consumes every multinode slice output and emits the original
    partitioned output names filled with empty dicts, acting purely as an
    ordering barrier.

    Example:
        >>> lbn = _SlicerNode(2, 'a', 'b', 'x')
        >>> def fn(x: int) -> int: x + 10
        >>> mns = [_MultiNode(slicer=lbn, func=fn,
        ...     partitioned_inputs='a', slice_count=2,
        ...     partitioned_outputs='b', slice_id=i, name='x')
        ...     for i in range(2)]
        >>> n = _SynchronizationNode(multinodes=mns, name='x',
        ...     partitioned_outputs='b',)
        >>> dictionary = {mn.outputs[0]: {'subpath/a': lambda: 13,
        ...     'subpath/b': lambda: 14,
        ...     'subpath/c': lambda: 15} for mn in mns}
        >>> n.run(inputs=dictionary)
        {'b': {}}
    """

    SYNCHRONIZATION_SUFFIX = '-synchronization'

    def __init__(
        self,
        multinodes: List[_MultiNode],
        partitioned_outputs: Union[str, List[str]],
        name: str,
        tags: Union[str, Iterable[str]] = None,
        confirms: Union[str, List[str]] = None,
        namespace: str = None,
    ):
        """Initializes the synchronization barrier.

        Args:
            multinodes (List[_MultiNode]): Multinodes this barrier waits on.
            partitioned_outputs (Union[str, List[str]]): Output names the
                barrier re-emits (empty).
            name (str): Base name, suffixed with `-synchronization`.
            tags (Union[str, Iterable[str]], optional): Node tags.
                Defaults to None.
            confirms (Union[str, List[str]], optional): DataSets confirmed
                before execution. Defaults to None.
            namespace (str, optional): Node namespace. Defaults to None.
        """
        self._multinodes = multinodes
        self._partitioned_outputs = partitioned_outputs
        super().__init__(
            func=nonefy,
            inputs=self._extract_inputs(multinodes),
            outputs=tolist(partitioned_outputs),
            name=self._add_synchronization_suffix(name),
            tags=tags,
            confirms=confirms,
            namespace=namespace,
        )

    def _add_synchronization_suffix(self, string: str) -> str:
        # idempotent: append the suffix only when not already present
        if string.endswith(self.SYNCHRONIZATION_SUFFIX):
            return string
        return f'{string}{self.SYNCHRONIZATION_SUFFIX}'

    # required for inheritance
    def _copy(self, **overwrite_params: Any) -> _SynchronizationNode:
        params = dict(
            multinodes=self._multinodes,
            partitioned_outputs=self._partitioned_outputs,
            name=self._name,
            namespace=self._namespace,
            tags=self._tags,
            confirms=self._confirms,
        )
        params.update(overwrite_params)
        return self.__class__(**params)

    @classmethod
    def _extract_inputs(cls, nodes: List[_MultiNode]) -> List[str]:
        # flatten every multinode's outputs into a single input list
        collected = []
        for multinode in nodes:
            collected.extend(multinode.outputs)
        return collected

    @property
    def func(self) -> Callable:
        def fn(*args: Any) -> List[dict]:
            # the barrier produces one empty dict per declared output
            return [{} for _ in self.outputs]
        return fn
def _treat_optional_one_or_many(arg: Union[None, T, List[str]]) -> List[T]:
"""Treats argument that can be None, List, or another type as list.
Args:
arg (Union[None, T, List[str]])
Returns:
List[T]
Example:
>>> _treat_optional_one_or_many(3)
[3]
>>> _treat_optional_one_or_many([3])
[3]
>>> _treat_optional_one_or_many(None)
[]
"""
if arg is None:
return []
if isinstance(arg, list):
return arg
else:
return [arg]
def _sortnodes(nodes: Iterable[Node]) -> List[Node]:
"""Sorts a list of nodes by its name.
Args:
nodes (List[Node])
Returns:
List[Node]
Example:
>>> _sortnodes([node(min, 'a', 'b', name='def'),
... node(max, 'b', 'c', name='abc')])
[Node(max, 'b', 'c', 'abc'), Node(min, 'a', 'b', 'def')]
"""
return sorted(nodes, key=lambda x: x.name)
def multipipeline(
    pipe: Pipeline,
    partitioned_input: Union[str, List[str]],
    name: str,
    configurator: str = None,
    tags: Union[str, Iterable[str]] = None,
    confirms: Union[str, List[str]] = None,
    namespace: str = None,
    n_slices: int = MAX_NODES * MAX_WORKERS,
    max_simultaneous_steps: int = None,
    filter: IsFunction[str] = truthify,
) -> Pipeline:
    """Creates multiple pipelines to process partitioned data.

    Multipipelines are the same as multinode, but instead of adding a
    synchronization node for each step, it creates small pipelines that work
    like a multinode, with a synchronization only at the end of the pipeline.
    This enables to process data in parallel, but without the need of waiting
    for a multinode layer to finish its work.

    See also:
        :py:func:`kedro_partitioned.multinode.multinode`

    Args:
        pipe (Pipeline): Pipeline to be parallelized by multiple nodes.
        partitioned_input (Union[str, List[str]]):
            Name of the `PartitionedDataSet` used as input. If a list is
            provided, it will work like a zip(partitions_a, ..., partitions_n)
        configurator (str, optional): Name of partitioned parameters used as
            input. e.g. 'param:configurators'
        name (str, optional): Name prefix for the multiple nodes, and the name
            of the post node. Defaults to None.
        tags (Union[str, Iterable[str]], optional): List of tags for all nodes
            generated by this function. Defaults to None.
        confirms (Union[str, List[str]], optional): List of DataSets that
            should be confirmed before the execution of the nodes.
            Defaults to None.
        namespace (str, optional): Namespace the nodes belong to.
            Defaults to None.
        n_slices (int): Number of multinodes to build.
            Defaults to MAX_NODES * MAX_WORKERS.
        max_simultaneous_steps (int): Maximum number of slices created for
            each branch. Defaults to None.
        filter (IsFunction[str]): A function applied to each partition of
            the partitioned inputs. If the function returns False, the
            partition won't be used.

    Returns:
        Pipeline

    Example:
        >>> sortnodes = lambda pipe: sorted(pipe.nodes, key=lambda n: n.name)
        >>> funcpipe = Pipeline([node(min, ['a', 'b'], 'c', name='abc'),
        ...                      node(max, ['c', 'd'], ['e'], name='def')])
        >>> sortnodes(multipipeline(
        ...     funcpipe,
        ...     ['a'],
        ...     'x',
        ...     n_slices=2))  # doctest: +NORMALIZE_WHITESPACE
        [Node(min, ['c-slicer', 'a', 'b'], ['c-slice-0'], 'abc-slice-0'),
         Node(min, ['c-slicer', 'a', 'b'], ['c-slice-1'], 'abc-slice-1'),
         Node(max, ['c-slicer', 'c-slice-0', 'd'], ['e-slice-0'],\
             'def-slice-0'),
         Node(max, ['c-slicer', 'c-slice-1', 'd'], ['e-slice-1'],\
             'def-slice-1'),
         Node(nonefy, ['a'], 'c-slicer', 'x'),
         Node(nonefy, ['e-slice-0', 'e-slice-1'], ['c', 'e'],\
             'x-synchronization')]

    Max Simultaneous Steps:
        This configuration defines the maximum number of steps per branch.
        Check the example below for more details:

        .. code-block:: python

            max_simultaneous_steps = None
            n_slices = 2
            func = pipe([A->B, B->C, [C, D]->E])
                        B->D
            output = pipe(A->B0, B0->C0, [C0, D0] -> E0)
                          A->B1 B1->C1 [C1, D1] -> D1
                          B0->D0
                          B1->D1

            max_simultaneous_steps = 2
            output: pipe(A->B0, B0->C0, [C0, D0] -> E0)
                         B0->D0

    Warning:
        every function must be declared considering partitioned inputs are
        the first arguments of the function, the configurator (if present)
        is the following argument, and the rest of the arguments are other
        outputs.

    Warning:
        The configurator syntax prioritizes more specific targets rather
        than generalist ones. However, if you have ambiguity
        between configurators, a wrong configurator may be used e.g. if
        you have a configurator with target ['a', 'b'], and another with
        target ['a'], the first match will be used i.e. order is random or
        list instance order driven.
    """
    # sorts just to keep output consistency
    partitioned_output = sorted(pipe.all_outputs())
    if max_simultaneous_steps is not None:
        # shrink the slice count so the widest layer never exceeds the
        # simultaneous step budget (at least one slice)
        n_slices = max(
            1,
            math.floor(
                max_simultaneous_steps
                / max(len(layer) for layer in pipe.grouped_nodes)
            )
        )
    # because a multinode becomes multiple nodes
    confirms = unique(_treat_optional_one_or_many(confirms))
    tags = unique(_treat_optional_one_or_many(tags) + [name])
    slicer = Pipeline([
        _SlicerNode(
            slice_count=n_slices,
            partitioned_inputs=partitioned_input,
            partitioned_outputs=tolist(partitioned_output)[0],
            name=name,
            tags=tags,
            confirms=confirms,
            namespace=namespace,
            filter=filter,
            configurator=configurator
        )
    ])
    # every dataset produced or consumed as partitioned data
    sources = set(tolist(partitioned_input) + tolist(partitioned_output))
    multinodes = Pipeline([])
    for layer in pipe.grouped_nodes:
        multinode_layer: List[_MultiNode] = []
        for lnode in layer:
            assert lnode._name, f'"{lnode}" name not defined'
            # split node inputs into partitioned vs regular inputs
            partitioned, other = partition(
                lambda x: x in sources, lnode.inputs
            )
            # pull the configurator (if declared) out of the regular inputs
            possible_configurator, other = partition(
                lambda x: x == configurator, other
            )
            node_configurator = (
                possible_configurator[0] if possible_configurator else None
            )
            for i in range(n_slices):
                multinode_layer.append(
                    _MultiNode(
                        name=lnode.name,
                        func=lnode._func,
                        namespace=namespace,
                        other_inputs=other,
                        partitioned_inputs=partitioned,
                        partitioned_outputs=lnode.outputs,
                        slice_count=n_slices,
                        slice_id=i,
                        slicer=slicer.nodes[0],
                        confirms=unique(lnode.confirms + confirms),
                        tags=unique(list(lnode.tags) + tags),
                        previous_nodes=multinodes._nodes,
                        configurator=node_configurator,
                    )
                )
        multinodes = multinodes + Pipeline(multinode_layer)
    # a single barrier at the very end, waiting on the last layer only
    synchronization = Pipeline([
        _SynchronizationNode(
            multinodes=_sortnodes(multinodes.grouped_nodes[-1]),
            partitioned_outputs=partitioned_output,
            name=name,
            tags=tags,
            confirms=confirms,
            namespace=namespace,
        )
    ])
    return slicer + multinodes + synchronization
[docs]def multinode(
func: Callable[[Args[Any]], Union[Any, List[Any]]],
partitioned_input: Union[str, List[str]],
partitioned_output: Union[str, List[str]],
name: str,
configurator: str = None,
other_inputs: List[str] = [],
tags: Union[str, Iterable[str]] = None,
confirms: Union[str, List[str]] = None,
namespace: str = None,
n_slices: int = MAX_NODES * MAX_WORKERS,
filter: IsFunction[str] = truthify,
) -> Pipeline:
"""Creates multiple nodes to process partitioned data.
Multinodes are a way to implement step level parallelism. It is useful
for processing independent data in parallel, managed by pipeline runners.
For example, in Kedro, running the pipeline with ParallelRunner would
enable the steps generated by multinode to be run using multiple cpus. At
the same time, if you run this pipeline in a distributed context, you
could rely on a "DistributedRunner" or in another pipeline manager like
AzureML or Kubeflow, without having to change the code.
Multinodes work like the following flowchart:
.. code-block:: text
+--------------+
| Configurator |
| parameter |-+
+--------------+ |
| |
v |
+-------------+ +->+------------+ +--------------------+
| Slicer Node |--+->| Slice-0 |--+->| Synchronization |
| my | +->| my-slice-0 | | | my-synchronization |
+-------------+ | +------------+ | +--------------------+
^ | ... | |
| +->+------------+ | v
+-------------+ +->| Slice-1 |--| +-------------+
| Partitioned |--+->| my-slice-1 | | | Partitioned |
| input-ds | | +------------+ | | output-ds |
+-------------+ +->+------------+ | +-------------+
Start +->| Slice-n |--+ End
+->| my-slice-n |
+------------+
Nodes specification:
Slicer Node:
name: multinode name
inputs: partitioned inputs and configurator cached flags
outputs: json with a list of partitions for each slice
Slice:
name: multinode name + slice id
inputs: partitioned inputs, slicer json, configurator data
outputs: subset of the partitioned outputs
Synchronization:
name: multinode name + synchronization
inputs: subset of the partitioned outputs
outputs: partitioned outputs without data (synchronization only)
Args:
func (Callable[[Args[Any]], Union[str, List[Any]]]): Function executed
by each of the n-nodes. It takes n positional arguments, being the
first of them one partition from the `partitioned_input`
partitioned_input (Union[str, List[str]]):
Name of the `PartitionedDataSet` used as input. If a list is
provided, it will work like a zip(partitions_a, ..., partitions_n)
partitioned_output (Union[str, List[str]]):
Name of the `PartitionedDataSet` used as output by func nodes
configurator (str, optional): An input name of a partitioned parameter
dict. e.g. 'params:config'
name (str): Name prefix for the multiple nodes, and the name
of the post node.
other_inputs (List[str], optional):
Name of other inputs for func. Defaults to [].
tags (Union[str, Iterable[str]], optional): List of tags for all nodes
generated by this function. Defaults to None.
confirms (Union[str, List[str]], optional): List of DataSets that
should be confirmed before the execution of the nodes.
Defaults to None.
namespace (str, optional): Namespace the nodes belong to.
Defaults to None.
n_slices (int): Number of multinodes to build.
Defaults to MAX_NODES * MAX_WORKERS.
filter (IsFunction[str], optional): Function to filter input
partitions. Defaults to truthify.
Returns:
Pipeline
Example:
>>> sortnodes = lambda pipe: sorted(pipe.nodes, key=lambda n: n.name)
>>> sortnodes(multinode(
... func=max,
... partitioned_input='a',
... partitioned_output='b',
... other_inputs=['d'],
... n_slices=2,
... name='x',)) # doctest: +NORMALIZE_WHITESPACE
[Node(nonefy, ['a'], 'b-slicer', 'x'),
Node(max, ['b-slicer', 'a', 'd'], ['b-slice-0'], 'x-slice-0'),
Node(max, ['b-slicer', 'a', 'd'], ['b-slice-1'], 'x-slice-1'),
Node(nonefy, ['b-slice-0', 'b-slice-1'], ['b'], 'x-synchronization')]
Accepts multiple inputs (works like zip(*partitioneds)):
>>> sortnodes(multinode(
... func=max,
... partitioned_input=['a', 'b'],
... partitioned_output='c',
... n_slices=2,
... other_inputs=['d'],
... name='x')) # doctest: +NORMALIZE_WHITESPACE
[Node(nonefy, ['a', 'b'], 'c-slicer', 'x'),
Node(max, ['c-slicer', 'a', 'b', 'd'], ['c-slice-0'], 'x-slice-0'),
Node(max, ['c-slicer', 'a', 'b', 'd'], ['c-slice-1'], 'x-slice-1'),
Node(nonefy, ['c-slice-0', 'c-slice-1'], ['c'], 'x-synchronization')]
Accepts multiple outputs:
>>> sortnodes(multinode(
... func=max,
... partitioned_input='a',
... partitioned_output=['b', 'c'],
... n_slices=2,
... other_inputs=['d'],
... name='x')) # doctest: +NORMALIZE_WHITESPACE
[Node(nonefy, ['a'], 'b-slicer', 'x'),
Node(max, ['b-slicer', 'a', 'd'], ['b-slice-0', 'c-slice-0'],\
'x-slice-0'),
Node(max, ['b-slicer', 'a', 'd'], ['b-slice-1', 'c-slice-1'],\
'x-slice-1'),
Node(nonefy, ['b-slice-0', 'c-slice-0', 'b-slice-1', 'c-slice-1'],\
['b', 'c'], 'x-synchronization')]
Tags and namespaces are allowed:
>>> mn = multinode(
... func=max,
... partitioned_input='a',
... partitioned_output=['b', 'c'],
... name='x',
... n_slices=2,
... tags=['test_tag'],
... namespace='namespace',)
>>> sortnodes(mn) # doctest: +NORMALIZE_WHITESPACE
[Node(nonefy, ['a'], 'b-slicer', 'x'),
Node(max, ['b-slicer', 'a'], ['b-slice-0', 'c-slice-0'], 'x-slice-0'),
Node(max, ['b-slicer', 'a'], ['b-slice-1', 'c-slice-1'], 'x-slice-1'),
Node(nonefy, ['b-slice-0', 'c-slice-0', 'b-slice-1', 'c-slice-1'],\
['b', 'c'], 'x-synchronization')]
>>> all([n.tags == {'x', 'test_tag'} for n in mn.nodes])
True
>>> all([n.namespace == 'namespace' for n in mn.nodes])
True
Configurators
A configurator is a dict parameter declared in parameters yamls that
contains two sections: 'template' and 'configurators'. The 'template'
section is a dict that contains the partitions subpath pattern, and
its configurations. The 'configurators' section is a list of
configurators. Each configurator is a dict that contains a 'target'
list specifying replacements for the pattern, and a data entry that
is going to be inputted to the multinode.
.. code-block:: yaml
config:
template:
pattern: 'a-part-{part}'
# optional, overwrites '.*' as the regex when a
# v target is set to '*'
# any:
# part: '(a|b|c|d)'
# optional, specifies priority of each target if left
# v to right order is not correct
# hierarchy:
# - part
configurators:
-
target: # replaces pattern's {} from left to right order
-
- a
- b
# or a regex alternate syntax
# - a|b
cached: true # will not run
data:
setting_a: 'foo'
setting_b: 2
-
target:
- c
data:
setting_a: 'zzz'
setting_b: 4
-
target:
- '*'
data:
setting_a: 'bar'
setting_b: 1
In the example above, target ['a', 'b'] will be the configurator of the
partition 'a-part-a' and 'a-part-b', the configurator with target 'c'
will be the configurator of the partition 'a-part-c', and the
configurator with target '*' will be the configurator of all other
partitions.
Example:
>>> mn = multinode(
... func=max,
... partitioned_input='a',
... partitioned_output=['b', 'c'],
... name='x',
... n_slices=2,
... tags=['test_tag'],
... namespace='namespace',
... configurator='params:config')
>>> sortnodes(mn) # doctest: +NORMALIZE_WHITESPACE
[Node(nonefy, ['a', 'params:config'], 'b-slicer', 'x'),
Node(max, ['b-slicer', 'a', 'params:config'],\
['b-slice-0', 'c-slice-0'], 'x-slice-0'),
Node(max, ['b-slicer', 'a', 'params:config'],\
['b-slice-1', 'c-slice-1'], 'x-slice-1'),
Node(nonefy, ['b-slice-0', 'c-slice-0', 'b-slice-1', 'c-slice-1'],\
['b', 'c'], 'x-synchronization')]
Note:
the multinode name is also added as a tag into all nodes in order
to allow running the multinode with `kedro run --tag`.
Warning:
every function must be declared considering that the partitioned
inputs are the first arguments of the function, the configurator
(if present) is the following argument, and the rest of the
arguments are the other inputs.
Warning:
The configurator syntax prioritizes more specific targets rather
than generalist ones. However, if you have ambiguity
between configurators, a wrong configurator may be used e.g. if
you have a configurator with target ['a', 'b'], and another with
target ['a'], the first match found will be used, i.e. the
outcome depends on the order of the configurator list.
"""
pipe = Pipeline([
node(
func=func,
inputs=(
tolist(partitioned_input) + optionaltolist(configurator)
+ tolist(other_inputs)
),
outputs=partitioned_output,
name=name,
)
])
return multipipeline(
pipe=pipe,
partitioned_input=partitioned_input,
configurator=configurator,
confirms=confirms,
filter=filter,
max_simultaneous_steps=None,
n_slices=n_slices,
name=name,
namespace=namespace,
tags=tags,
)