# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
import json
import math
import os
import subprocess
import sys
import warnings
from multiprocessing import Pool


def _compiletask(platform, *jsons):
    """
    compile func called in single process

    Parameters:
        platform: str. AKG platform or TBE platform
        *jsons: str. json str contain kernel info, suitable for json compile api

    Raises:
        ValueError: if a kernel fails to compile or the platform is unknown.
    """
    if platform == "AKG":
        # Imported lazily so that `_akg` is only required when AKG kernels
        # are actually compiled.
        p = __import__("_akg", globals(), locals(), ['ms'], 0)
        func = getattr(p.ms, "compilewithjson")
        for json_item in jsons:
            res = func(json_item)
            if not res:
                raise ValueError("Compile error")
    elif platform == "TBE":
        # Each TBE kernel is compiled by a fresh interpreter running the
        # bundled compiler script; the kernel json is fed through stdin.
        tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0],
                                    "tbe_compiler", "compiler.py")
        for json_item in jsons:
            res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
            if res.returncode != 0:
                raise ValueError("Tbe compile error")
    else:
        raise ValueError("not support this platform {0}".format(platform))


def _build_tasks(jsons_tbe, jsons_akg, process):
    """
    Distribute kernel jsons over worker tasks, one task per process slot.

    Processes are shared between the platforms in proportion to how many
    kernels each platform has; a platform with at least one kernel always
    gets at least one slot. Kernels are dealt round-robin over the slots of
    their own platform.

    Parameters:
        jsons_tbe: list. json str of TBE kernels
        jsons_akg: list. json str of AKG kernels
        process: int. requested processes num (>= 1)

    Returns:
        list of tuples ``(platform, json, ...)`` suitable as the argument
        list of ``Pool.starmap_async(_compiletask, ...)``. Slots that end up
        without any kernel are omitted.
    """
    total = len(jsons_tbe) + len(jsons_akg)
    # Integer arithmetic keeps the proportional split exact (no float
    # rounding); the AKG share is at most `process`.
    process_akg = len(jsons_akg) * process // total if total else 0
    if process_akg == 0 and jsons_akg:
        process_akg = 1
    process_tbe = max(process - process_akg, 0)
    if process_tbe == 0 and jsons_tbe:
        # Only reachable when process == 1 and both platforms are present:
        # give TBE its own slot and keep going rather than aborting.
        process_tbe = 1
        warnings.warn("we add a process for compile more operator", RuntimeWarning)

    tbe_slots = [["TBE"] for _ in range(process_tbe)]
    akg_slots = [["AKG"] for _ in range(process_akg)]
    for idx, json_item in enumerate(jsons_tbe):
        tbe_slots[idx % process_tbe].append(json_item)
    for idx, json_item in enumerate(jsons_akg):
        akg_slots[idx % process_akg].append(json_item)

    # A slot holding only its platform tag has nothing to compile.
    return [tuple(slot) for slot in tbe_slots + akg_slots if len(slot) > 1]


def compilekernelparallel(jsons, process, waitime):
    """
    compile kernel use multi processes

    Parameters:
        jsons: list. json str list contain kernel info; each json must carry
            a "platform" key whose value is "TBE" or "AKG"
        process: int. processes num (must be >= 1)
        waitime: int. max time the function blocked

    Returns:
        True when every kernel has been compiled successfully.

    Raises:
        ValueError: on invalid arguments, or when a kernel fails to compile
            (re-raised from the worker process).
        RuntimeError: if a kernel json names an unsupported platform.
        multiprocessing.TimeoutError: if compilation does not finish within
            `waitime` seconds.
    """
    if not isinstance(jsons, list):
        raise ValueError("jsons must be a list")
    if not isinstance(process, int):
        raise ValueError("process must be a num")
    if not isinstance(waitime, int):
        raise ValueError("waittime must be a num")
    if process < 1:
        raise ValueError("process must be a positive num")

    jsons_akg = []
    jsons_tbe = []
    for json_ in jsons:
        platform = json.loads(json_)["platform"]
        if platform == "TBE":
            jsons_tbe.append(json_)
        elif platform == "AKG":
            jsons_akg.append(json_)
        else:
            raise RuntimeError(
                "not support this platform {0}".format(platform))

    tasks = _build_tasks(jsons_tbe, jsons_akg, process)
    if not tasks:
        # Nothing to compile: skip spawning a process pool entirely.
        return True

    with Pool(processes=min(process, len(tasks))) as pool:
        res = pool.starmap_async(_compiletask, tasks)
        res.get(timeout=waitime)
    return True