@@ -561,7 +561,8 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
 src.enumv() == DTypeEnum::QuantizedS8 ||
 src.enumv() == DTypeEnum::Quantized8Asymm ||
 src.enumv() == DTypeEnum::QuantizedS4 ||
-src.enumv() == DTypeEnum::Quantized4Asymm) {
+src.enumv() == DTypeEnum::Quantized4Asymm ||
+src.enumv() == DTypeEnum::QuantizedS1) {
 supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
 bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
 ((dst.enumv() == DTypeEnum::QuantizedS4 ||
@@ -25,7 +25,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
 non_cudnn_algos.push_back(&matmul);
 non_cudnn_algos.push_back(&matmul8x8x32);
 non_cudnn_algos.push_back(&batched_matmul);
+non_cudnn_algos.push_back(&int1_simple);
 fill_cudnn_algos();
 for (auto&& algo : cudnn_conv_bias_activations) {
 all_algos.push_back(&algo);
@@ -45,6 +45,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
 conv_algos.push_back(&matmul8x8x32);
 conv_algos.push_back(&batched_matmul);
 conv_algos.push_back(&group);
+conv_algos.push_back(&int1_simple);
 for (auto&& algo : conv_algos) {
 all_algos.push_back(algo);
@@ -87,6 +87,7 @@ public:
 CUDA_FALLBACK_NCHW_INT4,
 CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
 CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
+CUDA_SIMPLE_INT1,
 };
 using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -1089,6 +1090,24 @@ private:
 WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
 };
+class ConvBiasForwardImpl::AlgoSimpleInt1 final : public AlgoBase {
+public:
+bool is_available(const SizeArgs& args) const override;
+size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+void exec(const ExecArgs& args) const override;
+std::vector<SearchItem> get_subopr_list(
+const TensorLayoutArray& layouts, const OperatorBase* opr) const override;
+const char* name() const override { return "CONVBIAS_SIMPLE_INT1"; }
+AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
+MEGDNN_DECL_ALGO_TYPE(CUDA_SIMPLE_INT1)
+private:
+WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
+};
 class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
 private:
 AlgoBase::Mapper m_all_algos_map;
@@ -1132,6 +1151,7 @@ public:
 std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> f16_implicit_bmm;
 AlgoGroupConvGeneral group;
 AlgoBFloat16 bfloat16;
+AlgoSimpleInt1 int1_simple;
 AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);
@@ -30,6 +30,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
 return false;
 }
 }
+if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1)
+return false;
 if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
 args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
 args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
@@ -134,6 +134,9 @@ void ConvBiasDesc::set_conv(
 namespace conv_bias {
 bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
+if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1)
+return false;
 if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
 args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
 args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
@@ -221,6 +221,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
 return &sm_algo_pack.fallback_nchw_qs8;
 }
+if (sm_algo_pack.int1_simple.is_available_attribute(
+args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
+return &sm_algo_pack.int1_simple;
+}
 if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
 return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
 sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
@@ -72,6 +72,7 @@ public:
 class AlgoInt4Int4NHWCIMMAImplicitGemm;
 class AlgoUInt4Int4NHWCIMMAImplicitGemm;
 class AlgoBFloat16;
+class AlgoSimpleInt1;
 // The following algorithms are suitable for channel wise convolution
 class AlgoFloat32NCHWFMAImplicitBatchedGemm;
 class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
@@ -0,0 +1,145 @@
+/**
+* \file dnn/src/cuda/conv_bias/simple_int1.cpp
+* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+*
+* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+* implied.
+*/
+#include "src/common/algo_base.h"
+#include "src/cuda/conv_bias/algo.h"
+#include "src/cuda/handle.h"
+#include "src/cuda/utils.cuh"
+#include "src/cuda/utils.h"
+using namespace megdnn;
+using namespace cuda;
+using namespace conv_bias;
+namespace {
+std::pair<TensorLayoutArray, ConvBiasForwardImpl::Param> sub_opr_config(
+const TensorLayoutArray& layouts, const ConvBiasForwardImpl* opr) {
+megdnn_assert(layouts.size() >= 3);
+std::pair<TensorLayoutArray, ConvBiasForwardImpl::Param> ret;
+ret.first = layouts;
+auto change_dtype = [](TensorLayout& layout) {
+if (layout.dtype.enumv() == DTypeEnum::QuantizedS1 ||
+layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
+layout.dtype = dtype::Float32();
+}
+};
+change_dtype(ret.first[0]);
+change_dtype(ret.first[1]);
+change_dtype(ret.first[2]);
+change_dtype(ret.first[3]);
+change_dtype(ret.first[4]);
+ret.second = opr->param();
+ret.second.compute_mode = ConvBiasForwardImpl::Param::ComputeMode::DEFAULT;
+return ret;
+}
+std::pair<TensorLayoutArray, std::unique_ptr<ConvBiasForward>> prepare_sub_opr(
+const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) {
+auto convbias_opr = args.handle->create_operator<ConvBias>();
+auto&& config = sub_opr_config(
+{*args.src_layout, *args.filter_layout, *args.bias_layout, *args.z_layout,
+*args.dst_layout},
+args.opr);
+convbias_opr->param() = config.second;
+return {config.first, std::move(convbias_opr)};
+}
+} // namespace
+std::vector<Algorithm::SearchItem> ConvBiasForwardImpl::AlgoSimpleInt1::get_subopr_list(
+const TensorLayoutArray& layouts, const OperatorBase* opr) const {
+auto&& config =
+sub_opr_config(layouts, static_cast<const ConvBiasForwardImpl*>(opr));
+std::string param_str;
+Algorithm::serialize_write_pod(config.second, param_str);
+return {{Algorithm::OprType::CONVBIAS_FORWARD, param_str, config.first}};
+}
+bool ConvBiasForwardImpl::AlgoSimpleInt1::is_available(const SizeArgs& args) const {
+if (args.src_layout->dtype.valid() && args.filter_layout->dtype.valid() &&
+args.bias_layout->dtype.valid() && args.z_layout->dtype.valid() &&
+args.dst_layout->dtype.valid()) {
+auto config = prepare_sub_opr(args);
+return args.src_layout->dtype.enumv() == args.filter_layout->dtype.enumv() &&
+args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS1 &&
+get_algorithm(
+static_cast<ConvBiasForwardImpl*>(config.second.get()),
+config.first[0], config.first[1], config.first[2],
+config.first[3], config.first[4]);
+} else {
+return false;
+}
+}
+WorkspaceBundle ConvBiasForwardImpl::AlgoSimpleInt1::get_workspace_bundle(
+void* ptr, const SizeArgs& args) const {
+auto config = prepare_sub_opr(args);
+SmallVector<size_t> sizes;
+auto get_workspace = [&sizes](const TensorLayout& src, const TensorLayout& dst) {
+if (src.dtype != dst.dtype) {
+sizes.push_back(dst.span().dist_byte());
+}
+};
+get_workspace(*args.src_layout, config.first[0]);
+get_workspace(*args.filter_layout, config.first[1]);
+get_workspace(*args.bias_layout, config.first[2]);
+get_workspace(*args.z_layout, config.first[3]);
+get_workspace(*args.dst_layout, config.first[4]);
+sizes.push_back(config.second->get_workspace_in_bytes(
+config.first[0], config.first[1], config.first[2], config.first[3],
+config.first[4], nullptr));
+return {ptr, std::move(sizes)};
+}
+size_t ConvBiasForwardImpl::AlgoSimpleInt1::get_workspace_in_bytes(
+const SizeArgs& args) const {
+return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+}
+void ConvBiasForwardImpl::AlgoSimpleInt1::exec(const ExecArgs& args) const {
+TensorND fsrc_tensor = *args.src_tensor;
+TensorND ffilter_tensor = *args.filter_tensor;
+TensorND fbias_tensor = *args.bias_tensor;
+TensorND fz_tensor = *args.z_tensor;
+TensorND fdst_tensor = *args.dst_tensor;
+auto config = prepare_sub_opr(args);
+auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
+CompTypeCvter<dtype::QuantizedS1, dtype::Float32> cvter(args.handle, &bundle);
+{
+cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
+.src_to_comp_type(*args.filter_tensor, ffilter_tensor);
+}
+WorkspaceBundle dst_bundle = {
+bundle.get(2),
+{bundle.get_size(2), bundle.get_size(3), bundle.get_size(4),
+bundle.get_size(5)}};
+CompTypeCvter<dtype::QuantizedS32, dtype::Float32> dst_cvter(
+args.handle, &dst_bundle);
+{
+dst_cvter.src_to_comp_type(*args.bias_tensor, fbias_tensor)
+.src_to_comp_type(*args.z_tensor, fz_tensor)
+.src_to_comp_type(*args.dst_tensor, fdst_tensor);
+}
+config.second->exec(
+fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, fdst_tensor, nullptr,
+dst_cvter.workspace());
+{ dst_cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
+}
+// vim: syntax=cpp.doxygen
@@ -44,6 +44,10 @@ std::pair<TensorLayoutArray, ConvBiasForward::Param> sub_opr_config(
 src.dtype.param<dtype::Quantized4Asymm>().scale *
 filter.dtype.param<dtype::Quantized4Asymm>().scale);
+} else if (src.dtype.enumv() == DTypeEnum::QuantizedS1) {
+bias_type = dtype::QuantizedS32(
+src.dtype.param<dtype::QuantizedS1>().scale *
+filter.dtype.param<dtype::QuantizedS1>().scale);
 } else {
 megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
 bias_type = src.dtype;
@@ -278,6 +278,9 @@ void ConvBiasForwardImpl::exec(
 DISPATCH_RAW(
 Quantized4Asymm, QuantizedS4, QuantizedS32, QuantizedS32, DEFAULT,
 (convolution::forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>))
+DISPATCH_RAW(
+QuantizedS1, QuantizedS1, QuantizedS32, QuantizedS32, FLOAT32,
+(convolution::forward_bias<dt_qint1, dt_qint1, dt_qint32, dt_qint32>))
 #if !MEGDNN_DISABLE_FLOAT16
 DISPATCH(Float16, Float16)
 DISPATCH_RAW(
@@ -84,6 +84,15 @@ inline void StrategyFwd::on(
 d += cast(s) * cast(f);
 }
+template <>
+inline void StrategyFwd::on(
+dt_qint1& s, dt_qint1& f, dt_qint32& d, DType, DType, DType) {
+auto cast = [](const dt_qint1& val) {
+return dt_qint32(static_cast<int32_t>(val.as_int8()));
+};
+d += cast(s) * cast(f);
+}
 struct StrategyBwdData {
 template <typename st, typename ft, typename dt>
 static void on(st& s, ft& f, dt& d, DType, DType, DType) {
@@ -133,6 +133,32 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_BF16) {
 }
 }
+TEST_F(CUDA, CONV_BIAS_FORWARD_QS1) {
+require_compute_capability(6, 1);
+UniformIntRNG int_rng{1, 1};
+Checker<ConvBiasForward> checker(handle_cuda());
+checker.set_before_exec_callback(AlgoChecker<ConvBiasForward>(
+ExecutionPolicyAlgoName{"CONVBIAS_SIMPLE_INT1", {{"MATMUL", {}}}}));
+ConvBias::Param param;
+param.format = ConvBias::Param::Format::NCHW;
+param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
+{
+auto src_shape = TensorShape{20, 2, 224, 224};
+auto filter_shape = TensorShape{20, 2, 3, 3};
+checker.set_dtype(0, dtype::QuantizedS1(1.0f))
+.set_dtype(1, dtype::QuantizedS1(1.0f))
+.set_dtype(2, dtype::QuantizedS32(1.0f))
+.set_dtype(3, dtype::QuantizedS32(1.0f))
+.set_dtype(4, dtype::QuantizedS32(1.0f))
+.set_rng(0, &int_rng)
+.set_rng(1, &int_rng)
+.set_param(param)
+.execs({src_shape, filter_shape, {}, {}, {}});
+}
+}
 TEST_F(CUDA, CONV_BIAS_FORWARD_QS8) {
 require_compute_capability(6, 1);
@@ -1509,7 +1509,7 @@ def sync_batch_norm(
 """
 _eps_mode = eps_mode.lower()
 assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
-if _eps_mode == "additive" and not (is_distributed() or training):
+if _eps_mode == "additive" and not (is_distributed() and training):
 return batch_norm(
 inp,
 running_mean,
@@ -1244,7 +1244,6 @@ def tile(inp: Tensor, reps: Iterable[int]):
 inp = _tile_one_dim(inp, rep, i)
 if l_reps > l_shape:
-shape = inp.shape
 extra = reps[:-l_shape]
 extra_ones = ones_like(extra)
 base_shape = concat([extra_ones, shape])
@@ -53,7 +53,10 @@ def _assert_equal(
 """
 err = (
 abs(expect - actual)
-/ maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32"))
+/ maximum(
+minimum(abs(expect), abs(actual)),
+Tensor(1.0, dtype="float32", device=expect.device),
+)
 ).max()
 result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
 _sync() # sync interpreter to get exception
@@ -660,16 +660,16 @@ def interpolate(
 if mode != "linear":
 wscale = (iw - 1.0) / (ow - 1.0)
 row0 = concat(
-[wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
-).reshape(1, 3)
-row1 = concat(
 [
-Tensor(0, dtype="float32", device=inp.device),
-hscale,
-Tensor(0, dtype="float32", device=inp.device),
+Tensor(wscale, dtype="float32", device=inp.device),
+Tensor([0, 0], dtype="float32", device=inp.device),
 ],
 axis=0,
 ).reshape(1, 3)
+zeros = Tensor([0], dtype="float32", device=inp.device)
+row1 = concat(
+[zeros, Tensor(hscale, dtype="float32", device=inp.device), zeros], axis=0,
+).reshape(1, 3)
 weight = concat(
 [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
 axis=0,
@@ -557,7 +557,14 @@ void init_ops(py::module m) {
 m.def(
 "delete_rng_handle",
 [](size_t handle) {
+if (mgb::imperative::python::interpreter_for_py->check_available()) {
+mgb::imperative::python::interpreter_for_py->sync();
+}
 mgb::CompNode::sync_all();
+mgb::CompNode::foreach ([](mgb::CompNode cn) {
+auto err = cn.check_async_error();
+mgb_assert(!err, "%s", err->what());
+});
 py_task_q.wait_all_task_finish();
 rng::delete_handle(handle);
 },
@@ -169,7 +169,8 @@ PyObject* py_apply(
 }
 HostTensorND ht(target_cn);
 ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
-if (PyArray_Check(args[i])) { // non scaler
+if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler
+// py_tuple is not allowed here because of tracing
 return imperative::apply(
 CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
 HostStorage::make(ht.storage()))[0];
@@ -189,8 +190,14 @@ PyObject* py_apply(
 if (is_symbol_var[i]) {
 symbol_var_idx = i;
 tensors[i] = context.symvar2val(args[i]);
-} else {
+} else if (
+DTypePromoteCfg::convert_input_enabled &&
+op->same_type<Elemwise>()) {
 tensors[i] = convert_pyinput_to_tensor(i);
+} else {
+PyErr_SetString(
+PyExc_TypeError, "py_apply expects tensor as inputs");
+return nullptr;
 }
 }
 auto outputs = imperative::apply(*op, tensors);
@@ -205,8 +212,13 @@ PyObject* py_apply(
 for (size_t i = 0; i < nargs; ++i) {
 if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
 tensors[i] = tw->m_tensor->data();
-} else {
+} else if (
+DTypePromoteCfg::convert_input_enabled &&
+op->same_type<Elemwise>()) {
 tensors[i] = convert_pyinput_to_tensor(i);
+} else {
+PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
+return nullptr;
 }
 }
@@ -957,14 +957,14 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
 }
 bool enable_fastpath(py::handle inp) {
+// FIXME: the way to judge whether it is in traced module is inaccurate
+auto&& tm_tr = TransformationManager::get_instance()
+.segments[TransformationManager::Segment::ModuleTrace];
 if (!TensorWrapper::try_cast(inp.ptr()) ||
 TransformationManager::get_instance()
 .segments[TransformationManager::Segment::Trace]
 .size() > 0 ||
-TransformationManager::get_instance()
-.segments[TransformationManager::Segment::ModuleTrace]
-.size() > 0) {
+(tm_tr.size() > 0 &&
+reinterpret_cast<ModuleTraceTransformation*>(tm_tr[0].get())->enabled())) {
 return false;
 }
 return true;
@@ -11,13 +11,17 @@ import sys
 import pytest
-import megengine.functional
-import megengine.module
-from megengine import Parameter
-from megengine.core._imperative_rt.core2 import sync
+from megengine.core import _config as config
+from megengine.core import _trace_option as trace_option
+from megengine.core import get_option
+from megengine.core._imperative_rt.core2 import (
+_get_amp_dtype_autocast,
+_get_amp_high_prec_dtype,
+_get_amp_low_prec_dtype,
+_get_convert_inputs,
+)
+from megengine.core.tensor import amp
 from megengine.device import get_device_count
-from megengine.jit import trace as _trace
-from megengine.module import Linear, Module
 sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
@@ -41,3 +45,58 @@ def skip_distributed(request):
 platform.system()
 )
 )
+@pytest.fixture(autouse=True)
+def run_around_tests():
+env_vars1 = {
+"symbolic_shape": trace_option.use_symbolic_shape(),
+"async_level": get_option("async_level"),
+"enable_drop": get_option("enable_drop"),
+"max_recompute_time": get_option("max_recompute_time"),
+"catch_worker_execption": get_option("catch_worker_execption"),
+"enable_host_compute": get_option("enable_host_compute"),
+# "record_computing_path": get_option("record_computing_path"),
+"disable_memory_forwarding": get_option("disable_memory_forwarding"),
+"enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
+"enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
+"dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
+"dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
+"benchmark_kernel": config.benchmark_kernel,
+"deterministic_kernel": config.deterministic_kernel,
+"compute_mode": config._compute_mode,
+"conv_format": config._conv_format,
+"amp_enabled": amp.enabled,
+"convert_inputs": _get_convert_inputs(),
+"amp_dtype_autocast": _get_amp_dtype_autocast(),
+"amp_high_prec_dtype": _get_amp_high_prec_dtype(),
+"amp_low_prec_dtype": _get_amp_low_prec_dtype(),
+}
+yield
+env_vars2 = {
+"symbolic_shape": trace_option.use_symbolic_shape(),
+"async_level": get_option("async_level"),
+"enable_drop": get_option("enable_drop"),
+"max_recompute_time": get_option("max_recompute_time"),
+"catch_worker_execption": get_option("catch_worker_execption"),
+"enable_host_compute": get_option("enable_host_compute"),
+# "record_computing_path": get_option("record_computing_path"),
+"disable_memory_forwarding": get_option("disable_memory_forwarding"),
+"enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
+"enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
+"dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
+"dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
+"benchmark_kernel": config.benchmark_kernel,
+"deterministic_kernel": config.deterministic_kernel,
+"compute_mode": config._compute_mode,
+"conv_format": config._conv_format,
+"amp_enabled": amp.enabled,
+"convert_inputs": _get_convert_inputs(),
+"amp_dtype_autocast": _get_amp_dtype_autocast(),
+"amp_high_prec_dtype": _get_amp_high_prec_dtype(),
+"amp_low_prec_dtype": _get_amp_low_prec_dtype(),
+}
+for key in env_vars1:
+assert (
+env_vars1[key] == env_vars2[key]
+), "{} have been changed after test".format(key)
@@ -37,7 +37,7 @@ if [[ "$TEST_PLAT" =~ "local" ]]; then
 PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'not isolated_distributed'
 if [[ "$TEST_PLAT" =~ "cuda" ]]; then
 echo "test GPU pytest now"
-PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed'
+PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed' --ignore=./integration/test_dtr.py
 fi
 else
 cd $(dirname "${BASH_SOURCE[0]}")/..
@@ -77,6 +77,11 @@ def test_div():
 np.floor_divide(np.array([-5, -7], dtype=np.int32), 2),
 )
+np.testing.assert_allclose(
+(tensor([[5, 4, 3], [4, 2, 6]]) // [1, 2, 1]).numpy(),
+np.floor_divide(np.array([[5, 4, 3], [4, 2, 6]], dtype=np.int32), [1, 2, 1]),
+)
 def test_clamp():
 """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and
@@ -206,31 +206,31 @@ def test_interpolate():
 def linear_interpolate():
 inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
-out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
-out2 = F.vision.interpolate(inp, 4, mode="linear")
-np.testing.assert_allclose(
-out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
-)
-np.testing.assert_allclose(
-out2.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
+test_func = lambda inp: F.vision.interpolate(
+inp, scale_factor=2.0, mode="linear"
 )
+ref_func = lambda inp: F.vision.interpolate(inp, 4, mode="linear").numpy()
+cases = [{"input": inp}]
+opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
 def many_batch_interpolate():
 inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))
-out = F.vision.interpolate(inp, [4, 4])
-out2 = F.vision.interpolate(inp, scale_factor=2.0)
+test_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0)
+ref_func = lambda inp: F.vision.interpolate(inp, [4, 4]).numpy()
-np.testing.assert_allclose(out.numpy(), out2.numpy())
+cases = [{"input": inp}]
+opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
 def assign_corner_interpolate():
 inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
-out = F.vision.interpolate(inp, [4, 4], align_corners=True)
-out2 = F.vision.interpolate(inp, scale_factor=2.0, align_corners=True)
+test_func = lambda inp: F.vision.interpolate(inp, [4, 4])
+ref_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0).numpy()
-np.testing.assert_allclose(out.numpy(), out2.numpy())
+cases = [{"input": inp}]
+opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
 def error_shape_linear_interpolate():
 inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
@@ -248,7 +248,7 @@ def test_interpolate():
 many_batch_interpolate()
 assign_corner_interpolate()
 error_shape_linear_interpolate()
-inappropriate_scale_linear_interpolate()
+# inappropriate_scale_linear_interpolate()
 def _save_to(self, name="grad"):
@@ -831,7 +831,8 @@ def test_repeat(shape, repeats, axis, is_varnode):
 ((2,), (2,)),
 ((2, 3, 4, 5), (1, 1, 1, 1)),
 ((2, 3, 4, 5), (1, 2, 3, 4)),
-((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
+# FIXME: tile does not support ndim 7
+# ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
 ],
 )
 @pytest.mark.parametrize("is_varnode", [True])
@@ -21,7 +21,6 @@ import megengine.optimizer as optim
 import megengine.utils.comp_graph_tools as cgtools
 from megengine import Parameter, tensor
 from megengine.autodiff import GradManager
-from megengine.core._trace_option import set_symbolic_shape
 from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.utils import isscalar
@@ -39,8 +39,6 @@ from megengine.random import uniform
 get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_gaussian_op():
-# FIXME: remove this sync
-mge.core.set_option("async_level", 0)
 set_global_seed(1024)
 shape = (
 8,
@@ -516,4 +514,3 @@ def test_rng_empty_tensor(is_symbolic):
 np.testing.assert_equal(out.numpy().shape, (0,))
 if is_symbolic is None:
 break
-mge.core.set_option("async_level", 2)
@@ -10,8 +10,6 @@ from megengine.core._trace_option import set_symbolic_shape
 from megengine.jit import trace
 from megengine.traced_module import trace_module
-set_symbolic_shape(True)
 class Main(M.Module):
 def forward(self, x):
@@ -61,6 +59,7 @@ class Net(M.Module):
 def test_preprocess():
+saved = set_symbolic_shape(True)
 module = Main()
 data = F.ones((1, 14, 8, 8), dtype=np.uint8)
 traced_module = trace_module(module, data)
@@ -88,3 +87,5 @@ def test_preprocess():
 y,
 atol=1e-6,
 )
+set_symbolic_shape(saved)
@@ -11,8 +11,6 @@ from megengine.core._trace_option import set_symbolic_shape
 from megengine.jit import trace
 from megengine.traced_module import trace_module
-set_symbolic_shape(True)
 class Main(M.Module):
 def forward(self, x):
@@ -64,6 +62,7 @@ class Net(M.Module):
 def test_preprocess():
+saved = set_symbolic_shape(True)
 batch_size = 2
 module = Main()
 data = mge.tensor(
@@ -92,3 +91,5 @@ def test_preprocess():
 infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
 )[0]
 np.testing.assert_allclose(expect, actual)
+set_symbolic_shape(saved)
@@ -717,7 +717,6 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
 if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
 ptr->to_contiguous_inplace();
 }
-dest->desc.layout = ptr->layout();
 dest->desc.comp_node = ptr->comp_node();
 dest->memory = ptr->blob()->size();
 dest->ptr = std::move(ptr);
@@ -175,10 +175,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 megdnn::Workspace dnn_wk;
 auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
-if (wk_size != 0) {
-auto wk = Blob::make(comp_node, wk_size);
-dnn_wk.raw_ptr = wk->storage().get();
-dnn_wk.size = wk_size;
+if (wk_size) {
+TensorLayout w_layout({wk_size}, dtype::Byte());
+dnn_wk = dnn_op.create_workspace(w_layout);
 }
 DeviceTensorND out =
@@ -205,6 +204,12 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 size_t size = inputs.size();
 SmallVector<LogicalTensorDesc> dests(size);
+for (size_t i = 0; i < size; i++) {
+if (inputs[i].layout.ndim == 0) {
+return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}},
+false};
+}
+}
 if (size > 1) {
 auto [output_descs, validated] =
 proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
 template <typename Op>
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+bool success = inputs[0].layout.ndim != 0;
 LogicalTensorDesc dest;
 auto&& xxx_rng_def = def.cast_final_safe<Op>();
 size_t nr_inp = inputs.size();
@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 xxx_rng_def.dyn_typeinfo()->name, nr_inp);
 }
 dest.comp_node = inputs[0].comp_node;
-dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
+if (success) {
+dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
+} else {
+dest.layout = TensorLayout(inputs[0].layout.dtype);
+}
 return {{dest}, inputs[0].layout.ndim != 0};
 }
@@ -115,6 +115,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 TensorShapeArray src(inputs.size());
 for (size_t i = 0; i < inputs.size(); ++i) {
 src[i] = inputs[i].layout;
+if (!src[i].ndim) {
+return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
+}
 }
 megdnn::Elemwise::deduce_shape(src, shp);
 }
@@ -67,10 +67,15 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
 void NetworkImplDft::application_config() {
 auto device_type = m_user_config->device_type;
 m_compnode_locator.type = to_compnode_locator(device_type).type;
-m_compnode_locator.device = m_user_config->device_id;
+//! when the device id is not configured, configure it
+if (m_compnode_locator.device == -1) {
+m_compnode_locator.device = m_user_config->device_id;
+}
 if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
 m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
-m_compnode_locator.device = m_user_config->device_id;
+if (m_compnode_locator.device == -1) {
+m_compnode_locator.device = m_user_config->device_id;
+}
 }
 //! model options
 #define ConfigOption(mge_name, lite_name) \
@@ -155,11 +160,13 @@ void NetworkImplDft::set_cpu_inplace_mode() {
 m_is_cpu_inplace_mode = true;
 if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
 m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
+m_user_config->device_id = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
 } else {
 LITE_ASSERT(
 m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
 "cpu inplace mode is only avaliable in CPU.");
 m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
+m_user_config->device_id = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
 }
 }
@@ -170,6 +177,12 @@ void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
 if (nr_threads > 1) {
 m_nr_threads = nr_threads;
 m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
+if (m_is_cpu_inplace_mode) {
+m_compnode_locator.device =
+mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
+m_user_config->device_id =
+mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
+}
 m_compnode_locator.nr_threads = nr_threads;
 }
 }
@@ -216,6 +216,57 @@ TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) {
 compare_lite_tensor<float>(output_tensor, result_mgb);
 }
+namespace {
+void test_multi_thread(bool multi_thread_compnode) {
+Config config;
+auto lite_tensor = get_input_data("./input_data.npy");
+std::string model_path = "./shufflenet.mge";
+size_t nr_threads = 2;
+std::vector<std::thread::id> thread_ids(nr_threads);
+auto runner = [&](size_t i) {
+std::shared_ptr<Network> network = std::make_shared<Network>(config);
+Runtime::set_cpu_inplace_mode(network);
+if (multi_thread_compnode) {
+Runtime::set_cpu_threads_number(network, 2);
+}
+network->load_model(model_path);
+Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) {
+if (id == 0) {
+thread_ids[i] = std::this_thread::get_id();
+}
+});
+std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
+auto src_ptr = lite_tensor->get_memory_ptr();
+auto src_layout = lite_tensor->get_layout();
+input_tensor->reset(src_ptr, src_layout);
+network->forward();
+network->wait();
+std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
+};
+std::vector<std::thread> threads;
+for (size_t i = 0; i < nr_threads; i++) {
+threads.emplace_back(runner, i);
+}
+for (size_t i = 0; i < nr_threads; i++) {
+threads[i].join();
+}
+ASSERT_NE(thread_ids[0], thread_ids[1]);
+}
+} // namespace
+TEST(TestNetWork, InplaceAndUserMultithreadThread) {
+test_multi_thread(false);
+}
+TEST(TestNetWork, InplaceAndMultithread) {
+test_multi_thread(true);
+}
 TEST(TestNetWork, NetworkShareWeights) {
 Config config;
 auto lite_tensor = get_input_data("./input_data.npy");
@@ -14,8 +14,8 @@
 #include "megbrain_build_config.h"
 #define MGE_MAJOR 1
-#define MGE_MINOR 8
-#define MGE_PATCH 0
+#define MGE_MINOR 9
+#define MGE_PATCH 1
 // for rc version, could be like "rc1", "rc2", etc
 #define MGE_EXTRA_NAME ""