| @@ -10,9 +10,9 @@ | |||||
| [submodule "third_party/protobuf"] | [submodule "third_party/protobuf"] | ||||
| path = third_party/protobuf | path = third_party/protobuf | ||||
| url = https://github.com/protocolbuffers/protobuf.git | url = https://github.com/protocolbuffers/protobuf.git | ||||
| [submodule "graphengine"] | |||||
| path = graphengine | |||||
| url = https://gitee.com/mindspore/graphengine.git | |||||
| [submodule "akg"] | [submodule "akg"] | ||||
| path = akg | path = akg | ||||
| url = https://gitee.com/mindspore/akg.git | url = https://gitee.com/mindspore/akg.git | ||||
| [submodule "graphengine"] | |||||
| path = graphengine | |||||
| url = https://gitee.com/ms-incubator/graphengine.git | |||||
| @@ -1 +1 @@ | |||||
| Subproject commit dda72a48c7e0033389bd377c5804d485fdf3112d | |||||
| Subproject commit 8891f0546c4a250095ff68e1262f58772b938fd9 | |||||
| @@ -141,7 +141,7 @@ if (ENABLE_GE) | |||||
| else () | else () | ||||
| target_link_libraries(mindspore ge_client) | target_link_libraries(mindspore ge_client) | ||||
| endif () | endif () | ||||
| target_link_libraries(mindspore graph tsdclient) | |||||
| target_link_libraries(mindspore graph tsdclient datatransfer) | |||||
| endif() | endif() | ||||
| if (ENABLE_D) | if (ENABLE_D) | ||||
| @@ -29,6 +29,7 @@ constexpr auto kInitData = "InitData"; | |||||
| constexpr auto kGetNext = "GetNext"; | constexpr auto kGetNext = "GetNext"; | ||||
| constexpr auto kPrint = "Print"; | constexpr auto kPrint = "Print"; | ||||
| constexpr auto kPack = "Pack"; | constexpr auto kPack = "Pack"; | ||||
| constexpr auto kOutputTypes = "output_types"; | constexpr auto kOutputTypes = "output_types"; | ||||
| constexpr auto kOutputShapes = "output_shapes"; | constexpr auto kOutputShapes = "output_shapes"; | ||||
| constexpr auto kChannelName = "channel_name"; | constexpr auto kChannelName = "channel_name"; | ||||
| @@ -58,7 +58,6 @@ class GetMakeRefEliminater : public OptimizerCaller { | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); | MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); | ||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); | MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); | ||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); | MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -1168,7 +1168,7 @@ INPUT_MAP(SparseApplyAdagradD) = { | |||||
| {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}}; | {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}}; | ||||
| ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())}, | ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())}, | ||||
| {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}}; | |||||
| OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; | |||||
| // ApplyProximalAdagradD | // ApplyProximalAdagradD | ||||
| INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, | INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, | ||||
| @@ -181,14 +181,21 @@ bool MsContext::OpenTsd() { | |||||
| } | } | ||||
| MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; | MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; | ||||
| #if (defined(ENABLE_TDTQUE) && defined(ENABLE_GE)) | |||||
| int32_t initStatus = tdt::TdtHostInit(device_id); | |||||
| if (initStatus != TDT_OK_CODE) { | |||||
| MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << "."; | |||||
| return false; | |||||
| } | |||||
| tdt_print_ = std::thread(TensorPrint()); | |||||
| #endif | |||||
| TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size); | TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size); | ||||
| if (status != TDT_OK) { | if (status != TDT_OK) { | ||||
| MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << "."; | MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << "."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| tsd_ref_++; | tsd_ref_++; | ||||
| #ifdef ENABLE_TDTQUE | |||||
| #if (defined(ENABLE_TDTQUE) && !defined(ENABLE_GE)) | |||||
| int32_t initStatus = tdt::TdtHostInit(device_id); | int32_t initStatus = tdt::TdtHostInit(device_id); | ||||
| if (initStatus != TDT_OK_CODE) { | if (initStatus != TDT_OK_CODE) { | ||||
| MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << "."; | MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << "."; | ||||
| @@ -342,7 +342,6 @@ class Optimizer(Cell): | |||||
| current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0) | current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0) | ||||
| lr += (current_dynamic_lr,) | lr += (current_dynamic_lr,) | ||||
| F.control_depend(lr, self.assignadd(self.global_step, 1)) | F.control_depend(lr, self.assignadd(self.global_step, 1)) | ||||
| else: | else: | ||||
| lr = self.learning_rate | lr = self.learning_rate | ||||
| if self.dynamic_lr: | if self.dynamic_lr: | ||||
| @@ -518,6 +518,18 @@ def get_bprop_l2_loss(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.RNNTLoss) | |||||
| def get_bprop_rnnt_loss(self): | |||||
| """Grad definition for `RNNTLoss` operation.""" | |||||
| expand = P.ExpandDims() | |||||
| def bprop(acts, labels, act_lens, label_lens, out, dout): | |||||
| grad_loss = out[1] | |||||
| grad = grad_loss * expand(expand(expand(dout[0], -1), -1), -1) | |||||
| return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens) | |||||
| return bprop | |||||
| @bprop_getters.register(P.PReLU) | @bprop_getters.register(P.PReLU) | ||||
| def get_bprop_prelu(self): | def get_bprop_prelu(self): | ||||
| """Grad definition for `PReLU` operation.""" | """Grad definition for `PReLU` operation.""" | ||||
| @@ -14,6 +14,7 @@ | |||||
| """aicpu ops""" | """aicpu ops""" | ||||
| from .init_data_set_queue import _init_data_set_queue_aicpu | from .init_data_set_queue import _init_data_set_queue_aicpu | ||||
| from .embedding_lookup import _embedding_lookup_aicpu | |||||
| from .dropout_genmask import _dropout_genmask_aicpu | from .dropout_genmask import _dropout_genmask_aicpu | ||||
| from .get_next import _get_next_aicpu | from .get_next import _get_next_aicpu | ||||
| from .print_tensor import _print_aicpu | from .print_tensor import _print_aicpu | ||||
| @@ -29,3 +30,6 @@ from .normal import _normal_aicpu | |||||
| from .ctcloss import _ctcloss_aicpu | from .ctcloss import _ctcloss_aicpu | ||||
| from .reverse_sequence import _reverse_sequence_aicpu | from .reverse_sequence import _reverse_sequence_aicpu | ||||
| from .crop_and_resize import _crop_and_resize_aicpu | from .crop_and_resize import _crop_and_resize_aicpu | ||||
| from .rnnt_loss import _rnnt_loss_aicpu | |||||
| from .random_categorical import _random_categorical_aicpu | |||||
| from .cast import _cast_aicpu | |||||
| @@ -0,0 +1,172 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Cast op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| cast_op_info = AiCPURegOp("Cast") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(cast_op_info) | |||||
| def _cast_aicpu(): | |||||
| """Cast AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,102 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """EmbeddingLookup op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| embeddingLookup_op_info = AiCPURegOp("EmbeddingLookup") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "params", "required") \ | |||||
| .input(1, "indices", "required") \ | |||||
| .input(2, "offset", "required") \ | |||||
| .output(0, "output", "required") \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(embeddingLookup_op_info) | |||||
| def _embedding_lookup_aicpu(): | |||||
| """EmbeddingLookup AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,48 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """RandomCategorical op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| random_categorical_op_info = AiCPURegOp("RandomCategorical") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "logits", "required") \ | |||||
| .input(1, "num_sample", "required") \ | |||||
| .input(2, "seed", "required") \ | |||||
| .output(0, "output", "required") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(random_categorical_op_info) | |||||
| def _random_categorical_aicpu(): | |||||
| """RandomCategorical AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,37 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """RNNTLoss op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| rnnt_loss_op_info = AiCPURegOp("RNNTLoss") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "acts", "required") \ | |||||
| .input(1, "labels", "required") \ | |||||
| .input(2, "input_lengths", "required") \ | |||||
| .input(3, "label_lengths", "required") \ | |||||
| .output(0, "costs", "required") \ | |||||
| .output(1, "grads", "required") \ | |||||
| .attr("blank_label", "int") \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(rnnt_loss_op_info) | |||||
| def _rnnt_loss_aicpu(): | |||||
| """RNNTLoss AiCPU register""" | |||||
| return | |||||
| @@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, | Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, | ||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps) | Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps) | ||||
| from .random_ops import (RandomChoiceWithMask, Normal) | |||||
| from .random_ops import (RandomChoiceWithMask, Normal, RandomCategorical) | |||||
| from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, | from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, | ||||
| BiasAdd, Conv2D, | BiasAdd, Conv2D, | ||||
| DepthwiseConv2dNative, | DepthwiseConv2dNative, | ||||
| @@ -69,6 +69,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl | |||||
| ResizeBilinear, Sigmoid, | ResizeBilinear, Sigmoid, | ||||
| SigmoidCrossEntropyWithLogits, | SigmoidCrossEntropyWithLogits, | ||||
| SmoothL1Loss, Softmax, Softplus, | SmoothL1Loss, Softmax, Softplus, | ||||
| RNNTLoss, | |||||
| SoftmaxCrossEntropyWithLogits, ROIAlign, | SoftmaxCrossEntropyWithLogits, ROIAlign, | ||||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | SparseSoftmaxCrossEntropyWithLogits, Tanh, | ||||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | ||||
| @@ -77,6 +78,8 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl | |||||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) | ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) | ||||
| from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, | from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, | ||||
| CheckValid, MakeRefKey, Partial, Depend, CheckBprop) | CheckValid, MakeRefKey, Partial, Depend, CheckBprop) | ||||
| from . import _quant_ops | |||||
| from ._quant_ops import * | |||||
| from .thor_ops import * | from .thor_ops import * | ||||
| __all__ = [ | __all__ = [ | ||||
| @@ -168,6 +171,7 @@ __all__ = [ | |||||
| 'Tanh', | 'Tanh', | ||||
| 'RandomChoiceWithMask', | 'RandomChoiceWithMask', | ||||
| 'Normal', | 'Normal', | ||||
| 'RandomCategorical', | |||||
| 'ResizeBilinear', | 'ResizeBilinear', | ||||
| 'ScalarSummary', | 'ScalarSummary', | ||||
| 'ImageSummary', | 'ImageSummary', | ||||
| @@ -198,6 +202,7 @@ __all__ = [ | |||||
| 'SmoothL1Loss', | 'SmoothL1Loss', | ||||
| 'L2Loss', | 'L2Loss', | ||||
| 'CTCLoss', | 'CTCLoss', | ||||
| 'RNNTLoss', | |||||
| 'ReduceAll', | 'ReduceAll', | ||||
| 'ScalarToArray', | 'ScalarToArray', | ||||
| 'ScalarToTensor', | 'ScalarToTensor', | ||||
| @@ -302,6 +307,7 @@ __all__ = [ | |||||
| "ApplyCenteredRMSProp", | "ApplyCenteredRMSProp", | ||||
| "SpaceToBatchND", | "SpaceToBatchND", | ||||
| "BatchToSpaceND", | "BatchToSpaceND", | ||||
| "ReverseSequence", | |||||
| "SquareSumAll", | "SquareSumAll", | ||||
| "BitwiseAnd", | "BitwiseAnd", | ||||
| "BitwiseOr", | "BitwiseOr", | ||||
| @@ -315,7 +321,8 @@ __all__ = [ | |||||
| "DataFormatDimMap", | "DataFormatDimMap", | ||||
| "ApproximateEqual", | "ApproximateEqual", | ||||
| "InplaceUpdate", | "InplaceUpdate", | ||||
| "InTopK" | |||||
| "InTopK", | |||||
| "CropAndResize" | |||||
| ] | ] | ||||
| __all__.sort() | __all__.sort() | ||||
| @@ -1093,8 +1093,18 @@ class StridedSliceGrad(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) | self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) | ||||
| def __infer__(self, dy, shapex, begin, end, strides): | def __infer__(self, dy, shapex, begin, end, strides): | ||||
| args = {"shapex": shapex['dtype'],"begin": begin['dtype'],"end": end['dtype'],"strides": strides['dtype']} | |||||
| args = {"dy": dy['dtype']} | |||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | validator.check_tensor_type_same(args, mstype.number_type, self.name) | ||||
| for idx, item in enumerate(shapex['value']): | |||||
| validator.check_value_type("shapex[%d]" % idx, item, [int], self.name) | |||||
| for idx, item in enumerate(begin['value']): | |||||
| validator.check_value_type("begin[%d]" % idx, item, [int], self.name) | |||||
| for idx, item in enumerate(end['value']): | |||||
| validator.check_value_type("end[%d]" % idx, item, [int], self.name) | |||||
| for idx, item in enumerate(strides['value']): | |||||
| validator.check_value_type("strides[%d]" % idx, item, [int], self.name) | |||||
| return {'shape': shapex['value'], | return {'shape': shapex['value'], | ||||
| 'dtype': dy['dtype'], | 'dtype': dy['dtype'], | ||||
| 'value': None} | 'value': None} | ||||
| @@ -1697,6 +1697,60 @@ class DataFormatDimMap(PrimitiveWithInfer): | |||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | ||||
| return x_type | return x_type | ||||
| class RNNTLoss(PrimitiveWithInfer): | |||||
| """ | |||||
| Computes the RNNTLoss and its gradient with respect to the softmax outputs. | |||||
| Args: | |||||
| blank_label (int): blank label. Default: 0. | |||||
| Inputs: | |||||
| - **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`. | |||||
| - **labels** (Tensor[int32]) - Tensor of shape :math:`(B, N)`. | |||||
| - **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. | |||||
| - **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. | |||||
| Outputs: | |||||
| - **costs** (Tensor[int32]) - Tensor of shape :math:`(B,)`. | |||||
| - **grads** (Tensor[int32]) - Has the same shape as `acts`. | |||||
| Examples: | |||||
| >>> B, T, U, V = 1, 2, 3, 5 | |||||
| >>> acts = np.random.random((B, T, U, V)).astype(np.float32) | |||||
| >>> labels = np.array([[1, 2]]).astype(np.int32) | |||||
| >>> input_length = np.array([T] * B).astype(np.int32) | |||||
| >>> label_length = np.array([len(l) for l in labels]).astype(np.int32) | |||||
| >>> rnnt_loss = P.RNNTLoss(blank_label=blank) | |||||
| >>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, blank_label=0): | |||||
| validator.check_value_type('blank_label', blank_label, [int], self.name) | |||||
| self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'], | |||||
| outputs=['costs', 'grads']) | |||||
| def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape): | |||||
| validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name) | |||||
| validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name) | |||||
| validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name) | |||||
| validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name) | |||||
| validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) | |||||
| validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) | |||||
| validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) | |||||
| costs_shape = (acts_shape[0],) | |||||
| return (costs_shape, acts_shape) | |||||
| def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type): | |||||
| validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name) | |||||
| validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name) | |||||
| validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name) | |||||
| validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32], self.name) | |||||
| validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name) | |||||
| validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name) | |||||
| validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name) | |||||
| return (acts_type, acts_type) | |||||
| class SGD(PrimitiveWithInfer): | class SGD(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -108,3 +108,60 @@ class Normal(PrimitiveWithInfer): | |||||
| "dtype": mstype.float32, | "dtype": mstype.float32, | ||||
| "value": None} | "value": None} | ||||
| return out | return out | ||||
| class RandomCategorical(PrimitiveWithInfer): | |||||
| """ | |||||
| Generates random samples from a given categorical distribution tensor. | |||||
| Args: | |||||
| dtype (mindspore.dtype): The type of output. Its value should be one of [mindspore.int16, | |||||
| mindspore.int32, mindspore.int64]. Default: mindspore.int64. | |||||
| Inputs: | |||||
| - **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes]. | |||||
| - **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed. | |||||
| - **seed** (int) - Random seed. Default: 0. | |||||
| Outputs: | |||||
| - **output** (Tensor) - The output Tensor with shape [batch_size, num_samples]. | |||||
| Examples: | |||||
| >>> class Net(nn.Cell): | |||||
| >>> def __init__(self, num_sample): | |||||
| >>> super(Net, self).__init__() | |||||
| >>> self.random_categorical = P.RandomCategorical(mindspore.int64) | |||||
| >>> self.num_sample = num_sample | |||||
| >>> def construct(self, logits, seed=0): | |||||
| >>> return self.random_categorical(logits, self.num_sample, seed) | |||||
| >>> | |||||
| >>> x = np.random.random((10, 5)).astype(np.float32) | |||||
| >>> net = Net(8) | |||||
| >>> output = net(Tensor(x)) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, dtype=mstype.int64): | |||||
| """Init RandomCategorical""" | |||||
| self.dtype = dtype | |||||
| valid_values = (mstype.int32, mstype.int16, mstype.int64) | |||||
| validator.check_type_name("dtype", dtype, valid_values, self.name) | |||||
| self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'], | |||||
| outputs=['output']) | |||||
| def __infer__(self, logits, num_samples, seed): | |||||
| logits_dtype = logits['dtype'] | |||||
| valid_types = (mstype.float32, mstype.float16, mstype.float64) | |||||
| validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) | |||||
| num_samples_v = num_samples['value'] | |||||
| seed_v = seed['value'] | |||||
| validator.check_value_type('num_samples', num_samples_v, (int,), self.name) | |||||
| validator.check_value_type('seed', seed_v, (int,), self.name) | |||||
| validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name) | |||||
| x_shape = list(logits['shape']) | |||||
| if len(x_shape) != 2: | |||||
| raise ValueError("RandomCategorical shape should be 2-dimension.") | |||||
| ndim = len(x_shape) - 1 | |||||
| x_shape[ndim] = num_samples_v | |||||
| return {'shape': (x_shape), | |||||
| 'dtype': (self.dtype), | |||||
| 'value': None} | |||||
| @@ -0,0 +1,75 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, x, dtype): | |||||
| super(Net, self).__init__() | |||||
| self.cast = P.Cast() | |||||
| self.x = x | |||||
| self.dtype = dtype | |||||
| def construct(self): | |||||
| return self.cast(self.x, self.dtype) | |||||
| def test_net_f32_bool(): | |||||
| x = np.random.randn(3,4).astype(np.float32) | |||||
| x[:,1] = 0 | |||||
| net = Net(Tensor(x), mstype.bool_) | |||||
| output = net() | |||||
| print(output.asnumpy()) | |||||
| print(Tensor(x).dtype) | |||||
| print(output.dtype) | |||||
| def test_net_f16_bool(): | |||||
| x = np.random.randn(3,4).astype(np.float16) | |||||
| x[:,1] = 0 | |||||
| net = Net(Tensor(x), mstype.bool_) | |||||
| output = net() | |||||
| print(output.asnumpy()) | |||||
| print(Tensor(x).dtype) | |||||
| print(output.dtype) | |||||
| def test_net_f64_bool(): | |||||
| x = np.random.randn(3,4).astype(np.float64) | |||||
| x[:,1] = 0 | |||||
| net = Net(Tensor(x), mstype.bool_) | |||||
| output = net() | |||||
| print(output.asnumpy()) | |||||
| print(Tensor(x).dtype) | |||||
| print(output.dtype) | |||||
| def test_net_int16_float16(): | |||||
| x = np.random.randint(-512, 512, size=(3,4)).astype(np.int16) | |||||
| net = Net(Tensor(x), mstype.float16) | |||||
| output = net() | |||||
| print(output.asnumpy()) | |||||
| print(Tensor(x).dtype) | |||||
| print(output.dtype) | |||||
| def test_net_int64_float16(): | |||||
| x = np.random.randint(-512, 512, size=(3,4)).astype(np.int64) | |||||
| net = Net(Tensor(x), mstype.float16) | |||||
| output = net() | |||||
| print(output.asnumpy()) | |||||
| print(Tensor(x).dtype) | |||||
| print(output.dtype) | |||||
| @@ -127,7 +127,6 @@ def test_net_int64(): | |||||
| print(output.asnumpy()) | print(output.asnumpy()) | ||||
| assert np.array_equal(output.asnumpy(), np.stack([x, y], axis)) | assert np.array_equal(output.asnumpy(), np.stack([x, y], axis)) | ||||
| def test_net_uint64(): | def test_net_uint64(): | ||||
| x = np.random.randn(3, 5, 4).astype(np.uint64) | x = np.random.randn(3, 5, 4).astype(np.uint64) | ||||
| y = np.random.randn(3, 5, 4).astype(np.uint64) | y = np.random.randn(3, 5, 4).astype(np.uint64) | ||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import mindspore | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.api import ms_function | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, num_sample): | |||||
| super(Net, self).__init__() | |||||
| self.random_categorical = P.RandomCategorical(mindspore.int64) | |||||
| self.num_sample = num_sample | |||||
| def construct(self, logits, seed=0): | |||||
| return self.random_categorical(logits, self.num_sample, seed) | |||||
| def test_net(): | |||||
| x = np.random.random((10, 5)).astype(np.float32) | |||||
| net = Net(8) | |||||
| output = net(Tensor(x)) | |||||
| print(x) | |||||
| print(output.asnumpy()) | |||||
| print(output.dtype()) | |||||
| @@ -0,0 +1,43 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import mindspore as ms | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.api import ms_function | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.rnnt_loss = P.RNNTLoss(blank_label=0) | |||||
| def construct(self, acts, labels, act_lens, label_lens): | |||||
| return self.rnnt_loss(acts, labels, act_lens, label_lens) | |||||
| def test_net(): | |||||
| B, T, U, V = 1, 2, 3, 5 | |||||
| acts = np.random.random((B, T, U, V)).astype(np.float32) | |||||
| labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32) | |||||
| input_length = np.array([T] * B).astype(np.int32) | |||||
| label_length = np.array([len(l) for l in labels]).astype(np.int32) | |||||
| rnnt_loss = Net() | |||||
| costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) | |||||
| print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) | |||||
| print(costs.asnumpy()) | |||||
| print(grads.asnumpy()) | |||||
| @@ -0,0 +1,42 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, offset): | |||||
| super(Net, self).__init__() | |||||
| self.embedding = P.EmbeddingLookup() | |||||
| self.offset = offset | |||||
| def construct(self, param, index): | |||||
| return self.embedding(param, index, self.offset) | |||||
| def test_embedding_lookup_sparse(): | |||||
| params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.int32) | |||||
| indices = Tensor(np.array([[5, 2], [8, 5]]), mstype.int32) | |||||
| offset = 4 | |||||
| embedding = Net(offset) | |||||
| out = embedding(params, indices) | |||||
| assert(out.asnumpy() == [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).all() | |||||
| @@ -29,7 +29,6 @@ context.set_context(mode=context.GRAPH_MODE) | |||||
| class LeNet5(nn.Cell): | class LeNet5(nn.Cell): | ||||
| """ LeNet5 definition """ | """ LeNet5 definition """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(LeNet5, self).__init__() | super(LeNet5, self).__init__() | ||||
| self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid') | self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid') | ||||