GitOrigin-RevId: 82833f41d9
tags/v1.0.0
| @@ -179,6 +179,11 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
| add_enum_alias('Format', 'ConvolutionV0') | add_enum_alias('Format', 'ConvolutionV0') | ||||
| ) | ) | ||||
| (pdef('AdaptivePooling'). | |||||
| add_enum_alias('Mode', 'Pooling'). | |||||
| add_enum_alias('Format', 'ConvolutionV0') | |||||
| ) | |||||
| (pdef('LRN', | (pdef('LRN', | ||||
| 'see ImageNet Classification with Deep Convolutional Neural Networks for' | 'see ImageNet Classification with Deep Convolutional Neural Networks for' | ||||
| ' meaning of the fields'). | ' meaning of the fields'). | ||||
| @@ -0,0 +1,148 @@ | |||||
| /** | |||||
| * \file src/opr/impl/dnn/adaptive_pooling.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
| #include "../internal/megdnn_opr_wrapper.inl" | |||||
| #include "megbrain/graph/grad_impl.h" | |||||
| #include "megbrain/opr/utility.h" | |||||
| #include "megdnn/opr_param_defs.h" | |||||
| #include "megdnn/oprs/nn.h" | |||||
| using namespace mgb; | |||||
| using namespace opr; | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(AdaptivePoolingForward); | |||||
| AdaptivePoolingForward::AdaptivePoolingForward(VarNode* src, VarNode* out_shape, | |||||
| const Param& param, | |||||
| const OperatorNodeConfig& config) | |||||
| : Super(OperatorNodeBaseCtorParam{src->owner_graph(), | |||||
| config, | |||||
| "adaptive_pooling", | |||||
| {src, out_shape}}) { | |||||
| init_megdnn_opr(*this, param); | |||||
| add_input({src, out_shape}); | |||||
| outshape_by_symvar_enable(1, 1); | |||||
| } | |||||
| SymbolVar AdaptivePoolingForward::make(SymbolVar src, SymbolVar out_shape, | |||||
| const Param& param, | |||||
| const OperatorNodeConfig& config) { | |||||
| return src.insert_single_output_opr<AdaptivePoolingForward>( | |||||
| src.node(), out_shape.node(), param, config); | |||||
| } | |||||
| void AdaptivePoolingForward::scn_do_execute() { | |||||
| megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), | |||||
| output(0)->dev_tensor().as_megdnn(), | |||||
| intl::get_megdnn_workspace_from_var(output().back())); | |||||
| } | |||||
| void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( | |||||
| TensorShape& dest, const ShapeInferInfo& shpinfo) { | |||||
| TensorShape oshp2d; | |||||
| cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0)); | |||||
| auto src = shpinfo.shape_inp_shp.at(0); | |||||
| mgb_assert(src.ndim == 4 && oshp2d.ndim == 2, | |||||
| "shape mismatch for AdaptivePooling: src=%s, out2d=%s", | |||||
| src.to_string().c_str(), oshp2d.to_string().c_str()); | |||||
| mgb_assert(param().format == Param::Format::NCHW, | |||||
| "AdaptivePooling only support NCHW"); | |||||
| dest.ndim = 4; | |||||
| dest.shape[0] = src.shape[0]; | |||||
| dest.shape[1] = src.shape[1]; | |||||
| dest.shape[2] = oshp2d.shape[0]; | |||||
| dest.shape[3] = oshp2d.shape[1]; | |||||
| } | |||||
| size_t AdaptivePoolingForward::get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const { | |||||
| return megdnn_opr()->get_workspace_in_bytes( | |||||
| {input_shapes[0], this->input(0)->dtype(), | |||||
| this->input(0)->format()}, | |||||
| {output_shapes[0], this->output(0)->dtype(), | |||||
| this->output(0)->format()}); | |||||
| } | |||||
| void AdaptivePoolingForward::init_output_dtype() { | |||||
| output(0)->dtype(input(0)->dtype()); | |||||
| } | |||||
| void AdaptivePoolingForward::add_input_layout_constraint() { | |||||
| mixin::megdnn_utils::add_input_layout_constraint_contig(*this); | |||||
| } | |||||
| void AdaptivePoolingForward::init_output_static_infer_desc() { | |||||
| Super::init_output_static_infer_desc(); | |||||
| init_output_static_infer_desc_workspace(false); | |||||
| } | |||||
| void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) { | |||||
| record_megdnn_opr(deps); | |||||
| } | |||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) { | |||||
| if (wrt_idx == 0) { | |||||
| // wrt src | |||||
| SymbolVar grad = AdaptivePoolingBackward::make( | |||||
| opr.input(0), opr.input(1), opr.output(0), out_grad[0], | |||||
| opr.param()); | |||||
| return grad.node(); | |||||
| } else { | |||||
| mgb_assert(wrt_idx == 1); | |||||
| return InvalidGrad::make(opr, wrt_idx); | |||||
| } | |||||
| } | |||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(AdaptivePoolingBackward); | |||||
| AdaptivePoolingBackward::AdaptivePoolingBackward( | |||||
| VarNode* src, VarNode* out_shape, VarNode* dst, VarNode* diff, | |||||
| const Param& param, const OperatorNodeConfig& config) | |||||
| : Super(OperatorNodeBaseCtorParam{src->owner_graph(), | |||||
| config, | |||||
| "adaptive_pooling_bwd", | |||||
| {src}}, | |||||
| 0, true) { | |||||
| init_megdnn_opr(*this, param); | |||||
| add_input({src, out_shape, dst, diff}); | |||||
| } | |||||
| SymbolVar AdaptivePoolingBackward::make(SymbolVar src, SymbolVar out_shape, | |||||
| SymbolVar dst, SymbolVar diff, | |||||
| const Param& param, | |||||
| const OperatorNodeConfig& config) { | |||||
| return src.insert_single_output_opr<AdaptivePoolingBackward>( | |||||
| src.node(), out_shape.node(), dst.node(), diff.node(), param, | |||||
| config); | |||||
| } | |||||
| void AdaptivePoolingBackward::scn_do_execute() { | |||||
| megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), | |||||
| input(2)->dev_tensor().as_megdnn(), | |||||
| input(3)->dev_tensor().as_megdnn(), | |||||
| output(0)->dev_tensor().as_megdnn(), | |||||
| intl::get_megdnn_workspace_from_var(output().back())); | |||||
| } | |||||
| size_t AdaptivePoolingBackward::get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const { | |||||
| return megdnn_opr()->get_workspace_in_bytes( | |||||
| {input_shapes[0], input(0)->dtype(), input(0)->format()}, | |||||
| {input_shapes[2], input(2)->dtype(), input(2)->format()}, | |||||
| {input_shapes[3], input(3)->dtype(), input(3)->format()}, | |||||
| {output_shapes[0], output(0)->dtype(), output(0)->format()}); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -140,6 +140,13 @@ decl_opr('Pooling', | |||||
| inputs=['src'], | inputs=['src'], | ||||
| params='Pooling') | params='Pooling') | ||||
| decl_opr('AdaptivePooling', | |||||
| inputs=[Doc('src', 'input image, shape (n, c, ih, iw)'), | |||||
| Doc('out_shape', 'output image shape, containing two elements specifying output height and width.')], | |||||
| params='AdaptivePooling', | |||||
| desc='Adaptive Pooling.' | |||||
| 'The output shape is (n, c, oh, ow), where (oh, ow) is given by *out_shape*.') | |||||
| decl_opr('ROIPooling', outputs=[0], | decl_opr('ROIPooling', outputs=[0], | ||||
| inputs=[Doc('src', 'input image, shape (n, c, ih, iw)'), | inputs=[Doc('src', 'input image, shape (n, c, ih, iw)'), | ||||
| Doc('rois', 'regions of interest, shape (m, 5). ' | Doc('rois', 'regions of interest, shape (m, 5). ' | ||||
| @@ -258,7 +265,7 @@ decl_opr('ROIAlign', outputs=[0], | |||||
| 'store it as a float, but it should be an integral value.' | 'store it as a float, but it should be an integral value.' | ||||
| ' The rois[:, 1:5] are (x0, y0, x1, y1) for each ROI, ' | ' The rois[:, 1:5] are (x0, y0, x1, y1) for each ROI, ' | ||||
| 'which would be multiplied by the scale value given in ' | 'which would be multiplied by the scale value given in ' | ||||
| 'param.')], | |||||
| 'param.')], | |||||
| params='ROIAlign', | params='ROIAlign', | ||||
| desc='ROI Align, see ' | desc='ROI Align, see ' | ||||
| 'Mask-RCNN: https://arxiv.org/pdf/1703.06870.pdf, ' | 'Mask-RCNN: https://arxiv.org/pdf/1703.06870.pdf, ' | ||||
| @@ -295,7 +302,7 @@ decl_opr('BatchConvBiasForward', | |||||
| ('execution_policy', 'ExecutionPolicy')], | ('execution_policy', 'ExecutionPolicy')], | ||||
| desc=Doc(None, | desc=Doc(None, | ||||
| r""" | r""" | ||||
| Apply a convolution of input tensor and filter tensor whose weights are not shared in batch dimensions. Outputs with batch index use the same weight. | |||||
| Apply a convolution of input tensor and filter tensor whose weights are not shared in batch dimensions. Outputs with batch index use the same weight. | |||||
| Assume input shape is :math:`(N, IC, IH, IW)` and filter shape is :math:`(batch, OC, IC, FH, FW)`, the output shape will be :math:`(N, OC, OH, OW)` where :math:`(OH, OW)` would be computed from padding, stride, :math:`(FH, FW)` and :math:`(IH, IW)`, as in convolution. | Assume input shape is :math:`(N, IC, IH, IW)` and filter shape is :math:`(batch, OC, IC, FH, FW)`, the output shape will be :math:`(N, OC, OH, OW)` where :math:`(OH, OW)` would be computed from padding, stride, :math:`(FH, FW)` and :math:`(IH, IW)`, as in convolution. | ||||
| for each output location, we have; | for each output location, we have; | ||||
| @@ -13,6 +13,7 @@ | |||||
| #include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
| #include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
| #include "megbrain/opr/dnn/roi_pooling.h" | #include "megbrain/opr/dnn/roi_pooling.h" | ||||
| #include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
| #include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
| @@ -388,6 +389,9 @@ namespace opr { | |||||
| MGB_SEREG_OPR(Pooling, 1); | MGB_SEREG_OPR(Pooling, 1); | ||||
| MGB_SEREG_OPR(PoolingBackward, 3); | MGB_SEREG_OPR(PoolingBackward, 3); | ||||
| MGB_SEREG_OPR(AdaptivePooling, 2); | |||||
| MGB_SEREG_OPR(AdaptivePoolingBackward, 4); | |||||
| MGB_SEREG_OPR(ROIPooling, 3); | MGB_SEREG_OPR(ROIPooling, 3); | ||||
| MGB_SEREG_OPR(ROIPoolingBackward, 4); | MGB_SEREG_OPR(ROIPoolingBackward, 4); | ||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * \file src/opr/include/megbrain/opr/dnn/adaptive_pooling.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
| #include "megdnn/opr_param_defs.h" | |||||
| #include "megdnn/oprs/nn.h" | |||||
| namespace mgb { | |||||
| namespace opr { | |||||
| MGB_DEFINE_OPR_CLASS( | |||||
| AdaptivePoolingForward, | |||||
| intl::WorkspaceSizeInfer<intl::OutshapeBySymvarSCNOpr< | |||||
| mixin::MegDNNOprHolderImpl<megdnn::AdaptivePoolingForward>>>) // { | |||||
| public: | |||||
| AdaptivePoolingForward(VarNode * src, VarNode * out_shape, | |||||
| const Param& param, | |||||
| const OperatorNodeConfig& config); | |||||
| static SymbolVar make(SymbolVar src, SymbolVar out_shape, | |||||
| const Param& param, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| static SymbolVar make(SymbolVar src, const TensorShape& out_shape, | |||||
| const Param& param, | |||||
| const OperatorNodeConfig& config = {}) { | |||||
| return make(src, cg::var_from_tensor_shape(src, out_shape), param, | |||||
| config); | |||||
| } | |||||
| private: | |||||
| void scn_do_execute() override; | |||||
| void outshape_by_symvar_do_get_output_shape( | |||||
| TensorShape & dest, const ShapeInferInfo& shpinfo) override; | |||||
| size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) | |||||
| const override; | |||||
| void init_output_dtype() override; | |||||
| void add_input_layout_constraint() override; | |||||
| void init_output_static_infer_desc() override; | |||||
| void record_execute_deps(ExecDependencyArray& deps) override; | |||||
| }; | |||||
| using AdaptivePooling = AdaptivePoolingForward; | |||||
| MGB_DEFINE_OPR_CLASS( | |||||
| AdaptivePoolingBackward, | |||||
| intl::MegDNNOprWrapperBwd<megdnn::AdaptivePoolingBackward>) // { | |||||
| public: | |||||
| AdaptivePoolingBackward(VarNode * src, VarNode * out_shape, VarNode * dst, | |||||
| VarNode * diff, const Param& param, | |||||
| const OperatorNodeConfig& config); | |||||
| static SymbolVar make(SymbolVar src, SymbolVar out_shape, SymbolVar dst, | |||||
| SymbolVar diff, const Param& param, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| private: | |||||
| void scn_do_execute() override; | |||||
| size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) | |||||
| const override; | |||||
| }; | |||||
| } // namespace opr | |||||
| } // namespace mgb | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * \file src/opr/test/dnn/adaptive_pooling.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
| #include "megbrain/comp_node_env.h" | |||||
| #include "megbrain/opr/dnn/pooling.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| #include "megbrain/test/autocheck.h" | |||||
| #include "megbrain/test/megdnn_helper.h" | |||||
| #include "megdnn/dtype.h" | |||||
| #include "megdnn/opr_param_defs.h" | |||||
| using namespace std; | |||||
| using namespace mgb; | |||||
| namespace { | |||||
| using Param = opr::AdaptivePoolingForward::Param; | |||||
| void run(Param::Mode mode) { | |||||
| using Checker = AutoOprChecker<2, 1>; | |||||
| Param param{mode}; | |||||
| auto make_graph = | |||||
| [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
| auto o0 = opr::GetVarShape::make(inputs[1]); | |||||
| auto o1 = opr::AdaptivePoolingForward::make(inputs[0], o0, param); | |||||
| return {o1}; | |||||
| }; | |||||
| auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
| auto opr = MegDNNHandle::get( | |||||
| CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||||
| ->create_operator<megdnn::AdaptivePoolingForward>(); | |||||
| opr->param() = param; | |||||
| size_t N = inp[0].get()->shape(0), C = inp[0].get()->shape(1); | |||||
| size_t OH = inp[1].get()->shape(0), OW = inp[1].get()->shape(1); | |||||
| dest[0].resize(TensorShape{N, C, OH, OW}); | |||||
| opr->exec(inp[0]->as_megdnn(), dest[0].as_megdnn(), {}); | |||||
| }; | |||||
| auto gen = [&](HostTensorND& src) { | |||||
| if (mode == Param::Mode::MAX) { | |||||
| HostTensorGenerator<dtype::Float32, RandomDistribution::CONSECUTIVE> | |||||
| src_gen(1.0f, 0.1f); | |||||
| src = *src_gen(src.shape(), src.comp_node()); | |||||
| } else { | |||||
| HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> | |||||
| src_gen(10.f); | |||||
| src = *src_gen(src.shape(), src.comp_node()); | |||||
| } | |||||
| }; | |||||
| Checker::RunOptions opt; | |||||
| opt.numdiff_max_err = 1e-2; | |||||
| Checker checker{make_graph, fwd}; | |||||
| checker.set_input_allow_grad(1, false) | |||||
| .set_input_generator(0, gen); | |||||
| checker.run({TensorShape{1, 1, 10, 7}, TensorShape{5, 4}}, opt); | |||||
| checker.run({TensorShape{1, 1, 9, 7}, TensorShape{5, 4}}, opt); | |||||
| checker.run({TensorShape{1, 2, 8, 9}, TensorShape{3, 4}}, opt); | |||||
| } | |||||
| } // anonymous namespace | |||||
| TEST(TestOprDNN, AdaptivePoolingMax) { | |||||
| run(Param::Mode::MAX); | |||||
| } | |||||
| TEST(TestOprDNN, AdaptivePoolingAverage) { | |||||
| run(Param::Mode::AVERAGE); | |||||
| } | |||||
| TEST(TestOprDNN, AdaptivePoolingAverageCountExcludePadding) { | |||||
| run(Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -99,6 +99,7 @@ union OperatorParam { | |||||
| DType = 67, | DType = 67, | ||||
| param.Remap = 68, | param.Remap = 68, | ||||
| param.NMSKeep = 69, | param.NMSKeep = 69, | ||||
| param.AdaptivePooling = 70, | |||||
| } | } | ||||
| table Operator { | table Operator { | ||||
| @@ -113,6 +113,20 @@ dtype, RandomDistribution::CONSTANT>::operator ()( | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| template<typename dtype> | |||||
| std::shared_ptr<HostTensorND> HostTensorGenerator< | |||||
| dtype, RandomDistribution::CONSECUTIVE>::operator ()( | |||||
| const TensorShape &shape, CompNode cn) { | |||||
| if (!cn.valid()) | |||||
| cn = CompNode::load("xpu0"); | |||||
| std::shared_ptr<HostTensorND> ret = | |||||
| std::make_shared<HostTensorND>(cn, shape, dtype()); | |||||
| auto ptr = ret->ptr<ctype>(); | |||||
| for (size_t i = 0, it = shape.total_nr_elems(); i < it; ++ i) { | |||||
| ptr[i] = m_val + i * m_delta; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| // explicit instantialization of HostTensorGenerator | // explicit instantialization of HostTensorGenerator | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -122,12 +136,16 @@ namespace mgb { | |||||
| dtype::Float32, RandomDistribution::UNIFORM>; | dtype::Float32, RandomDistribution::UNIFORM>; | ||||
| template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
| dtype::Float32, RandomDistribution::CONSTANT>; | dtype::Float32, RandomDistribution::CONSTANT>; | ||||
| template class HostTensorGenerator< | |||||
| dtype::Float32, RandomDistribution::CONSECUTIVE>; | |||||
| template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
| dtype::Float16, RandomDistribution::GAUSSIAN>; | dtype::Float16, RandomDistribution::GAUSSIAN>; | ||||
| template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
| dtype::Int8, RandomDistribution::UNIFORM>; | dtype::Int8, RandomDistribution::UNIFORM>; | ||||
| template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
| dtype::Int8, RandomDistribution::CONSTANT>; | dtype::Int8, RandomDistribution::CONSTANT>; | ||||
| template class HostTensorGenerator< | |||||
| dtype::Int8, RandomDistribution::CONSECUTIVE>; | |||||
| template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
| dtype::Uint8, RandomDistribution::UNIFORM>; | dtype::Uint8, RandomDistribution::UNIFORM>; | ||||
| template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
| @@ -168,7 +168,7 @@ class RNGxorshf { | |||||
| }; | }; | ||||
| enum class RandomDistribution { | enum class RandomDistribution { | ||||
| GAUSSIAN, UNIFORM, CONSTANT | |||||
| GAUSSIAN, UNIFORM, CONSTANT, CONSECUTIVE | |||||
| }; | }; | ||||
| template<class dtype> | template<class dtype> | ||||
| @@ -342,6 +342,29 @@ class HostTensorGenerator<dtype, RandomDistribution::CONSTANT> final: | |||||
| private: | private: | ||||
| ctype m_default_val; | ctype m_default_val; | ||||
| }; | }; | ||||
| //! consecutive value | |||||
| template<class dtype> | |||||
| class HostTensorGenerator<dtype, RandomDistribution::CONSECUTIVE> final: | |||||
| public HostTensorGeneratorBase { | |||||
| public: | |||||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||||
| HostTensorGenerator(ctype val, ctype delta) | |||||
| : HostTensorGeneratorBase{next_rand_seed()}, | |||||
| m_val{val}, m_delta{delta} {} | |||||
| std::shared_ptr<HostTensorND> operator ()( | |||||
| const TensorShape &shape, CompNode cn = {}) override; | |||||
| using HostTensorGeneratorBase::operator(); | |||||
| private: | |||||
| ctype m_val; | |||||
| ctype m_delta; | |||||
| }; | |||||
| template <> | template <> | ||||
| class HostTensorGenerator<dtype::Bool, RandomDistribution::UNIFORM> final | class HostTensorGenerator<dtype::Bool, RandomDistribution::UNIFORM> final | ||||
| : public HostTensorGeneratorBase { | : public HostTensorGeneratorBase { | ||||