import pickle
import pathlib
import luigi
import luigi.tools.deps
import d6tcollect
import d6tflow.targets
import d6tflow.settings as settings
from d6tflow.cache import data as cache
import d6tflow.cache


def _taskpipeoperation(task, fun, funargs=None):
    pipe = task.get_pipe()
    fun = getattr(pipe, fun)
    funargs = {} if funargs is None else funargs
    return fun(**funargs)


class TaskData(luigi.Task):
    """
    Task which has data as input and output

    Args:
        target_class (obj): target data format
        target_ext (str): file extension
        persist (list): list of strings to identify data
        data (dict): data container for all outputs
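
    example (illustrative sketch; ``TaskTrain`` and its dataframes are hypothetical)::

        import pandas as pd

        class TaskTrain(d6tflow.tasks.TaskPqPandas):
            persist = ['model', 'metrics']  # two named outputs

            def run(self):
                df_model = pd.DataFrame({'coef': [0.1, 0.2]})
                df_metrics = pd.DataFrame({'rmse': [0.05]})
                self.save({'model': df_model, 'metrics': df_metrics})
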
"""
    target_class = d6tflow.targets.DataTarget
    target_ext = 'ext'
    persist = ['data']
    metadata = None

    def __init__(self, *args, path=None, flows=None, **kwargs):
        kwargs_ = {k: v for k, v in kwargs.items()
                   if k in self.get_param_names(include_significant=True)}
        super().__init__(*args, **kwargs_)
        # use a path set on the child class if present, else the constructor argument
        self.path = getattr(self, 'path', path)
        # alias: a child class may declare `persists` instead of `persist`
        self.persist = getattr(self, 'persists', self.persist)
        # flow(s) this task is attached to
        self.flows = flows

    @classmethod
    def get_param_values(cls, params, args, kwargs):
        kwargs_ = {k: v for k, v in kwargs.items()
                   if k in cls.get_param_names(include_significant=True)}
        return super(TaskData, cls).get_param_values(params, args, kwargs_)

    def reset(self, confirm=True):
        """
        Reset a task, e.g. by deleting its output file
        """
        return self.invalidate(confirm)

    def invalidate(self, confirm=True):
        """
        Reset a task, e.g. by deleting its output file
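
        example (illustrative; ``TaskTrain`` is a hypothetical task)::

            TaskTrain().invalidate(confirm=False)  # delete outputs without prompting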
"""
if confirm:
c = input(
'Confirm invalidating task: {} (y/n). PS You can disable this message by passing confirm=False'.format(
self.__class__.__qualname__))
else:
c = 'y'
if c == 'y': # and self.complete():
if self.persist == ['data']: # 1 data shortcut
self.output().invalidate()
else:
[t.invalidate() for t in self.output().values()]
return True

    @d6tcollect._collectClass
    def complete(self, cascade=True):
        """
        Check if a task is complete by checking if its output exists, e.g. if the output file exists
        """
        complete = super().complete()
        if d6tflow.settings.check_dependencies and cascade and not getattr(self, 'external', False):
            complete = complete and all(
                [t.complete() for t in luigi.task.flatten(self.requires())])
        return complete

    def _getpath(self, k, subdir=True, check_pipe=False):
        # resolve the output directory: d6tpipe dir > explicit task path > default settings
        if check_pipe and hasattr(self, 'pipename'):
            import d6tflow.pipes
            dirpath = d6tflow.pipes.get_dirpath(self.pipename)
        elif self.path is not None:
            dirpath = pathlib.Path(self.path)
        else:
            dirpath = settings.dirpath
        # add a task group subdirectory if set
        if hasattr(self, 'task_group'):
            dirpath = dirpath / f"group={getattr(self, 'task_group')}"
        # build the file name from the task id (or the key only, if not saving with params)
        tidroot = getattr(self, 'target_dir', self.task_id.split('_')[0])
        fname = '{}-{}'.format(self.task_id, k) if (settings.save_with_param and getattr(
            self, 'save_attrib', True)) else '{}'.format(k)
        fname += '.{}'.format(self.target_ext)
        if subdir:
            path = dirpath / tidroot / fname
        else:
            path = dirpath / fname
        return path
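
    # Illustration (hypothetical values, not part of the API): with the default data
    # directory, target_ext='parquet' and a parameterless task id 'TaskTrain__99914b932b',
    # _getpath('data') would resolve to 'data/TaskTrain/TaskTrain__99914b932b-data.parquet'.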

    def output(self):
        """
        Similar to luigi task output
        """
        save_ = getattr(self, 'persist', [])
        output = dict([(k, self.target_class(self._getpath(k, check_pipe=True)))
                       for k in save_])
        if self.persist == ['data']:  # 1 data shortcut
            output = output['data']
        return output

    def outputLoad(self, keys=None, as_dict=False, cached=False):
        """
        Load all or several outputs from task

        Args:
            keys (list): list of data keys to load
            as_dict (bool): return data as a dict keyed by persist name instead of a list
            cached (bool): cache data in memory

        Returns: list or dict of all task output
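
        example (illustrative; assumes a completed task with persist=['model', 'metrics'])::

            df_model, df_metrics = TaskTrain().outputLoad()
            metrics = TaskTrain().outputLoad(keys=['metrics'], as_dict=True)  # {'metrics': ...}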
"""
        if not self.complete():
            raise RuntimeError(
                'Cannot load, task not complete, run flow first')
        # default to all persisted keys
        keys = self.persist if keys is None else keys
        # validate requested keys against self.persist
        if type(keys) is not list:
            if keys not in self.persist:
                raise IndexError('Key name does not match')
        else:
            for key in keys:
                if key not in self.persist:
                    raise IndexError('Key name does not match')
        if self.persist == ['data']:  # 1 data shortcut
            persist_data = self.output().load()
            return persist_data
        # load the requested outputs
        data = {k: v.load(cached)
                for k, v in self.output().items() if k in keys}
        # return as list unless a dict was requested
        if not as_dict:
            data = list(data.values())
        # a single key was requested: unwrap the value
        if type(keys) is not list:
            data = data[0]
        return data

    def save(self, data, **kwargs):
        """
        Persist data to target

        Args:
            data (dict): data to save. Keys are the self.persist keys and values are the data
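
        example (illustrative; assumes persist=['model', 'metrics'] and two dataframes)::

            self.save({'model': df_model, 'metrics': df_metrics})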
"""
if self.persist == ['data']: # 1 data shortcut
self.output().save(data, **kwargs)
else:
targets = self.output()
if not set(data.keys()) == set(targets.keys()):
raise ValueError(
'Save dictionary needs to consistent with Task.persist')
for k, v in data.items():
targets[k].save(v, **kwargs)

    def _get_meta_path(self, task):
        # metadata is stored as a pickle file next to the task outputs
        meta_path = task._getpath('meta').with_suffix('.pickle')
        meta_path.parent.mkdir(exist_ok=True, parents=True)
        return meta_path

    @d6tcollect._collectClass
    def get_pipename(self):
        """
        Get associated pipe name
        """
        return getattr(self, 'pipename', d6tflow.cache.pipe_default_name)

    def get_pipe(self):
        """
        Get associated pipe object
        """
        import d6tflow.pipes
        return d6tflow.pipes.get_pipe(self.get_pipename())

    def pull(self, **kwargs):
        """
        Pull files from data repo
        """
        return _taskpipeoperation(self, 'pull', **kwargs)

    def pull_preview(self, **kwargs):
        """
        Preview which files would be pulled from data repo
        """
        return _taskpipeoperation(self, 'pull_preview', **kwargs)

    def push(self, **kwargs):
        """
        Push files to data repo
        """
        return _taskpipeoperation(self, 'push', **kwargs)

    def push_preview(self, **kwargs):
        """
        Preview which files would be pushed to data repo
        """
        return _taskpipeoperation(self, 'push_preview', **kwargs)


class TaskCache(TaskData):
    """
    Task which saves to cache
    """
    target_class = d6tflow.targets.CacheTarget
    target_ext = 'cache'


class TaskCachePandas(TaskData):
    """
    Task which saves pandas dataframes to cache
    """
    target_class = d6tflow.targets.PdCacheTarget
    target_ext = 'cache'


class TaskJson(TaskData):
    """
    Task which saves to json
    """
    target_class = d6tflow.targets.JsonTarget
    target_ext = 'json'


class TaskPickle(TaskData):
    """
    Task which saves to pickle
    """
    target_class = d6tflow.targets.PickleTarget
    target_ext = 'pkl'


class TaskCSVPandas(TaskData):
    """
    Task which saves to CSV
    """
    target_class = d6tflow.targets.CSVPandasTarget
    target_ext = 'csv'


class TaskCSVGZPandas(TaskData):
    """
    Task which saves to gzip-compressed CSV
    """
    target_class = d6tflow.targets.CSVGZPandasTarget
    target_ext = 'csv.gz'


class TaskExcelPandas(TaskData):
    """
    Task which saves to Excel
    """
    target_class = d6tflow.targets.ExcelPandasTarget
    target_ext = 'xlsx'


class TaskPqPandas(TaskData):
    """
    Task which saves to parquet
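
    example (illustrative sketch; assumes pandas imported as ``pd``)::

        class TaskTrain(d6tflow.tasks.TaskPqPandas):
            def run(self):
                df = pd.DataFrame({'a': [1, 2, 3]})
                self.save(df)  # single output saved as parquet

        d6tflow.run(TaskTrain())
        df = TaskTrain().outputLoad()
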
"""
target_class = d6tflow.targets.PqPandasTarget
target_ext = 'parquet'


class TaskAggregator(luigi.Task):
    """
    Task which yields other tasks

    NB: Use this class by implementing `run()` which should do nothing but yield other tasks

    example::

        class TaskCollector(d6tflow.tasks.TaskAggregator):
            def run(self):
                yield Task1()
                yield Task2()

    """

    def reset(self, confirm=True):
        return self.invalidate(confirm=confirm)

    def invalidate(self, confirm=True):
        [t.invalidate(confirm) for t in self.run()]

    def complete(self, cascade=True):
        return all([t.complete(cascade) for t in self.run()])

    def output(self):
        return [t.output() for t in self.run()]

    def outputLoad(self, keys=None, as_dict=False, cached=False):
        return [t.outputLoad(keys, as_dict, cached) for t in self.run()]