|
- #!/usr/bin/env python3
- # -*- coding: utf-8; mode: python; tab-width: 4; indent-tabs-mode: nil -*-
- # vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4 fileencoding=utf-8
-
- import os
- import sys
- import pkgutil
- import multiprocessing
- import multiprocessing.queues
- # from multiprocessing.managers import BaseManager
-
- from logger import (D, I, E, C, W)
-
- from tasks.itask import ITask, TaskResult
-
-
# Absolute path of the directory holding this file; the task package
# directory ('pipeline') is resolved relative to it.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
-
-
class Dag:
    """A graph of ``ITask`` objects whose tasks run in separate processes.

    Tasks are registered by :meth:`load_tasks` under ``task.name()``; edges
    are the task names returned by each task's ``prev_tasks()`` and
    ``next_tasks()``. :meth:`start` launches the root tasks, then starts each
    successor once all of its predecessors report ``ITask.DONE``.
    """

    # DFS colours used by the ring (cycle) check in _dfs():
    #   UNVISITED   - node not reached yet
    #   SUB_VISITED - node is on the current DFS path (still being explored)
    #   VISITED     - node and everything reachable from it fully explored
    (UNVISITED, VISITED, SUB_VISITED) = range(3)

    # Seconds start() blocks on the result queue before checking for child
    # processes that died without reporting a result.
    _POLL_INTERVAL = 1.0

    def __init__(self):
        self._tasks = {}                       # task name -> ITask instance
        self._roots = []                       # names of tasks without predecessors
        self._queue = multiprocessing.Queue()  # children put their TaskResult here
        self._finish_num = 0                   # results received by start()
        self._procs = {}                       # task name -> Process not yet reported

    def load_tasks(self, *pkg_dir_list):
        """Import task modules and instantiate every ``ITask`` subclass found.

        Args:
            *pkg_dir_list: directories (or lists of directories), each accepted
                by :meth:`_load_classes`.

        Each task instance is registered under ``task.name()``; a later task
        with the same name replaces an earlier one. Tasks whose
        ``prev_tasks()`` is empty are recorded (once) as roots.
        """
        all_classes = {}
        for pkgdir in pkg_dir_list:
            all_classes.update(self._load_classes(pkgdir, ITask))

        for task_cls in all_classes.values():
            task = task_cls()
            name = task.name()
            self._tasks[name] = task
            # Start nodes are the tasks that depend on nothing.
            if not task.prev_tasks() and name not in self._roots:
                self._roots.append(name)

    def _load_classes(self, pkg_dir, base_class):
        """Import every module in one or more directories and collect subclasses.

        A directory containing ``__init__.py`` is treated as a package: its
        parent is put on ``sys.path`` and its modules are imported as
        ``<package>.<module>``. Any other directory is put on ``sys.path``
        itself and its modules are imported as top-level modules.

        NOTE: when writing a subclass, import its base class by its full name.
        NOTE: class names must be unique; on a clash, a class found later
        during the walk replaces an earlier one with the same name.

        Args:
            pkg_dir: a directory path (str or path-like), or a list/tuple of them.
                Invalid entries in a list are skipped with a warning.
            base_class: the class whose subclasses are collected.

        Returns:
            dict mapping ``__name__`` to class for every direct or indirect
            subclass of *base_class* known after the imports. This also includes
            subclasses imported elsewhere. The dict is empty when no valid
            directory was given.

        Raises:
            Exception: *pkg_dir* is neither a list/tuple nor an existing directory.
        """
        import importlib

        if isinstance(pkg_dir, (list, tuple)):
            pkgdirs = []
            for d in pkg_dir:
                if os.path.isdir(d):
                    pkgdirs.append(os.path.abspath(d))
                else:
                    W("Path<{}> is not a valid directory!".format(d))
        elif isinstance(pkg_dir, (str, os.PathLike)) and os.path.isdir(pkg_dir):
            pkgdirs = [os.path.abspath(pkg_dir)]
        else:
            raise Exception(f"Invalid package directory path: {pkg_dir}!")

        if not pkgdirs:
            return {}

        for pkgdir in pkgdirs:
            # Decided per directory so a package prefix never leaks into the next one.
            if os.path.isfile(os.path.join(pkgdir, '__init__.py')):
                search_path, pkg = os.path.dirname(pkgdir), os.path.basename(pkgdir)
            else:
                search_path, pkg = pkgdir, ''
            if search_path not in sys.path:
                sys.path.append(search_path)

            for _, name, _ in pkgutil.iter_modules([pkgdir]):
                # Without a package, '.'.join(['', name]) would produce '.name',
                # a relative import that always fails; use the bare module name.
                module_name = f"{pkg}.{name}" if pkg else name
                try:
                    importlib.import_module(module_name)
                except Exception as e:
                    # Best effort: one broken module must not stop the others.
                    W(f"Failed to load module: {module_name} in {pkgdir}: {e!r}")

        return self._collect_subclasses(base_class)

    @staticmethod
    def _collect_subclasses(base_class):
        """Return ``{cls.__name__: cls}`` for all direct and indirect subclasses.

        Pre-order walk: a subclass is recorded before its own subclasses, so on
        a name clash the class visited last wins.
        """
        found = {}
        for sub in base_class.__subclasses__():
            found[sub.__name__] = sub
            found.update(Dag._collect_subclasses(sub))
        return found

    def gen(self):
        """Validate the graph before running it.

        Raises:
            Exception: the graph contains a ring (cycle).
        """
        if self._check_ring():
            raise Exception("Ring exists in current DAG!")

    def start(self):
        """Run every reachable task in dependency order; return when none is running.

        Root tasks are started first with an empty parameter list. Each result
        read from the queue is handed to the reporting task's ``update()``.
        Then every successor whose predecessors all have ``status() == ITask.DONE``
        is started. Its parameters are the predecessors' ``output()`` values, in
        ``prev_tasks()`` order. A task is started at most once per call.

        A failed task is logged, and its successors never become ready. A child
        that exits with a non-zero code without reporting is logged and
        forgotten, so the loop still terminates.
        """
        import queue  # multiprocessing.Queue.get(timeout=...) raises queue.Empty

        launched = set()

        def is_ready(task) -> bool:
            return all(self._tasks[dep].status() == ITask.DONE
                       for dep in task.prev_tasks() or ())

        def launch(name, params):
            # Guards against duplicate roots and successors reached twice.
            if name not in launched:
                launched.add(name)
                self._create_task_proc(self._tasks[name], params)

        for root in self._roots:
            launch(root, [])

        # Terminal condition: nothing is running, so nothing new can become ready.
        while self._procs:
            try:
                task_result = self._queue.get(timeout=self._POLL_INTERVAL)
            except queue.Empty:
                self._reap_crashed_procs()
                continue

            proc = self._procs.pop(task_result.name, None)
            if proc is not None:
                # Its result has been received, so the child is already exiting.
                proc.join()
            self._finish_num += 1

            finished = self._tasks[task_result.name]
            finished.update(task_result)
            if task_result.status != ITask.DONE:
                E(f"Task({task_result.name}) failed!")

            for next_name in finished.next_tasks() or ():
                next_task = self._tasks[next_name]
                if next_name in launched or not is_ready(next_task):
                    continue
                params = [self._tasks[dep].output() for dep in next_task.prev_tasks()]
                launch(next_name, params)

        not_started = sorted(set(self._tasks) - launched)
        if not_started:
            W(f"Tasks not started: {', '.join(not_started)}")

    def _reap_crashed_procs(self):
        """Forget children that exited with a non-zero code without reporting."""
        for name, proc in list(self._procs.items()):
            # exitcode is None while running. 0 means _run_task returned after
            # queue.put, so that result is still in transit. Only non-zero codes
            # (e.g. task.run raised) mean no result will ever arrive.
            # NOTE(review): a task that calls sys.exit(0) without reporting is
            # not detected.
            if proc.exitcode not in (None, 0):
                E(f"Task({name}) exited with code {proc.exitcode} without reporting a result!")
                del self._procs[name]

    def _create_task_proc(self, task, params: list) -> None:
        """Start *task* in a daemon child process and register it as running."""
        # _run_task is a staticmethod rather than a closure, so the target can be
        # pickled by the "spawn"/"forkserver" start methods as well as "fork".
        p = multiprocessing.Process(target=self._run_task, name=task.name(),
                                    args=(task, self._queue, params))
        p.daemon = True
        p.start()
        self._procs[task.name()] = p

    @staticmethod
    def _run_task(task, result_queue, params) -> None:
        """Child-process entry point: run *task* with *params* and report its result.

        An exception from ``task.run`` propagates and makes the child exit with
        a non-zero code, which the parent detects through the exit code.
        """
        D(f"Task({task.name()}) running in process {os.getpid()}")
        result_queue.put(task.run(params))

    def _check_ring(self):
        """Return True if the graph contains a ring (cycle), otherwise False.

        A DFS is started from every task, not only from the roots, so a ring
        with no entry point is still found. Edges follow ``next_tasks()``.
        """
        marks = {}
        return any(self._dfs(name, marks) for name in self._tasks)

    def _dfs(self, node, marks=None):
        """Iterative DFS from task name *node*; return True when a ring is reachable.

        Args:
            node: name of the task to start from.
            marks: shared name -> colour map (see the class constants). Pass the
                same dict across calls so every node is explored only once.
        """
        if marks is None:
            marks = {}
        if marks.get(node, self.UNVISITED) != self.UNVISITED:
            return False

        def successors(name):
            # Names that are not registered tasks are treated as leaves.
            task = self._tasks.get(name)
            return iter(task.next_tasks() or ()) if task is not None else iter(())

        marks[node] = self.SUB_VISITED
        stack = [(node, successors(node))]
        while stack:
            name, children = stack[-1]
            for child in children:
                state = marks.get(child, self.UNVISITED)
                if state == self.SUB_VISITED:
                    return True  # back edge: child is on the current path
                if state == self.UNVISITED:
                    marks[child] = self.SUB_VISITED
                    stack.append((child, successors(child)))
                    break
            else:
                # Every successor explored without finding a back edge.
                marks[name] = self.VISITED
                stack.pop()
        return False
-
-
if __name__ == '__main__':
    g = Dag()
    g.load_tasks(os.path.join(SCRIPT_DIR, 'pipeline'))
    # Validate the graph first: tasks on a ring never become ready, so
    # running a cyclic graph would silently stall instead of failing fast.
    g.gen()
    g.start()
|