GitOrigin-RevId: bd969d1339
tags/v1.3.0
| @@ -529,7 +529,11 @@ class InputNode(OpNode): | |||
| @property | |||
| def device(self): | |||
| return self.outputs[0].device | |||
| var = self.outputs[0] | |||
| if isinstance(var, VarNode): | |||
| return var.device | |||
| else: | |||
| return var.comp_node | |||
| @property | |||
| def dtype(self): | |||
| @@ -36,6 +36,10 @@ def _str2device_type(type_str: str, allow_unspec: bool = True): | |||
| return DeviceType.CPU | |||
| elif type_str == "GPU" or type_str == "CUDA": | |||
| return DeviceType.CUDA | |||
| elif type_str == "CAMBRICON": | |||
| return DeviceType.CAMBRICON | |||
| elif type_str == "ATLAS": | |||
| return DeviceType.ATLAS | |||
| else: | |||
| assert allow_unspec and str == "XPU", "device type can only be cpu, gpu or xpu" | |||
| return DeviceType.UNSPEC | |||
| @@ -65,6 +69,24 @@ def is_cuda_available() -> bool: | |||
| return CompNode._get_device_count(t, False) > 0 | |||
| def is_cambricon_available() -> bool: | |||
| """ | |||
| Returns whether cambricon device is available on this system. | |||
| """ | |||
| t = _str2device_type("cambricon") | |||
| return CompNode._get_device_count(t, False) > 0 | |||
| def is_atlas_available() -> bool: | |||
| """ | |||
| Returns whether atlas device is available on this system. | |||
| """ | |||
| t = _str2device_type("atlas") | |||
| return CompNode._get_device_count(t, False) > 0 | |||
| def set_default_device(device: str = "xpux"): | |||
| r""" | |||
| Sets default computing node. | |||
| @@ -20,3 +20,30 @@ def tensorrt_runtime_opr(inputs, *, data: bytes = None): | |||
| op = builtin.TensorRTRuntime(data, len(data)) | |||
| # return sequence of outputs | |||
| return apply(op, *inputs) | |||
| def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable): | |||
| r""" | |||
| Load a serialized Cambricon model as a runtime operator in MegEngine. | |||
| :param inputs: list of input tensors. | |||
| :param data: the serialized Cambricon model. | |||
| :param symbol: name of the function in Cambricon model. | |||
| :param tensor_dim_mutable: whether the input tensors' shapes are mutable | |||
| in ``cnrtModel_t``. | |||
| """ | |||
| op = builtin.CambriconRuntime(data, len(data), symbol, tensor_dim_mutable) | |||
| return apply(op, *inputs) | |||
| def atlas_runtime_opr(inputs, data): | |||
| r""" | |||
| Load a serialized Atlas model as a runtime operator in MegEngine. | |||
| :param inputs: list of input tensors. | |||
| :param data: the serialized Atlas model. | |||
| """ | |||
| op = builtin.AtlasRuntime(data, len(data)) | |||
| return apply(op, *inputs) | |||
| @@ -786,7 +786,11 @@ class trace: | |||
| ) | |||
| output_names = output_names or self._output_names | |||
| dumped_device = as_device("xpux") | |||
| def dumped_device(info): | |||
| device_name = info.device.logical_name | |||
| if device_name[:3] in ("cpu", "gpu", "xpu"): | |||
| return as_device("xpux") | |||
| return info.device | |||
| h2v = {} | |||
| graph = G.Graph() | |||
| @@ -794,19 +798,21 @@ class trace: | |||
| # apply graph_opt_level in dump | |||
| if self._graph_opt_level is not None: | |||
| graph.options.graph_opt_level = self._graph_opt_level | |||
| for i, h in enumerate(self._arg_bindings): | |||
| info = self._tinfo[h] | |||
| h2v[h] = graph.make_h2d( | |||
| dtype=info.dtype, | |||
| device=dumped_device, | |||
| device=dumped_device(info), | |||
| shape=info.shape or (1,), | |||
| name=arg_names[i] if arg_names else None, | |||
| ) | |||
| for k, h in self._kwarg_bindings.items(): | |||
| info = self._tinfo[h] | |||
| h2v[h] = graph.make_h2d( | |||
| dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k | |||
| dtype=info.dtype, | |||
| device=dumped_device(info), | |||
| shape=info.shape or (1,), | |||
| name=k, | |||
| ) | |||
| for op, ihandles, ohandles in self._seq: | |||
| @@ -833,7 +839,7 @@ class trace: | |||
| h2v[h] = graph.make_const( | |||
| info.bound_data.numpy(), | |||
| dtype=info.dtype, | |||
| device=dumped_device, | |||
| device=dumped_device(info), | |||
| name=info.name, | |||
| ) | |||
| ivars.append(h2v[h]) | |||
| @@ -9,7 +9,11 @@ | |||
| # pylint: disable=redefined-builtin | |||
| import numpy as np | |||
| from ..functional.external import tensorrt_runtime_opr | |||
| from ..functional.external import ( | |||
| atlas_runtime_opr, | |||
| cambricon_runtime_opr, | |||
| tensorrt_runtime_opr, | |||
| ) | |||
| from .module import Module | |||
| @@ -33,3 +37,52 @@ class TensorrtRuntimeSubgraph(Module): | |||
| def forward(self, *inputs): | |||
| return tensorrt_runtime_opr(inputs, data=self._data) | |||
| class CambriconRuntimeSubgraph(Module): | |||
| r"""Load a serialized CambriconRuntime subgraph. | |||
| See :func:`~.cambricon_runtime_opr` for more details. | |||
| """ | |||
| def __init__(self, data, symbol, tensor_dim_mutable, **kwargs): | |||
| super(CambriconRuntimeSubgraph, self).__init__(**kwargs) | |||
| self._data = data | |||
| self.symbol = symbol | |||
| self.tensor_dim_mutable = tensor_dim_mutable | |||
| @property | |||
| def data(self): | |||
| return self._data | |||
| @data.setter | |||
| def data(self, val): | |||
| self._data = np.frombuffer(val, dtype=np.uint8) | |||
| def forward(self, *inputs): | |||
| outputs = cambricon_runtime_opr( | |||
| inputs, self._data, self.symbol, self.tensor_dim_mutable | |||
| ) | |||
| return outputs | |||
| class AtlasRuntimeSubgraph(Module): | |||
| r"""Load a serialized AtlasRuntime subgraph. | |||
| See :func:`~.atlas_runtime_opr` for more details. | |||
| """ | |||
| def __init__(self, data, **kwargs): | |||
| super(AtlasRuntimeSubgraph, self).__init__(**kwargs) | |||
| self._data = data | |||
| @property | |||
| def data(self): | |||
| return self._data | |||
| @data.setter | |||
| def data(self, val): | |||
| self._data = np.frombuffer(val, dtype=np.uint8) | |||
| def forward(self, *inputs): | |||
| return atlas_runtime_opr(inputs, data=self._data) | |||
| @@ -427,8 +427,9 @@ class GraphInference: | |||
| list(self._inp_dict.keys()), list(inputs.keys()) | |||
| ) | |||
| for key in self._inp_dict: | |||
| self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor()) | |||
| self._inp_dict[key].set_value( | |||
| Tensor(inputs[key], device=self._inp_dict[key].device)._dev_tensor() | |||
| ) | |||
| self._func.execute() | |||
| self._func.wait() | |||
| @@ -171,6 +171,8 @@ void init_common(py::module m) { | |||
| .value("UNSPEC", CompNode::DeviceType::UNSPEC) | |||
| .value("CUDA", CompNode::DeviceType::CUDA) | |||
| .value("CPU", CompNode::DeviceType::CPU) | |||
| .value("CAMBRICON", CompNode::DeviceType::CAMBRICON) | |||
| .value("ATLAS", CompNode::DeviceType::ATLAS) | |||
| .value("MULTITHREAD", CompNode::DeviceType::MULTITHREAD) | |||
| .value("MAX_DEVICE_ID", CompNode::DeviceType::MAX_DEVICE_ID); | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * \file imperative/src/impl/ops/tensorrt_runtime.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "../op_trait.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #if MGB_ATLAS | |||
| #include "megbrain/opr/atlas_runtime_op.h" | |||
| namespace mgb::imperative { | |||
| namespace { | |||
| namespace atlas_runtime { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const AtlasRuntime&>(def); | |||
| SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::AtlasRuntimeOpr::make(op.buf.c_str(), op.buf_size, | |||
| symbol_var_inputs, config); | |||
| } | |||
| OP_TRAIT_REG(AtlasRuntime, AtlasRuntime) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace atlas_runtime | |||
| } // namespace | |||
| } // namespace mgb::imperative | |||
| #endif | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * \file imperative/src/impl/ops/tensorrt_runtime.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "../op_trait.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #if MGB_CAMBRICON | |||
| #include "megbrain/cambricon/cambricon_runtime_opr.h" | |||
| namespace mgb::imperative { | |||
| namespace { | |||
| namespace cambricon_runtime { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const CambriconRuntime&>(def); | |||
| SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::CambriconRuntimeOpr::make(op.buf.c_str(), op.buf_size, | |||
| op.symbol, symbol_var_inputs, | |||
| op.tensor_dim_mutable, config); | |||
| } | |||
| OP_TRAIT_REG(CambriconRuntime, CambriconRuntime) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace cambricon_runtime | |||
| } // namespace | |||
| } // namespace mgb::imperative | |||
| #endif | |||
| @@ -266,6 +266,22 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> { | |||
| ); | |||
| } | |||
| def AtlasRuntime: MgbHashableOp<"AtlasRuntime"> { | |||
| let extraArguments = (ins | |||
| MgbStringAttr:$buf, | |||
| MgbSizeTAddr:$buf_size | |||
| ); | |||
| } | |||
| def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> { | |||
| let extraArguments = (ins | |||
| MgbStringAttr:$buf, | |||
| MgbSizeTAddr:$buf_size, | |||
| MgbStringAttr:$symbol, | |||
| MgbBoolAttr:$tensor_dim_mutable | |||
| ); | |||
| } | |||
| def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | |||
| #endif // MGB_OPS | |||