GitOrigin-RevId: bd969d1339
tags/v1.3.0
@@ -529,7 +529,11 @@ class InputNode(OpNode):
     @property
     def device(self):
-        return self.outputs[0].device
+        var = self.outputs[0]
+        if isinstance(var, VarNode):
+            return var.device
+        else:
+            return var.comp_node

     @property
     def dtype(self):
| @@ -36,6 +36,10 @@ def _str2device_type(type_str: str, allow_unspec: bool = True): | |||||
| return DeviceType.CPU | return DeviceType.CPU | ||||
| elif type_str == "GPU" or type_str == "CUDA": | elif type_str == "GPU" or type_str == "CUDA": | ||||
| return DeviceType.CUDA | return DeviceType.CUDA | ||||
| elif type_str == "CAMBRICON": | |||||
| return DeviceType.CAMBRICON | |||||
| elif type_str == "ATLAS": | |||||
| return DeviceType.ATLAS | |||||
| else: | else: | ||||
| assert allow_unspec and str == "XPU", "device type can only be cpu, gpu or xpu" | assert allow_unspec and str == "XPU", "device type can only be cpu, gpu or xpu" | ||||
| return DeviceType.UNSPEC | return DeviceType.UNSPEC | ||||
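For reference, the extended dispatch can be mimicked with the self-contained sketch below. It is a stand-in, not the real module: DeviceType here is a local Enum, and the real helper presumably upper-cases type_str before matching (the availability helpers added below pass lowercase names). The sketch also compares type_str == "XPU" in the fallback, where the context line above compares the builtin str, an apparent pre-existing slip.

    from enum import Enum

    class DeviceType(Enum):  # local stand-in for CompNode::DeviceType
        UNSPEC = 0
        CUDA = 1
        CPU = 2
        CAMBRICON = 3
        ATLAS = 4

    def str2device_type(type_str: str, allow_unspec: bool = True) -> DeviceType:
        type_str = type_str.upper()  # accept "atlas" as well as "ATLAS"
        if type_str == "CPU":
            return DeviceType.CPU
        elif type_str in ("GPU", "CUDA"):
            return DeviceType.CUDA
        elif type_str == "CAMBRICON":
            return DeviceType.CAMBRICON
        elif type_str == "ATLAS":
            return DeviceType.ATLAS
        else:
            assert allow_unspec and type_str == "XPU"
            return DeviceType.UNSPEC

    assert str2device_type("atlas") is DeviceType.ATLAS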
@@ -65,6 +69,24 @@ def is_cuda_available() -> bool:
     return CompNode._get_device_count(t, False) > 0


+def is_cambricon_available() -> bool:
+    """
+    Returns whether a Cambricon device is available on this system.
+    """
+    t = _str2device_type("cambricon")
+    return CompNode._get_device_count(t, False) > 0
+
+
+def is_atlas_available() -> bool:
+    """
+    Returns whether an Atlas device is available on this system.
+    """
+    t = _str2device_type("atlas")
+    return CompNode._get_device_count(t, False) > 0
+
+
 def set_default_device(device: str = "xpux"):
     r"""
     Sets default computing node.
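A hedged usage sketch for the new availability helpers. They are added next to is_cuda_available and set_default_device (the device helper module); whether they are also re-exported at package level like megengine.is_cuda_available is an assumption, and the concrete device strings are hypothetical:

    from megengine.device import (
        is_atlas_available,
        is_cambricon_available,
        set_default_device,
    )

    # Prefer an accelerator when present; "cambricon0"/"atlas0" are hypothetical
    # device strings, and "xpux" means "any available device".
    if is_cambricon_available():
        set_default_device("cambricon0")
    elif is_atlas_available():
        set_default_device("atlas0")
    else:
        set_default_device("xpux")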
@@ -20,3 +20,30 @@ def tensorrt_runtime_opr(inputs, *, data: bytes = None):
     op = builtin.TensorRTRuntime(data, len(data))
     # return sequence of outputs
     return apply(op, *inputs)
+
+
+def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable):
+    r"""
+    Load a serialized Cambricon model as a runtime operator in MegEngine.
+
+    :param inputs: list of input tensors.
+    :param data: the serialized Cambricon model.
+    :param symbol: name of the function in the Cambricon model.
+    :param tensor_dim_mutable: whether the input tensors' shapes are mutable
+        in ``cnrtModel_t``.
+    """
+    op = builtin.CambriconRuntime(data, len(data), symbol, tensor_dim_mutable)
+    return apply(op, *inputs)
+
+
+def atlas_runtime_opr(inputs, data):
+    r"""
+    Load a serialized Atlas model as a runtime operator in MegEngine.
+
+    :param inputs: list of input tensors.
+    :param data: the serialized Atlas model.
+    """
+    op = builtin.AtlasRuntime(data, len(data))
+    return apply(op, *inputs)
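A minimal usage sketch for the two new operators. The model file names, input shape, and the function symbol "subnet0" are placeholders; only the call signatures come from the hunk above. In practice the inputs must live on the matching cambricon/atlas device:

    import numpy as np
    import megengine as mge
    from megengine.functional.external import atlas_runtime_opr, cambricon_runtime_opr

    # Shape/dtype must match what the offline-compiled model expects.
    inp = mge.Tensor(np.random.randn(1, 3, 224, 224).astype("float32"))

    with open("model.cambricon", "rb") as f:  # placeholder path
        cambricon_outputs = cambricon_runtime_opr(
            [inp], f.read(), "subnet0", tensor_dim_mutable=False
        )

    with open("model.om", "rb") as f:  # placeholder path
        atlas_outputs = atlas_runtime_opr([inp], data=f.read())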
@@ -786,7 +786,11 @@ class trace:
         )
         output_names = output_names or self._output_names

-        dumped_device = as_device("xpux")
+        def dumped_device(info):
+            device_name = info.device.logical_name
+            if device_name[:3] in ("cpu", "gpu", "xpu"):
+                return as_device("xpux")
+            return info.device

         h2v = {}
         graph = G.Graph()
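The rule the new helper implements, as a pure-Python sketch over logical device names: host-portable backends (cpu/gpu/xpu) are still rewritten to "xpux" at dump time, while anything else, e.g. a Cambricon or Atlas device, keeps its concrete device:

    def map_dump_device(logical_name: str) -> str:
        # cpu*/gpu*/xpu* collapse to the portable "xpux"; other backends survive.
        if logical_name[:3] in ("cpu", "gpu", "xpu"):
            return "xpux"
        return logical_name

    assert map_dump_device("gpu0:0") == "xpux"
    assert map_dump_device("atlas0") == "atlas0"  # hypothetical Atlas logical name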
@@ -794,19 +798,21 @@ class trace:
         # apply graph_opt_level in dump
         if self._graph_opt_level is not None:
             graph.options.graph_opt_level = self._graph_opt_level

         for i, h in enumerate(self._arg_bindings):
             info = self._tinfo[h]
             h2v[h] = graph.make_h2d(
                 dtype=info.dtype,
-                device=dumped_device,
+                device=dumped_device(info),
                 shape=info.shape or (1,),
                 name=arg_names[i] if arg_names else None,
             )
         for k, h in self._kwarg_bindings.items():
             info = self._tinfo[h]
             h2v[h] = graph.make_h2d(
-                dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
+                dtype=info.dtype,
+                device=dumped_device(info),
+                shape=info.shape or (1,),
+                name=k,
             )

         for op, ihandles, ohandles in self._seq:
@@ -833,7 +839,7 @@ class trace:
                     h2v[h] = graph.make_const(
                         info.bound_data.numpy(),
                         dtype=info.dtype,
-                        device=dumped_device,
+                        device=dumped_device(info),
                         name=info.name,
                     )
                 ivars.append(h2v[h])
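Taken together, the three call sites mean a dumped graph now records non-xpu I/O devices. A hedged end-to-end sketch; the "atlas0" device string and the output file name are hypothetical:

    import numpy as np
    import megengine as mge
    from megengine.jit import trace

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return x * 2

    x = mge.Tensor(np.ones((2, 2), dtype="float32"), device="atlas0")
    f(x)
    # The H2D node for "x" keeps device atlas0 instead of being rewritten to xpux.
    f.dump("model_with_atlas_io.mge", arg_names=["x"])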
@@ -9,7 +9,11 @@
 # pylint: disable=redefined-builtin
 import numpy as np

-from ..functional.external import tensorrt_runtime_opr
+from ..functional.external import (
+    atlas_runtime_opr,
+    cambricon_runtime_opr,
+    tensorrt_runtime_opr,
+)
 from .module import Module
@@ -33,3 +37,52 @@ class TensorrtRuntimeSubgraph(Module):
     def forward(self, *inputs):
         return tensorrt_runtime_opr(inputs, data=self._data)
+
+
+class CambriconRuntimeSubgraph(Module):
+    r"""Load a serialized CambriconRuntime subgraph.
+
+    See :func:`~.cambricon_runtime_opr` for more details.
+    """
+
+    def __init__(self, data, symbol, tensor_dim_mutable, **kwargs):
+        super(CambriconRuntimeSubgraph, self).__init__(**kwargs)
+        self._data = data
+        self.symbol = symbol
+        self.tensor_dim_mutable = tensor_dim_mutable
+
+    @property
+    def data(self):
+        return self._data
+
+    @data.setter
+    def data(self, val):
+        self._data = np.frombuffer(val, dtype=np.uint8)
+
+    def forward(self, *inputs):
+        outputs = cambricon_runtime_opr(
+            inputs, self._data, self.symbol, self.tensor_dim_mutable
+        )
+        return outputs
+
+
+class AtlasRuntimeSubgraph(Module):
+    r"""Load a serialized AtlasRuntime subgraph.
+
+    See :func:`~.atlas_runtime_opr` for more details.
+    """
+
+    def __init__(self, data, **kwargs):
+        super(AtlasRuntimeSubgraph, self).__init__(**kwargs)
+        self._data = data
+
+    @property
+    def data(self):
+        return self._data
+
+    @data.setter
+    def data(self, val):
+        self._data = np.frombuffer(val, dtype=np.uint8)
+
+    def forward(self, *inputs):
+        return atlas_runtime_opr(inputs, data=self._data)
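A hedged construction sketch for the new modules, mirroring how the existing TensorrtRuntimeSubgraph is used; file name, symbol, shape, and device are placeholders:

    import numpy as np
    import megengine as mge
    from megengine.module.external import CambriconRuntimeSubgraph

    with open("model.cambricon", "rb") as f:  # placeholder path
        m = CambriconRuntimeSubgraph(f.read(), symbol="subnet0", tensor_dim_mutable=False)

    inp = mge.Tensor(np.random.randn(1, 3, 224, 224).astype("float32"))
    outputs = m(inp)

Note that the data setter re-wraps assigned bytes as a uint8 numpy array, so the serialized model can be swapped after construction.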
@@ -427,8 +427,9 @@ class GraphInference:
             list(self._inp_dict.keys()), list(inputs.keys())
         )
         for key in self._inp_dict:
-            self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
+            self._inp_dict[key].set_value(
+                Tensor(inputs[key], device=self._inp_dict[key].device)._dev_tensor()
+            )

         self._func.execute()
         self._func.wait()
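A hedged sketch of the fixed behavior: each input fed to GraphInference is now materialized on its own input node's device before set_value, so e.g. Atlas-bound H2D nodes receive tensors on the right comp node. Construction from a dumped model file and the inp_dict keyword are assumed from megengine.utils.comp_graph_tools; the file name is a placeholder:

    import numpy as np
    from megengine.utils.comp_graph_tools import GraphInference

    net = GraphInference("model_with_atlas_io.mge")  # placeholder path
    outputs = net.run(inp_dict={"x": np.ones((2, 2), dtype="float32")})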
@@ -171,6 +171,8 @@ void init_common(py::module m) {
             .value("UNSPEC", CompNode::DeviceType::UNSPEC)
             .value("CUDA", CompNode::DeviceType::CUDA)
             .value("CPU", CompNode::DeviceType::CPU)
+            .value("CAMBRICON", CompNode::DeviceType::CAMBRICON)
+            .value("ATLAS", CompNode::DeviceType::ATLAS)
             .value("MULTITHREAD", CompNode::DeviceType::MULTITHREAD)
             .value("MAX_DEVICE_ID", CompNode::DeviceType::MAX_DEVICE_ID);
@@ -0,0 +1,36 @@
+/**
+ * \file imperative/src/impl/ops/atlas_runtime.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "../op_trait.h"
+#include "megbrain/imperative/ops/autogen.h"
+#if MGB_ATLAS
+#include "megbrain/opr/atlas_runtime_op.h"
+
+namespace mgb::imperative {
+namespace {
+namespace atlas_runtime {
+
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const AtlasRuntime&>(def);
+    SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::AtlasRuntimeOpr::make(op.buf.c_str(), op.buf_size,
+                                      symbol_var_inputs, config);
+}
+
+OP_TRAIT_REG(AtlasRuntime, AtlasRuntime)
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+
+}  // namespace atlas_runtime
+}  // namespace
+}  // namespace mgb::imperative
+#endif
@@ -0,0 +1,37 @@
+/**
+ * \file imperative/src/impl/ops/cambricon_runtime.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "../op_trait.h"
+#include "megbrain/imperative/ops/autogen.h"
+#if MGB_CAMBRICON
+#include "megbrain/cambricon/cambricon_runtime_opr.h"
+
+namespace mgb::imperative {
+namespace {
+namespace cambricon_runtime {
+
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const CambriconRuntime&>(def);
+    SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::CambriconRuntimeOpr::make(op.buf.c_str(), op.buf_size,
+                                          op.symbol, symbol_var_inputs,
+                                          op.tensor_dim_mutable, config);
+}
+
+OP_TRAIT_REG(CambriconRuntime, CambriconRuntime)
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+
+}  // namespace cambricon_runtime
+}  // namespace
+}  // namespace mgb::imperative
+#endif
@@ -266,6 +266,22 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> {
   );
 }

+def AtlasRuntime: MgbHashableOp<"AtlasRuntime"> {
+  let extraArguments = (ins
+    MgbStringAttr:$buf,
+    MgbSizeTAttr:$buf_size
+  );
+}
+
+def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {
+  let extraArguments = (ins
+    MgbStringAttr:$buf,
+    MgbSizeTAttr:$buf_size,
+    MgbStringAttr:$symbol,
+    MgbBoolAttr:$tensor_dim_mutable
+  );
+}
+
 def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;

 #endif // MGB_OPS
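These TableGen definitions generate the builtin ops that functional/external.py constructs above. A minimal hedged sketch of direct construction; the byte string and symbol are placeholders:

    from megengine.core.ops import builtin

    data = b"\x00\x01"  # placeholder for a serialized model blob
    atlas_op = builtin.AtlasRuntime(data, len(data))
    cambricon_op = builtin.CambriconRuntime(data, len(data), "subnet0", True)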