@@ -1 +1 @@
-Subproject commit f308919c39811c2c3e07fb0dcc8054a533c84cbc
+Subproject commit 2956e64803cad9b84316cdf2b25d034c5f944ccc
@@ -21,3 +21,5 @@ from .softmax import expand_softmax
 from .square import expand_square
 from .bias_add import expand_biasadd
 from .bias_add_grad import expand_biasaddgrad
+from .fused_adam import expand_fusedadam
+from .fused_adam_weight_decay import expand_fusedadamweightdecay
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for fused_adam"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_fusedadam(expand_info):
+    """FusedAdam expander"""
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    input_desc_2 = expand_info['input_desc'][2]
+    input_desc_3 = expand_info['input_desc'][3]
+    input_desc_4 = expand_info['input_desc'][4]
+    input_desc_5 = expand_info['input_desc'][5]
+    input_desc_6 = expand_info['input_desc'][6]
+    input_desc_7 = expand_info['input_desc'][7]
+    input_desc_8 = expand_info['input_desc'][8]
+    input_desc_9 = expand_info['input_desc'][9]
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
+        one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
+        eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
+        lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
+        param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
+        m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
+        v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
+        gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
+        graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient)
+
+        # compute result
+        beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
+        one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
+        next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
+        beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
+        grad_square = graph_builder.emit('Mul', [gradient, gradient])
+        one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
+        next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
+        sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
+        sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
+        update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
+        update_with_lr = graph_builder.emit('Mul', [lr, update])
+        next_para = graph_builder.emit('Sub', [param, update_with_lr])
+        param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
+        m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
+        v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})
+
+        # set graph output.
+        graph_scope.set_output(param_result, m_result, v_result)
+
+    graph = graph_builder.get()[0]
+    return graph
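A minimal NumPy sketch of the update rule the expander above emits, useful for cross-checking the op sequence (illustrative only; `fused_adam_reference` is a hypothetical helper, not a MindSpore API — the framework passes `1 - beta_1` and `1 - beta_2` in as the `one_sub_*` inputs):

```python
import numpy as np

def fused_adam_reference(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
                         eps, lr, param, m, v, gradient):
    """NumPy mirror of the graph emitted by expand_fusedadam."""
    next_m = beta_1 * m + one_sub_beta_1 * gradient        # Mul, Mul, TensorAdd
    next_v = beta_2 * v + one_sub_beta_2 * gradient ** 2   # Mul, Mul, Mul, TensorAdd
    update = next_m / (np.sqrt(next_v) + eps)              # Sqrt, TensorAdd, RealDiv
    next_param = param - lr * update                       # Mul, Sub
    # The three InplaceAssign ops write these back into param, m and v.
    return next_param, next_m, next_v
```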
@@ -0,0 +1,76 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for fused_adam_weight_decay"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_fusedadamweightdecay(expand_info):
+    """FusedAdamWeightDecay expander"""
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    input_desc_2 = expand_info['input_desc'][2]
+    input_desc_3 = expand_info['input_desc'][3]
+    input_desc_4 = expand_info['input_desc'][4]
+    input_desc_5 = expand_info['input_desc'][5]
+    input_desc_6 = expand_info['input_desc'][6]
+    input_desc_7 = expand_info['input_desc'][7]
+    input_desc_8 = expand_info['input_desc'][8]
+    input_desc_9 = expand_info['input_desc'][9]
+    input_desc_10 = expand_info['input_desc'][10]
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
+        one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
+        eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
+        lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
+        param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
+        m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
+        v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
+        gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
+        weight_decay = graph_builder.tensor(input_desc_10['shape'], input_desc_10['data_type'], input_desc_10['format'])
+        graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
+                              eps, lr, param, m, v, gradient, weight_decay)
+
+        # compute result
+        beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
+        one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
+        next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
+        beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
+        grad_square = graph_builder.emit('Mul', [gradient, gradient])
+        one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
+        next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
+        sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
+        sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
+        update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
+        param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param])
+        update = graph_builder.emit('TensorAdd', [update, param_with_weight_decay])
+        update_with_lr = graph_builder.emit('Mul', [lr, update])
+        next_para = graph_builder.emit('Sub', [param, update_with_lr])
+        para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
+        m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
+        v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})
+
+        # set graph output.
+        graph_scope.set_output(para_result, m_result, v_result)
+
+    graph = graph_builder.get()[0]
+    return graph
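The only difference from the FusedAdam expander is the decoupled weight-decay term folded into the update before the learning-rate scaling (AdamW-style). A matching NumPy sketch (illustrative, hypothetical helper name):

```python
import numpy as np

def fused_adam_weight_decay_reference(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
                                      eps, lr, param, m, v, gradient, weight_decay):
    """NumPy mirror of the graph emitted by expand_fusedadamweightdecay."""
    next_m = beta_1 * m + one_sub_beta_1 * gradient
    next_v = beta_2 * v + one_sub_beta_2 * gradient ** 2
    update = next_m / (np.sqrt(next_v) + eps)
    update = update + weight_decay * param   # the only step added vs. FusedAdam
    next_param = param - lr * update
    return next_param, next_m, next_v
```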
@@ -154,6 +154,8 @@ class PrimLib:
         'ControlDepend': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
         'Tanh': Prim(ELEMWISE),
+        'ExpandDims': Prim(ELEMWISE),
+        'InplaceAssign': Prim(ELEMWISE),
         '@ReduceInit': Prim(ELEMWISE),
     }
@@ -70,6 +70,7 @@ class OpInfer:
     infer_shape_func = {
         # add special infer func here
+        'InplaceAssign': lambda inputs, attrs: inputs[2].shape
     }
     infer_dtype_func = {
         # add special infer func here
@@ -560,7 +560,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
   auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
     auto shape = GetNodeShape(node);
     if (shape.size() != 0 && shape.size() != 1) {
-      return node;
+      return nullptr;
     } else {
       auto tmp_node = node->cast<CNodePtr>();
       auto transpose_node = tmp_node->input(1);
@@ -635,7 +635,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
       AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
       return new_cnode;
     }
-    return node;
+    return nullptr;
   };
   auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr {
     auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node);
@@ -702,7 +702,8 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
-    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, prim::kPrimGeluGrad,
+    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu,
+    prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay,
   };
   return expand_ops;
 }
@@ -729,7 +730,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
     prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
     prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
     prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh};
+    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape};
   return fusible_basic_ops;
 }
@@ -174,6 +174,8 @@ inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>
 inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
 inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
 inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
+inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam");
+inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
 
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@@ -20,6 +20,7 @@ from .hsigmoid import _hsigmoid_akg
 from .hsigmoid_grad import _hsigmoid_grad_akg
 from .hswish import _hswish_akg
 from .hswish_grad import _hswish_grad_akg
+from .inplace_assign import _inplace_assign_akg
 from .lessequal import _lessequal_akg
 from .logical_and import _logical_and_akg
 from .logical_not import _logical_not_akg
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""InplaceAssign op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
+
+op_info = AkgGpuRegOp("InplaceAssign") \
+    .fusion_type("ELEMWISE") \
+    .input(0, "x") \
+    .input(1, "y") \
+    .input(2, "z") \
+    .output(0, "output") \
+    .attr("fake_output", "optional", "bool") \
+    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \
+    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \
+    .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \
+    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
+    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
+    .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
+    .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
+    .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
+    .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
+    .get_op_info()
+
+
+@op_info_register(op_info)
+def _inplace_assign_akg():
+    """InplaceAssign Akg register"""
+    return
@@ -82,7 +82,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
                      ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformCandidateSampler)
 from . import _quant_ops
 from ._quant_ops import *
-from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
+from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
                         CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull)
 from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
                         CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
@@ -65,6 +65,36 @@ class Assign(PrimitiveWithCheck):
         validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)
 
 
+class InplaceAssign(PrimitiveWithInfer):
+    """
+    Inplace assign `Parameter` with a value.
+
+    This primitive can only be used in a graph kernel.
+
+    Inputs:
+        - **variable** (Parameter) - The `Parameter`.
+        - **value** (Tensor) - The value to be assigned.
+        - **depend** (Tensor) - The dependent tensor that keeps this op connected in the graph.
+
+    Outputs:
+        Tensor, has the same type as the original `variable`.
+
+    Examples:
+        >>> def construct(self, x):
+        >>>    val = x - 1.0
+        >>>    ret = x + 2.0
+        >>>    return InplaceAssign()(x, val, ret)
+        >>> x = Tensor([2.0], mindspore.float32)
+        >>> net = Net()
+        >>> net(x)
+    """
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
+
+    def infer_shape(self, x, y, z):
+        return z
+
+    def infer_dtype(self, x, y, z):
+        return z
+
+
 class BoundingBoxEncode(PrimitiveWithInfer):
     """
     Encodes bounding boxes locations.
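A self-contained sketch of the docstring example above, filled out with the `Net` wrapper it assumes (illustrative; it mirrors the docstring rather than prescribing usage — in practice the first input is a `Parameter` and the in-place write only takes effect inside a graph kernel):

```python
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import InplaceAssign

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inplace_assign = InplaceAssign()

    def construct(self, x):
        val = x - 1.0   # value written back into x
        ret = x + 2.0   # `depend` input; also what the op returns
        return self.inplace_assign(x, val, ret)

x = Tensor([2.0], mindspore.float32)
net = Net()
out = net(x)   # out follows `ret`; x receives `val` in place
```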
@@ -509,6 +539,7 @@ class PopulationCount(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name)
         return mstype.tensor_type(mstype.uint8)
 
+
 class Push(PrimitiveWithInfer):
     """
     Pushes the inputs of the corresponding optimizer to parameter server.
@@ -539,6 +570,7 @@ class Push(PrimitiveWithInfer):
     def infer_dtype(self, inputs, shapes):
         return mstype.uint64
 
+
 class Pull(PrimitiveWithInfer):
     """
     Pulls weight from parameter server.
@@ -563,6 +595,7 @@ class Pull(PrimitiveWithInfer):
     def infer_dtype(self, key_dtype, weight_dtype):
         return mstype.float32
 
+
 class identity(Primitive):
     """
     Makes a identify primitive, used for pynative mode.
@@ -52,6 +52,7 @@ def CalGelu(x):
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_gelu():
+    np.random.seed(0)
     input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
 
     net = GeluNet()
@@ -67,6 +68,7 @@ def test_gelu():
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_gelu_grad():
+    np.random.seed(0)
     input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
     input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
     input_y = CalGelu(input_x)