@@ -1 +1 @@
-Subproject commit f308919c39811c2c3e07fb0dcc8054a533c84cbc
+Subproject commit 2956e64803cad9b84316cdf2b25d034c5f944ccc
@@ -21,3 +21,5 @@ from .softmax import expand_softmax
 from .square import expand_square
 from .bias_add import expand_biasadd
 from .bias_add_grad import expand_biasaddgrad
+from .fused_adam import expand_fusedadam
+from .fused_adam_weight_decay import expand_fusedadamweightdecay
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for fused_adam"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_fusedadam(expand_info):
| """FusedAdma expander""" | |||
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    input_desc_2 = expand_info['input_desc'][2]
+    input_desc_3 = expand_info['input_desc'][3]
+    input_desc_4 = expand_info['input_desc'][4]
+    input_desc_5 = expand_info['input_desc'][5]
+    input_desc_6 = expand_info['input_desc'][6]
+    input_desc_7 = expand_info['input_desc'][7]
+    input_desc_8 = expand_info['input_desc'][8]
+    input_desc_9 = expand_info['input_desc'][9]
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
+        one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
+        eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
+        lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
+        param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
+        m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
+        v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
+        gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
+        graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient)
+
+        # compute result
+        beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
+        one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
+        next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
+        beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
+        grad_square = graph_builder.emit('Mul', [gradient, gradient])
+        one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
+        next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
+        sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
+        sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
+        update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
+        update_with_lr = graph_builder.emit('Mul', [lr, update])
+        next_para = graph_builder.emit('Sub', [param, update_with_lr])
+        param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
+        m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
+        v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})
+
+        # set graph output.
+        graph_scope.set_output(param_result, m_result, v_result)
+
+    graph = graph_builder.get()[0]
+    return graph
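For reference, the graph above is a direct unrolling of the textbook Adam step into elementwise ops, with the three `InplaceAssign` emits writing the new parameter and moments back into `param`, `m` and `v` (`fake_output` appears to mark those outputs as stand-ins for the in-place writes). A minimal NumPy sketch of the same arithmetic, assuming the inputs are broadcast-compatible arrays (an illustration of the math only, not the MindSpore kernel):

```python
import numpy as np

def adam_update(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
                eps, lr, param, m, v, gradient):
    # next_m = beta1 * m + (1 - beta1) * grad
    next_m = beta_1 * m + one_sub_beta_1 * gradient
    # next_v = beta2 * v + (1 - beta2) * grad^2
    next_v = beta_2 * v + one_sub_beta_2 * gradient * gradient
    # param' = param - lr * next_m / (sqrt(next_v) + eps)
    update = next_m / (np.sqrt(next_v) + eps)
    next_param = param - lr * update
    return next_param, next_m, next_v
```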
@@ -0,0 +1,76 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for fused_adam_weight_decay"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_fusedadamweightdecay(expand_info):
| """FusedAdmaWeightDecay expander""" | |||
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    input_desc_2 = expand_info['input_desc'][2]
+    input_desc_3 = expand_info['input_desc'][3]
+    input_desc_4 = expand_info['input_desc'][4]
+    input_desc_5 = expand_info['input_desc'][5]
+    input_desc_6 = expand_info['input_desc'][6]
+    input_desc_7 = expand_info['input_desc'][7]
+    input_desc_8 = expand_info['input_desc'][8]
+    input_desc_9 = expand_info['input_desc'][9]
+    input_desc_10 = expand_info['input_desc'][10]
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
+        one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
+        eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
+        lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
+        param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
+        m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
+        v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
+        gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
+        weight_decay = graph_builder.tensor(input_desc_10['shape'], input_desc_10['data_type'], input_desc_10['format'])
+        graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
+                              eps, lr, param, m, v, gradient, weight_decay)
+
+        # compute result
+        beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
+        one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
+        next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
+        beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
+        grad_square = graph_builder.emit('Mul', [gradient, gradient])
+        one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
+        next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
+        sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
+        sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
+        update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
+        param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param])
+        update = graph_builder.emit('TensorAdd', [update, param_with_weight_decay])
+        update_with_lr = graph_builder.emit('Mul', [lr, update])
+        next_para = graph_builder.emit('Sub', [param, update_with_lr])
+        para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
+        m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
+        v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})
+
+        # set graph output.
+        graph_scope.set_output(para_result, m_result, v_result)
+
+    graph = graph_builder.get()[0]
+    return graph
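The weight-decay expander is identical except for the decoupled decay term that joins the update before the learning-rate scaling. The corresponding NumPy sketch (again illustration only, under the same broadcast assumptions):

```python
import numpy as np

def adam_weight_decay_update(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
                             eps, lr, param, m, v, gradient, weight_decay):
    next_m = beta_1 * m + one_sub_beta_1 * gradient
    next_v = beta_2 * v + one_sub_beta_2 * gradient * gradient
    # Decoupled weight decay: the decay term joins the update before lr scaling.
    update = next_m / (np.sqrt(next_v) + eps) + weight_decay * param
    next_param = param - lr * update
    return next_param, next_m, next_v
```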
@@ -154,6 +154,8 @@ class PrimLib:
         'ControlDepend': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
         'Tanh': Prim(ELEMWISE),
+        'ExpandDims': Prim(ELEMWISE),
+        'InplaceAssign': Prim(ELEMWISE),
         '@ReduceInit': Prim(ELEMWISE),
     }
@@ -70,6 +70,7 @@ class OpInfer:
     infer_shape_func = {
         # add special infer func here
+        'InplaceAssign': lambda inputs, attrs: inputs[2].shape
     }
     infer_dtype_func = {
         # add special infer func here
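`InplaceAssign` forwards its third input, so its inferred shape (and, symmetrically, dtype) is taken from `inputs[2]`. A hedged sketch of how such a special-case table is typically consulted; the dispatch function and the default rule here are assumptions, not the actual `OpInfer` code:

```python
def infer_shape(op_name, inputs, attrs):
    # Ops with special semantics override the default rule; InplaceAssign
    # follows input 2, the "depend" tensor whose value it passes through.
    special = OpInfer.infer_shape_func
    if op_name in special:
        return special[op_name](inputs, attrs)
    return inputs[0].shape  # assumed default: elementwise ops follow input 0
```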
@@ -560,7 +560,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
   auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
     auto shape = GetNodeShape(node);
     if (shape.size() != 0 && shape.size() != 1) {
-      return node;
+      return nullptr;
     } else {
       auto tmp_node = node->cast<CNodePtr>();
       auto transpose_node = tmp_node->input(1);
@@ -635,7 +635,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
       AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
       return new_cnode;
     }
-    return node;
+    return nullptr;
   };
   auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr {
     auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node);
@@ -702,7 +702,8 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
-    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, prim::kPrimGeluGrad,
+    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu,
+    prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay,
   };
   return expand_ops;
 }
@@ -729,7 +730,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
     prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
     prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
     prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh};
+    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape};
   return fusible_basic_ops;
 }
@@ -174,6 +174,8 @@ inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>
 inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
 inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
 inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
+inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam");
+inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
 
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@@ -20,6 +20,7 @@ from .hsigmoid import _hsigmoid_akg
 from .hsigmoid_grad import _hsigmoid_grad_akg
 from .hswish import _hswish_akg
 from .hswish_grad import _hswish_grad_akg
+from .inplace_assign import _inplace_assign_akg
 from .lessequal import _lessequal_akg
 from .logical_and import _logical_and_akg
 from .logical_not import _logical_not_akg
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""InplaceAssign op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
+
+op_info = AkgGpuRegOp("InplaceAssign") \
+    .fusion_type("ELEMWISE") \
+    .input(0, "x") \
+    .input(1, "y") \
+    .input(2, "z") \
+    .output(0, "output") \
+    .attr("fake_output", "optional", "bool") \
+    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \
+    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \
+    .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \
+    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
+    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
+    .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
+    .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
+    .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
+    .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
+    .get_op_info()
+
+
+@op_info_register(op_info)
+def _inplace_assign_akg():
+    """InplaceAssign Akg register"""
+    return
@@ -82,7 +82,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
                      ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformCandidateSampler)
 from . import _quant_ops
 from ._quant_ops import *
-from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
+from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
                         CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull)
 from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
                         CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
@@ -65,6 +65,36 @@ class Assign(PrimitiveWithCheck):
         validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)
+
+
+class InplaceAssign(PrimitiveWithInfer):
+    """
+    Assigns a value to a `Parameter` in place.
+
+    This primitive can only be used in graph kernels.
+
+    Inputs:
+        - **variable** (Parameter) - The `Parameter` to be updated in place.
+        - **value** (Tensor) - The value to be assigned.
+        - **depend** (Tensor) - The dependent tensor that keeps this op connected in the graph.
+
+    Outputs:
+        Tensor, has the same type as the original `variable`.
+
+    Examples:
+        >>> def construct(self, x):
+        >>>     val = x - 1.0
+        >>>     ret = x + 2.0
+        >>>     return InplaceAssign()(x, val, ret)
+        >>> x = Tensor([2.0], mindspore.float32)
+        >>> net = Net()
+        >>> net(x)
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
+
+    def infer_shape(self, x, y, z):
+        return z
+
+    def infer_dtype(self, x, y, z):
+        return z
+
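To make the docstring example concrete, here is a minimal runnable sketch; the `Net` wrapper is an assumption, and the import path follows the `other_ops` export added above:

```python
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import InplaceAssign

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.inplace_assign = InplaceAssign()

    def construct(self, x):
        val = x - 1.0   # value written into x in place
        ret = x + 2.0   # depend tensor that keeps the op wired into the graph
        return self.inplace_assign(x, val, ret)

net = Net()
print(net(Tensor([2.0], mindspore.float32)))
```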
 
 class BoundingBoxEncode(PrimitiveWithInfer):
     """
     Encodes bounding box locations.
@@ -509,6 +539,7 @@ class PopulationCount(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name)
         return mstype.tensor_type(mstype.uint8)
 
+
 class Push(PrimitiveWithInfer):
     """
     Pushes the inputs of the corresponding optimizer to the parameter server.
@@ -539,6 +570,7 @@ class Push(PrimitiveWithInfer):
     def infer_dtype(self, inputs, shapes):
         return mstype.uint64
 
+
 class Pull(PrimitiveWithInfer):
     """
     Pulls weight from the parameter server.
@@ -563,6 +595,7 @@ class Pull(PrimitiveWithInfer):
     def infer_dtype(self, key_dtype, weight_dtype):
         return mstype.float32
 
+
 class identity(Primitive):
     """
     Makes an identity primitive, used for pynative mode.
@@ -52,6 +52,7 @@ def CalGelu(x):
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_gelu():
+    np.random.seed(0)
     input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
     net = GeluNet()
@@ -67,6 +68,7 @@ def test_gelu():
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_gelu_grad():
+    np.random.seed(0)
     input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
     input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
     input_y = CalGelu(input_x)