GitOrigin-RevId: 4ade455897
tags/v1.6.0
@@ -613,17 +613,17 @@ void init_ops(py::module m) {
 }
 #define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \
-    case mgb::custom::ParamDynType::dyn_type: { \
+    case custom::ParamDynType::dyn_type: { \
         param_val = py::handle(kv.second).cast<static_type>(); \
         break; \
     }
 #define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \
-    case mgb::custom::ParamDynType::dyn_type: { \
+    case custom::ParamDynType::dyn_type: { \
         auto pyvals = py::handle(kv.second).cast<py::list>(); \
         static_type vals; \
         using basic_type = \
-                mgb::custom::get_vector_template_arg_type<static_type>::type; \
+                custom::get_vector_template_arg_type<static_type>::type; \
         for (auto &pyval: pyvals) { \
             vals.push_back(py::handle(pyval).cast<basic_type>()); \
         } \
@@ -631,7 +631,7 @@ void init_ops(py::module m) {
         break; \
     }
-PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyObject *kwnames) {
+PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs) {
     auto op_name = py::handle(args[0]).cast<std::string>();
     auto kwargs = py::handle(args[1]).cast<py::dict>();
@@ -680,7 +680,7 @@ PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyOb
 py::list install_custom(const std::string &name, const std::string &path) {
     py::list ret;
-    const auto &ops_in_lib = mgb::custom::LibManager::inst()->install(name, path);
+    const auto &ops_in_lib = custom::LibManager::inst()->install(name, path);
     for (const auto &op: ops_in_lib) {
         ret.append(op);
     }
@@ -688,7 +688,7 @@ py::list install_custom(const std::string &name, const std::string &path) {
 }
 bool uninstall_custom(const std::string &name) {
-    return mgb::custom::LibManager::inst()->uninstall(name);
+    return custom::LibManager::inst()->uninstall(name);
 }
 py::list get_custom_op_list(void) {
@@ -697,16 +697,28 @@ py::list get_custom_op_list(void) {
     for (auto &op: all_ops) {
         ret.append(op);
     }
-    return std::move(ret);
+    return ret;
 }
+#ifndef METH_FASTCALL
+PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
+    auto* arr = &PyTuple_GET_ITEM(args, 0);
+    auto size = PyTuple_GET_SIZE(args);
+    return make_custom_op(self, arr, size);
+}
+#endif
 void init_custom(pybind11::module m) {
     m.def("_install", &install_custom);
     m.def("_uninstall", &uninstall_custom);
     m.def("_get_custom_op_list", &get_custom_op_list);
+    static PyMethodDef method_def = {
+#ifdef METH_FASTCALL
+            "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
+#else
+            "_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, ""
+#endif
+    };
+    auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
+    pybind11::setattr(m, method_def.ml_name, func);
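Note on the block above: registering _make_custom_op through a raw PyMethodDef lets it use METH_FASTCALL (arguments passed as a C array plus count, with no per-call tuple allocation), which m.def cannot express; on older CPython without METH_FASTCALL, py35_make_custom_op unpacks the METH_VARARGS tuple into the same (array, size) pair. A standalone sketch of that adapter pattern, with illustrative names only:

// Illustrative only: bridging METH_VARARGS to a FASTCALL-style callee.
static PyObject* fastcall_impl(PyObject* self, PyObject** args, Py_ssize_t nargs);

static PyObject* varargs_shim(PyObject* self, PyObject* args) {
    // A tuple stores its items contiguously, so the address of item 0
    // doubles as the argument array expected by the FASTCALL convention.
    return fastcall_impl(self, &PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args));
}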
@@ -70,7 +70,7 @@ void CustomOpDef::compute(const SmallVector<DeviceTensorND> &inputs,
 std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
         const SmallVector<TensorPtr> &inputs) const {
     SmallVector<LogicalTensorDesc> input_descs(inputs.size());
-    for (int i=0; i<inputs.size(); i++) {
+    for (size_t i=0; i<inputs.size(); i++) {
         input_descs[i].comp_node = inputs[i]->comp_node();
         input_descs[i].layout = inputs[i]->layout();
     }
@@ -84,7 +84,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
     SmallVector<megdnn::DType> i_dtypes(inputs.size());
     SmallVector<TensorFormat> i_formats(inputs.size());
-    for (int i=0; i<inputs.size(); i++) {
+    for (size_t i=0; i<inputs.size(); i++) {
         i_devices[i] = inputs[i].comp_node;
         i_shapes[i] = inputs[i].layout; // TensorLayout is derived from TensorShape
         i_dtypes[i] = inputs[i].layout.dtype;
@@ -132,7 +132,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
     }
     SmallVector<LogicalTensorDesc> outputs(this->output_num());
-    for (int i=0; i<this->output_num(); i++) {
+    for (size_t i=0; i<this->output_num(); i++) {
         outputs[i].comp_node = std::move(o_devices[i]);
         outputs[i].layout = std::move(
                 TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i])
@@ -0,0 +1,181 @@
/**
 * \file src/custom/impl/manager.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/custom/manager.h"
#include "megbrain/common.h"

#include <unordered_set>

#ifndef _WIN32
#include <dlfcn.h>
#endif

using namespace mgb;

namespace custom {

CustomOpManager *CustomOpManager::inst(void) {
    static CustomOpManager op_manager;
    return &op_manager;
}

CustomOpManager::~CustomOpManager() {
    mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
    LibManager::inst()->m_custom_libs.clear();
}

std::shared_ptr<CustomOp> CustomOpManager::insert(const std::string &name, uint32_t version) {
    MGB_LOCK_GUARD(m_mtx);
    auto iter = m_name2op.find(name);
    if (iter != m_name2op.end()) {
        mgb_log_warn("Register Custom Op Failed! Op %s has been registered", name.c_str());
        return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
    }
    std::shared_ptr<const CustomOp> op = std::make_shared<const CustomOp>(name, version);
    m_name2op[op->op_type()] = op;
    m_id2op[op->runtime_id()] = op;
    return std::const_pointer_cast<CustomOp, const CustomOp>(op);
}

bool CustomOpManager::erase(const std::string &name) {
    MGB_LOCK_GUARD(m_mtx);
    auto iter = m_name2op.find(name);
    if (iter == m_name2op.end()) {
        mgb_log_warn("Erase Custom Op Failed! %s has not been registered", name.c_str());
        return false;
    }
    std::shared_ptr<const CustomOp> op = iter->second;
    m_id2op.erase(op->runtime_id());
    m_name2op.erase(op->op_type());
    return true;
}

bool CustomOpManager::erase(const RunTimeId &id) {
    MGB_LOCK_GUARD(m_mtx);
    auto iter = m_id2op.find(id);
    if (iter == m_id2op.end()) {
        mgb_log_warn("Erase Custom Op Failed! The Op has not been registered");
        return false;
    }
    std::shared_ptr<const CustomOp> op = iter->second;
    m_id2op.erase(op->runtime_id());
    m_name2op.erase(op->op_type());
    return true;
}

std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(const std::string &name, uint32_t version) {
    auto iter = m_name2op.find(name);
    if (iter == m_name2op.end()) {
        return insert(name, version);
    }
    return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
}

RunTimeId CustomOpManager::to_id(const std::string &name) const {
    std::shared_ptr<const CustomOp> op = find(name);
    return op->runtime_id();
}

std::string CustomOpManager::to_name(const RunTimeId &id) const {
    std::shared_ptr<const CustomOp> op = find(id);
    return op->op_type();
}

std::shared_ptr<const CustomOp> CustomOpManager::find(const std::string &name) const {
    auto ret = m_name2op.find(name);
    mgb_assert(ret != m_name2op.end(),
        "Find Custom Op Failed! Op %s has not been registered", name.c_str()
    );
    return ret->second;
}

std::shared_ptr<const CustomOp> CustomOpManager::find(const RunTimeId &id) const {
    auto ret = m_id2op.find(id);
    mgb_assert(ret != m_id2op.end(), "Find Custom Op Failed! Op has not been registered");
    return ret->second;
}

std::vector<std::string> CustomOpManager::op_name_list(void) {
    std::vector<std::string> ret;
    for (auto kv: m_name2op) {
        ret.emplace_back(kv.first);
    }
    return ret;
}

std::vector<RunTimeId> CustomOpManager::op_id_list(void) {
    std::vector<RunTimeId> ret;
    for (auto kv: m_id2op) {
        ret.emplace_back(kv.first);
    }
    return ret;
}
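A short usage sketch of the manager above (hedged: CUSTOM_OP_VERSION is the version macro referenced later in this diff, and the op name is illustrative). It shows that the name view and the id view stay consistent across insert, lookup, and erase:

inline void manager_demo() {
    auto op = CustomOpManager::inst()->insert("DemoOp", CUSTOM_OP_VERSION);
    RunTimeId id = CustomOpManager::inst()->to_id("DemoOp");
    mgb_assert(CustomOpManager::inst()->to_name(id) == "DemoOp", "view mismatch");
    CustomOpManager::inst()->erase(id);   // removes from both m_id2op and m_name2op
}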
#ifndef _WIN32
CustomLib::CustomLib(const std::string &path, int mode = RTLD_LAZY)
        : m_handle(nullptr, [](void* handle) { dlclose(handle); }) {
    auto op_list_before_load = CustomOpManager::inst()->op_name_list();
    std::unordered_set<std::string> op_set_before_load(
            op_list_before_load.begin(), op_list_before_load.end());

    m_handle.reset(dlopen(path.c_str(), mode));
    mgb_assert(m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror());

    auto op_list_after_load = CustomOpManager::inst()->op_name_list();
    for (auto &op: op_list_after_load) {
        if (op_set_before_load.find(op) == op_set_before_load.end()) {
            m_ops.emplace_back(op);
        }
    }
}
#else
CustomLib::CustomLib(const std::string &path, int mode = 0)
        : m_handle(nullptr, [](void* handle) {}) {
    mgb_assert(false, "custom op is only supported on Linux now");
}
#endif
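The constructor above detects a library's ops by diffing the registered-op set around dlopen(): the library is expected to register its ops from static initializers, so merely loading it is enough. A hedged sketch of what the library side might look like (the real registration entry point or macro is not shown in this diff):

// In the custom op library (illustrative only):
namespace {
struct AutoRegister {
    AutoRegister() {
        // runs at dlopen() time, before CustomLib re-reads the op list
        CustomOpManager::inst()->insert("LibProvidedOp", CUSTOM_OP_VERSION);
    }
} auto_register_instance;
}  // anonymous namespace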
const std::vector<std::string> &CustomLib::ops_in_lib(void) const {
    return m_ops;
}

CustomLib::~CustomLib() {
    for (auto &op: m_ops) {
        CustomOpManager::inst()->erase(op);
    }
}

bool CustomLib::valid() const {
    return m_handle != nullptr;
}

LibManager *LibManager::inst(void) {
    static LibManager custom_libs;
    return &custom_libs;
}

const std::vector<std::string> &LibManager::install(const std::string &name, const std::string &path) {
    MGB_LOCK_GUARD(m_mtx);
    LibHandle handle = std::make_shared<CustomLib>(path);
    m_custom_libs.insert({name, handle});
    return m_custom_libs[name]->ops_in_lib();
}

bool LibManager::uninstall(const std::string &name) {
    MGB_LOCK_GUARD(m_mtx);
    mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error");
    return true;
}

std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) {
    return CustomOpManager::inst()->insert(opname, version);
}

}  // namespace custom
@@ -0,0 +1,531 @@
/**
 * \file src/custom/impl/op.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/common.h"
#include "megbrain/custom/op.h"
#include "megbrain/custom/utils.h"

#include <unordered_set>
#include <sstream>

using namespace mgb;

namespace custom {

class ArgInfoImpl {
    std::string m_name;
    std::string m_desc;
    std::unordered_set<std::string> m_dtypes;
    int m_ndim;    // int rather than size_t so that m_ndim = -1 can denote "any dim"
    std::string m_mem_stgy;
    friend class ArgInfo;
};

CUSTOM_PIMPL_CLS_DEFINE(ArgInfo)

ArgInfo::ArgInfo(const std::string &name,
                 const std::string &desc,
                 const std::unordered_set<std::string> &dtypes,
                 const int &ndim,
                 const std::string &mem_stgy): m_impl(new ArgInfoImpl(), impl_deleter<ArgInfoImpl>) {
    for (auto &&dtype: dtypes) {
        mgb_assert(DType::is_legal(dtype), "unsupported tensor data type: %s", dtype.c_str());
    }
    mgb_assert(mem_stgy == "default", "only default mem strategy is supported now!");
    TypedRef(ArgInfoImpl, m_impl.get()).m_name = name;
    TypedRef(ArgInfoImpl, m_impl.get()).m_desc = desc;
    TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes = dtypes;
    TypedRef(ArgInfoImpl, m_impl.get()).m_ndim = ndim;
    TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy = mem_stgy;
}

const std::string &ArgInfo::name(void) const {
    return TypedRef(ArgInfoImpl, m_impl.get()).m_name;
}

const std::string &ArgInfo::desc(void) const {
    return TypedRef(ArgInfoImpl, m_impl.get()).m_desc;
}

const std::unordered_set<std::string> &ArgInfo::dtypes(void) const {
    return TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes;
}

int ArgInfo::ndim(void) const {
    return TypedRef(ArgInfoImpl, m_impl.get()).m_ndim;
}

const std::string &ArgInfo::mem_strategy(void) const {
    return TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy;
}

std::string ArgInfo::str() const {
    std::stringstream ss;
    ss << "name: " << TypedRef(ArgInfoImpl, m_impl.get()).m_name << "\n"
       << "desc: " << TypedRef(ArgInfoImpl, m_impl.get()).m_desc << "\nlegal_dtypes: {";
    for (auto &val: TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes) {
        ss << val << ", ";
    }
    if (TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes.size() != 0) {
        ss.seekp(ss.tellp() - std::streampos(2));
    }
    ss << "}\ndims: " << TypedRef(ArgInfoImpl, m_impl.get()).m_ndim << "\n"
       << "memory_strategy: " << TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy;
    return ss.str();
}

#define assert_inputs_size_right(inputs_vec) \
    mgb_assert( \
        inputs_vec.size() == input_num(), \
        "op %s need %lu inputs but given %lu", \
        op_type().c_str(), static_cast<unsigned long>(input_num()), \
        static_cast<unsigned long>(inputs_vec.size()) \
    )

#define assert_outputs_size_right(outputs_vec) \
    mgb_assert( \
        outputs_vec.size() == output_num(), \
        "op %s have %lu outputs but given %lu", \
        op_type().c_str(), static_cast<unsigned long>(output_num()), \
        static_cast<unsigned long>(outputs_vec.size()) \
    )

#define assert_arg_shape_dim_right(real_shape, arg_info) \
    mgb_assert( \
        (arg_info).ndim() == -1 || static_cast<int>((real_shape).ndim()) == \
            static_cast<int>((arg_info).ndim()), \
        "%s's args: %s dim match error, need %d but given %d", op_type().c_str(), \
        (arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \
        static_cast<int>((real_shape).ndim()) \
    )

template <typename T>
class Function;

template<typename RType, typename... Args>
class Function<RType(Args...)> {
public:
    using Functor = RType (*)(Args...);

    Function() = default;
    Function(Functor f): m_f(f) {}
    Function(const Function &rhs) {
        m_f = rhs.m_f;
    }

    RType operator()(Args... args) {
        custom_assert(m_f != nullptr, "invalid function ptr\n");
        return m_f(std::forward<Args>(args)...);
    }

    void operator=(const Function &rhs) {    // not allowed continuous assignment
        m_f = rhs.m_f;
    }
    void operator=(const Functor f) {
        m_f = f;
    }

private:
    Functor m_f = nullptr;
};
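Function is a minimal stand-in for std::function that stores only a raw function pointer (no captures, no allocation) and asserts if invoked while unset. A small usage sketch:

// Usage sketch for the Function wrapper above.
inline int add_ints(int a, int b) { return a + b; }

inline void function_demo() {
    Function<int(int, int)> f;      // m_f starts out as nullptr
    f = add_ints;                   // operator=(Functor)
    int r = f(1, 2);                // would custom_assert if f were still unset
    mgb_assert(r == 3, "unexpected result");
}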
template <typename Functions>
class FuncWithSig: public Functions {
public:
    using Functions::operator();
    using Functions::operator=;
};

class CustomOpImpl {
    static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION;
    const uint32_t m_version;
    const std::string m_op_type;
    std::string m_op_desc;
    std::vector<ArgInfo> m_input_infos;
    std::vector<ArgInfo> m_output_infos;
    ParamInfo m_param_infos;

    using DeviceInfer = FuncWithSig<Function<void(const std::vector<Device>&, const Param&, std::vector<Device>&)>>;
    using ShapeInfer = FuncWithSig<Function<void(const std::vector<Shape>&, const Param&, std::vector<Shape>&)>>;
    using DTypeInfer = FuncWithSig<Function<void(const std::vector<DType>&, const Param&, std::vector<DType>&)>>;
    using FormatInfer = FuncWithSig<Function<void(const std::vector<Format>&, const Param&, std::vector<Format>&)>>;
    using Preprocess = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
    using Postprocess = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
    using Compute = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;

    DeviceInfer infer_output_device_func;
    ShapeInfer infer_output_shape_func;
    DTypeInfer infer_output_dtype_func;
    FormatInfer infer_output_format_func;
    std::unordered_map<std::string, Compute> compute_funcs;
    std::unordered_map<std::string, Preprocess> preprocess_funcs;
    std::unordered_map<std::string, Postprocess> postprocess_funcs;

public:
    CustomOpImpl(const std::string&, uint32_t version);
    PREVENT_COPY_AND_ASSIGN(CustomOpImpl);
    friend CustomOp;
};

CustomOpImpl::CustomOpImpl(const std::string &op_type, uint32_t version)
        : m_version(version), m_op_type(op_type) {
    if (m_version != CURRENT_VERSION) {
        mgb_log_warn(
            "the version of loaded custom op %s is %u, but custom op version "
            "of the system is %u\n", op_type.c_str(), m_version, CURRENT_VERSION
        );
    }

    infer_output_device_func = [](const std::vector<Device> &inputs,
                                  const Param&,
                                  std::vector<Device> &outputs) -> void {
        static UnImpleWarnLog log_once("output_device_infer", "device", "x86");
        for (size_t i=0; i<outputs.size(); ++i) {
            outputs[i] = inputs.size() > 0 ? inputs[0] : Device("x86");
        }
    };

    infer_output_shape_func = [](const std::vector<Shape> &inputs,
                                 const Param&,
                                 std::vector<Shape> &outputs) -> void {
        static UnImpleWarnLog log_once("output_shape_infer", "shape", "{1}");
        for (size_t i=0; i<outputs.size(); ++i) {
            outputs[i] = inputs.size() > 0 ? inputs[0] : Shape({1});
        }
    };

    infer_output_dtype_func = [](const std::vector<DType> &inputs,
                                 const Param&,
                                 std::vector<DType> &outputs) -> void {
        static UnImpleWarnLog log_once("output_dtype_infer", "dtype", "float32");
        for (size_t i=0; i<outputs.size(); ++i) {
            outputs[i] = inputs.size() > 0 ? inputs[0] : DType("float32");
        }
    };

    infer_output_format_func = [](const std::vector<Format> &inputs,
                                  const Param&,
                                  std::vector<Format> &outputs) -> void {
        for (size_t i=0; i<outputs.size(); ++i) {
            outputs[i] = inputs.size() > 0 ? inputs[0] : Format("default");
        }
    };

    for (const auto &device: Device::legal_devices()) {
        compute_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor> &outputs) -> void {
            auto device = outputs[0].device();
            mgb_assert(false, "There is no forward function for your op on device `%s`. "
                       "Please implement this function and register it.", device.str().c_str());
        };
        preprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) -> void {
            return;
        };
        postprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) -> void {
            return;
        };
    }
    m_param_infos.set_tag(op_type);
}

CustomOp::CustomOp(const std::string &op_type, uint32_t version)
        : m_impl(new CustomOpImpl(op_type, version), impl_deleter<CustomOpImpl>) {
}

#define OpImplRef(raw_ptr) reinterpret_cast<CustomOpImpl*>(raw_ptr)

CustomOp &CustomOp::set_device_infer(DeviceInferFuncPtr func) {
    OpImplRef(m_impl.get())->infer_output_device_func = func;
    return *this;
}

CustomOp &CustomOp::set_shape_infer(ShapeInferFuncPtr func) {
    OpImplRef(m_impl.get())->infer_output_shape_func = func;
    return *this;
}

CustomOp &CustomOp::set_dtype_infer(DTypeInferFuncPtr func) {
    OpImplRef(m_impl.get())->infer_output_dtype_func = func;
    return *this;
}

CustomOp &CustomOp::set_format_infer(FormatInferFuncPtr func) {
    OpImplRef(m_impl.get())->infer_output_format_func = func;
    return *this;
}

CustomOp &CustomOp::set_preprocess(PreprocessFuncPtr func) {
    set_preprocess("x86", func);
    return *this;
}

CustomOp &CustomOp::set_preprocess(const std::string &device, PreprocessFuncPtr func) {
    OpImplRef(m_impl.get())->preprocess_funcs[device] = func;
    return *this;
}

CustomOp &CustomOp::set_postprocess(PostprocessFuncPtr func) {
    set_postprocess("x86", func);
    return *this;
}

CustomOp &CustomOp::set_postprocess(const std::string &device, PostprocessFuncPtr func) {
    OpImplRef(m_impl.get())->postprocess_funcs[device] = func;
    return *this;
}

CustomOp &CustomOp::set_compute(ComputeFuncPtr func) {
    set_compute("x86", func);
    return *this;
}

CustomOp &CustomOp::set_compute(const std::string &device, ComputeFuncPtr func) {
    OpImplRef(m_impl.get())->compute_funcs[device] = func;
    return *this;
}

CustomOp &CustomOp::set_description(const std::string &op_desc) {
    OpImplRef(m_impl.get())->m_op_desc = op_desc;
    return *this;
}

CustomOp &CustomOp::add_input(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) {
    auto &ref = OpImplRef(m_impl.get())->m_input_infos;
    for (const auto &input: ref) {
        mgb_assert(input.name() != name, "input %s has been registered", name.c_str());
    }
    ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy);
    return *this;
}

CustomOp &CustomOp::add_output(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) {
    auto &ref = OpImplRef(m_impl.get())->m_output_infos;
    for (const auto &output: ref) {
        mgb_assert(output.name() != name, "output %s has been registered", name.c_str());
    }
    ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy);
    return *this;
}

CustomOp &CustomOp::add_input(const std::string &name, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) {
    add_input(name, name, legal_dtypes, dims, mem_stgy);
    return *this;
}

CustomOp &CustomOp::add_output(const std::string &name, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) {
    add_output(name, name, legal_dtypes, dims, mem_stgy);
    return *this;
}

CustomOp &CustomOp::add_inputs(const size_t &num) {
    size_t cur_inp_num = input_num();
    for (size_t i=cur_inp_num; i<cur_inp_num+num; i++) {
        add_input(op_type() + "_Input_" + std::to_string(i));
    }
    return *this;
}

CustomOp &CustomOp::add_outputs(const size_t &num) {
    size_t cur_oup_num = output_num();
    for (size_t i=cur_oup_num; i<cur_oup_num+num; i++) {
        add_output(op_type() + "_Output_" + std::to_string(i));
    }
    return *this;
}

CustomOp &CustomOp::add_param(const std::string &name, const ParamVal &default_val) {
    add_param(name, name, default_val);
    return *this;
}

CustomOp &CustomOp::add_param(const std::string &name, const std::string &desc, const ParamVal &default_val) {
    auto &meta = OpImplRef(m_impl.get())->m_param_infos.meta();
    for (const auto &schema: meta) {
        mgb_assert(name != schema.name(), "param %s has been registered\n", name.c_str());
    }
    ParamSchema sch = ParamSchema(name, default_val, desc);
    meta.emplace_back(sch);
    return *this;
}
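Taken together, the setters above form a fluent builder. A hedged registration sketch follows; the op name, dtype string, float-valued param, and the availability of CustomOpManager here rely on declarations (manager.h, op.h defaults, basic param types) not shown at this point in the diff:

// A minimal sketch of defining an op through the builder API above.
inline void compute_on_x86(const std::vector<Tensor>&, const Param&,
                           std::vector<Tensor>&) { /* user kernel body */ }

inline void register_demo_op() {
    CustomOp &op = *CustomOpManager::inst()->insert("AddBias", CUSTOM_OP_VERSION);
    op.set_description("y = x + bias")
      .add_input("x", {"float32"}, -1, "default")    // ndim -1: any dim allowed
      .add_output("y", {"float32"}, -1, "default")
      .add_param("bias", 0.f)                        // name doubles as the desc
      .set_compute("x86", compute_on_x86);
}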
std::string CustomOp::op_type(void) const {
    return OpImplRef(m_impl.get())->m_op_type;
}

std::string CustomOp::op_desc(void) const {
    return OpImplRef(m_impl.get())->m_op_desc;
}

RunTimeId CustomOp::runtime_id(void) const {
    return (RunTimeId)(this);
}

size_t CustomOp::input_num(void) const {
    return OpImplRef(m_impl.get())->m_input_infos.size();
}

size_t CustomOp::output_num(void) const {
    return OpImplRef(m_impl.get())->m_output_infos.size();
}

std::string CustomOp::str(void) const {
    std::stringstream ss;
    ss << "op name: " << op_type() << "\nop desc: " << op_desc() << "\n\ninputs:\n";
    for (const auto &input: inputs_info()) {
        ss << input.str();
        ss << "\n--------------------\n";
    }
    ss << "\noutputs:\n";
    for (const auto &output: outputs_info()) {
        ss << output.str();
        ss << "\n--------------------\n";
    }
    ss << "\nparams:\n";
    for (const auto &param: param_info().meta()) {
        ss << param.str();
        ss << "\n--------------------\n";
    }
    return ss.str();
}

const ParamInfo &CustomOp::param_info(void) const {
    return OpImplRef(m_impl.get())->m_param_infos;
}

ArgInfo CustomOp::input_info(size_t idx) const {
    return OpImplRef(m_impl.get())->m_input_infos[idx];
}

ArgInfo CustomOp::output_info(size_t idx) const {
    return OpImplRef(m_impl.get())->m_output_infos[idx];
}

const std::vector<ArgInfo> &CustomOp::inputs_info(void) const {
    return OpImplRef(m_impl.get())->m_input_infos;
}

const std::vector<ArgInfo> &CustomOp::outputs_info(void) const {
    return OpImplRef(m_impl.get())->m_output_infos;
}

std::vector<Device> CustomOp::infer_output_device(const std::vector<Device> &inputs, const Param &param) const {
    assert_inputs_size_right(inputs);
    std::vector<Device> outputs(output_num());
    OpImplRef(m_impl.get())->infer_output_device_func(inputs, param, outputs);
    assert_outputs_size_right(outputs);
    return outputs;
}

std::vector<Shape> CustomOp::infer_output_shape(const std::vector<Shape> &inputs, const Param &param) const {
    assert_inputs_size_right(inputs);
    for (size_t i=0; i<inputs_info().size(); i++) {
        assert_arg_shape_dim_right(inputs[i], input_info(i));
    }
    std::vector<Shape> outputs(output_num());
    OpImplRef(m_impl.get())->infer_output_shape_func(inputs, param, outputs);
    for (size_t i=0; i<outputs_info().size(); i++) {
        assert_arg_shape_dim_right(outputs[i], output_info(i));
    }
    assert_outputs_size_right(outputs);
    return outputs;
}

std::vector<DType> CustomOp::infer_output_dtype(const std::vector<DType> &inputs, const Param &param) const {
    assert_inputs_size_right(inputs);
    for (size_t i=0; i<inputs_info().size(); i++) {
        std::unordered_set<std::string> legal_input_dtypes_i = input_info(i).dtypes();
        mgb_assert(
            legal_input_dtypes_i.find(inputs[i].str()) != legal_input_dtypes_i.end(),
            "dtypes of input: %s(%s) is not allowed, the info of this input is:\n%s",
            input_info(i).name().c_str(), inputs[i].str().c_str(),
            input_info(i).str().c_str()
        );
    }
    std::vector<DType> outputs(output_num());
    OpImplRef(m_impl.get())->infer_output_dtype_func(inputs, param, outputs);
    for (size_t i=0; i<outputs_info().size(); i++) {
        std::unordered_set<std::string> legal_output_dtypes_i = output_info(i).dtypes();
        mgb_assert(
            legal_output_dtypes_i.find(outputs[i].str()) != legal_output_dtypes_i.end(),
            "dtypes of output: %s is %s, the info of this output is:\n%s",
            output_info(i).name().c_str(), outputs[i].str().c_str(),
            output_info(i).str().c_str()
        );
    }
    assert_outputs_size_right(outputs);
    return outputs;
}

std::vector<Format> CustomOp::infer_output_format(const std::vector<Format> &inputs, const Param &param) const {
    assert_inputs_size_right(inputs);
    for (size_t i=0; i<inputs.size(); i++) {
        mgb_assert(
            inputs[i].is_default(),
            "the tensor format of %s:%s is not default",
            op_type().c_str(), input_info(i).name().c_str()
        );
    }
    std::vector<Format> outputs(output_num());
    OpImplRef(m_impl.get())->infer_output_format_func(inputs, param, outputs);
    for (size_t i=0; i<outputs.size(); i++) {
        mgb_assert(
            outputs[i].is_default(),
            "the tensor format of %s:%s is not default",
            op_type().c_str(), output_info(i).name().c_str()
        );
    }
    assert_outputs_size_right(outputs);
    return outputs;
}

void CustomOp::compute(const std::vector<Tensor> &inputs, const Param &param, std::vector<Tensor> &outputs) const {
    assert_inputs_size_right(inputs);
    assert_outputs_size_right(outputs);
    if (outputs.size() == 0) {
        return;
    }

    std::string device = outputs[0].device().str();
    for (size_t i=1; i<outputs.size(); ++i) {
        mgb_assert(
            outputs[i].device().str() == device,
            "all output tensors should have the same device attribute"
        );
    }

    // need to add other input/output check
    mgb_assert(Device::is_legal(device), "unsupported device type: %s", device.c_str());

    auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device];
    auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device];
    auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device];

    preprocess_func(inputs, param, outputs);
    forward_func(inputs, param, outputs);
    postprocess_func(outputs, param, outputs);
    assert_outputs_size_right(outputs);
}

}  // namespace custom
@@ -0,0 +1,179 @@
/**
 * \file src/custom/impl/param.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/custom/param.h"
#include "megbrain/common.h"
#include "megbrain/utils/hash.h"

#include <limits>
#include <sstream>
#include <map>

using namespace mgb;

namespace custom {

class ParamSchemaImpl {
    std::string m_name;
    std::string m_desc;
    ParamVal m_default;
    friend ParamSchema;
};

class ParamInfoImpl {
    std::vector<ParamSchema> m_meta;
    uint32_t TAG;
    friend ParamInfo;
};

class ParamImpl {
    std::unordered_map<std::string, ParamVal> m_vals;

    ParamImpl() = default;
    ParamImpl(const ParamImpl &rhs) = default;
    ParamImpl &operator=(const ParamImpl &rhs) {
        mgb_assert(
            m_vals.size() == rhs.m_vals.size(),
            "params of different op, assignment failed!"
        );
        for (const auto &kv: rhs.m_vals) {
            auto iter = m_vals.find(kv.first);
            mgb_assert(iter != m_vals.end(), "params of different op, assignment failed!");
            iter->second = kv.second;
        }
        return *this;
    }
    friend Param;
};

CUSTOM_PIMPL_CLS_DEFINE(ParamSchema)

ParamSchema::ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc)
        : m_impl(new ParamSchemaImpl(), impl_deleter<ParamSchemaImpl>) {
    TypedRef(ParamSchemaImpl, m_impl.get()).m_name = name;
    TypedRef(ParamSchemaImpl, m_impl.get()).m_default = value;
    TypedRef(ParamSchemaImpl, m_impl.get()).m_desc = desc;
}

const std::string &ParamSchema::name(void) const {
    return TypedRef(ParamSchemaImpl, m_impl.get()).m_name;
}

const std::string &ParamSchema::desc(void) const {
    return TypedRef(ParamSchemaImpl, m_impl.get()).m_desc;
}

const ParamVal &ParamSchema::default_val(void) const {
    return TypedRef(ParamSchemaImpl, m_impl.get()).m_default;
}

ParamDynType ParamSchema::type(void) const {
    return TypedRef(ParamSchemaImpl, m_impl.get()).m_default.type();
}

std::string ParamSchema::str(void) const {
    std::stringstream ss;
    ss << "name: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_name
       << "\ndesc: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_desc
       << "\n" << TypedRef(ParamSchemaImpl, m_impl.get()).m_default.str();
    return ss.str();
}

CUSTOM_PIMPL_CLS_DEFINE(ParamInfo)

void ParamInfo::set_tag(const std::string &hash_str) {
    const char *ptr = hash_str.c_str();
    TypedRef(ParamInfoImpl, m_impl.get()).TAG = 0;
    for (size_t i=0; i<hash_str.size(); i++) {
        TypedRef(ParamInfoImpl, m_impl.get()).TAG =
            mgb::hash_pair_combine(TypedRef(ParamInfoImpl, m_impl.get()).TAG, mgb::hash(*(ptr++))) %
            std::numeric_limits<uint32_t>::max();
    }
}
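The tag above is a rolling per-character hash folded into 32 bits; the same computation in isolation, as a sketch (assuming mgb::hash accepts a char, as the call above implies):

inline uint32_t tag_of(const std::string &s) {
    uint32_t tag = 0;
    for (char c : s) {
        // combine the running tag with each character's hash, kept in uint32 range
        tag = mgb::hash_pair_combine(tag, mgb::hash(c)) %
              std::numeric_limits<uint32_t>::max();
    }
    return tag;
}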
void ParamInfo::set_meta(const std::vector<ParamSchema> &meta) {
    TypedRef(ParamInfoImpl, m_impl.get()).m_meta = meta;
}

uint32_t ParamInfo::tag(void) const {
    return TypedRef(ParamInfoImpl, m_impl.get()).TAG;
}

std::vector<ParamSchema> &ParamInfo::meta(void) {
    return TypedRef(ParamInfoImpl, m_impl.get()).m_meta;
}

const std::vector<ParamSchema> &ParamInfo::meta(void) const {
    return TypedRef(ParamInfoImpl, m_impl.get()).m_meta;
}

CUSTOM_PIMPL_CLS_DEFINE(Param)

Param::Param(const ParamInfo &info): m_impl(new ParamImpl(), impl_deleter<ParamImpl>) {
    for (const auto &schema: info.meta()) {
        TypedRef(ParamImpl, m_impl.get()).m_vals.emplace(schema.name(), schema.default_val());
    }
}

ParamVal &Param::operator[](const std::string &name) {
    return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second;
}

const ParamVal &Param::operator[](const std::string &name) const {
    return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second;
}

const std::unordered_map<std::string, ParamVal> &Param::raw() const {
    return TypedRef(ParamImpl, m_impl.get()).m_vals;
}

bool Param::exist(const std::string &name) const {
    return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name) !=
           TypedRef(ParamImpl, m_impl.get()).m_vals.end();
}

std::string Param::to_bytes(void) const {
    std::string res;
    std::map<std::string, ParamVal> ordered_vals(
        TypedRef(ParamImpl, m_impl.get()).m_vals.begin(),
        TypedRef(ParamImpl, m_impl.get()).m_vals.end());
    for (auto &&kv: ordered_vals) {
        res += ParamVal::to_bytes(kv.second);
    }
    return res;
}

void Param::from_bytes(const std::string &bytes) {
    std::map<std::string, ParamVal> ordered_vals(
        TypedRef(ParamImpl, m_impl.get()).m_vals.begin(),
        TypedRef(ParamImpl, m_impl.get()).m_vals.end());
    size_t offset = 0;
    for (auto &kv: ordered_vals) {
        kv.second = ParamVal::from_bytes(bytes, offset);
    }
    TypedRef(ParamImpl, m_impl.get()).m_vals.clear();
    TypedRef(ParamImpl, m_impl.get()).m_vals.insert(ordered_vals.begin(), ordered_vals.end());
    mgb_assert(offset == bytes.size(), "wrong data loader");
}
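Because both to_bytes and from_bytes walk the params in name-sorted order (via the intermediate std::map), serialization round-trips. A brief sketch using only the API defined above:

inline void param_roundtrip_demo(const ParamInfo &info) {
    Param p(info);                      // starts from the schema defaults
    std::string bytes = p.to_bytes();
    Param q(info);
    q.from_bytes(bytes);                // asserts the whole buffer is consumed
    mgb_assert(p == q, "round trip should preserve every value");
}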
bool operator==(const Param &lhs, const Param &rhs) {
    if (lhs.raw().size() != rhs.raw().size())
        return false;
    for (const auto &kv: lhs.raw()) {
        auto riter = rhs.raw().find(kv.first);
        if (riter == rhs.raw().end() || !((kv.second) == riter->second)) {
            return false;
        }
    }
    return true;
}

}  // namespace custom
@@ -0,0 +1,400 @@
/**
 * \file src/custom/impl/param_val.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/custom/param_val.h"
#include "megbrain/common.h"

#pragma GCC diagnostic ignored "-Wsign-compare"

using namespace mgb;

namespace custom {

/**
 * Macro Callback for Case
 */
#define CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        std::unique_ptr<void, void_deleter> new_ptr( \
            new static_type(TypedRef(static_type, rhs.m_ptr.get())), \
            impl_deleter<static_type> \
        ); \
        m_ptr.swap(new_ptr); \
        break; \
    }

#define CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        TypedRef(static_type, m_ptr.get()) = TypedRef(static_type, rhs.m_ptr.get()); \
        break; \
    }

#define CUSTOM_ASSERT_OPERAND_VALID(operand, opr) \
    mgb_assert( \
        operand.m_ptr != nullptr && operand.m_type != ParamDynType::Invalid, \
        "invalid %s of operator %s of ParamVal", #operand, #opr \
    )

#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \
    mgb_assert( \
        lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \
        type2name[lhs.m_type].c_str(), #op, \
        type2name[rhs.m_type].c_str() \
    )

#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \
    case (ParamDynType::dyn_type): { \
        const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \
        return lval op rval; \
    }

#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC(dyn_type, static_type, op) \
    case (ParamDynType::dyn_type): { \
        const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \
        switch (rhs.m_type) { \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY( \
                CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL, op) \
            default: \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
        } \
        break; \
    }

#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC(dyn_type, static_type, op) \
    case (ParamDynType::dyn_type): { \
        CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
        const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \
        const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \
        return lval op rval; \
    }

#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(op, ret_type) \
    ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \
        \
        switch (lhs.m_type) { \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \
                CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \
            default: \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
        } \
        return {}; \
    }

#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(op, ret_type) \
    ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \
        \
        switch (lhs.m_type) { \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \
                CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \
            CUSTOM_FOR_STRING_PARAMTYPE( \
                CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
            default: \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
        } \
        return {}; \
    }

#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(op, ret_type) \
    ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \
        \
        switch (lhs.m_type) { \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \
                CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \
            CUSTOM_FOR_STRING_PARAMTYPE( \
                CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
            CUSTOM_FOR_EACH_LIST_PARAMTYPE( \
                CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
            default: \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
        } \
        return {}; \
    }

#define CUSTOM_CASE_TO_PRINT_NONLIST(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        auto rval = TypedRef(static_type, m_ptr.get()); \
        ss << rval; \
        break; \
    }

#define CUSTOM_CASE_TO_PRINT_LIST(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        auto rval = TypedRef(static_type, m_ptr.get()); \
        ss << vec2str(rval); \
        break; \
    }

#define CUSTOM_CASE_TO_RET_SIZE(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        return TypedRef(static_type, m_ptr.get()).size(); \
        break; \
    }

#define CUSTOM_CASE_TO_DUMP_BASIC(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        res.resize(sizeof(ParamDynType) + sizeof(static_type)); \
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \
        memcpy(&res[sizeof(ParamDynType)], value.m_ptr.get(), sizeof(static_type)); \
        break; \
    }

#define CUSTOM_CASE_TO_DUMP_LIST(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        auto &ref = TypedRef(static_type, value.m_ptr.get()); \
        size_t len = ref.size(); \
        size_t elem_size = len != 0 ? sizeof(ref[0]) : 0; \
        res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size); \
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); \
        memcpy(&res[sizeof(ParamDynType)+sizeof(len)], ref.data(), len*elem_size); \
        break; \
    }

#define CUSTOM_CASE_TO_LOAD_BASIC(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        static_type val; \
        memcpy(&val, &bytes[offset], sizeof(val)); \
        offset += sizeof(val); \
        return val; \
        break; \
    }

#define CUSTOM_CASE_TO_LOAD_LIST(dyn_type, static_type) \
    case (ParamDynType::dyn_type): { \
        size_t len = 0; \
        memcpy(&len, &bytes[offset], sizeof(len)); \
        offset += sizeof(len); \
        static_type vals; \
        vals.resize(len); \
        size_t elem_size = len != 0 ? sizeof(vals[0]) : 0; \
        memcpy(&vals[0], &bytes[offset], len*elem_size); \
        offset += len*elem_size; \
        return vals; \
        break; \
    }

ParamVal::ParamVal(): m_ptr(nullptr, [](void*) -> void {}) {
    m_type = ParamDynType::Invalid;
}

ParamVal::ParamVal(const char *str): ParamVal(std::string(str)) {
}

ParamVal::ParamVal(const std::initializer_list<const char*> &strs): ParamVal(std::vector<const char*>(strs)) {
}

ParamVal::ParamVal(const std::vector<const char*> &strs)
        : m_ptr(new std::vector<std::string>(), impl_deleter<std::vector<std::string>>) {
    m_type = ParamDynType::StringList;
    for (const auto &str: strs) {
        TypedRef(std::vector<std::string>, m_ptr.get()).emplace_back(str);
    }
}

ParamVal::ParamVal(const ParamVal &rhs): m_ptr(nullptr, [](void*) -> void {}) {
    mgb_assert(
        rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr,
        "invalid rhs of copy constructor of ParamVal"
    );
    m_type = rhs.m_type;
    switch (m_type) {
        CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS)
        default: {
            mgb_assert(false, "invalid rhs of copy constructor of ParamVal");
        }
    }
}

ParamVal &ParamVal::operator=(const char *str) {
    this->operator=(std::string(str));
    return *this;
}

ParamVal &ParamVal::operator=(const std::initializer_list<const char*> &strs) {
    this->operator=(std::vector<const char*>(strs));
    return *this;
}

ParamVal &ParamVal::operator=(const std::vector<const char*> &strs) {
    std::vector<std::string> tmp_strs;
    for (const auto &str: strs) {
        tmp_strs.emplace_back(str);
    }
    this->operator=(tmp_strs);
    return *this;
}

ParamVal &ParamVal::operator=(const ParamVal &rhs) {
    if (&rhs == this)
        return *this;
    mgb_assert(
        rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr,
        "invalid rhs of assignment operator of ParamVal"
    );
    if (rhs.m_type == m_type) {
        switch (m_type) {
            CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS);
            default:
                mgb_assert(false, "invalid rhs of assignment operator of ParamVal");
        }
    }
    else {
        m_type = rhs.m_type;
        switch (m_type) {
            CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS);
            default:
                mgb_assert(false, "invalid rhs of assignment operator of ParamVal");
        }
    }
    return *this;
}

const void *ParamVal::raw_ptr(void) const {
    return m_ptr.get();
}

void *ParamVal::raw_ptr(void) {
    return m_ptr.get();
}

ParamDynType ParamVal::type(void) const {
    return m_type;
}

std::string ParamVal::str() const {
    std::stringstream ss;
    ss << "type: " << type2name[m_type] << "\n" << "value: ";
    switch (m_type) {
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
        CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST)
        default:
            mgb_assert(false, "invalid data of assignment operator of ParamVal");
    }
    return ss.str();
}

size_t ParamVal::size(void) const {
    switch (m_type) {
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE)
        CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE)
        default:
            mgb_assert(false, "there is no size() for basic data types");
    }
}

std::string ParamVal::to_bytes(const ParamVal &value) {
    std::string res;
    // because of the specialization of std::vector<bool>
    if (value.type() == ParamDynType::BoolList) {
        std::vector<bool> &ref = TypedRef(std::vector<bool>, value.m_ptr.get());
        size_t len = ref.size();
        size_t elem_size = sizeof(bool);
        res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size);
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));
        size_t startpos = sizeof(ParamDynType) + sizeof(len);
        for (size_t idx=0; idx<len; idx++) {
            bool b = ref[idx];
            memcpy(&res[startpos + idx*sizeof(b)], &b, sizeof(b));
        }
        return res;
    }
    else if (value.type() == ParamDynType::StringList) {
        std::vector<std::string> &ref = TypedRef(std::vector<std::string>, value.m_ptr.get());
        size_t len = ref.size();
        res.resize(sizeof(ParamDynType) + sizeof(len));
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));
        for (size_t idx=0; idx<ref.size(); ++idx) {
            size_t str_len = ref[idx].size();
            std::string bytes(sizeof(str_len) + str_len, ' ');
            memcpy(&bytes[0], &str_len, sizeof(str_len));
            memcpy(&bytes[sizeof(str_len)], ref[idx].data(), str_len);
            res += bytes;
        }
        return res;
    }
    switch (value.type()) {
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_DUMP_BASIC)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST)
        CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST)
        default:
            mgb_assert(false, "invalid param type");
    }
    return res;
}
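For reference, the byte layout produced by to_bytes above, read directly off the code:

// basic value:  [ParamDynType][raw bytes of the static type]
// string:       [ParamDynType][size_t len][len chars]
// BoolList:     [ParamDynType][size_t len][len x bool], copied element-wise
//               because std::vector<bool> exposes no contiguous data()
// StringList:   [ParamDynType][size_t len][per string: size_t str_len, chars]
// other lists:  [ParamDynType][size_t len][len x element], memcpy'd in bulk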
ParamVal ParamVal::from_bytes(const std::string &bytes, size_t &offset) {
    ParamDynType data_type = ParamDynType::Invalid;
    memcpy(&data_type, &bytes[offset], sizeof(ParamDynType));
    offset += sizeof(ParamDynType);
    if (data_type == ParamDynType::BoolList) {
        std::vector<bool> ret;
        size_t len = 0;
        memcpy(&len, &bytes[offset], sizeof(len));
        offset += sizeof(len);
        for (size_t idx=0; idx<len; ++idx) {
            bool b = true;
            memcpy(&b, &bytes[offset], sizeof(bool));
            offset += sizeof(bool);
            ret.push_back(b);
        }
        return ret;
    }
    else if (data_type == ParamDynType::StringList) {
        std::vector<std::string> ret;
        size_t len = 0;
        memcpy(&len, &bytes[offset], sizeof(len));
        offset += sizeof(len);
        for (size_t idx=0; idx<len; ++idx) {
            size_t str_len = 0;
            memcpy(&str_len, &bytes[offset], sizeof(str_len));
            offset += sizeof(str_len);
            std::string str(str_len, ' ');
            memcpy(&str[0], &bytes[offset], str_len);
            offset += str_len;
            ret.push_back(str);
        }
        return ret;
    }
    switch (data_type) {
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_LOAD_BASIC)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST)
        CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST);
        default:
            mgb_assert(false, "invalid param type");
    }
    return {};
}

CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(+, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(-, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(*, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(/, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(==, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(!=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<, bool)
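A small sketch of what the generated operators permit (assuming int maps to a basic param type via the constructors declared in param_val.h): arithmetic is defined for basic types, + additionally for strings, and comparisons for basic, string, and list types, with mixed-type expressions rejected at runtime by CUSTOM_INVALID_EXPR_EXCP:

inline void param_val_demo() {
    ParamVal a = 2, b = 3;                      // basic (integer) values
    mgb_assert((a + b) == ParamVal(5), "add");  // arithmetic on basic types
    ParamVal s = "ab";                          // string value
    mgb_assert((s + ParamVal("c")) == ParamVal("abc"), "concat");
    // ParamVal(2) + ParamVal("x") would trip the mixed-type assertion.
}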
}  // namespace custom
@@ -0,0 +1,486 @@
/**
 * \file src/custom/impl/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/custom/tensor.h"
#include "megbrain/comp_node.h"
#include "megbrain/common.h"
#include "megbrain/tensor.h"

#include <cctype>
#include <algorithm>

using namespace mgb;

namespace custom {

template<typename T>
SmallVector<T> to_builtin_vector(const std::vector<T> &custom_data) {
    SmallVector<T> builtin_data(custom_data.size());
    memcpy(builtin_data.data(), custom_data.data(), sizeof(T)*custom_data.size());
    return builtin_data;
}

using DeviceImpl = CompNode;
using ShapeImpl = megdnn::TensorShape;
using DTypeImpl = megdnn::DType;
using FormatImpl = megdnn::TensorLayout::Format;
using TensorImpl = DeviceTensorND;

#define DeviceImplRef(rawptr) (*reinterpret_cast<DeviceImpl*>(rawptr))
#define ShapeImplRef(rawptr) (*reinterpret_cast<ShapeImpl*>(rawptr))
#define DTypeImplRef(rawptr) (*reinterpret_cast<DTypeImpl*>(rawptr))
#define FormatImplRef(rawptr) (*reinterpret_cast<FormatImpl*>(rawptr))
#define TensorImplRef(rawptr) (*reinterpret_cast<TensorImpl*>(rawptr))

#define DeviceImplConstRef(rawptr) static_cast<const DeviceImpl&>(*reinterpret_cast<const DeviceImpl*>(rawptr))
#define ShapeImplConstRef(rawptr) static_cast<const ShapeImpl&>(*reinterpret_cast<const ShapeImpl*>(rawptr))
#define DTypeImplConstRef(rawptr) static_cast<const DTypeImpl&>(*reinterpret_cast<const DTypeImpl*>(rawptr))
#define FormatImplConstRef(rawptr) static_cast<const FormatImpl&>(*reinterpret_cast<const FormatImpl*>(rawptr))
#define TensorImplConstRef(rawptr) static_cast<const TensorImpl&>(*reinterpret_cast<const TensorImpl*>(rawptr))

static std::unordered_map<DeviceImpl::DeviceType, std::string,
                          EnumHash<DeviceImpl::DeviceType>,
                          EnumCmp<DeviceImpl::DeviceType>> dev_benum2cstr;
static std::unordered_map<DeviceImpl::DeviceType, DeviceEnum,
                          EnumHash<DeviceImpl::DeviceType>,
                          EnumCmp<DeviceImpl::DeviceType>> dev_benum2cenum;
static std::unordered_map<std::string, std::string> dev_cstr2bstr;
static std::unordered_map<DeviceEnum, std::string,
                          EnumHash<DeviceEnum>,
                          EnumCmp<DeviceEnum>> dev_cenum2bstr;

#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \
    auto be2cs##custom_impl = dev_benum2cstr.emplace( \
        DeviceImpl::DeviceType::builtin_device, std::string(#custom_impl)); \
    auto be2ce##custom_impl = dev_benum2cenum.emplace( \
        DeviceImpl::DeviceType::builtin_device, DeviceEnum::custom_impl); \
    auto cs2bs##custom_impl = dev_cstr2bstr.emplace( \
        std::string(#custom_impl), std::string(builtin_str)); \
    auto ce2bs##custom_impl = dev_cenum2bstr.emplace( \
        DeviceEnum::custom_impl, std::string(builtin_str));

CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE)
#undef CUSTOM_BIND_DEVICE
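Schematically, each expansion of the macro above wires one device into all four lookup maps; e.g. a hypothetical CUSTOM_BIND_DEVICE(x86, CPU, "cpux") (argument values assumed for illustration, the real list lives in CUSTOM_FOR_EACH_DEVICE_TYPE) would yield:

// dev_benum2cstr : DeviceImpl::DeviceType::CPU -> "x86"
// dev_benum2cenum: DeviceImpl::DeviceType::CPU -> DeviceEnum::x86
// dev_cstr2bstr  : "x86"                       -> "cpux"   (CompNode load string)
// dev_cenum2bstr : DeviceEnum::x86             -> "cpux"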
| CUSTOM_PIMPL_CLS_DEFINE(Device) | |||
| const void *Device::impl() const { | |||
| return m_impl.get(); | |||
| } | |||
| Device::Device(const void *impl): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||
| mgb_assert(impl != nullptr, "invalid ptr"); | |||
| if (!DeviceImplConstRef(impl).valid()) { | |||
| m_impl.reset(new DeviceImpl()); | |||
| return; | |||
| } | |||
| auto builtin_device_enum = DeviceImplConstRef(impl).device_type(); | |||
| mgb_assert( | |||
| dev_benum2cenum.find(builtin_device_enum) != dev_benum2cenum.end(), | |||
| "unsupported compnode type: %s", DeviceImplConstRef(impl).to_string().c_str() | |||
| ); | |||
| m_impl.reset(new DeviceImpl(DeviceImplConstRef(impl))); | |||
| } | |||
| Device::Device(const std::string &device): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||
| mgb_assert(is_legal(device), "invalid device type: %s", device.c_str()); | |||
| std::string builtin_device = dev_cstr2bstr[device]; | |||
| m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); | |||
| } | |||
| // to avoid the ambiguous from Device(const void *impl) | |||
| Device::Device(const char *device): Device(std::string(device)) { | |||
| } | |||
| Device::Device(DeviceEnum device): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||
| mgb_assert(is_legal(device), "invalid device type"); | |||
| std::string builtin_device = dev_cenum2bstr[device]; | |||
| m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); | |||
| } | |||
| std::string Device::str(void) const { | |||
| if (!DeviceImplRef(m_impl.get()).valid()) { | |||
| return "invalid"; | |||
| } | |||
| auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); | |||
| auto iter = dev_benum2cstr.find(builtin_device_type); | |||
| mgb_assert( | |||
| iter != dev_benum2cstr.end(), "invalid device type %s\n", | |||
| DeviceImplRef(m_impl.get()).to_string().c_str() | |||
| ); | |||
| return iter->second; | |||
| } | |||
| DeviceEnum Device::enumv(void) const { | |||
| mgb_assert( | |||
| DeviceImplRef(m_impl.get()).valid(), | |||
| "cannot get the enum value of invalid device" | |||
| ); | |||
| auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); | |||
| auto iter = dev_benum2cenum.find(builtin_device_type); | |||
| mgb_assert( | |||
| iter != dev_benum2cenum.end(), "invalid device type %s\n", | |||
| DeviceImplRef(m_impl.get()).to_string().c_str() | |||
| ); | |||
| return iter->second; | |||
| } | |||
| bool Device::is_legal(const std::string &device_type) { | |||
| return dev_cstr2bstr.find(device_type) != dev_cstr2bstr.end(); | |||
| } | |||
| bool Device::is_legal(DeviceEnum device_type) { | |||
| return dev_cenum2bstr.find(device_type) != dev_cenum2bstr.end(); | |||
| } | |||
| std::vector<std::string> Device::legal_devices(void) { | |||
| std::vector<std::string> ret; | |||
| for (const auto &kv: dev_cstr2bstr) { | |||
| ret.emplace_back(kv.first); | |||
| } | |||
| return ret; | |||
| } | |||
| bool operator==(const Device &lhs, const Device &rhs) { | |||
| return lhs.str() == rhs.str(); | |||
| } | |||
| CUSTOM_PIMPL_CLS_DEFINE(Shape) | |||
| const void *Shape::impl() const { | |||
| return m_impl.get(); | |||
| } | |||
| Shape::Shape(const void *impl): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||
| mgb_assert(impl != nullptr, "invalid ptr"); | |||
| m_impl.reset(new ShapeImpl(ShapeImplConstRef(impl))); | |||
| } | |||
| Shape::Shape(const std::vector<size_t> &rhs): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||
| m_impl.reset(new ShapeImpl(to_builtin_vector<size_t>(rhs))); | |||
| } | |||
| Shape::Shape(const std::initializer_list<size_t> &rhs): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||
| m_impl.reset(new ShapeImpl(rhs)); | |||
| } | |||
| size_t &Shape::operator[](size_t idx) { | |||
| mgb_assert(idx < ndim(), "tensor dimension idx %lu out of range, it should be less than %lu", static_cast<unsigned long>(idx), static_cast<unsigned long>(ndim())); | |||
| return ShapeImplRef(m_impl.get()).operator[](idx); | |||
| } | |||
| size_t Shape::operator[](size_t idx) const { | |||
| return const_cast<Shape*>(this)->operator[](idx); | |||
| } | |||
| void Shape::ndim(size_t dim) { | |||
| mgb_assert(dim < ShapeImpl::MAX_NDIM, "dimension must be less than %lu", static_cast<unsigned long>(ShapeImpl::MAX_NDIM)); | |||
| ShapeImplRef(m_impl.get()).ndim = dim; | |||
| } | |||
| size_t Shape::ndim(void) const { | |||
| return ShapeImplRef(m_impl.get()).ndim; | |||
| } | |||
| bool operator==(const Shape &lhs, const Shape &rhs) { | |||
| return ShapeImplRef(lhs.m_impl.get()).eq_shape(ShapeImplRef(rhs.m_impl.get())); | |||
| } | |||
| static std::unordered_map<std::string, megdnn::DTypeEnum> dtype_cstr2benum; | |||
| static std::unordered_map<DTypeEnum, megdnn::DTypeEnum, | |||
| EnumHash<DTypeEnum>, | |||
| EnumCmp<DTypeEnum>> dtype_cenum2benum; | |||
| static std::unordered_map<megdnn::DTypeEnum, std::string, | |||
| EnumHash<megdnn::DTypeEnum>, | |||
| EnumCmp<megdnn::DTypeEnum>> dtype_benum2cstr; | |||
| static std::unordered_map<megdnn::DTypeEnum, DTypeEnum, | |||
| EnumHash<megdnn::DTypeEnum>, | |||
| EnumCmp<megdnn::DTypeEnum>> dtype_benum2cenum; | |||
| static std::unordered_map<DTypeEnum, std::string, | |||
| EnumHash<DTypeEnum>, | |||
| EnumCmp<DTypeEnum>> dtype_cenum2cstr; | |||
| #define CUSTOM_BIND_DTYPE(custom_impl, builtin_dtype, ctype) \ | |||
| auto cs2be##custom_impl = dtype_cstr2benum.emplace( \ | |||
| std::string(#custom_impl), megdnn::DTypeEnum::builtin_dtype); \ | |||
| auto ce2be##custom_impl = dtype_cenum2benum.emplace( \ | |||
| DTypeEnum::custom_impl, megdnn::DTypeEnum::builtin_dtype); \ | |||
| auto be2cs##custom_impl = dtype_benum2cstr.emplace( \ | |||
| megdnn::DTypeEnum::builtin_dtype, std::string(#custom_impl)); \ | |||
| auto be2ce##custom_impl = dtype_benum2cenum.emplace( \ | |||
| megdnn::DTypeEnum::builtin_dtype, DTypeEnum::custom_impl); \ | |||
| auto ce2cs##custom_impl = dtype_cenum2cstr.emplace( \ | |||
| DTypeEnum::custom_impl, std::string(#custom_impl)); | |||
| CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE) | |||
| #undef CUSTOM_BIND_DTYPE | |||
| CUSTOM_PIMPL_CLS_DEFINE(DType) | |||
| const void *DType::impl() const { | |||
| return m_impl.get(); | |||
| } | |||
| DType::DType(const void *impl): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||
| mgb_assert(impl != nullptr, "invalid ptr"); | |||
| m_impl.reset(new DTypeImpl(DTypeImplConstRef(impl))); | |||
| } | |||
| DType::DType(const std::string &dtype): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||
| auto iter = dtype_cstr2benum.find(dtype); | |||
| mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); | |||
| mgb_assert( | |||
| dtype[0] != 'q', "can not construct quantized dtype " | |||
| "%s without scale and zero_point", dtype.c_str() | |||
| ); | |||
| m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second))); | |||
| } | |||
| DType::DType(const char *dtype): DType(std::string(dtype)) { | |||
| } | |||
| DType::DType(const std::string &dtype, float scale, uint8_t zero_point) | |||
| : m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||
| auto iter = dtype_cstr2benum.find(dtype); | |||
| mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); | |||
| mgb_assert( | |||
| dtype[0] == 'q', "given scale/zero_point to construct " | |||
| "non-quantized dtype: %s is not allowed", dtype.c_str() | |||
| ); | |||
| if (dtype == "quint8") { | |||
| m_impl.reset(new megdnn::ParameterizedDType< | |||
| megdnn::DTypeEnum::Quantized8Asymm>(scale, zero_point)); | |||
| } | |||
| else { | |||
| mgb_assert( | |||
| zero_point == 0, "invalid zero point %d for dtype %s", | |||
| zero_point, dtype.c_str() | |||
| ); | |||
| if (dtype == "qint8") { | |||
| m_impl.reset(new megdnn::ParameterizedDType< | |||
| megdnn::DTypeEnum::QuantizedS8>(scale)); | |||
| } | |||
| else if (dtype == "qint16") { | |||
| m_impl.reset(new megdnn::ParameterizedDType< | |||
| megdnn::DTypeEnum::QuantizedS16>(scale)); | |||
| } | |||
| else if (dtype == "qint32") { | |||
| m_impl.reset(new megdnn::ParameterizedDType< | |||
| megdnn::DTypeEnum::QuantizedS32>(scale)); | |||
| } | |||
| else { | |||
| mgb_assert(false, "invalid dtype %s", dtype.c_str()); | |||
| } | |||
| } | |||
| } | |||
| DType::DType(const char *dtype, float scale, uint8_t zero_point) | |||
| : DType(std::string(dtype), scale, zero_point) { | |||
| } | |||
| DType::DType(DTypeEnum dtype): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||
| auto iter = dtype_cenum2benum.find(dtype); | |||
| mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype"); | |||
| mgb_assert(dtype < DTypeEnum::quint8, | |||
| "can not construct quantized dtype without scale and zero_point"); | |||
| m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second))); | |||
| } | |||
| DType::DType(DTypeEnum dtype, float scale, uint8_t zero_point) | |||
| : DType(dtype_cenum2cstr.find(dtype)->second, scale, zero_point) { | |||
| } | |||
| std::string DType::str(void) const { | |||
| if (!DTypeImplRef(m_impl.get()).valid()) | |||
| return "invalid"; | |||
| auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv()); | |||
| if (iter == dtype_benum2cstr.end()) | |||
| return "invalid"; | |||
| return iter->second; | |||
| } | |||
| DTypeEnum DType::enumv(void) const { | |||
| auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv()); | |||
| mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype"); | |||
| return iter->second; | |||
| } | |||
| float DType::scale() const { | |||
| if (enumv() == DTypeEnum::qint8) { | |||
| return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS8>().scale; | |||
| } | |||
| else if (enumv() == DTypeEnum::qint16) { | |||
| return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS16>().scale; | |||
| } | |||
| else if (enumv() == DTypeEnum::qint32) { | |||
| return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS32>().scale; | |||
| } | |||
| else if (enumv() == DTypeEnum::quint8) { | |||
| return DTypeImplRef(m_impl.get()).param<dtype::Quantized8Asymm>().scale; | |||
| } | |||
| else { | |||
| mgb_assert(false, "dtype %s has no scale", str().c_str()); | |||
| return 0.f; | |||
| } | |||
| } | |||
| uint8_t DType::zero_point() const { | |||
| mgb_assert(enumv()==DTypeEnum::quint8, "dtype %s has no zero point", str().c_str()); | |||
| return DTypeImplRef(m_impl.get()).param<dtype::Quantized8Asymm>().zero_point; | |||
| } | |||
| bool DType::is_legal(const std::string &dtype) { | |||
| return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end(); | |||
| } | |||
| bool DType::is_legal(const DTypeEnum &dtype) { | |||
| return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end(); | |||
| } | |||
| std::vector<std::string> DType::legal_dtypes(void) { | |||
| std::vector<std::string> ret; | |||
| for (const auto &kv: dtype_cstr2benum) | |||
| ret.emplace_back(kv.first); | |||
| return ret; | |||
| } | |||
| bool operator==(const DType &lhs, const DType &rhs) { | |||
| return DTypeImplRef(lhs.m_impl.get()) == DTypeImplRef(rhs.m_impl.get()); | |||
| } | |||
| bool operator==(const DType &lhs, const std::string &rhs) { | |||
| return lhs.str() == rhs; | |||
| } | |||
| bool operator==(const DType &lhs, const char *rhs) { | |||
| return operator==(lhs, std::string(rhs)); | |||
| } | |||
| bool operator==(const std::string &lhs, const DType &rhs) { | |||
| return operator==(rhs, lhs); | |||
| } | |||
| bool operator==(const char *lhs, const DType &rhs) { | |||
| return operator==(rhs, std::string(lhs)); | |||
| } | |||
| CUSTOM_PIMPL_CLS_DEFINE(Format) | |||
| const void *Format::impl() const { | |||
| return m_impl.get(); | |||
| } | |||
| Format::Format(const void *impl): m_impl(nullptr, impl_deleter<FormatImpl>) { | |||
| mgb_assert(impl != nullptr, "invalid ptr"); | |||
| mgb_assert(FormatImplConstRef(impl).is_default(), "only default format is supported now"); | |||
| m_impl.reset(new FormatImpl(FormatImplConstRef(impl))); | |||
| } | |||
| Format::Format(const std::string &format): m_impl(nullptr, impl_deleter<FormatImpl>) { | |||
| mgb_assert(format == "default", "only default format is supported now"); | |||
| m_impl.reset(new FormatImpl()); | |||
| } | |||
| Format::Format(const char *format): Format(std::string(format)) { | |||
| } | |||
| std::string Format::str(void) const { | |||
| return FormatImplRef(m_impl.get()).to_string(); | |||
| } | |||
| bool Format::is_default(void) const { | |||
| return FormatImplRef(m_impl.get()).is_default(); | |||
| } | |||
| const void *Tensor::impl(void) const { | |||
| return m_tensor; | |||
| } | |||
| Tensor::Tensor(const void *impl) { | |||
| mgb_assert(impl != nullptr, "invalid ptr"); | |||
| m_tensor = const_cast<void*>(impl); | |||
| } | |||
| const size_t *Tensor::shapes_raw(void) const { | |||
| return TensorImplRef(m_tensor).shape().shape; | |||
| } | |||
| const ptrdiff_t *Tensor::strides_raw(void) const { | |||
| return TensorImplRef(m_tensor).layout().stride; | |||
| } | |||
| Tensor::Tensor(const Tensor &rhs) { | |||
| mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for copy constructor\n"); | |||
| m_tensor = rhs.m_tensor; | |||
| } | |||
| Tensor &Tensor::operator=(const Tensor &rhs) { | |||
| mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for assignment operator"); | |||
| if (&rhs == this || rhs.m_tensor == m_tensor) | |||
| return *this; | |||
| m_tensor = rhs.m_tensor; | |||
| return *this; | |||
| } | |||
| Shape Tensor::shape(void) const { | |||
| auto builtin = TensorImplRef(m_tensor).shape(); | |||
| return Shape(&builtin); | |||
| } | |||
| DType Tensor::dtype(void) const { | |||
| auto builtin = TensorImplRef(m_tensor).dtype(); | |||
| return DType(&builtin); | |||
| } | |||
| Format Tensor::format(void) const { | |||
| auto builtin = TensorImplRef(m_tensor).format(); | |||
| return Format(&builtin); | |||
| } | |||
| Device Tensor::device(void) const { | |||
| auto builtin = TensorImplRef(m_tensor).comp_node(); | |||
| return Device(&builtin); | |||
| } | |||
| size_t Tensor::size(void) const { | |||
| return TensorImplRef(m_tensor).shape().total_nr_elems(); | |||
| } | |||
| std::vector<ptrdiff_t> Tensor::stride(void) const { | |||
| std::vector<ptrdiff_t> ret(TensorImplRef(m_tensor).shape().ndim); | |||
| for (size_t i=0; i<ret.size(); i++) | |||
| ret[i] = TensorImplRef(m_tensor).layout().stride[i]; | |||
| return ret; | |||
| } | |||
| float Tensor::scale(void) const { | |||
| return dtype().scale(); | |||
| } | |||
| uint8_t Tensor::zero_point(void) const { | |||
| return dtype().zero_point(); | |||
| } | |||
| void *Tensor::data(void) { | |||
| return static_cast<void*>(TensorImplRef(m_tensor).raw_ptr()); | |||
| } | |||
| const void *Tensor::data(void) const { | |||
| return static_cast<const void*>(TensorImplRef(m_tensor).raw_ptr()); | |||
| } | |||
| } // namespace custom | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * \file src/custom/impl/utils.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/custom/utils.h" | |||
| #include "megbrain/common.h" | |||
| #include <sstream> | |||
| using namespace mgb; | |||
| namespace custom { | |||
| void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...) { | |||
| std::string msg = ssprintf("assertion `%s' failed at %s:%d: %s", expr, file, line, func); | |||
| if (msg_fmt) { | |||
| msg_fmt = convert_fmt_str(msg_fmt); | |||
| va_list ap; | |||
| va_start(ap, msg_fmt); | |||
| msg.append("\nextra message: "); | |||
| msg.append(svsprintf(msg_fmt, ap)); | |||
| va_end(ap); | |||
| } | |||
| printf("%s\n", msg.c_str()); | |||
| } | |||
| UnImpleWarnLog::UnImpleWarnLog(const std::string &func, const std::string &attr, | |||
| const std::string &val) { | |||
| mgb_log_warn("you are using the default custom %s function, the `%s` attribute " | |||
| "of all the outputs tensor will be the same with inputs tensor[0]. " | |||
| "If there is no input tensor, it will be `%s`", | |||
| func.c_str(), attr.c_str(), val.c_str()); | |||
| } | |||
| } // namespace custom | |||
| @@ -0,0 +1,185 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/accessor.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <cstddef> | |||
| #include <cstdint> | |||
| namespace custom { | |||
| #ifdef __CUDACC__ | |||
| #define CUSTOM_HOST __host__ | |||
| #define CUSTOM_DEVICE __device__ | |||
| #else | |||
| #define CUSTOM_HOST | |||
| #define CUSTOM_DEVICE | |||
| #endif | |||
| #define CUSTOM_HOST_DEVICE CUSTOM_HOST CUSTOM_DEVICE | |||
| template <typename T> | |||
| struct DefaultPtrTraits { | |||
| using PtrType = T*; | |||
| }; | |||
| #ifdef __CUDACC__ | |||
| template <typename T> | |||
| struct RestrictPtrTraits { | |||
| using PtrType = T* __restrict__; | |||
| }; | |||
| #endif | |||
| template <typename T, size_t N, | |||
| template <typename U> class PtrTraits = DefaultPtrTraits, | |||
| typename index_t = int64_t> | |||
| class TensorAccessorProxyBase { | |||
| public: | |||
| using PtrType = typename PtrTraits<T>::PtrType; | |||
| protected: | |||
| PtrType m_data; | |||
| const index_t* m_sizes; | |||
| const index_t* m_strides; | |||
| public: | |||
| CUSTOM_HOST_DEVICE TensorAccessorProxyBase(PtrType data, const index_t *sizes, const index_t *strides) { | |||
| m_data = data; | |||
| m_sizes = sizes; | |||
| m_strides = strides; | |||
| } | |||
| CUSTOM_HOST_DEVICE index_t stride(index_t i) const { | |||
| return m_strides[i]; | |||
| } | |||
| CUSTOM_HOST_DEVICE index_t size(index_t i) const { | |||
| return m_sizes[i]; | |||
| } | |||
| CUSTOM_HOST_DEVICE PtrType data() const { | |||
| return m_data; | |||
| } | |||
| }; | |||
| template<typename T, size_t N, | |||
| template <typename U> class PtrTraits = DefaultPtrTraits, | |||
| typename index_t = int64_t> | |||
| class TensorAccessorProxy: public TensorAccessorProxyBase<T, N, PtrTraits, index_t> { | |||
| public: | |||
| using PtrType = typename PtrTraits<T>::PtrType; | |||
| CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides) | |||
| : TensorAccessorProxyBase<T, N, PtrTraits, index_t>(data, sizes, strides) { | |||
| } | |||
| CUSTOM_HOST_DEVICE TensorAccessorProxy<T, N-1, PtrTraits, index_t> operator[](index_t i) { | |||
| return TensorAccessorProxy<T, N-1, PtrTraits, index_t>( | |||
| this->m_data + this->m_strides[0] * i, | |||
| this->m_sizes + 1, | |||
| this->m_strides + 1 | |||
| ); | |||
| } | |||
| CUSTOM_HOST_DEVICE const TensorAccessorProxy<T, N-1, PtrTraits, index_t> operator[](index_t i) const { | |||
| return TensorAccessorProxy<T, N-1, PtrTraits, index_t>( | |||
| this->m_data + this->m_strides[0] * i, | |||
| this->m_sizes + 1, | |||
| this->m_strides + 1 | |||
| ); | |||
| } | |||
| }; | |||
| template<typename T, template <typename U> class PtrTraits, typename index_t> | |||
| class TensorAccessorProxy<T, 1, PtrTraits, index_t> | |||
| : public TensorAccessorProxyBase<T, 1, PtrTraits, index_t> { | |||
| public: | |||
| using PtrType = typename PtrTraits<T>::PtrType; | |||
| CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides) | |||
| : TensorAccessorProxyBase<T, 1, PtrTraits, index_t>(data, sizes, strides ) { | |||
| } | |||
| CUSTOM_HOST_DEVICE T &operator[](index_t i) { | |||
| return this->m_data[this->m_strides[0]*i]; | |||
| } | |||
| CUSTOM_HOST_DEVICE const T &operator[](index_t i) const { | |||
| return this->m_data[this->m_strides[0]*i]; | |||
| } | |||
| }; | |||
| template<typename T, size_t N, | |||
| template <typename U> class PtrTraits = DefaultPtrTraits, | |||
| typename index_t = int64_t> | |||
| class TensorAccessorBase { | |||
| public: | |||
| using PtrType = typename PtrTraits<T>::PtrType; | |||
| protected: | |||
| PtrType m_data; | |||
| index_t m_sizes[N]; | |||
| index_t m_strides[N]; | |||
| public: | |||
| CUSTOM_HOST_DEVICE TensorAccessorBase(PtrType data, const size_t *sizes, const ptrdiff_t *strides) { | |||
| m_data = data; | |||
| for (size_t i=0; i<N; ++i) { | |||
| m_sizes[i] = sizes[i]; | |||
| m_strides[i] = strides[i]; | |||
| } | |||
| } | |||
| CUSTOM_HOST_DEVICE index_t stride(index_t i) const { | |||
| return m_strides[i]; | |||
| } | |||
| CUSTOM_HOST_DEVICE index_t size(index_t i) const { | |||
| return m_sizes[i]; | |||
| } | |||
| CUSTOM_HOST_DEVICE PtrType data() const { | |||
| return m_data; | |||
| } | |||
| }; | |||
| template<typename T, size_t N, | |||
| template <typename U> class PtrTraits = DefaultPtrTraits, | |||
| typename index_t = int64_t> | |||
| class TensorAccessor: public TensorAccessorBase<T, N, PtrTraits, index_t> { | |||
| public: | |||
| using PtrType = typename PtrTraits<T>::PtrType; | |||
| CUSTOM_HOST_DEVICE TensorAccessor(PtrType data, const size_t *sizes, const ptrdiff_t *strides) | |||
| : TensorAccessorBase<T, N, PtrTraits, index_t>(data, sizes, strides) { | |||
| } | |||
| CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) { | |||
| return TensorAccessorProxy<T, N, PtrTraits, index_t>( | |||
| this->m_data, | |||
| this->m_sizes, | |||
| this->m_strides | |||
| )[i]; | |||
| } | |||
| CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) const { | |||
| return TensorAccessorProxy<T, N, PtrTraits, index_t>( | |||
| this->m_data, | |||
| this->m_sizes, | |||
| this->m_strides | |||
| )[i]; | |||
| } | |||
| }; | |||
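| // Usage sketch: a 2-d accessor over a raw row-major buffer; the sizes and | |||
| // strides below are illustrative, with strides counted in elements: | |||
| //     float buf[6] = {0, 1, 2, 3, 4, 5}; | |||
| //     size_t sizes[2] = {2, 3}; | |||
| //     ptrdiff_t strides[2] = {3, 1}; | |||
| //     TensorAccessor<float, 2> acc(buf, sizes, strides); | |||
| //     acc[1][2] = 42.f;                // writes buf[1 * 3 + 2] | |||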
| } // namespace custom | |||
| @@ -0,0 +1,108 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/custom.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "op.h" | |||
| #include "tensor.h" | |||
| #include "param.h" | |||
| namespace custom { | |||
| std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version); | |||
| } | |||
| #define CUSTOM_OP_REG(OpName) CustomOp &_##OpName = (*(op_insert(#OpName, CUSTOM_OP_VERSION))) | |||
| #define CUSTOM_OP_REG_BEGIN(OpName) \ | |||
| namespace custom { \ | |||
| namespace OpName { | |||
| #define CUSTOM_OP_REG_END(OpName) \ | |||
| } \ | |||
| } | |||
| #define CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, hint, ...) \ | |||
| case (case_type): { \ | |||
| using hint = real_type; \ | |||
| return __VA_ARGS__(); \ | |||
| } | |||
| #define CASE_TO_PERFORM_ON_SCALAR(name, case_type, real_type, ...) \ | |||
| CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, scalar_t, __VA_ARGS__) | |||
| #define DISPATCH_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||
| [&]() { \ | |||
| const auto &dtype = tensor_dtype; \ | |||
| switch (dtype.enumv()) { \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||
| default: \ | |||
| custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
| name, dtype.str().c_str()); \ | |||
| } \ | |||
| }() | |||
| #define DISPATCH_INT_TYPES(tensor_dtype, name, ...) \ | |||
| [&]() { \ | |||
| const auto &dtype = tensor_dtype; \ | |||
| switch (dtype.enumv()) { \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||
| default: \ | |||
| custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
| name, dtype.str().c_str()); \ | |||
| } \ | |||
| }() | |||
| #define DISPATCH_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||
| [&]() { \ | |||
| const auto &dtype = tensor_dtype; \ | |||
| switch (dtype.enumv()) { \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||
| default: \ | |||
| custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
| name, dtype.str().c_str()); \ | |||
| } \ | |||
| }() | |||
| #define DISPATCH_SIGN_INT_TYPES(tensor_dtype, name, ...) \ | |||
| [&]() { \ | |||
| const auto &dtype = tensor_dtype; \ | |||
| switch (dtype.enumv()) { \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||
| default: \ | |||
| custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
| name, dtype.str().c_str()); \ | |||
| } \ | |||
| }() | |||
| #define DISPATCH_SIGN_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||
| [&]() { \ | |||
| const auto &dtype = tensor_dtype; \ | |||
| switch (dtype.enumv()) { \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||
| CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||
| default: \ | |||
| custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
| name, dtype.str().c_str()); \ | |||
| } \ | |||
| }() | |||
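| // Registration sketch for a hypothetical "Relu" op (names and kernel body | |||
| // are illustrative, not part of the library): the CUSTOM_OP_REG* macros are | |||
| // combined with a dispatch macro, which defines scalar_t for the lambda: | |||
| // | |||
| //     CUSTOM_OP_REG_BEGIN(Relu) | |||
| //     void relu_compute(const std::vector<Tensor> &inputs, const Param &, | |||
| //                       std::vector<Tensor> &outputs) { | |||
| //         DISPATCH_FLOAT_TYPES(inputs[0].dtype(), "relu", [&]() { | |||
| //             auto *src = inputs[0].data<scalar_t>(); | |||
| //             auto *dst = outputs[0].data<scalar_t>(); | |||
| //             for (size_t i = 0; i < inputs[0].size(); ++i) | |||
| //                 dst[i] = src[i] > 0 ? src[i] : 0; | |||
| //         }); | |||
| //     } | |||
| //     CUSTOM_OP_REG(Relu) | |||
| //         .add_input("x").add_output("y") | |||
| //         .set_compute("x86", relu_compute); | |||
| //     CUSTOM_OP_REG_END(Relu) | |||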
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/data_adaptor.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/thin/small_vector.h" | |||
| namespace custom { | |||
| template <typename BuiltinT, typename CustomT> | |||
| BuiltinT to_builtin(const CustomT &custom) { | |||
| return *reinterpret_cast<const BuiltinT*>(custom.impl()); | |||
| } | |||
| template <typename BuiltinT, typename CustomT> | |||
| CustomT to_custom(const BuiltinT &builtin) { | |||
| return CustomT(&builtin); | |||
| } | |||
| template <typename BuiltinT, typename CustomT> | |||
| megdnn::SmallVector<BuiltinT> to_builtin(const std::vector<CustomT> &customs) { | |||
| megdnn::SmallVector<BuiltinT> builtins; | |||
| for (size_t i=0; i<customs.size(); ++i) { | |||
| builtins.push_back(to_builtin<BuiltinT, CustomT>(customs[i])); | |||
| } | |||
| return builtins; | |||
| } | |||
| template <typename BuiltinT, typename CustomT> | |||
| std::vector<CustomT> to_custom( | |||
| const megdnn::SmallVector<BuiltinT> &builtins) { | |||
| std::vector<CustomT> customs; | |||
| for (size_t i=0; i<builtins.size(); ++i) { | |||
| customs.push_back(to_custom<BuiltinT, CustomT>(builtins[i])); | |||
| } | |||
| return customs; | |||
| } | |||
| } // namespace custom | |||
| #define to_custom_device(expr) custom::to_custom<CompNode, custom::Device>(expr) | |||
| #define to_builtin_device(expr) custom::to_builtin<CompNode, custom::Device>(expr) | |||
| #define to_custom_shape(expr) custom::to_custom<megdnn::TensorShape, custom::Shape>(expr) | |||
| #define to_builtin_shape(expr) custom::to_builtin<megdnn::TensorShape, custom::Shape>(expr) | |||
| #define to_custom_dtype(expr) custom::to_custom<megdnn::DType, custom::DType>(expr) | |||
| #define to_builtin_dtype(expr) custom::to_builtin<megdnn::DType, custom::DType>(expr) | |||
| #define to_custom_format(expr) custom::to_custom<megdnn::TensorLayout::Format, custom::Format>(expr) | |||
| #define to_builtin_format(expr) custom::to_builtin<megdnn::TensorLayout::Format, custom::Format>(expr) | |||
| #define to_custom_tensor(expr) custom::to_custom<DeviceTensorND, custom::Tensor>(expr) | |||
| #define to_builtin_tensor(expr) custom::to_builtin<DeviceTensorND, custom::Tensor>(expr) | |||
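| // Usage sketch: the adaptors convert between builtin and custom views of the | |||
| // same object; assuming cn is an existing CompNode instance: | |||
| //     custom::Device dev = to_custom_device(cn); | |||
| //     CompNode back = to_builtin_device(dev);    // round-trips the device | |||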
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/manager.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "custom.h" | |||
| #include "megbrain/common.h" | |||
| namespace custom { | |||
| class CustomOpManager { | |||
| std::unordered_map<std::string, std::shared_ptr<const CustomOp>> m_name2op; | |||
| std::unordered_map<RunTimeId, std::shared_ptr<const CustomOp>> m_id2op; | |||
| MGB_MUTEX m_mtx; | |||
| CustomOpManager() = default; | |||
| public: | |||
| PREVENT_COPY_AND_ASSIGN(CustomOpManager); | |||
| static CustomOpManager *inst(void); | |||
| ~CustomOpManager(); | |||
| std::shared_ptr<CustomOp> insert(const std::string &name, uint32_t version); | |||
| bool erase(const std::string &name); | |||
| bool erase(const RunTimeId &id); | |||
| std::shared_ptr<CustomOp> find_or_reg(const std::string &name, uint32_t version); | |||
| RunTimeId to_id(const std::string &name) const; | |||
| std::string to_name(const RunTimeId &id) const; | |||
| std::shared_ptr<const CustomOp> find(const std::string &name) const; | |||
| std::shared_ptr<const CustomOp> find(const RunTimeId &id) const; | |||
| std::vector<std::string> op_name_list(void); | |||
| std::vector<RunTimeId> op_id_list(void); | |||
| }; | |||
| class CustomLib { | |||
| std::unique_ptr<void, void_deleter> m_handle; | |||
| std::vector<std::string> m_ops; | |||
| public: | |||
| PREVENT_COPY_AND_ASSIGN(CustomLib); | |||
| CustomLib(const std::string &path, int mode); | |||
| const std::vector<std::string> &ops_in_lib(void) const; | |||
| ~CustomLib(); | |||
| bool valid(void) const; | |||
| }; | |||
| using LibHandle = std::shared_ptr<CustomLib>; | |||
| class LibManager { | |||
| std::unordered_map<std::string, LibHandle> m_custom_libs; | |||
| MGB_MUTEX m_mtx; | |||
| LibManager() = default; | |||
| public: | |||
| PREVENT_COPY_AND_ASSIGN(LibManager); | |||
| static LibManager *inst(void); | |||
| const std::vector<std::string> &install(const std::string &name, const std::string &path); | |||
| bool uninstall(const std::string &name); | |||
| friend class CustomOpManager; | |||
| }; | |||
| } // namespace custom | |||
| @@ -0,0 +1,109 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/op.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "tensor.h" | |||
| #include "param.h" | |||
| #include <unordered_set> | |||
| #define PREVENT_COPY_AND_ASSIGN(Cls) \ | |||
| Cls(const Cls&) = delete; \ | |||
| Cls(const Cls&&) = delete; \ | |||
| Cls &operator=(const Cls&) = delete; \ | |||
| Cls &operator=(const Cls&&) = delete | |||
| #define CUSTOM_OP_MAJOR 0 | |||
| #define CUSTOM_OP_MINOR 1 | |||
| #define CUSTOM_OP_PATCH 0 | |||
| #define CUSTOM_OP_VERSION (CUSTOM_OP_MAJOR * 10000 + CUSTOM_OP_MINOR * 100 + CUSTOM_OP_PATCH) | |||
| namespace custom { | |||
| using RunTimeId = uint64_t; | |||
| class ArgInfo { | |||
| CUSTOM_PIMPL_CLS_DECL(ArgInfo); | |||
| ArgInfo(const std::string &name, | |||
| const std::string &desc, | |||
| const std::unordered_set<std::string> &dtypes, | |||
| const int &ndim, | |||
| const std::string &mem_stgy); | |||
| const std::string &name(void) const; | |||
| const std::string &desc(void) const; | |||
| const std::unordered_set<std::string> &dtypes(void) const; | |||
| int ndim(void) const; | |||
| const std::string &mem_strategy(void) const; | |||
| std::string str() const; | |||
| }; | |||
| class CustomOp { | |||
| std::unique_ptr<void, void_deleter> m_impl; | |||
| public: | |||
| CustomOp(const std::string &op_type, uint32_t version); | |||
| PREVENT_COPY_AND_ASSIGN(CustomOp); | |||
| using DeviceInferFuncPtr = void(*)(const std::vector<Device>&, const Param&, std::vector<Device>&); | |||
| using ShapeInferFuncPtr = void(*)(const std::vector<Shape>&, const Param&, std::vector<Shape>&); | |||
| using DTypeInferFuncPtr = void(*)(const std::vector<DType>&, const Param&, std::vector<DType>&); | |||
| using FormatInferFuncPtr = void(*)(const std::vector<Format>&, const Param&, std::vector<Format>&); | |||
| using PreprocessFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||
| using PostprocessFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||
| using ComputeFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||
| // write: register the op's forward behaviors | |||
| CustomOp &set_device_infer(DeviceInferFuncPtr func); | |||
| CustomOp &set_shape_infer(ShapeInferFuncPtr func); | |||
| CustomOp &set_dtype_infer(DTypeInferFuncPtr func); | |||
| CustomOp &set_format_infer(FormatInferFuncPtr func); | |||
| CustomOp &set_preprocess(PreprocessFuncPtr func); | |||
| CustomOp &set_preprocess(const std::string &device, PreprocessFuncPtr func); | |||
| CustomOp &set_postprocess(PostprocessFuncPtr func); | |||
| CustomOp &set_postprocess(const std::string &device, PostprocessFuncPtr func); | |||
| CustomOp &set_compute(ComputeFuncPtr func); | |||
| CustomOp &set_compute(const std::string &device, ComputeFuncPtr func); | |||
| CustomOp &set_description(const std::string &op_desc); | |||
| CustomOp &add_input(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||
| CustomOp &add_output(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||
| CustomOp &add_input(const std::string &name, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||
| CustomOp &add_output(const std::string &name, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||
| CustomOp &add_inputs(const size_t &input_num); | |||
| CustomOp &add_outputs(const size_t &output_num); | |||
| CustomOp &add_param(const std::string &name, const ParamVal &default_val); | |||
| CustomOp &add_param(const std::string &name, const std::string &desc, const ParamVal &default_val); | |||
| // read | |||
| std::string op_type(void) const; | |||
| std::string op_desc(void) const; | |||
| RunTimeId runtime_id(void) const; | |||
| size_t input_num(void) const; | |||
| size_t output_num(void) const; | |||
| std::string str(void) const; | |||
| const ParamInfo ¶m_info(void) const; | |||
| ArgInfo input_info(size_t idx) const; | |||
| ArgInfo output_info(size_t idx) const; | |||
| const std::vector<ArgInfo> &inputs_info(void) const; | |||
| const std::vector<ArgInfo> &outputs_info(void) const; | |||
| // use | |||
| std::vector<Device> infer_output_device(const std::vector<Device>&, const Param&) const; | |||
| std::vector<Shape> infer_output_shape (const std::vector<Shape>&, const Param&) const; | |||
| std::vector<DType> infer_output_dtype (const std::vector<DType>&, const Param&) const; | |||
| std::vector<Format> infer_output_format(const std::vector<Format>&, const Param&) const; | |||
| void compute(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) const; | |||
| }; | |||
| } // namespace custom | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/param.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "param_val.h" | |||
| namespace custom { | |||
| class ParamSchemaImpl; | |||
| class ParamInfoImpl; | |||
| class ParamImpl; | |||
| // Schema of a param element | |||
| class ParamSchema { | |||
| CUSTOM_PIMPL_CLS_DECL(ParamSchema); | |||
| ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc=""); | |||
| const std::string &name(void) const; | |||
| const std::string &desc(void) const; | |||
| const ParamVal &default_val(void) const; | |||
| ParamDynType type(void) const; | |||
| std::string str(void) const; | |||
| }; | |||
| class ParamInfo { | |||
| CUSTOM_PIMPL_CLS_DECL(ParamInfo); | |||
| void set_tag(const std::string&); | |||
| void set_meta(const std::vector<ParamSchema> &meta); | |||
| uint32_t tag(void) const; | |||
| std::vector<ParamSchema> &meta(void); | |||
| const std::vector<ParamSchema> &meta(void) const; | |||
| }; | |||
| class Param { | |||
| CUSTOM_PIMPL_CLS_DECL(Param); | |||
| Param(const ParamInfo&); | |||
| ParamVal &operator[](const std::string&); | |||
| const ParamVal &operator[](const std::string&) const; | |||
| const std::unordered_map<std::string, ParamVal> &raw() const; | |||
| bool exist(const std::string &name) const; | |||
| std::string to_bytes(void) const; | |||
| void from_bytes(const std::string&); | |||
| }; | |||
| bool operator==(const Param&, const Param&); | |||
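| // Usage sketch inside an op callback; the param names are illustrative and | |||
| // must have been registered via CustomOp::add_param: | |||
| //     float scale = param["scale"].as<float>(); | |||
| //     if (param.exist("axis")) { /* ... */ } | |||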
| } // namespace custom | |||
| @@ -0,0 +1,290 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/param_val.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| #include <vector> | |||
| #include <cassert> | |||
| #include <sstream> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "utils.h" | |||
| namespace custom { | |||
| /** | |||
| * new basic data types can be added here; "basic" means binary ops | |||
| * such as +, -, *, /, ==, != can be performed between any two of them | |||
| */ | |||
| #define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) \ | |||
| cb(Int32, int32_t, ##__VA_ARGS__) \ | |||
| cb(Int64, int64_t, ##__VA_ARGS__) \ | |||
| cb(Uint32, uint32_t, ##__VA_ARGS__) \ | |||
| cb(Uint64, uint64_t, ##__VA_ARGS__) \ | |||
| cb(Float32, float, ##__VA_ARGS__) \ | |||
| cb(Float64, double, ##__VA_ARGS__) \ | |||
| cb(Bool, bool, ##__VA_ARGS__) | |||
| #define CUSTOM_FOR_STRING_PARAMTYPE(cb, ...) \ | |||
| cb(String, std::string, ##__VA_ARGS__) | |||
| #define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \ | |||
| cb(Int32List, std::vector<int32_t>, ##__VA_ARGS__) \ | |||
| cb(Int64List, std::vector<int64_t>, ##__VA_ARGS__) \ | |||
| cb(Uint32List, std::vector<uint32_t>, ##__VA_ARGS__) \ | |||
| cb(Uint64List, std::vector<uint64_t>, ##__VA_ARGS__) \ | |||
| cb(Float32List, std::vector<float>, ##__VA_ARGS__) \ | |||
| cb(Float64List, std::vector<double>, ##__VA_ARGS__) | |||
| #define CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ...) \ | |||
| cb(BoolList, std::vector<bool>, ##__VA_ARGS__) | |||
| #define CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ...) \ | |||
| cb(StringList, std::vector<std::string>, ##__VA_ARGS__) | |||
| /** | |||
| * to avoid recursive macro expansion | |||
| */ | |||
| #define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \ | |||
| cb(Int32, int32_t, ##__VA_ARGS__) \ | |||
| cb(Int64, int64_t, ##__VA_ARGS__) \ | |||
| cb(Uint32, uint32_t, ##__VA_ARGS__) \ | |||
| cb(Uint64, uint64_t, ##__VA_ARGS__) \ | |||
| cb(Float32, float, ##__VA_ARGS__) \ | |||
| cb(Float64, double, ##__VA_ARGS__) \ | |||
| cb(Bool, bool, ##__VA_ARGS__) | |||
| #define CUSTOM_FOR_EACH_VALID_PARAMTYPE(cb, ...) \ | |||
| CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
| CUSTOM_FOR_STRING_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
| CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
| CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
| CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) | |||
| #define CUSTOM_FOR_EACH_LIST_PARAMTYPE(cb, ...) \ | |||
| CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
| CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
| CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) | |||
| /** | |||
| * Macro Callback for Register | |||
| */ | |||
| #define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type, | |||
| #define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) {ParamDynType::dyn_type, #dyn_type}, | |||
| #define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \ | |||
| template <> \ | |||
| struct get_dyn_type<static_type> { \ | |||
| static constexpr ParamDynType type = ParamDynType::dyn_type;\ | |||
| }; | |||
| #define CUSTOM_REG_STATIC_PARAMTYPE_GETTER(dyn_type, static_type) \ | |||
| template <> \ | |||
| struct get_static_type<ParamDynType::dyn_type> { \ | |||
| using type = static_type; \ | |||
| }; | |||
| enum class ParamDynType: uint32_t { | |||
| CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE) | |||
| Invalid=255 | |||
| }; | |||
| static std::unordered_map<ParamDynType, std::string, EnumHash<ParamDynType>, EnumCmp<ParamDynType>> type2name = { | |||
| CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME) | |||
| {ParamDynType::Invalid, "Invalid"} | |||
| }; | |||
| /** | |||
| * get the dynamic data type according to the builtin static data type | |||
| * we can use it like: | |||
| * ParamDynType dyn_type = get_dyn_type<int32_t>::type; | |||
| * assert(dyn_type == ParamDynType::Int32) | |||
| */ | |||
| template <typename T> | |||
| struct get_dyn_type { | |||
| static constexpr ParamDynType type = ParamDynType::Invalid; | |||
| }; | |||
| /** | |||
| * get the static data type according to the dynamic data type | |||
| * we can use it like: | |||
| * get_static_type<ParamDynType::Int32>::type int_32_value; | |||
| * assert(std::is_same<decltype(int_32_value), int>::value) | |||
| */ | |||
| template <ParamDynType> | |||
| struct get_static_type; | |||
| CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER) | |||
| CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER) | |||
| #undef CUSTOM_REG_DYN_PARAMTYPE | |||
| #undef CUSTOM_REG_DYN_PARAMTYPE_NAME | |||
| #undef CUSTOM_REG_DYN_PARAMTYPE_GETTER | |||
| #undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER | |||
| template <typename T> | |||
| struct get_vector_template_arg_type; | |||
| template <typename T> | |||
| struct get_vector_template_arg_type<std::vector<T>> { | |||
| using type = std::decay_t<T>; | |||
| }; | |||
| template <typename T> | |||
| struct is_vector { | |||
| static constexpr bool value = false; | |||
| }; | |||
| template <typename T> | |||
| struct is_vector <std::vector<T>> { | |||
| static constexpr bool value = true; | |||
| }; | |||
| template <typename T> | |||
| std::string vec2str(const std::vector<T> &vec) { | |||
| std::stringstream ss; | |||
| ss << "{"; | |||
| for (const auto &val: vec) { | |||
| ss << val << ", "; | |||
| } | |||
| if (vec.size() != 0) { | |||
| ss.seekp(ss.tellp()-std::streampos(2)); | |||
| } | |||
| ss << "}"; | |||
| return ss.str(); | |||
| } | |||
| /** | |||
| * we use void* rather than a template to realize a fully dynamic type. | |||
| * if we used a template such as: | |||
| * template <typename T> | |||
| * class ParamVal { | |||
| * T m_data; | |||
| * } | |||
| * Con1: users would need to set the type explicitly at template instantiation | |||
| * Con2: ParamVal<int> could not be assigned to ParamVal<double> | |||
| */ | |||
| class ParamVal { | |||
| std::unique_ptr<void, void_deleter> m_ptr; | |||
| ParamDynType m_type; | |||
| public: | |||
| template <typename T> | |||
| ParamVal(const T &val); | |||
| template <typename T> | |||
| ParamVal(const std::initializer_list<T> &val); | |||
| ParamVal(); | |||
| ParamVal(const char *str); | |||
| ParamVal(const std::initializer_list<const char*> &strs); | |||
| ParamVal(const std::vector<const char*> &strs); | |||
| ParamVal(const ParamVal &rhs); | |||
| template <typename T> | |||
| ParamVal &operator=(const T &rhs); | |||
| template <typename T> | |||
| ParamVal &operator=(const std::initializer_list<T> &val); | |||
| ParamVal &operator=(const char *str); | |||
| ParamVal &operator=(const std::initializer_list<const char*> &strs); | |||
| ParamVal &operator=(const std::vector<const char*> &strs); | |||
| ParamVal &operator=(const ParamVal &rhs); | |||
| template <typename T> | |||
| const T &as(void) const; | |||
| template <typename T> | |||
| T &as(void); | |||
| const void *raw_ptr(void) const; | |||
| void *raw_ptr(void); | |||
| ParamDynType type(void) const; | |||
| std::string str(void) const; | |||
| size_t size(void) const; | |||
| static std::string to_bytes(const ParamVal &value); | |||
| static ParamVal from_bytes(const std::string &bytes, size_t &offset); | |||
| friend ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs); | |||
| friend ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs); | |||
| friend ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs); | |||
| friend ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs); | |||
| friend bool operator==(const ParamVal &lhs, const ParamVal &rhs); | |||
| friend bool operator!=(const ParamVal &lhs, const ParamVal &rhs); | |||
| friend bool operator> (const ParamVal &lhs, const ParamVal &rhs); | |||
| friend bool operator< (const ParamVal &lhs, const ParamVal &rhs); | |||
| friend bool operator>=(const ParamVal &lhs, const ParamVal &rhs); | |||
| friend bool operator<=(const ParamVal &lhs, const ParamVal &rhs); | |||
| }; | |||
| ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs); | |||
| ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs); | |||
| ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs); | |||
| ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs); | |||
| bool operator==(const ParamVal &lhs, const ParamVal &rhs); | |||
| bool operator!=(const ParamVal &lhs, const ParamVal &rhs); | |||
| bool operator> (const ParamVal &lhs, const ParamVal &rhs); | |||
| bool operator< (const ParamVal &lhs, const ParamVal &rhs); | |||
| bool operator>=(const ParamVal &lhs, const ParamVal &rhs); | |||
| bool operator<=(const ParamVal &lhs, const ParamVal &rhs); | |||
| template <typename T> | |||
| ParamVal::ParamVal(const T &val): m_ptr(nullptr, impl_deleter<std::decay_t<T>>) { | |||
| using DecayType = std::decay_t<T>; | |||
| m_type = get_dyn_type<DecayType>::type; | |||
| custom_assert(m_type != ParamDynType::Invalid, "param construct error! unsupported builtin type"); | |||
| m_ptr.reset(new DecayType(val)); | |||
| } | |||
| template <typename T> | |||
| ParamVal::ParamVal(const std::initializer_list<T> &val): ParamVal(std::vector<std::decay_t<T>>(val)) { | |||
| } | |||
| template <typename T> | |||
| ParamVal &ParamVal::operator=(const T &rhs) { | |||
| using DecayType = std::decay_t<T>; | |||
| ParamDynType rhs_dyn_type = get_dyn_type<DecayType>::type; | |||
| custom_assert(rhs_dyn_type != ParamDynType::Invalid, "unsupported builtin dtype"); | |||
| if (rhs_dyn_type == m_type) { | |||
| TypedRef(DecayType, m_ptr.get()) = rhs; | |||
| } | |||
| else { | |||
| m_type = rhs_dyn_type; | |||
| std::unique_ptr<void, void_deleter> new_ptr(new DecayType(rhs), impl_deleter<DecayType>); | |||
| m_ptr.swap(new_ptr); | |||
| } | |||
| return *this; | |||
| } | |||
| template <typename T> | |||
| ParamVal &ParamVal::operator=(const std::initializer_list<T> &val) { | |||
| return this->operator=(std::vector<std::decay_t<T>>(val)); | |||
| } | |||
| template <typename T> | |||
| const T &ParamVal::as(void) const { | |||
| return const_cast<ParamVal*>(this)->as<T>(); | |||
| } | |||
| template <typename T> | |||
| T &ParamVal::as(void) { | |||
| using DecayType = std::decay_t<T>; | |||
| ParamDynType t_dyn_type = get_dyn_type<DecayType>::type; | |||
| custom_assert( | |||
| t_dyn_type == m_type, "type mismatch, type %s cannot be cast to type %s\n", | |||
| type2name[m_type].c_str(), type2name[t_dyn_type].c_str() | |||
| ); | |||
| return TypedRef(T, m_ptr.get()); | |||
| } | |||
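| // Usage sketch of the dynamic typing above (int is int32_t on typical | |||
| // platforms, so the literals below map to Int32/Int32List): | |||
| //     ParamVal v = 1;                  // holds Int32 | |||
| //     v = 2.0;                         // rebinds to Float64 | |||
| //     v = {1, 2, 3};                   // rebinds to Int32List | |||
| //     assert(v.as<std::vector<int32_t>>()[1] == 2); | |||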
| } // namespace custom | |||
| @@ -0,0 +1,280 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/tensor.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <vector> | |||
| #include <string> | |||
| #include "utils.h" | |||
| #include "accessor.h" | |||
| namespace custom { | |||
| #define CUSTOM_DATA_ADAPTOR_FRIEND_DECL \ | |||
| template <typename BuiltinT, typename CustomT> \ | |||
| friend BuiltinT to_builtin(const CustomT &custom); \ | |||
| template <typename BuiltinT, typename CustomT> \ | |||
| friend CustomT to_custom(const BuiltinT &builtin) | |||
| #define CUSTOM_FOR_EACH_DEVICE_TYPE(cb) \ | |||
| cb(x86, CPU, "cpux") \ | |||
| cb(cuda, CUDA, "gpux") | |||
| #define CUSTOM_DEVICE_TYPE_ENUM_DECL(custom_type, builtin_type, builtin_str) custom_type, | |||
| class Device { | |||
| const void *impl() const; | |||
| Device(const void *impl); | |||
| CUSTOM_PIMPL_CLS_DECL(Device); | |||
| public: | |||
| enum class DeviceEnum: uint32_t { | |||
| CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_DEVICE_TYPE_ENUM_DECL) | |||
| }; | |||
| Device(const std::string &device); | |||
| Device(const char *device); | |||
| Device(DeviceEnum device); | |||
| std::string str(void) const; | |||
| DeviceEnum enumv(void) const; | |||
| static bool is_legal(const std::string &device); | |||
| static bool is_legal(DeviceEnum device); | |||
| static std::vector<std::string> legal_devices(void); | |||
| friend class Tensor; | |||
| friend bool operator==(const Device &lhs, const Device &rhs); | |||
| CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
| }; | |||
| using DeviceEnum = Device::DeviceEnum; | |||
| bool operator==(const Device &lhs, const Device &rhs); | |||
| class Shape { | |||
| const void *impl() const; | |||
| Shape(const void *impl); | |||
| CUSTOM_PIMPL_CLS_DECL(Shape); | |||
| public: | |||
| Shape(const std::vector<size_t> &rhs); | |||
| Shape(const std::initializer_list<size_t> &rhs); | |||
| size_t &operator[](size_t idx); | |||
| size_t operator[](size_t idx) const; | |||
| void ndim(size_t dim); | |||
| size_t ndim(void) const; | |||
| friend class Tensor; | |||
| friend bool operator==(const Shape &lhs, const Shape &rhs); | |||
| CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
| }; | |||
| bool operator==(const Shape &lhs, const Shape &rhs); | |||
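| // NOTE: 16-bit float types are exposed as same-width integer aliases below, | |||
| // so this public header need not depend on a concrete half type; | |||
| // is_compatible() only checks the size for these two dtypes. | |||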
| using float16_t = uint16_t; | |||
| using bfloat16_t = uint16_t; | |||
| #if MEGDNN_DISABLE_FLOAT16 | |||
| #define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) | |||
| #else | |||
| #define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) cb(custom_dtype, dnn_dtype, c_dtype) | |||
| #endif | |||
| #define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \ | |||
| cb(float32, Float32, float) \ | |||
| cb(uint8, Uint8, uint8_t) \ | |||
| cb(int8, Int8, int8_t) \ | |||
| cb(int16, Int16, int16_t) \ | |||
| cb(int32, Int32, int32_t) \ | |||
| fp16_wrap(cb, float16, Float16, float16_t) \ | |||
| fp16_wrap(cb, bfloat16, BFloat16, bfloat16_t) \ | |||
| cb(uint16, Uint16, uint16_t) \ | |||
| cb(quint8, Quantized8Asymm, uint8_t) \ | |||
| cb(qint32, QuantizedS32, int32_t) \ | |||
| cb(qint8, QuantizedS8, int8_t) \ | |||
| cb(qint16, QuantizedS16, int16_t) | |||
| #define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type, | |||
| class DType { | |||
| const void *impl() const; | |||
| DType(const void *impl); | |||
| CUSTOM_PIMPL_CLS_DECL(DType); | |||
| public: | |||
| enum class DTypeEnum: uint32_t { | |||
| CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DTYPE_ENUM_DECL) | |||
| }; | |||
| DType(const std::string &dtype); | |||
| DType(const char *dtype); | |||
| DType(const std::string &dtype, float scale, uint8_t zero_point = 0); | |||
| DType(const char *dtype, float scale, uint8_t zero_point = 0); | |||
| DType(DTypeEnum dtype); | |||
| DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0); | |||
| std::string str(void) const; | |||
| DTypeEnum enumv() const; | |||
| float scale(void) const; | |||
| uint8_t zero_point(void) const; | |||
| template<typename T> | |||
| bool is_compatible(void) const; | |||
| static bool is_legal(const std::string &dtype); | |||
| static bool is_legal(const DTypeEnum &dtype); | |||
| static std::vector<std::string> legal_dtypes(void); | |||
| friend class Tensor; | |||
| friend bool operator==(const DType &lhs, const DType &rhs); | |||
| CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
| }; | |||
| using DTypeEnum = DType::DTypeEnum; | |||
| template <DTypeEnum> | |||
| struct DTypeTrait; | |||
| #define CUSTOM_DEFINE_DTYPE_TRAIT(custom_type, builtin_type, ctype) \ | |||
| template <> \ | |||
| struct DTypeTrait<DTypeEnum::custom_type> { \ | |||
| using type = ctype; \ | |||
| }; | |||
| #define CUSTOM_CASE_TO_COMPARE_DTYPE(custom_type, builtin_type, ctype) \ | |||
| case (DTypeEnum::custom_type): { \ | |||
| return std::is_same<DecayT, ctype>::value; \ | |||
| } | |||
| CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DEFINE_DTYPE_TRAIT) | |||
| template<typename T> | |||
| bool DType::is_compatible(void) const { | |||
| using DecayT = typename std::decay<T>::type; | |||
| auto dtype_enum = enumv(); | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| if (dtype_enum == DTypeEnum::float16) { | |||
| return sizeof(DecayT) == sizeof(DTypeTrait<DTypeEnum::float16>::type); | |||
| } | |||
| else if (dtype_enum == DTypeEnum::bfloat16) { | |||
| return sizeof(DecayT) == sizeof(DTypeTrait<DTypeEnum::bfloat16>::type); | |||
| } | |||
| #endif | |||
| switch (dtype_enum) { | |||
| CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_CASE_TO_COMPARE_DTYPE) | |||
| default: | |||
| return false; | |||
| } | |||
| } | |||
| bool operator==(const DType &lhs, const DType &rhs); | |||
| bool operator==(const DType &lhs, const std::string &rhs); | |||
| bool operator==(const DType &lhs, const char *rhs); | |||
| bool operator==(const std::string &lhs, const DType &rhs); | |||
| bool operator==(const char *lhs, const DType &rhs); | |||
| class Format { | |||
| const void *impl() const; | |||
| Format(const void *impl); | |||
| CUSTOM_PIMPL_CLS_DECL(Format); | |||
| public: | |||
| Format(const std::string &format); | |||
| Format(const char *format); | |||
| std::string str(void) const; | |||
| bool is_default(void) const; | |||
| friend class Tensor; | |||
| CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
| }; | |||
| class Tensor { | |||
| void *m_tensor; | |||
| const void *impl(void) const; | |||
| Tensor(const void *impl); | |||
| const size_t *shapes_raw(void) const; | |||
| const ptrdiff_t *strides_raw(void) const; | |||
| public: | |||
| Tensor() = delete; | |||
| Tensor(const Tensor &rhs); | |||
| Tensor &operator=(const Tensor &rhs); | |||
| Shape shape(void) const; | |||
| DType dtype(void) const; | |||
| Format format(void) const; | |||
| Device device(void) const; | |||
| size_t size(void) const; | |||
| std::vector<ptrdiff_t> stride(void) const; | |||
| float scale(void) const; | |||
| uint8_t zero_point(void) const; | |||
| void *data(void); | |||
| const void *data(void) const; | |||
| template <typename T> | |||
| T *data(void); | |||
| template <typename T> | |||
| const T *data(void) const; | |||
| template <typename T, size_t N, | |||
| template <typename U> class PtrTraits = DefaultPtrTraits, | |||
| typename index_t = int64_t> | |||
| const TensorAccessor<T, N, PtrTraits, index_t> accessor() const; | |||
| template <typename T, size_t N, | |||
| template <typename U> class PtrTraits = DefaultPtrTraits, | |||
| typename index_t = int64_t> | |||
| TensorAccessor<T, N, PtrTraits, index_t> accessor(); | |||
| CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
| }; | |||
| template <typename T> | |||
| T *Tensor::data(void) { | |||
| custom_assert(dtype().is_compatible<T>(), | |||
| "invalid conversion, tensor data type is %s", dtype().str().c_str()); | |||
| return reinterpret_cast<T*>(data()); | |||
| } | |||
| template <typename T> | |||
| const T *Tensor::data(void) const { | |||
| return const_cast<Tensor*>(this)->data<T>(); | |||
| } | |||
| template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t> | |||
| const TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() const { | |||
| return const_cast<Tensor*>(this)->accessor<T, N, PtrTraits, index_t>(); | |||
| } | |||
| template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t> | |||
| TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() { | |||
| custom_assert(N == shape().ndim(), | |||
| "cannot get a %lu-d accessor for a tensor with dim %lu", static_cast<unsigned long>(N), static_cast<unsigned long>(shape().ndim())); | |||
| custom_assert(N > 0, "cannot get 0-d accessor"); | |||
| T *ptr = data<T>(); | |||
| return TensorAccessor<T, N, PtrTraits, index_t>(ptr, shapes_raw(), strides_raw()); | |||
| } | |||
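| // usage sketch (mirrors the TensorAccessor tests later in this diff): | |||
| //   auto acc = tensor.accessor<int32_t, 4>(); | |||
| //   int32_t v = acc[n][c][h][w]; | |||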
| #undef CUSTOM_DATA_ADAPTOR_FRIEND_DECL | |||
| #undef CUSTOM_DEVICE_TYPE_ENUM_DECL | |||
| #undef CUSTOM_DTYPE_ENUM_DECL | |||
| #undef CUSTOM_DEFINE_DTYPE_TRAIT | |||
| #undef CUSTOM_CASE_TO_COMPARE_DTYPE | |||
| } // custom | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * \file src/custom/include/megbrain/custom/utils.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <cassert> | |||
| namespace custom { | |||
| void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...); | |||
| #define custom_expect(expr, msg...) \ | |||
| do { \ | |||
| if (!(expr)) { \ | |||
| assert_failed_log( \ | |||
| __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \ | |||
| ); \ | |||
| } \ | |||
| } while (0) | |||
| #define custom_assert(expr, msg...) \ | |||
| do { \ | |||
| if (!(expr)) { \ | |||
| assert_failed_log( \ | |||
| __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \ | |||
| ); \ | |||
| } \ | |||
| assert((expr)); \ | |||
| } while (0) | |||
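| // custom_expect only logs on failure; custom_assert additionally aborts | |||
| // through assert(). usage sketch (hypothetical message): | |||
| //   custom_assert(ptr != nullptr, "invalid pointer for %s", name); | |||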
| class UnImpleWarnLog { | |||
| public: | |||
| UnImpleWarnLog(const std::string &func, const std::string &attr, | |||
| const std::string &val); | |||
| }; | |||
| using void_deleter = void(*)(void*); | |||
| template<typename Impl> | |||
| void impl_deleter(void *ptr) { | |||
| delete reinterpret_cast<Impl*>(ptr); | |||
| } | |||
| #define TypedPtr(type, raw_ptr) reinterpret_cast<type*>(raw_ptr) | |||
| #define TypedRef(type, raw_ptr) (*reinterpret_cast<type*>(raw_ptr)) | |||
| #define CUSTOM_PIMPL_CLS_DECL(Cls) \ | |||
| std::unique_ptr<void, void_deleter> m_impl; \ | |||
| public: \ | |||
| Cls(); \ | |||
| Cls(const Cls &rhs); \ | |||
| Cls &operator=(const Cls &rhs) | |||
| #define CUSTOM_PIMPL_CLS_DEFINE(Cls) \ | |||
| Cls::Cls(): m_impl(new Cls##Impl(), impl_deleter<Cls##Impl>) {} \ | |||
| \ | |||
| Cls::Cls(const Cls &rhs): m_impl(nullptr, impl_deleter<Cls##Impl>) { \ | |||
| custom_assert( \ | |||
| rhs.m_impl != nullptr, \ | |||
| "invalid rhs for the copy constructor of %s", #Cls \ | |||
| ); \ | |||
| m_impl.reset(new Cls##Impl(TypedRef(Cls##Impl, rhs.m_impl.get()))); \ | |||
| } \ | |||
| \ | |||
| Cls &Cls::operator=(const Cls &rhs) { \ | |||
| custom_assert( \ | |||
| m_impl != nullptr && rhs.m_impl != nullptr, \ | |||
| "invalid assignment of %s, lhs or rhs is invalid", #Cls \ | |||
| ); \ | |||
| if (&rhs == this) \ | |||
| return *this; \ | |||
| \ | |||
| TypedRef(Cls##Impl, m_impl.get()) = TypedRef(Cls##Impl, rhs.m_impl.get()); \ | |||
| return *this; \ | |||
| } | |||
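| // usage sketch: a pimpl class puts CUSTOM_PIMPL_CLS_DECL(Cls) in its body | |||
| // (see the Format class earlier in this diff) and pairs it with a ClsImpl | |||
| // plus CUSTOM_PIMPL_CLS_DEFINE(Cls) in the corresponding .cpp | |||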
| /** | |||
| * these two functors are defined explicitly for use with std::unordered_map, | |||
| * to improve compatibility across different compiler versions | |||
| */ | |||
| template <typename T> | |||
| struct EnumHash { | |||
| size_t operator()(const T &rhs) const { | |||
| return static_cast<size_t>(rhs); | |||
| } | |||
| }; | |||
| template <typename T> | |||
| struct EnumCmp { | |||
| bool operator()(const T &lhs, const T &rhs) const { | |||
| return static_cast<size_t>(lhs) == static_cast<size_t>(rhs); | |||
| } | |||
| }; | |||
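| // usage sketch (hypothetical enum key type E and value type V): | |||
| //   std::unordered_map<E, V, EnumHash<E>, EnumCmp<E>> table; | |||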
| } // custom | |||
| @@ -0,0 +1,96 @@ | |||
| /** | |||
| * \file src/custom/test/manager.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/custom/manager.h" | |||
| #include "megbrain/custom/custom.h" | |||
| #include "gtest/gtest.h" | |||
| #define MANAGER_TEST_LOG 0 | |||
| namespace custom { | |||
| TEST(TestOpManager, TestOpManager) { | |||
| CustomOpManager *com = CustomOpManager::inst(); | |||
| com->insert("Op1", CUSTOM_OP_VERSION); | |||
| com->insert("Op2", CUSTOM_OP_VERSION); | |||
| std::shared_ptr<CustomOp> ptr = com->find_or_reg("Op3", CUSTOM_OP_VERSION); | |||
| ASSERT_TRUE(ptr != nullptr); | |||
| std::vector<std::string> op_names = com->op_name_list(); | |||
| std::vector<RunTimeId> op_ids = com->op_id_list(); | |||
| ASSERT_TRUE(op_names.size() == 3); | |||
| ASSERT_TRUE(op_ids.size() == 3); | |||
| #if MANAGER_TEST_LOG | |||
| for (std::string &name: op_names) { | |||
| std::cout << name << std::endl; | |||
| } | |||
| #endif | |||
| for (std::string &name: op_names) { | |||
| std::shared_ptr<const CustomOp> op = com->find(name); | |||
| ASSERT_TRUE(op != nullptr); | |||
| ASSERT_TRUE(op->op_type() == name); | |||
| RunTimeId id = com->to_id(name); | |||
| ASSERT_TRUE(com->find(id) == op); | |||
| } | |||
| for (RunTimeId &id: op_ids) { | |||
| std::shared_ptr<const CustomOp> op = com->find(id); | |||
| ASSERT_TRUE(op != nullptr); | |||
| ASSERT_TRUE(op->runtime_id() == id); | |||
| std::string name = com->to_name(id); | |||
| ASSERT_TRUE(com->find(name) == op); | |||
| } | |||
| ASSERT_FALSE(com->erase("Op0")); | |||
| #if MANAGER_TEST_LOG | |||
| for (auto &name: com->op_name_list()) { | |||
| std::cout << name << std::endl; | |||
| } | |||
| #endif | |||
| ASSERT_TRUE(com->erase("Op1")); | |||
| ASSERT_TRUE(com->erase(com->to_id("Op2"))); | |||
| ASSERT_TRUE(com->op_id_list().size() == 1); | |||
| ASSERT_TRUE(com->op_name_list().size() == 1); | |||
| ASSERT_TRUE(com->op_name_list()[0] == "Op3"); | |||
| ptr.reset(); | |||
| ASSERT_TRUE(com->erase("Op3")); | |||
| } | |||
| TEST(TestOpManager, TestOpReg) { | |||
| CUSTOM_OP_REG(Op1) | |||
| .add_inputs(2) | |||
| .add_outputs(3) | |||
| .add_input("lhs") | |||
| .add_param("param1", 1) | |||
| .add_param("param2", 3.45); | |||
| CUSTOM_OP_REG(Op2) | |||
| .add_input("lhs") | |||
| .add_input("rhs") | |||
| .add_output("out") | |||
| .add_param("param1", "test") | |||
| .add_param("param2", true) | |||
| .add_param("", "no name"); | |||
| (void)_Op1; | |||
| (void)_Op2; | |||
| #if MANAGER_TEST_LOG | |||
| for (const auto &name: CustomOpManager::inst()->op_name_list()) { | |||
| std::cout << CustomOpManager::inst()->find(name)->str() << std::endl; | |||
| } | |||
| #endif | |||
| } | |||
| } // custom | |||
| @@ -0,0 +1,205 @@ | |||
| /** | |||
| * \file src/custom/test/op.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/custom/op.h" | |||
| #include "megbrain/comp_node.h" | |||
| #include "megbrain/tensor.h" | |||
| #include "megbrain/custom/data_adaptor.h" | |||
| #include "gtest/gtest.h" | |||
| #include "megbrain_build_config.h" | |||
| #define OP_TEST_LOG 0 | |||
| using namespace mgb; | |||
| namespace custom { | |||
| TEST(TestCustomOp, TestCustomOpInfoSetter) { | |||
| CustomOp test("TestOp", CUSTOM_OP_VERSION); | |||
| test.set_description("Test Op") | |||
| .add_input("lhs", "lhs of test op", {"float32", "int32"}, 2) | |||
| .add_inputs(2) | |||
| .add_input("rhs", "rhs of test op", {"float32", "int32"}, 2) | |||
| .add_outputs(1) | |||
| .add_output("out", "out of test op", {"float32", "int32"}, 2) | |||
| .add_outputs(3); | |||
| ASSERT_TRUE(test.op_type() == "TestOp"); | |||
| ASSERT_TRUE(test.op_desc() == "Test Op"); | |||
| ASSERT_TRUE(test.input_num() == 4); | |||
| ASSERT_TRUE(test.output_num() == 5); | |||
| #if OP_TEST_LOG | |||
| for (auto input: test.inputs_info()) { | |||
| std::cout << input.str() << std::endl; | |||
| } | |||
| for (auto output: test.outputs_info()) { | |||
| std::cout << output.str() << std::endl; | |||
| } | |||
| #endif | |||
| test.add_param("param1", "param1 - float", 1.23f) | |||
| .add_param("param2", "param2 - float list", {2.34f, 3.45f}) | |||
| .add_param("param3", "param3 - string", "test-string") | |||
| .add_param("param4", {"test", "string", "list"}) | |||
| .add_param("param5", 1); | |||
| #if OP_TEST_LOG | |||
| ParamInfo pinfo = test.param_info(); | |||
| for (auto kv: pinfo.meta()) { | |||
| std::cout << kv.str() << std::endl; | |||
| } | |||
| #endif | |||
| } | |||
| void device_infer(const std::vector<Device> &inputs, const Param ¶ms, | |||
| std::vector<Device> &outputs) { | |||
| (void)params; | |||
| outputs[0] = inputs[1]; | |||
| outputs[1] = inputs[0]; | |||
| } | |||
| void shape_infer(const std::vector<Shape> &inputs, const Param ¶ms, | |||
| std::vector<Shape> &outputs) { | |||
| (void)params; | |||
| outputs[0] = inputs[1]; | |||
| outputs[1] = inputs[0]; | |||
| } | |||
| void dtype_infer(const std::vector<DType> &inputs, const Param ¶ms, | |||
| std::vector<DType> &outputs) { | |||
| (void)params; | |||
| outputs[0] = inputs[1]; | |||
| outputs[1] = inputs[0]; | |||
| } | |||
| void format_infer(const std::vector<Format> &inputs, const Param ¶ms, | |||
| std::vector<Format> &outputs) { | |||
| (void)params; | |||
| outputs[0] = inputs[1]; | |||
| outputs[1] = inputs[0]; | |||
| } | |||
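| // each infer function above swaps the attributes of the two inputs so the | |||
| // test below can distinguish the registered inference from the default one | |||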
| void cpu_kernel(const std::vector<Tensor> &inputs, const Param ¶ms, | |||
| std::vector<Tensor> &outputs) { | |||
| (void)inputs; | |||
| (void)outputs; | |||
| #if OP_TEST_LOG | |||
| std::cout << "Checking CPU Forward - " << params["device"].as<std::string>() << std::endl; | |||
| #endif | |||
| ASSERT_TRUE(params["device"] == "x86"); | |||
| } | |||
| void gpu_kernel(const std::vector<Tensor> &inputs, const Param ¶ms, | |||
| std::vector<Tensor> &outputs) { | |||
| (void)inputs; | |||
| (void)outputs; | |||
| #if OP_TEST_LOG | |||
| std::cout << "Checking GPU Forward - " << params["device"].as<std::string>() << std::endl; | |||
| #endif | |||
| ASSERT_TRUE(params["device"] == "cuda"); | |||
| } | |||
| TEST(TestCustomOp, TestCustomOpFuncSetter) { | |||
| #if MGB_CUDA | |||
| CustomOp test("TestOp", CUSTOM_OP_VERSION); | |||
| test.set_description("Test Op Forward Backward Union") | |||
| .add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2) | |||
| .add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2) | |||
| .add_output("outl", "outl of Test op", {"float32", "int32"}, 2) | |||
| .add_output("outr", "outr of Test op", {"float32", "int32"}, 2) | |||
| .add_param("smooth", "smooth", 0.f) | |||
| .add_param("device", "using for judge device", "x86"); | |||
| std::vector<Device> idevices = {"x86", "cuda"}; | |||
| std::vector<Shape> ishapes = {{2, 3}, {3, 4}}; | |||
| std::vector<DType> idtypes = {"int32", "float32"}; | |||
| std::vector<Format> iformats = {"default", "default"}; | |||
| Param param(test.param_info()); | |||
| std::vector<Device> odevices = test.infer_output_device(idevices, param); | |||
| std::vector<Shape> oshapes = test.infer_output_shape (ishapes, param); | |||
| std::vector<DType> odtypes = test.infer_output_dtype (idtypes, param); | |||
| std::vector<Format> oformats = test.infer_output_format(iformats, param); | |||
| ASSERT_TRUE(odevices.size() == 2); | |||
| ASSERT_TRUE(oshapes.size() == 2); | |||
| ASSERT_TRUE(odtypes.size() == 2); | |||
| ASSERT_TRUE(oformats.size() == 2); | |||
| ASSERT_TRUE(odevices[0] == "x86"); | |||
| ASSERT_TRUE(odevices[1] == "x86"); | |||
| ASSERT_TRUE(oshapes[0] == Shape({2,3})); | |||
| ASSERT_TRUE(oshapes[1] == Shape({2,3})); | |||
| ASSERT_TRUE(odtypes[0] == "int32"); | |||
| ASSERT_TRUE(odtypes[1] == "int32"); | |||
| ASSERT_TRUE(oformats[0].is_default()); | |||
| ASSERT_TRUE(oformats[1].is_default()); | |||
| test.set_device_infer(device_infer) | |||
| .set_shape_infer(shape_infer) | |||
| .set_dtype_infer(dtype_infer) | |||
| .set_format_infer(format_infer); | |||
| odevices = test.infer_output_device(idevices, param); | |||
| oshapes = test.infer_output_shape (ishapes, param); | |||
| odtypes = test.infer_output_dtype (idtypes, param); | |||
| oformats = test.infer_output_format(iformats, param); | |||
| ASSERT_TRUE(odevices.size() == 2); | |||
| ASSERT_TRUE(oshapes.size() == 2); | |||
| ASSERT_TRUE(odtypes.size() == 2); | |||
| ASSERT_TRUE(oformats.size() == 2); | |||
| ASSERT_TRUE(odevices[0] == "cuda"); | |||
| ASSERT_TRUE(odevices[1] == "x86"); | |||
| ASSERT_TRUE(oshapes[0] == Shape({3,4})); | |||
| ASSERT_TRUE(oshapes[1] == Shape({2,3})); | |||
| ASSERT_TRUE(odtypes[0] == "float32"); | |||
| ASSERT_TRUE(odtypes[1] == "int32"); | |||
| ASSERT_TRUE(oformats[0].is_default()); | |||
| ASSERT_TRUE(oformats[1].is_default()); | |||
| test.set_compute(cpu_kernel); | |||
| DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); | |||
| DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); | |||
| DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); | |||
| DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); | |||
| std::vector<Tensor> cinputs = {to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)}; | |||
| std::vector<Tensor> coutputs = {to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)}; | |||
| param["device"] = "x86"; | |||
| test.compute(cinputs, param, coutputs); | |||
| test.set_compute("cuda", gpu_kernel); | |||
| DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); | |||
| DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); | |||
| DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); | |||
| DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); | |||
| std::vector<Tensor> ginputs = {to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)}; | |||
| std::vector<Tensor> goutputs = {to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)}; | |||
| param["device"] = "cuda"; | |||
| test.compute(ginputs, param, goutputs); | |||
| #endif | |||
| } | |||
| } // custom | |||
| @@ -0,0 +1,208 @@ | |||
| /** | |||
| * \file src/custom/test/param.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/custom/param.h" | |||
| #include "gtest/gtest.h" | |||
| #include <iostream> | |||
| #define PARAM_TEST_LOG 0 | |||
| namespace custom { | |||
| #define SchemaDef \ | |||
| ParamSchema schema_bool("param_bool", true, "bool"); \ | |||
| ParamSchema schema_flt("param_flt", 2.3f, "float"); \ | |||
| ParamSchema schema_int("param_int", 4, "int"); \ | |||
| ParamSchema schema_str("param_str", "test", "string"); \ | |||
| ParamSchema schema_bool_list("param_bl", {true, false, true}, "bool list"); \ | |||
| ParamSchema schema_flt_list("param_fl", {1.1f, 2.2f, 3.3f}, "float list"); \ | |||
| ParamSchema schema_int_list("param_il", {1, 2, 3}, "int list"); \ | |||
| ParamSchema schema_str_list("param_sl", {"test1", "test2", "test3"}, "string list") | |||
| #define InfoDef \ | |||
| info.meta().emplace_back(schema_bool); \ | |||
| info.meta().emplace_back(schema_flt); \ | |||
| info.meta().emplace_back(schema_int); \ | |||
| info.meta().emplace_back(schema_str); \ | |||
| info.meta().emplace_back(schema_bool_list); \ | |||
| info.meta().emplace_back(schema_flt_list); \ | |||
| info.meta().emplace_back(schema_int_list); \ | |||
| info.meta().emplace_back(schema_str_list) | |||
| TEST(TestParam, TestParamScheme) { | |||
| #if PARAM_TEST_LOG | |||
| SchemaDef; | |||
| ParamSchema new_schema = schema_int; | |||
| std::cout << schema_bool.str() << std::endl; | |||
| std::cout << schema_flt.str() << std::endl; | |||
| std::cout << schema_int.str() << std::endl; | |||
| std::cout << schema_str.str() << std::endl; | |||
| std::cout << schema_bool_list.str() << "len: " << schema_bool_list.default_val().size() << std::endl; | |||
| std::cout << schema_flt_list.str() << "len: " << schema_flt_list.default_val().size() << std::endl; | |||
| std::cout << schema_int_list.str() << "len: " << schema_int_list.default_val().size() << std::endl; | |||
| std::cout << schema_str_list.str() << "len: " << schema_str_list.default_val().size() << std::endl; | |||
| std::cout << new_schema.str() << std::endl; | |||
| #endif | |||
| } | |||
| TEST(TestParam, TestParamVal) { | |||
| ParamVal pv1 = 1.2f, pv2 = true, pv3 = "test", pv4 = {0, 1, 2}, | |||
| pv5 = {true, false, true}; | |||
| #if PARAM_TEST_LOG | |||
| ParamVal pv6 = {"test1", "test2", "test3"}; | |||
| std::cout << pv1.str() << std::endl; | |||
| std::cout << pv2.str() << std::endl; | |||
| std::cout << pv3.str() << std::endl; | |||
| std::cout << pv4.str() << std::endl; | |||
| std::cout << pv5.str() << std::endl; | |||
| std::cout << pv6.str() << std::endl; | |||
| #endif | |||
| ParamVal pv_manip = pv1; | |||
| ASSERT_TRUE(pv_manip.type() == pv1.type()); | |||
| ASSERT_TRUE(pv_manip == pv1); | |||
| pv_manip = 1.3; | |||
| ASSERT_TRUE(pv_manip.type() != pv1.type()); | |||
| ASSERT_TRUE(pv_manip != pv1); | |||
| ASSERT_TRUE(pv_manip > pv1); | |||
| pv_manip = pv_manip + pv1; | |||
| ASSERT_TRUE(pv_manip.type() == ParamDynType::Float64); | |||
| ASSERT_TRUE(pv_manip == 1.3 + 1.2f); | |||
| pv_manip = 1.3f + 1.2f; | |||
| ASSERT_TRUE(pv_manip.type() == pv1.type()); | |||
| pv_manip = false; | |||
| ASSERT_TRUE(pv_manip.type() == pv2.type()); | |||
| ASSERT_TRUE(pv_manip.type() == ParamDynType::Bool); | |||
| ASSERT_TRUE(pv_manip != pv2); | |||
| pv_manip = "test"; | |||
| ASSERT_TRUE(pv_manip.type() == pv3.type()); | |||
| ASSERT_TRUE(pv_manip.type() == ParamDynType::String); | |||
| ASSERT_TRUE(pv_manip == pv3); | |||
| pv_manip = "test1"; | |||
| ASSERT_TRUE(pv_manip > pv3); | |||
| pv_manip = pv_manip + pv3; | |||
| ASSERT_TRUE(pv_manip == "test1test"); | |||
| pv_manip = {0, 1, 2}; | |||
| ASSERT_TRUE(pv_manip.type() == pv4.type()); | |||
| ASSERT_TRUE(pv_manip.type() == ParamDynType::Int32List); | |||
| ASSERT_TRUE(pv_manip == pv4); | |||
| pv_manip = {3, 2, 1}; | |||
| ASSERT_TRUE(pv_manip != pv4); | |||
| ASSERT_TRUE(pv_manip > pv4); | |||
| pv_manip = {true, false, true}; | |||
| ASSERT_TRUE(pv_manip.type() == pv5.type()); | |||
| ASSERT_TRUE(pv_manip.type() == ParamDynType::BoolList); | |||
| ASSERT_TRUE(pv_manip == pv5); | |||
| pv_manip = {false, true, false}; | |||
| ASSERT_TRUE(pv_manip != pv5); | |||
| } | |||
| TEST(TestParam, TestParamInfo) { | |||
| ParamInfo info; | |||
| info.set_tag("Test"); | |||
| #if PARAM_TEST_LOG | |||
| uint32_t tag = info.tag(); | |||
| std::cout << tag << std::endl; | |||
| #endif | |||
| SchemaDef; | |||
| InfoDef; | |||
| ParamInfo new_info1, new_info2; | |||
| new_info1.set_meta(info.meta()); | |||
| new_info2.meta() = info.meta(); | |||
| #if PARAM_TEST_LOG | |||
| for (auto ele: new_info1.meta()) { | |||
| std::cout << ele.str() << std::endl; | |||
| } | |||
| for (auto ele: new_info2.meta()) { | |||
| std::cout << ele.str() << std::endl; | |||
| } | |||
| #endif | |||
| } | |||
| TEST(TestParam, TestParam) { | |||
| ParamInfo info; | |||
| SchemaDef; | |||
| InfoDef; | |||
| Param param(info); | |||
| #if PARAM_TEST_LOG | |||
| std::vector<std::string> names = {"param_bool", "param_flt", "param_int", "param_str", "param_bl", "param_fl", "param_il", "param_sl"}; | |||
| for (auto &name: names) { | |||
| std::cout << param[name].str() << std::endl; | |||
| } | |||
| #endif | |||
| ASSERT_TRUE(param["param_bool"] == true); | |||
| ASSERT_TRUE(param["param_flt"] == 2.3f); | |||
| ASSERT_TRUE(param["param_int"] == 4); | |||
| ASSERT_TRUE(param["param_str"] == "test"); | |||
| ASSERT_TRUE(param["param_bl"] == ParamVal({true, false, true})); | |||
| ASSERT_TRUE(param["param_fl"] == ParamVal({1.1f, 2.2f, 3.3f})); | |||
| ASSERT_TRUE(param["param_il"] == ParamVal({1, 2, 3})); | |||
| ASSERT_TRUE(param["param_sl"] == ParamVal({"test1", "test2", "test3"})); | |||
| param["param_bool"] = false; | |||
| param["param_flt"] = 3.4f; | |||
| param["param_int"] = 5; | |||
| param["param_str"] = "tset"; | |||
| param["param_bl"] = {false, true, false, true}; | |||
| param["param_fl"] = {7.6f, 6.5f}; | |||
| param["param_il"] = {5, 4, 3, 2, 1}; | |||
| param["param_sl"] = {"1tset", "2tset", "3tset", "4tset", "5tset"}; | |||
| ASSERT_TRUE(param["param_bool"] != true); | |||
| ASSERT_TRUE(param["param_flt"] != 2.3f); | |||
| ASSERT_TRUE(param["param_int"] != 4); | |||
| ASSERT_TRUE(param["param_str"] != "test"); | |||
| ASSERT_TRUE(param["param_bl"] != ParamVal({true, false, true})); | |||
| ASSERT_TRUE(param["param_fl"] != ParamVal({1.1f, 2.2f, 3.3f})); | |||
| ASSERT_TRUE(param["param_il"] != ParamVal({1, 2, 3})); | |||
| ASSERT_TRUE(param["param_sl"] != ParamVal({"test1", "test2", "test3"})); | |||
| ASSERT_TRUE(param["param_bool"] == false); | |||
| ASSERT_TRUE(param["param_flt"] == 3.4f); | |||
| ASSERT_TRUE(param["param_int"] == 5); | |||
| ASSERT_TRUE(param["param_str"] == "tset"); | |||
| ASSERT_TRUE(param["param_bl"] == ParamVal({false, true, false, true})); | |||
| ASSERT_TRUE(param["param_fl"] == ParamVal({7.6f, 6.5f})); | |||
| ASSERT_TRUE(param["param_il"] == ParamVal({5, 4, 3, 2, 1})); | |||
| ASSERT_TRUE(param["param_sl"] == ParamVal({"1tset", "2tset", "3tset", "4tset", "5tset"})); | |||
| #if PARAM_TEST_LOG | |||
| Param copy_param = param; | |||
| for (auto &name: names) { | |||
| std::cout << copy_param[name].str() << std::endl; | |||
| } | |||
| #endif | |||
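| // byte-level round trip: serialize the params and reload them into a | |||
| // fresh Param built from the same info | |||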
| Param loaded_param(info); | |||
| std::string bytes = param.to_bytes(); | |||
| loaded_param.from_bytes(bytes); | |||
| #if PARAM_TEST_LOG | |||
| for (auto &kv: loaded_param.raw()) { | |||
| std::cout << kv.first << ":\n" << kv.second.str() << std::endl; | |||
| } | |||
| #endif | |||
| } | |||
| } // custom | |||
| @@ -0,0 +1,325 @@ | |||
| /** | |||
| * \file src/custom/test/tensor.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/custom/tensor.h" | |||
| #include "megbrain/custom/data_adaptor.h" | |||
| #include "megbrain/comp_node.h" | |||
| #include "megbrain/tensor.h" | |||
| #include "gtest/gtest.h" | |||
| #include "megbrain_build_config.h" | |||
| #define TENSOR_TEST_LOG 0 | |||
| using namespace mgb; | |||
| namespace custom { | |||
| TEST(TestDevice, TestDevice) { | |||
| #if MGB_CUDA | |||
| ASSERT_TRUE(Device::is_legal("x86")); | |||
| ASSERT_TRUE(Device::is_legal(DeviceEnum::cuda)); | |||
| ASSERT_FALSE(Device::is_legal("cpu")); | |||
| Device dev1; | |||
| ASSERT_TRUE(dev1.str() == "invalid"); | |||
| dev1 = "x86"; | |||
| ASSERT_TRUE("x86" == dev1); | |||
| Device dev2 = "cuda"; | |||
| ASSERT_TRUE(dev2 == "cuda"); | |||
| ASSERT_FALSE(dev2 == dev1); | |||
| Device dev3 = dev2; | |||
| ASSERT_TRUE(dev3 == dev2); | |||
| ASSERT_FALSE(dev3 == dev1); | |||
| Device dev4 = DeviceEnum::cuda; | |||
| ASSERT_TRUE(dev4.enumv() == DeviceEnum::cuda); | |||
| #if TENSOR_TEST_LOG | |||
| std::cout << dev1.str() << "\n" << dev2.str() << "\n" | |||
| << dev3.str() << "\n" << dev4.str() << std::endl; | |||
| #endif | |||
| CompNode compnode = to_builtin<CompNode, Device>(dev3); | |||
| ASSERT_TRUE(compnode.to_string_logical() == "gpux:0"); | |||
| compnode = CompNode::load("cpu0:0"); | |||
| Device dev5 = to_custom<CompNode, Device>(compnode); | |||
| ASSERT_TRUE(dev5.str() == "x86"); | |||
| std::vector<Device> devs1 = {"x86", "cuda", "x86"}; | |||
| megdnn::SmallVector<CompNode> compnodes = to_builtin<CompNode, Device>(devs1); | |||
| ASSERT_TRUE(compnodes[0].to_string_logical() == "cpux:0"); | |||
| ASSERT_TRUE(compnodes[1].to_string_logical() == "gpux:0"); | |||
| ASSERT_TRUE(compnodes[2].to_string_logical() == "cpux:0"); | |||
| std::vector<Device> devs2 = to_custom<CompNode, Device>(compnodes); | |||
| ASSERT_TRUE(devs2[0] == "x86"); | |||
| ASSERT_TRUE(devs2[1].str() == "cuda"); | |||
| ASSERT_TRUE(devs2[2] == "x86"); | |||
| #endif | |||
| } | |||
| TEST(TestShape, TestShape) { | |||
| Shape shape1, shape2; | |||
| ASSERT_TRUE(shape1.ndim() == 0); | |||
| shape1 = {16, 32, 8, 8}; | |||
| shape2 = shape1; | |||
| ASSERT_TRUE(shape2.ndim() == 4); | |||
| ASSERT_TRUE(shape2[0] == 16); | |||
| ASSERT_TRUE(shape2[1] == 32); | |||
| ASSERT_TRUE(shape2[2] == 8); | |||
| ASSERT_TRUE(shape2[3] == 8); | |||
| Shape shape3 = {16, 32, 8, 8}; | |||
| const Shape shape4 = shape1; | |||
| ASSERT_TRUE(shape3 == shape4); | |||
| shape3[0] = 32; | |||
| ASSERT_FALSE(shape3 == shape4); | |||
| ASSERT_TRUE(shape3[0] == 32); | |||
| ASSERT_TRUE(shape4[0] == 16); | |||
| Shape shape5 = {2, 3, 4}; | |||
| TensorShape bshape1 = to_builtin<TensorShape, Shape>(shape5); | |||
| ASSERT_TRUE(bshape1.ndim == 3); | |||
| ASSERT_TRUE(bshape1[0] == 2); | |||
| ASSERT_TRUE(bshape1[1] == 3); | |||
| ASSERT_TRUE(bshape1[2] == 4); | |||
| bshape1 = {4, 2, 3}; | |||
| Shape shape6 = to_custom<TensorShape, Shape>(bshape1); | |||
| ASSERT_TRUE(shape6.ndim() == 3); | |||
| ASSERT_TRUE(shape6[0] == 4); | |||
| ASSERT_TRUE(shape6[1] == 2); | |||
| ASSERT_TRUE(shape6[2] == 3); | |||
| Shape shape7; | |||
| shape7.ndim(3); | |||
| shape7[1] = 4; | |||
| ASSERT_TRUE(shape7 == Shape({0, 4, 0})); | |||
| std::vector<Shape> shapes1 = {{2, 3, 4}, {6}, {5, 7}}; | |||
| megdnn::SmallVector<TensorShape> bshapes = to_builtin<TensorShape, Shape>(shapes1); | |||
| ASSERT_TRUE(bshapes[0].total_nr_elems() == 2*3*4); | |||
| ASSERT_TRUE(bshapes[1].total_nr_elems() == 6); | |||
| ASSERT_TRUE(bshapes[2].total_nr_elems() == 35); | |||
| std::vector<Shape> shapes2 = to_custom<TensorShape, Shape>(bshapes); | |||
| ASSERT_TRUE(shapes2[0] == Shape({2, 3, 4})); | |||
| ASSERT_TRUE(shapes2[1] == Shape({6})); | |||
| ASSERT_TRUE(shapes2[2] == Shape({5, 7})); | |||
| } | |||
| TEST(TestDType, TestDType) { | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| ASSERT_TRUE(DType::is_legal("uint8")); | |||
| ASSERT_TRUE(DType::is_legal(DTypeEnum::bfloat16)); | |||
| DType dtype1, dtype2; | |||
| ASSERT_TRUE(dtype1.str() == "invalid"); | |||
| dtype1 = "float32"; | |||
| ASSERT_TRUE(dtype1.str() == "float32"); | |||
| dtype2 = dtype1; | |||
| DType dtype3 = dtype2; | |||
| ASSERT_TRUE(dtype3 == dtype1); | |||
| ASSERT_TRUE(dtype3 == "float32"); | |||
| dtype3 = "int8"; | |||
| ASSERT_FALSE("float32" == dtype3.str()); | |||
| ASSERT_FALSE(dtype3 == dtype2); | |||
| DType dtype4 = DTypeEnum::int8, dtype5 = dtype3; | |||
| ASSERT_TRUE(dtype4 == dtype5); | |||
| ASSERT_TRUE(dtype4.is_compatible<int8_t>()); | |||
| ASSERT_FALSE(dtype4.is_compatible<uint8_t>()); | |||
| DType dtype6 = "int32"; | |||
| megdnn::DType bdtype1 = to_builtin<megdnn::DType, DType>(dtype6); | |||
| ASSERT_TRUE(bdtype1.name() == std::string("Int32")); | |||
| bdtype1 = megdnn::DType::from_enum(megdnn::DTypeEnum::BFloat16); | |||
| DType dtype7 = to_custom<megdnn::DType, DType>(bdtype1); | |||
| ASSERT_TRUE(dtype7.enumv() == DTypeEnum::bfloat16); | |||
| std::vector<DType> dtypes1 = {"int8", "uint8", "float16"}; | |||
| megdnn::SmallVector<megdnn::DType> bdtypes | |||
| = to_builtin<megdnn::DType, DType>(dtypes1); | |||
| ASSERT_TRUE(bdtypes[0].name() == std::string("Int8")); | |||
| ASSERT_TRUE(bdtypes[1].name() == std::string("Uint8")); | |||
| ASSERT_TRUE(bdtypes[2].name() == std::string("Float16")); | |||
| std::vector<DType> dtypes2 = to_custom<megdnn::DType, DType>(bdtypes); | |||
| ASSERT_TRUE(dtypes2[0] == "int8"); | |||
| ASSERT_TRUE(dtypes2[1] == "uint8"); | |||
| ASSERT_TRUE(dtypes2[2] == "float16"); | |||
| #endif | |||
| } | |||
| TEST(TestDType, TestDTypeQuantized) { | |||
| DType quint8_1("quint8", 3.2, 15); | |||
| DType quint8_2("quint8", 3.2, 15); | |||
| DType quint8_3("quint8", 3.2, 16); | |||
| DType quint8_4("quint8", 3.1, 15); | |||
| ASSERT_TRUE(quint8_1 == quint8_2); | |||
| ASSERT_FALSE(quint8_1 == quint8_3); | |||
| ASSERT_FALSE(quint8_1 == quint8_4); | |||
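| // equality of quantized dtypes also compares scale and zero_point, | |||
| // as the three asserts above demonstrate | |||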
| ASSERT_TRUE(quint8_1.scale() == 3.2f); | |||
| ASSERT_TRUE(quint8_1.zero_point() == 15); | |||
| DType qint8("qint8", 3.3f); | |||
| DType qint16("qint16", 3.4f); | |||
| DType qint32("qint32", 3.5f); | |||
| ASSERT_TRUE(qint8.scale() == 3.3f); | |||
| ASSERT_TRUE(qint16.scale() == 3.4f); | |||
| ASSERT_TRUE(qint32.scale() == 3.5f); | |||
| ASSERT_TRUE(qint8.enumv() == DTypeEnum::qint8); | |||
| ASSERT_TRUE(qint8.str() == "qint8"); | |||
| } | |||
| TEST(TestFormat, TestFormat) { | |||
| Format format1, format2("default"); | |||
| ASSERT_TRUE(format1.is_default()); | |||
| ASSERT_TRUE(format2.is_default()); | |||
| Format format3 = format1; | |||
| ASSERT_TRUE(format3.is_default()); | |||
| } | |||
| TEST(TestTensor, TestTensor) { | |||
| CompNode builtin_device = CompNode::load("cpux:0"); | |||
| TensorShape builtin_shape = {3, 2, 4}; | |||
| megdnn::DType builtin_dtype = dtype::Int32{}; | |||
| DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||
| Tensor tensor1 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||
| Tensor tensor2 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||
| Device device = tensor1.device(); | |||
| Shape shape = tensor1.shape(); | |||
| DType dtype = tensor1.dtype(); | |||
| ASSERT_TRUE(device == "x86"); | |||
| ASSERT_TRUE(shape.ndim() == 3); | |||
| ASSERT_TRUE(shape[0] == 3); | |||
| ASSERT_TRUE(shape[1] == 2); | |||
| ASSERT_TRUE(shape[2] == 4); | |||
| ASSERT_TRUE(shape == std::vector<size_t>({3, 2, 4})); | |||
| ASSERT_TRUE(dtype == "int32"); | |||
| int *raw_ptr1 = tensor1.data<int>(); | |||
| for (size_t i=0; i<tensor1.size(); i++) | |||
| raw_ptr1[i] = i; | |||
| int *raw_ptr2 = tensor2.data<int>(); | |||
| for (size_t i=0; i<tensor2.size(); i++) | |||
| ASSERT_TRUE(raw_ptr2[i] == static_cast<int>(i)); | |||
| Tensor tensor3 = tensor2; | |||
| int *raw_ptr3 = tensor3.data<int>(); | |||
| for (size_t i=0; i<tensor3.size(); i++) | |||
| ASSERT_TRUE(raw_ptr3[i] == static_cast<int>(i)); | |||
| ASSERT_TRUE(raw_ptr1 == raw_ptr2); | |||
| ASSERT_TRUE(raw_ptr1 == raw_ptr3); | |||
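| // the pointer equality above shows that Tensor copies are shallow views | |||
| // sharing the same underlying storage | |||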
| for (size_t i=0; i<tensor3.size(); i++) { | |||
| raw_ptr3[i] = -static_cast<int>(i); | |||
| } | |||
| for (size_t i=0; i<tensor1.size(); i++) { | |||
| ASSERT_TRUE(raw_ptr1[i] == -static_cast<int>(i)); | |||
| } | |||
| DeviceTensorND new_dev_tensor = to_builtin<DeviceTensorND, Tensor>(tensor3); | |||
| int *builtin_ptr = new_dev_tensor.ptr<int>(); | |||
| for (size_t i=0; i<new_dev_tensor.shape().total_nr_elems(); i++) { | |||
| ASSERT_TRUE(builtin_ptr[i] == -static_cast<int>(i)); | |||
| } | |||
| } | |||
| TEST(TestTensor, TestTensorQuantized) { | |||
| #if MGB_CUDA | |||
| CompNode builtin_device = CompNode::load("gpux:0"); | |||
| TensorShape builtin_shape = {3, 2, 4}; | |||
| megdnn::DType builtin_dtype = dtype::Quantized8Asymm{3.2f, uint8_t(15)}; | |||
| DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||
| Tensor tensor1 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||
| Tensor tensor2 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||
| Device device1 = tensor1.device(), device2 = tensor2.device(); | |||
| Shape shape1 = tensor1.shape(), shape2 = tensor2.shape(); | |||
| DType dtype1 = tensor1.dtype(), dtype2 = tensor2.dtype(); | |||
| ASSERT_TRUE(device1 == "cuda"); | |||
| ASSERT_TRUE(shape1.ndim() == 3); | |||
| ASSERT_TRUE(shape1[0] == 3); | |||
| ASSERT_TRUE(shape1[1] == 2); | |||
| ASSERT_TRUE(shape1[2] == 4); | |||
| ASSERT_TRUE(shape1 == std::vector<size_t>({3, 2, 4})); | |||
| ASSERT_TRUE(dtype1 == "quint8"); | |||
| ASSERT_TRUE(dtype1.scale() == 3.2f); | |||
| ASSERT_TRUE(dtype1.zero_point() == 15); | |||
| ASSERT_TRUE(device1 == device2); | |||
| ASSERT_TRUE(shape1 == shape2); | |||
| ASSERT_TRUE(dtype1 == dtype2); | |||
| #endif | |||
| } | |||
| TEST(TestTensor, TestTensorAccessorND) { | |||
| size_t N = 2, C = 4, H = 6, W = 8; | |||
| CompNode builtin_device = CompNode::load("cpux"); | |||
| TensorShape builtin_shape = {N, C, H, W}; | |||
| megdnn::DType builtin_dtype = dtype::Int32{}; | |||
| DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||
| int *builtin_ptr = dev_tensor.ptr<int>(); | |||
| for (size_t i=0; i<dev_tensor.shape().total_nr_elems(); i++) { | |||
| builtin_ptr[i] = i; | |||
| } | |||
| Tensor tensor = to_custom_tensor(dev_tensor); | |||
| auto accessor = tensor.accessor<int32_t, 4>(); | |||
| for (size_t n=0; n<N; ++n) { | |||
| for (size_t c=0; c<C; ++c) { | |||
| for (size_t h=0; h<H; ++h) { | |||
| for (size_t w=0; w<W; ++w) { | |||
| int32_t idx = n*C*H*W + c*H*W + h*W + w; | |||
| ASSERT_TRUE(accessor[n][c][h][w] == idx); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| TEST(TestTensor, TestTensorAccessor1D) { | |||
| CompNode builtin_device = CompNode::load("cpux"); | |||
| TensorShape builtin_shape = {32}; | |||
| megdnn::DType builtin_dtype = dtype::Float32{}; | |||
| DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||
| float *builtin_ptr = dev_tensor.ptr<float>(); | |||
| for (size_t i=0; i<dev_tensor.shape().total_nr_elems(); i++) { | |||
| builtin_ptr[i] = i; | |||
| } | |||
| Tensor tensor = to_custom_tensor(dev_tensor); | |||
| auto accessor = tensor.accessor<float, 1>(); | |||
| for (size_t n=0; n<32; ++n) { | |||
| ASSERT_TRUE(accessor[n] == n); | |||
| } | |||
| } | |||
| } // custom | |||
| @@ -18,7 +18,7 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode); | |||
| void CustomOpNode::infer_output_comp_node(void) { | |||
| SmallVector<CompNode> input_comp_nodes(input_num()); | |||
| for (int i=0; i<input_num(); ++i) { | |||
| for (size_t i=0; i<input_num(); ++i) { | |||
| input_comp_nodes[i] = input(i)->comp_node(); | |||
| } | |||
| @@ -28,7 +28,7 @@ void CustomOpNode::infer_output_comp_node(void) { | |||
| ) | |||
| ); | |||
| for (int i=0; i<output_num(); ++i) { | |||
| for (size_t i=0; i<output_num(); ++i) { | |||
| mgb_assert(output_comp_nodes[i] == output_comp_nodes[0], | |||
| "only single comp node operator is supported"); | |||
| output(i)->comp_node(output_comp_nodes[i]); | |||
| @@ -39,7 +39,7 @@ void CustomOpNode::infer_output_comp_node(void) { | |||
| void CustomOpNode::infer_output_dtype(void) { | |||
| SmallVector<DType> input_dtypes(input_num()); | |||
| for (int i=0; i<input_num(); ++i) { | |||
| for (size_t i=0; i<input_num(); ++i) { | |||
| input_dtypes[i] = input(i)->dtype(); | |||
| } | |||
| @@ -49,14 +49,14 @@ void CustomOpNode::infer_output_dtype(void) { | |||
| ) | |||
| ); | |||
| for (int i=0; i<output_num(); ++i) { | |||
| for (size_t i=0; i<output_num(); ++i) { | |||
| output(i)->dtype(output_dtypes[i]); | |||
| } | |||
| } | |||
| void CustomOpNode::infer_output_format(void) { | |||
| SmallVector<TensorFormat> input_formats(input_num()); | |||
| for (int i=0; i<input_num(); ++i) { | |||
| for (size_t i=0; i<input_num(); ++i) { | |||
| input_formats[i] = input(i)->format(); | |||
| } | |||
| @@ -66,14 +66,14 @@ void CustomOpNode::infer_output_format(void) { | |||
| ) | |||
| ); | |||
| for (int i=0; i<output_num(); ++i) { | |||
| for (size_t i=0; i<output_num(); ++i) { | |||
| output(i)->format(output_formats[i]); | |||
| } | |||
| } | |||
| void CustomOpNode::infer_output_shape(void) { | |||
| SmallVector<TensorShape> input_shapes(input_num()); | |||
| for (int i=0; i<input_num(); ++i) { | |||
| for (size_t i=0; i<input_num(); ++i) { | |||
| input_shapes[i] = input(i)->shape(); | |||
| } | |||
| @@ -83,7 +83,7 @@ void CustomOpNode::infer_output_shape(void) { | |||
| ) | |||
| ); | |||
| for (int i=0; i<output_num(); ++i) { | |||
| for (size_t i=0; i<output_num(); ++i) { | |||
| output(i)->shape(output_shapes[i]); | |||
| } | |||
| } | |||
| @@ -235,10 +235,10 @@ CustomOpNode::CustomOpNode(const std::shared_ptr<const custom::CustomOp> &op, | |||
| const OperatorNodeConfig &config): | |||
| OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs), m_op(op), m_param(param) { | |||
| mgb_assert(input_num() == inputs.size(), "wrong input tensors list length"); | |||
| for (int i=0; i < input_num(); ++i) | |||
| for (size_t i=0; i < input_num(); ++i) | |||
| add_input({inputs[i]}); | |||
| for (int i=0; i<output_num(); ++i) | |||
| for (size_t i=0; i<output_num(); ++i) | |||
| add_output(output_info(i).name()); | |||
| if (!std::is_empty<custom::Param>::value) { | |||
| @@ -306,11 +306,11 @@ std::string CustomOpNode::op_desc(void) const { | |||
| return m_op->op_desc(); | |||
| } | |||
| int CustomOpNode::input_num(void) const { | |||
| size_t CustomOpNode::input_num(void) const { | |||
| return m_op->input_num(); | |||
| } | |||
| int CustomOpNode::output_num(void) const { | |||
| size_t CustomOpNode::output_num(void) const { | |||
| return m_op->output_num(); | |||
| } | |||
| @@ -93,8 +93,8 @@ public: | |||
| custom::Param param(void) const; | |||
| std::string op_type(void) const; | |||
| std::string op_desc(void) const; | |||
| int input_num(void) const; | |||
| int output_num(void) const; | |||
| size_t input_num(void) const; | |||
| size_t output_num(void) const; | |||
| custom::ArgInfo input_info(size_t idx) const; | |||
| custom::ArgInfo output_info(size_t idx) const; | |||
| }; | |||
| @@ -1,7 +1,7 @@ | |||
| include_directories("./src/include") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") | |||
| file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp) | |||
| file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp ../src/custom/test/*.cpp) | |||
| if(MGE_WITH_JIT) | |||
| file(GLOB_RECURSE SOURCES_ ../src/jit/test/*.cpp) | |||
| list(APPEND SOURCES ${SOURCES_}) | |||