Merge pull request !25026 from fangzehua/dynamic_mae_1014 (tags/v1.6.0)
| @@ -27,7 +27,7 @@ template <typename T> | |||
| void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| node_wpt_ = kernel_node; | |||
| cnode_ptr_ = kernel_node; | |||
| axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS)); | |||
| auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (axis_ < 0) { | |||
| @@ -38,18 +38,18 @@ void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| template <typename T> | |||
| bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| auto node = node_wpt_.lock(); | |||
| if (!node) { | |||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||
| auto node_ = cnode_ptr_.lock(); | |||
| if (!node_) { | |||
| MS_LOG(EXCEPTION) << "cnode_ptr_ is expired."; | |||
| } | |||
| const size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| const size_t input_num = AnfAlgo::GetInputTensorNum(node_); | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOutputsNum, kernel_name_); | |||
| std::vector<std::vector<size_t>> input_flat_shape_list; | |||
| input_flat_shape_list.reserve(input_num); | |||
| for (size_t i = 0; i < input_num; i++) { | |||
| auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(node, i); | |||
| auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(node_, i); | |||
| auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_); | |||
| (void)input_flat_shape_list.emplace_back(flat_shape); | |||
| } | |||
| @@ -38,7 +38,6 @@ class ConcatCPUKernel : public CPUKernel { | |||
| private: | |||
| int axis_{0}; | |||
| CNodeWeakPtr node_wpt_; | |||
| }; | |||
| MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, float); | |||
| @@ -0,0 +1,93 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/concat_offset_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
namespace {
// Number of output tensors the ConcatOffset kernel expects at launch.
constexpr size_t kConcatOffsetOutputNum{1};
// Required length of the inferred output shape; Launch reads it as [input_num, rank].
constexpr size_t kConcatOffsetOutputShapeSize{2};
}  // namespace
| template <typename T> | |||
| void ConcatOffsetCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| cnode_ptr_ = kernel_node; | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS); | |||
| auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (axis < 0) { | |||
| axis_ = LongToSize(axis + input_1_shape.size()); | |||
| } else { | |||
| axis_ = LongToSize(axis); | |||
| } | |||
| if (axis_ >= input_1_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "axis should less input shape size, but got axis: " << axis_ | |||
| << ", input shape size: " << input_1_shape.size(); | |||
| } | |||
| } | |||
| template <typename T> | |||
| bool ConcatOffsetCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOffsetOutputNum, kernel_name_); | |||
| auto node_ = cnode_ptr_.lock(); | |||
| if (!node_) { | |||
| MS_LOG(EXCEPTION) << "cnode_ptr_ is expired."; | |||
| } | |||
| auto output_addr = reinterpret_cast<int64_t *>(outputs[0]->addr); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node_); | |||
| std::vector<size_t> offset{0}; | |||
| size_t all_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0)[axis_]; | |||
| // cal offset | |||
| for (size_t i = 1; i < input_num; i++) { | |||
| auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(node_, i); | |||
| if (axis_ >= input_shape_i.size()) { | |||
| MS_LOG(EXCEPTION) << "axis should less input shape size, but got axis: " << axis_ | |||
| << ", input shape size: " << input_shape_i.size(); | |||
| } | |||
| offset.emplace_back(all_shape); | |||
| all_shape += input_shape_i[axis_]; | |||
| } | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(node_, 0); | |||
| if (output_shape.size() != kConcatOffsetOutputShapeSize) { | |||
| MS_LOG(EXCEPTION) << "The length of output_shape must be " << kConcatOffsetOutputShapeSize | |||
| << ", but got:" << output_shape.size(); | |||
| } | |||
| if (output_shape[0] != input_num) { | |||
| MS_LOG(EXCEPTION) << "ConcatOffset output_shape[0] must equal to input_num, but got " << output_shape[0]; | |||
| } | |||
| size_t rank = output_shape[1]; | |||
| size_t idx = 0; | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| for (size_t j = 0; j < rank; ++j) { | |||
| if (j == axis_) { | |||
| output_addr[idx] = offset[i]; | |||
| } else { | |||
| output_addr[idx] = 0; | |||
| } | |||
| idx++; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONCAT_OFFSET_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONCAT_OFFSET_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class ConcatOffsetCPUKernel : public CPUKernel { | |||
| public: | |||
| ConcatOffsetCPUKernel() = default; | |||
| ~ConcatOffsetCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| size_t axis_{0}; | |||
| }; | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, float); | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, int8_t) | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, int16_t) | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, int32_t) | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, int64_t) | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, uint8_t) | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, uint16_t) | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, uint32_t) | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, uint64_t) | |||
| MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCPUKernel, bool) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONCAT_OFFSET_CPU_KERNEL_H_ | |||
| @@ -25,8 +25,25 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void CpuDynamicKernel::UpdateArgs() { | |||
| if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) { | |||
| return; | |||
| } | |||
| auto cnode = cnode_ptr_.lock(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_LOG(INFO) << "Update Args: " << cnode->fullname_with_scope(); | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(cnode); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| auto cpu_kernel_mod = dynamic_cast<CPUKernel *>(kernel_mod); | |||
| MS_EXCEPTION_IF_NULL(cpu_kernel_mod); | |||
| cpu_kernel_mod->Init(cnode); | |||
| } | |||
| void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| TypeId type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, input_index); | |||
| @@ -103,6 +103,16 @@ struct ParallelSearchInfo { | |||
| size_t search_count{0}; | |||
| }; | |||
| class CpuDynamicKernel : public device::DynamicKernel { | |||
| public: | |||
| explicit CpuDynamicKernel(const CNodePtr &cnode_ptr) : DynamicKernel(nullptr, cnode_ptr) {} | |||
| ~CpuDynamicKernel() = default; | |||
| void UpdateArgs() override; | |||
| void PostExecute() final { MS_LOG(EXCEPTION) << "`PostExecute()` should not invoked with cpu backend"; }; | |||
| void Execute() final { MS_LOG(EXCEPTION) << "`Execute()` should not invoked with cpu backend"; } | |||
| }; | |||
| class CPUKernel : public kernel::KernelMod { | |||
| public: | |||
| CPUKernel() = default; | |||
| @@ -119,12 +129,19 @@ class CPUKernel : public kernel::KernelMod { | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| void SetCNodePtr(const CNodePtr &kernel_node) { cnode_ptr_ = kernel_node; } | |||
| const CNodeWeakPtr &GetCNodePtr() { return cnode_ptr_; } | |||
| void InitDynamicKernel(const CNodePtr &cnode_ptr) { dynamic_kernel_ = std::make_shared<CpuDynamicKernel>(cnode_ptr); } | |||
| device::DynamicKernelPtr DynamicKernel() const { return dynamic_kernel_; } | |||
| protected: | |||
| virtual void InitInputOutputSize(const CNodePtr &kernel_node); | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| ParallelSearchInfo parallel_search_info_; | |||
| CNodeWeakPtr cnode_ptr_; | |||
| device::DynamicKernelPtr dynamic_kernel_; | |||
| template <typename T> | |||
| inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) { | |||
| @@ -26,9 +26,9 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| const std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose", "Unpack", "AddN"}; | |||
| } // namespace | |||
| const std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose", | |||
| "Unpack", "AddN", "ConcatOffset", "DynamicStitch"}; | |||
| } | |||
| CPUKernelFactory &CPUKernelFactory::GetInstance() { | |||
| static CPUKernelFactory instance; | |||
| return instance; | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/dynamic_shape_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
namespace {
// Number of output tensors the DynamicShape kernel expects at launch.
constexpr size_t kDynamicShapeOutputNum{1};
}  // namespace
| template <typename T> | |||
| void DynamicShapeCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| cnode_ptr_ = kernel_node; | |||
| size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_count != 1) { | |||
| MS_LOG(EXCEPTION) << input_count << " arguments were provided, but DynamicShapeCPUKernel expects 1."; | |||
| } | |||
| } | |||
| template <typename T> | |||
| bool DynamicShapeCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDynamicShapeOutputNum, kernel_name_); | |||
| auto node_ = cnode_ptr_.lock(); | |||
| if (node_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "cnode_ptr_ is expired."; | |||
| } | |||
| auto output_addr = reinterpret_cast<int64_t *>(outputs[0]->addr); | |||
| std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(node_, 0); | |||
| if (output_shape.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "The length of output_shape must be 1, but got:" << output_shape.size(); | |||
| } | |||
| if (output_shape[0] != input_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "DynamicShape output_shape[0] must equal to the size of input_shape, but got " | |||
| << output_shape[0]; | |||
| } | |||
| for (size_t i = 0; i < output_shape[0]; ++i) { | |||
| output_addr[i] = input_shape[i]; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_SHAPE_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_SHAPE_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class DynamicShapeCPUKernel : public CPUKernel { | |||
| public: | |||
| DynamicShapeCPUKernel() = default; | |||
| ~DynamicShapeCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, float); | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, int8_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, int16_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, int32_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, int64_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, uint8_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, uint16_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, uint32_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, uint64_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCPUKernel, bool) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_SHAPE_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#include "backend/kernel_compiler/cpu/dynamic_stitch_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include "runtime/device/cpu/cpu_device_address.h"
| namespace mindspore { | |||
| namespace kernel { | |||
namespace {
// Number of output tensors the DynamicStitch kernel expects at launch.
constexpr size_t kDynamicStitchOutputNum{1};
}  // namespace
| template <typename T> | |||
| void DynamicStitchCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| cnode_ptr_ = kernel_node; | |||
| } | |||
// Returns the number of elements described by `shape`, i.e. the product of all its
// dimensions. An empty shape yields 1; any zero dimension yields 0.
// Requires <numeric> for std::accumulate and <functional> for std::multiplies.
size_t GetShapeSize(const std::vector<size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
}
| template <typename T> | |||
| void DynamicStitchCPUKernel<T>::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDynamicStitchOutputNum, kernel_name_); | |||
| auto node_ = cnode_ptr_.lock(); | |||
| int first_dim_size = 0; | |||
| size_t input_count = AnfAlgo::GetInputTensorNum(node_); | |||
| input_tuple_num_ = input_count / 2; | |||
| int max_index = -1; | |||
| for (size_t i = 0; i < input_tuple_num_; ++i) { | |||
| auto indice = reinterpret_cast<int32_t *>(inputs[i]->addr); | |||
| auto shape_size = GetShapeSize(AnfAlgo::GetPrevNodeOutputInferShape(node_, i)); | |||
| for (size_t j = 0; j < shape_size; ++j) { | |||
| max_index = std::max(indice[j], max_index); | |||
| } | |||
| } | |||
| first_dim_size = max_index + 1; | |||
| std::vector<TypeId> dtypes{AnfAlgo::GetOutputDeviceDataType(node_, 0)}; | |||
| std::vector<size_t> result_shape{IntToSize(first_dim_size)}; | |||
| auto data0_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, input_tuple_num_); | |||
| auto indice_dims = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0).size(); | |||
| for (size_t d = indice_dims; d < data0_shape.size(); ++d) { | |||
| result_shape.emplace_back(data0_shape[d]); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, {result_shape}, node_.get()); | |||
| size_t num_out_dims = 2; | |||
| std::vector<size_t> out_dims(num_out_dims, 0); | |||
| for (size_t out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) { | |||
| out_dims[out_dim] = out_dim >= result_shape.size() ? 1 : result_shape[out_dim]; | |||
| } | |||
| for (size_t in_dim = num_out_dims; in_dim < result_shape.size(); ++in_dim) { | |||
| out_dims[num_out_dims - 1] *= result_shape[in_dim]; | |||
| } | |||
| auto merged = reinterpret_cast<T *>(outputs[0]->addr); | |||
| size_t slice_size = out_dims[1]; | |||
| size_t slice_bytes = slice_size * sizeof(T); | |||
| for (size_t i = 0; i < input_tuple_num_; i++) { | |||
| auto indice = reinterpret_cast<int32_t *>(inputs[i]->addr); | |||
| auto data = reinterpret_cast<T *>(inputs[i + input_tuple_num_]->addr); | |||
| auto shape_size = GetShapeSize(AnfAlgo::GetPrevNodeOutputInferShape(node_, i)); | |||
| for (size_t j = 0; j < shape_size; ++j) { | |||
| auto ret = memcpy_s(merged + indice[j] * slice_size, slice_bytes, data + j * slice_size, slice_bytes); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| bool DynamicStitchCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| LaunchKernel(inputs, outputs); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_STITCH_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_STITCH_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class DynamicStitchCPUKernel : public CPUKernel { | |||
| public: | |||
| DynamicStitchCPUKernel() = default; | |||
| ~DynamicStitchCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| size_t input_tuple_num_{1}; | |||
| }; | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, float); | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, int8_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, int16_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, int32_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, int64_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, uint8_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, uint16_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, uint32_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, uint64_t) | |||
| MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCPUKernel, bool) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_STITCH_CPU_KERNEL_H_ | |||
| @@ -33,6 +33,7 @@ template <typename T> | |||
| void ReduceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| axis_.clear(); | |||
| input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| @@ -80,6 +81,9 @@ void ReduceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| } else if (kernel_name_ == prim::kPrimReduceMean->name()) { | |||
| reduce_type_ = kReduceMean; | |||
| reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; }; | |||
| } else if (kernel_name == "ReduceProd") { | |||
| reduce_type_ = kReduceProd; | |||
| reduce_func_ = [](const T *input, size_t pos, T *out) { *out *= input[pos]; }; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name_; | |||
| } | |||
| @@ -38,7 +38,7 @@ class ReduceCPUKernel : public CPUKernel { | |||
| private: | |||
| void AccelerateLongVector(T *input_addr, T *output_addr, size_t input_size); | |||
| enum ReduceType { kReduceAll, kReduceAny, kReduceMax, kReduceMin, kReduceSum, kReduceMean }; | |||
| enum ReduceType { kReduceAll, kReduceAny, kReduceMax, kReduceMin, kReduceSum, kReduceMean, kReduceProd }; | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<int64_t> axis_; | |||
| ReduceType reduce_type_{kReduceAll}; | |||
| @@ -66,6 +66,11 @@ MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, double); | |||
| MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int32_t); | |||
| MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int64_t); | |||
| MS_REG_CPU_KERNEL_T(ReduceProd, KernelAttr(), ReduceCPUKernel, float); | |||
| MS_REG_CPU_KERNEL_T(ReduceProd, KernelAttr(), ReduceCPUKernel, double); | |||
| MS_REG_CPU_KERNEL_T(ReduceProd, KernelAttr(), ReduceCPUKernel, int32_t); | |||
| MS_REG_CPU_KERNEL_T(ReduceProd, KernelAttr(), ReduceCPUKernel, int64_t); | |||
| MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr(), ReduceCPUKernel, bool); | |||
| MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr(), ReduceCPUKernel, bool); | |||
| @@ -26,34 +26,25 @@ constexpr size_t kReshapeOutputsNum = 1; | |||
| void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| node_wpt_ = kernel_node; | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| x_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| type_size_ = GetTypeByte(TypeIdToType(x_data_type_)); | |||
| } | |||
| bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), kReshapeInputsNum, kernel_name_); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs is empty"; | |||
| } | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kReshapeOutputsNum, kernel_name_); | |||
| if (inputs[0]->size != outputs[0]->size) { | |||
| return false; | |||
| MS_LOG(EXCEPTION) << "Input size: {" << inputs[0]->size << "} is not equal to output size: {" << outputs[0]->size | |||
| << "}"; | |||
| } | |||
| if (inputs[0]->addr == outputs[0]->addr) { | |||
| return true; | |||
| } | |||
| auto node = node_wpt_.lock(); | |||
| if (!node) { | |||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||
| } | |||
| auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); | |||
| size_t mem_bits = type_size_; | |||
| for (size_t i = 0; i < x_shape.size(); ++i) { | |||
| mem_bits *= x_shape[i]; | |||
| } | |||
| auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); | |||
| if (ret != EOK) { | |||
| size_t copy_size = outputs[0]->size; | |||
| auto ret = memcpy_s(outputs[0]->addr, copy_size, inputs[0]->addr, copy_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | |||
| } | |||
| return true; | |||
| @@ -33,11 +33,6 @@ class ReshapeCPUKernel : public CPUKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| CNodeWeakPtr node_wpt_; | |||
| TypeId x_data_type_{kNumberTypeInt32}; | |||
| size_t type_size_ = 4; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReshapeCPUKernel); | |||
| @@ -58,6 +53,13 @@ MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutpu | |||
| ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), | |||
| ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Reshape, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), | |||
| @@ -24,6 +24,7 @@ namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| constexpr size_t kSliceInputsNum = 1; | |||
| constexpr size_t kSliceDynamicInputNum = 3; | |||
| constexpr size_t kSliceOutputsNum = 1; | |||
| } // namespace | |||
| @@ -38,6 +39,7 @@ int NormalizeBeginPos(int begin_pos, int dim_len) { | |||
| void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| cnode_ptr_ = kernel_node; | |||
| static const std::unordered_map<TypeId, int> type_size_map = {{kNumberTypeBool, sizeof(bool)}, | |||
| {kNumberTypeInt32, sizeof(int)}, | |||
| {kNumberTypeFloat32, sizeof(float)}, | |||
| @@ -46,12 +48,17 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| if (input_shape.size() > DIMENSION_8D || input_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "Slice only support 1D to 8D input tensor, but got " << input_shape.size() << "D."; | |||
| } | |||
| auto size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE); | |||
| auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN); | |||
| if (begin.size() != input_shape.size() || size.size() != input_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "Slice requires the length of begin and size must be equal to input dimension."; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| // begin and size are const input | |||
| if (input_num == 1) { | |||
| auto size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE); | |||
| auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN); | |||
| if (begin.size() != input_shape.size() || size.size() != input_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "Slice requires the length of begin and size must be equal to input dimension."; | |||
| } | |||
| InitSliceParam(input_shape, begin, size); | |||
| } | |||
| InitSliceParam(input_shape, begin, size); | |||
| TypeId dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| auto size_pair = type_size_map.find(dtype); | |||
| @@ -102,15 +109,37 @@ void SliceSimpleDim2(const int8_t *input, int8_t *output, const SliceParameter * | |||
| bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSliceInputsNum, kernel_name_); | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSliceOutputsNum, kernel_name_); | |||
| if (outputs[0]->size == 0) { | |||
| MS_LOG(WARNING) << "Slice output memory size should be greater than 0, but got 0."; | |||
| return true; | |||
| if (inputs.size() != kSliceInputsNum && inputs.size() != kSliceDynamicInputNum) { | |||
| MS_LOG(EXCEPTION) << "Input num should be " << kSliceInputsNum << " or " << kSliceDynamicInputNum << ", but got " | |||
| << inputs.size(); | |||
| } | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSliceOutputsNum, kernel_name_); | |||
| auto input_addr = inputs[0]->addr; | |||
| auto output_addr = outputs[0]->addr; | |||
| if (inputs.size() == kSliceDynamicInputNum) { | |||
| auto cnode = cnode_ptr_.lock(); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||
| auto begin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1); | |||
| auto size_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 2); | |||
| if (begin_shape.size() != 1 || size_shape.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Slice requires the dimension of begin and size must be equal to 1."; | |||
| } | |||
| if (begin_shape[0] != input_shape.size() || size_shape[0] != input_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "Slice requires the length of begin and size must be equal to input dimension."; | |||
| } | |||
| auto begin_ptr = reinterpret_cast<int32_t *>(inputs[1]->addr); | |||
| auto size_ptr = reinterpret_cast<int32_t *>(inputs[2]->addr); | |||
| std::vector<int64_t> begin{begin_ptr, begin_ptr + begin_shape[0]}; | |||
| std::vector<int64_t> size{size_ptr, size_ptr + size_shape[0]}; | |||
| for (size_t i = 0; i < begin.size(); ++i) { | |||
| if (input_shape[i] < IntToSize(begin[i] + size[i])) { | |||
| MS_LOG(EXCEPTION) << "Slice shape can not bigger than origin shape."; | |||
| } | |||
| } | |||
| InitSliceParam(input_shape, begin, size); | |||
| } | |||
| if (origin_dim_size_ == 2) { | |||
| auto task = [this, &input_addr, &output_addr](size_t start, size_t end) { | |||
| auto src = | |||
| @@ -49,6 +49,27 @@ MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput | |||
| SliceCPUKernel); | |||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| SliceCPUKernel); | |||
// Three-input Slice variants: a data tensor followed by two int32 tensors, which SliceCPUKernel::Launch
// reinterprets as begin (inputs[1]) and size (inputs[2]). Registered here for float32, int32 and bool data.
| MS_REG_CPU_KERNEL(Slice, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| SliceCPUKernel); | |||
| MS_REG_CPU_KERNEL(Slice, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| SliceCPUKernel); | |||
| MS_REG_CPU_KERNEL(Slice, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeBool), | |||
| SliceCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -16,27 +16,41 @@ | |||
| #include "backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h" | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
// Input/output counts and limits used by SliceGradCPUKernel.
| namespace { | |||
// Expected input counts that Launch selects by kernel name (SliceGrad vs. any other name).
| constexpr size_t kSliceGradInputsNum = 2; | |||
| constexpr size_t kStridedSliceGradInputsNum = 1; | |||
// Input counts at which InitKernel returns before reading attributes and LaunchKernel calls InitParams instead.
| constexpr size_t kSliceGradDynamicInputsNum = 4; | |||
| constexpr size_t kStridedSliceGradDynamicInputsNum = 5; | |||
| constexpr size_t kOutputsNum = 1; | |||
// Highest input rank InitKernel accepts; larger (or empty) shapes raise an exception.
| constexpr size_t kSliceGradMaxInputShapeSize = 4; | |||
| } // namespace | |||
| void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| cnode_ptr_ = kernel_node; | |||
| begin_.clear(); | |||
| size_.clear(); | |||
| strides_.clear(); | |||
| end_.clear(); | |||
| input_element_num_.clear(); | |||
| output_element_num_.clear(); | |||
| input_shape_.clear(); | |||
| output_shape_.clear(); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.empty() || input_shape.size() > kSliceGradMaxInputShapeSize) { | |||
| MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceGradGpuKernel only support 1-4D."; | |||
| } | |||
| output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.size() > 4) { | |||
| MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceGradCpuKernel support 4d or lower."; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num == kSliceGradDynamicInputsNum || input_num == kStridedSliceGradDynamicInputsNum) { | |||
| return; | |||
| } | |||
| // in the case that begin, end, size, stride are const value. | |||
| std::vector<int64_t> begin_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN); | |||
| (void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_), | |||
| [](const int64_t &value) { return LongToInt(value); }); | |||
| @@ -71,15 +85,16 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| void SliceGradCPUKernel::ExpandAllMemberDims() { | |||
| auto output_len = output_shape_.size(); | |||
| if (output_len < 4) { | |||
| for (size_t i = 0; i < 4 - output_len; ++i) { | |||
| constexpr size_t expand_dims = 4; | |||
| if (output_len < expand_dims) { | |||
| for (size_t i = 0; i < expand_dims - output_len; ++i) { | |||
| (void)output_shape_.insert(output_shape_.begin(), 1); | |||
| (void)begin_.insert(begin_.begin(), 0); | |||
| (void)strides_.insert(strides_.begin(), 1); | |||
| (void)end_.insert(end_.begin(), 1); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < 4; ++i) { | |||
| for (size_t i = 0; i < expand_dims; ++i) { | |||
| if (SignOfStride(i)) { | |||
| int ax = (end_[i] - begin_[i]) * SignOfStride(i); | |||
| if (ax < 0) { | |||
| @@ -90,11 +105,57 @@ void SliceGradCPUKernel::ExpandAllMemberDims() { | |||
| } | |||
| } | |||
| void SliceGradCPUKernel::InitParams(const std::vector<kernel::AddressPtr> &inputs) { | |||
| auto cnode = cnode_ptr_.lock(); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(cnode); | |||
| auto begin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 2); | |||
| auto begin_ptr = reinterpret_cast<int32_t *>(inputs[2]->addr); | |||
| std::vector<int32_t> begin{begin_ptr, begin_ptr + begin_shape[0]}; | |||
| (void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_), | |||
| [](const int32_t &value) { return value; }); | |||
| if (kernel_name == prim::kPrimStridedSliceGrad->name()) { | |||
| auto end_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3); | |||
| auto stride_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 4); | |||
| if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "StridedSliceGrad requires the dimension of begin, end, strides must be equal to 1."; | |||
| } | |||
| auto end_ptr = reinterpret_cast<int32_t *>(inputs[3]->addr); | |||
| auto strides_ptr = reinterpret_cast<int32_t *>(inputs[4]->addr); | |||
| std::vector<int32_t> end{end_ptr, end_ptr + end_shape[0]}; | |||
| std::vector<int32_t> strides{strides_ptr, strides_ptr + stride_shape[0]}; | |||
| (void)std::transform(strides.begin(), strides.end(), std::back_inserter(strides_), | |||
| [](const int32_t &value) { return value; }); | |||
| (void)std::transform(end.begin(), end.end(), std::back_inserter(end_), [](const int32_t &value) { return value; }); | |||
| if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) { | |||
| MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; | |||
| } | |||
| FormatArgs(true); | |||
| } else { | |||
| auto size_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3); | |||
| if (begin_shape.size() != 1 || size_shape.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "SliceGrad requires the dimension of begin, end must be equal to 1."; | |||
| } | |||
| auto size_ptr = reinterpret_cast<int32_t *>(inputs[3]->addr); | |||
| std::vector<int32_t> size{size_ptr, size_ptr + size_shape[0]}; | |||
| (void)std::transform(size.begin(), size.end(), std::back_inserter(size_), | |||
| [](const int32_t &value) { return value; }); | |||
| if (size_.size() != output_shape_.size() || begin_.size() != output_shape_.size()) { | |||
| MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; | |||
| } | |||
| FormatArgs(false); | |||
| } | |||
| ExpandAllMemberDims(); | |||
| CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); | |||
| CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); | |||
| } | |||
| bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| size_t expect_inputs_num = | |||
| kernel_name_ == prim::kPrimSliceGrad->name() ? kSliceGradInputsNum : kStridedSliceGradInputsNum; | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), expect_inputs_num, kernel_name_); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Input is empty"; | |||
| } | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); | |||
| bool ret = true; | |||
| @@ -114,10 +175,14 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, c | |||
| template <typename T> | |||
| bool SliceGradCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) const { | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| // init params for not const inputs | |||
| if (inputs.size() == kSliceGradDynamicInputsNum || inputs.size() == kStridedSliceGradDynamicInputsNum) { | |||
| InitParams(inputs); | |||
| } | |||
| auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret; | |||
| @@ -36,13 +36,12 @@ class SliceGradCPUKernel : public CPUKernel { | |||
| private: | |||
| template <typename T> | |||
| bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) const; | |||
| bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs); | |||
| template <typename T> | |||
| void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset, | |||
| const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num, | |||
| int id) const; | |||
| void InitParams(const std::vector<kernel::AddressPtr> &inputs); | |||
| void ExpandAllMemberDims(); | |||
| bool CanCopyMemoryOnAxis(size_t dim) const; | |||
| int SignOfStride(size_t axis) const; | |||
| @@ -73,6 +72,39 @@ MS_REG_CPU_KERNEL( | |||
| SliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||
| SliceGradCPUKernel); | |||
// Four-input SliceGrad variants: two tensors of the data type followed by two int32 tensors.
// SliceGradCPUKernel::InitParams reads inputs[2] as begin and inputs[3] as size.
| MS_REG_CPU_KERNEL(SliceGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(SliceGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(SliceGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(SliceGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeBool), | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| @@ -81,6 +113,38 @@ MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat64 | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||
| SliceGradCPUKernel); | |||
// Tensor-input StridedSliceGrad variants: one tensor of the data type followed by three int32 tensors.
// NOTE(review): these signatures declare 4 inputs, while SliceGradCPUKernel::InitParams reads begin/end/strides
// from inputs[2], inputs[3] and inputs[4] for StridedSliceGrad and kStridedSliceGradDynamicInputsNum is 5.
// Confirm the declared input list matches the layout the kernel receives at launch.
| MS_REG_CPU_KERNEL(StridedSliceGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSliceGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSliceGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| SliceGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSliceGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeBool), | |||
| SliceGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -26,6 +26,7 @@ namespace mindspore { | |||
| namespace kernel { | |||
// Input/output counts accepted by StridedSliceCPUKernel::Launch.
| namespace { | |||
// Attribute form: the data tensor is the only input; InitKernel reads begin/end/strides from node attributes.
| constexpr size_t kStridedSliceInputsNum = 1; | |||
// Tensor form: data tensor plus begin, end and strides, read by Launch from inputs[1], inputs[2] and inputs[3].
| constexpr size_t kStridedSliceDynamicInputsNum = 4; | |||
| constexpr size_t kStridedSliceOutputsNum = 1; | |||
| } // namespace | |||
| @@ -43,12 +44,18 @@ int NormalizePos(int pos, int dim_len, PosType pos_type) { | |||
| void StridedSliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| cnode_ptr_ = kernel_node; | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| if (input_shape_.size() > DIMENSION_8D || input_shape_.empty()) { | |||
| MS_LOG(EXCEPTION) << "StridedSlice only support 1D to 8D input tensor, but got " << input_shape_.size() << "D."; | |||
| } | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| return; | |||
| } | |||
| // for begin, end, stride are const input | |||
| auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN); | |||
| auto end = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END); | |||
| auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES); | |||
| @@ -56,9 +63,8 @@ void StridedSliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_LOG(EXCEPTION) | |||
| << "StridedSLice requires the length of begin, stride and end must be equal and less than input dimension."; | |||
| } | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| InitSliceParam(begin, end, stride); | |||
| InitSliceParam(begin, end, stride); | |||
| parallel_ = MatchParallelPattern(); | |||
| if (parallel_) { | |||
| InitParallelParam(); | |||
| @@ -214,14 +220,41 @@ void StridedSliceCPUKernel::ParallelRun(const uint8_t *input_addr, uint8_t *outp | |||
| bool StridedSliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /* workspace */, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), kStridedSliceInputsNum, kernel_name_); | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceOutputsNum, kernel_name_); | |||
| if (outputs[0]->size == 0) { | |||
| MS_LOG(WARNING) << "StridedSlice output memory size should be greater than 0, but got 0."; | |||
| return true; | |||
| if (inputs.size() != kStridedSliceInputsNum && inputs.size() != kStridedSliceDynamicInputsNum) { | |||
| MS_LOG(EXCEPTION) << "Input num should be " << kStridedSliceInputsNum << " or " << kStridedSliceDynamicInputsNum | |||
| << ", but got " << inputs.size(); | |||
| } | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceOutputsNum, kernel_name_); | |||
| auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr); | |||
| auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr); | |||
| auto cnode = cnode_ptr_.lock(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| if (input_num == kStridedSliceDynamicInputsNum) { | |||
| // for begin, end, stride are not const input | |||
| auto begin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1); | |||
| auto end_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 2); | |||
| auto stride_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3); | |||
| if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "StridedSliceGrad requires the dimension of begin, end, strides must be equal to 1."; | |||
| } | |||
| auto begin_ptr = reinterpret_cast<int64_t *>(inputs[1]->addr); | |||
| auto end_ptr = reinterpret_cast<int64_t *>(inputs[2]->addr); | |||
| auto strides_ptr = reinterpret_cast<int64_t *>(inputs[3]->addr); | |||
| std::vector<int64_t> begin{begin_ptr, begin_ptr + begin_shape[0]}; | |||
| std::vector<int64_t> end{end_ptr, end_ptr + end_shape[0]}; | |||
| std::vector<int64_t> stride{strides_ptr, strides_ptr + stride_shape[0]}; | |||
| if (begin.size() != end.size() || begin.size() != stride.size() || begin.size() > input_shape_.size()) { | |||
| MS_LOG(EXCEPTION) | |||
| << "StridedSLice requires the length of begin, stride and end must be equal and less than input dimension."; | |||
| } | |||
| InitSliceParam(begin, end, stride); | |||
| parallel_ = MatchParallelPattern(); | |||
| if (parallel_) { | |||
| InitParallelParam(); | |||
| } | |||
| } | |||
| int thread_num = slice_param_.op_parameter_.thread_num_; | |||
| if (parallel_ && thread_num >= 2) { | |||
| ParallelRun(input_addr, output_addr, thread_num); | |||
| @@ -66,6 +66,39 @@ MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).Ad | |||
| StridedSliceCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| StridedSliceCPUKernel); | |||
// Four-input StridedSlice variants: a data tensor followed by three int64 tensors, which
// StridedSliceCPUKernel::Launch reinterprets as begin (inputs[1]), end (inputs[2]) and strides (inputs[3]).
// Registered here for bool, int32, float32 and float64 data.
| MS_REG_CPU_KERNEL(StridedSlice, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeBool), | |||
| StridedSliceCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSlice, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| StridedSliceCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSlice, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| StridedSliceCPUKernel); | |||
| MS_REG_CPU_KERNEL(StridedSlice, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| StridedSliceCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -22,10 +22,38 @@ namespace mindspore { | |||
| namespace kernel { | |||
// Input/output counts accepted by TileCPUKernel::Launch.
| namespace { | |||
// Attribute form: the data tensor is the only input; multiples come from the "multiples" node attribute.
| constexpr size_t kTileInputsNum = 1; | |||
// Tensor form: data tensor plus multiples, which LaunchKernel reads from inputs[1] as int32 values.
| constexpr size_t kTileDynamicInputsNum = 2; | |||
| constexpr size_t kTileOutputsNum = 1; | |||
| } // namespace | |||
| void TileCPUKernel::TileMultipleCompute() { | |||
| size_t ones = multiples_.size() - x_shape_.size(); | |||
| if (ones > 0) { | |||
| for (size_t i = 0; i < ones; ++i) { | |||
| x_shape_.insert(x_shape_.begin(), 1); | |||
| } | |||
| } | |||
| if (x_shape_.size() > MAX_TILE_DIM_SIZE || x_shape_.size() > y_shape_.size()) { | |||
| MS_LOG(EXCEPTION) << "Tile input shape should not be greater than default max size :" << MAX_TILE_DIM_SIZE | |||
| << " and output shape : " << y_shape_.size() << ", but got input shape " << x_shape_.size(); | |||
| } | |||
| input_size_ = 1; | |||
| tile_parameter_.in_dim_ = x_shape_.size(); | |||
| for (int i = 0; i < tile_parameter_.in_dim_; i++) { | |||
| input_size_ *= x_shape_[i]; | |||
| tile_parameter_.in_shape_[i] = x_shape_[i]; | |||
| tile_parameter_.out_shape_[i] = y_shape_[i]; | |||
| } | |||
| int stridex = 1; | |||
| int stridey = 1; | |||
| for (int i = tile_parameter_.in_dim_ - 1; i >= 0; i--) { | |||
| tile_parameter_.in_strides_[i] = stridex; | |||
| tile_parameter_.out_strides_[i] = stridey; | |||
| stridex *= x_shape_[i]; | |||
| stridey *= y_shape_[i]; | |||
| } | |||
| int large_one_multiple_count_ = 0; | |||
| int multiple = 0; | |||
| size_t mul_index = 0; | |||
| @@ -52,45 +80,22 @@ void TileCPUKernel::TileMultipleCompute() { | |||
| void TileCPUKernel::TileTensorParamrInit(const CNodePtr &kernel_node) { | |||
| x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| y_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| if (x_shape_.size() > MAX_TILE_DIM_SIZE || x_shape_.size() > y_shape_.size()) { | |||
| MS_LOG(EXCEPTION) << "Tile input shape should not be greater than default max size :" << MAX_TILE_DIM_SIZE | |||
| << " and output shape : " << y_shape_.size() << ", but got input shape " << x_shape_.size(); | |||
| } | |||
| std::vector<int64_t> multiples_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "multiples"); | |||
| (void)std::transform(multiples_me.begin(), multiples_me.end(), std::back_inserter(multiples_), | |||
| [](const int64_t &value) { return LongToInt(value); }); | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| size_t ones = multiples_.size() - x_shape_.size(); | |||
| if (ones > 0) { | |||
| for (size_t i = 0; i < ones; ++i) { | |||
| (void)x_shape_.insert(x_shape_.begin(), 1); | |||
| } | |||
| multiples_.clear(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num == kTileInputsNum) { | |||
| std::vector<int64_t> multiples_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "multiples"); | |||
| (void)std::transform(multiples_me.begin(), multiples_me.end(), std::back_inserter(multiples_), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| TileMultipleCompute(); | |||
| } | |||
| input_size_ = 1; | |||
| tile_parameter_.in_dim_ = x_shape_.size(); | |||
| for (int i = 0; i < tile_parameter_.in_dim_; i++) { | |||
| input_size_ *= x_shape_[i]; | |||
| tile_parameter_.in_shape_[i] = x_shape_[i]; | |||
| tile_parameter_.out_shape_[i] = y_shape_[i]; | |||
| } | |||
| int stridex = 1; | |||
| int stridey = 1; | |||
| for (int i = tile_parameter_.in_dim_ - 1; i >= 0; i--) { | |||
| tile_parameter_.in_strides_[i] = stridex; | |||
| tile_parameter_.out_strides_[i] = stridey; | |||
| stridex *= x_shape_[i]; | |||
| stridey *= y_shape_[i]; | |||
| } | |||
| TileMultipleCompute(); | |||
| } | |||
| void TileCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| TileTensorParamrInit(kernel_node); | |||
| cnode_ptr_ = kernel_node; | |||
| launch_map_[kNumberTypeInt8] = &TileCPUKernel::LaunchKernel<int8_t>; | |||
| launch_map_[kNumberTypeInt16] = &TileCPUKernel::LaunchKernel<int16_t>; | |||
| @@ -113,7 +118,10 @@ void TileCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| bool TileCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTileInputsNum, kernel_name_); | |||
| if (inputs.size() != kTileInputsNum && inputs.size() != kTileDynamicInputsNum) { | |||
| MS_LOG(EXCEPTION) << "Input num should be " << kTileInputsNum << " or " << kTileDynamicInputsNum << ", but got " | |||
| << inputs.size(); | |||
| } | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTileOutputsNum, kernel_name_); | |||
| launch_func_(this, inputs, outputs); | |||
| return true; | |||
| @@ -123,8 +131,23 @@ template <typename T> | |||
| void TileCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| auto x_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto y_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| tile_parameter_.data_size_ = sizeof(T); | |||
| auto cnode = cnode_ptr_.lock(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| if (input_num == kTileDynamicInputsNum) { | |||
| auto multiples_addr = reinterpret_cast<int32_t *>(inputs[1]->addr); | |||
| auto multiple_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1); | |||
| size_t multiple_nums = 1; | |||
| for (size_t i = 0; i < multiple_shape.size(); ++i) { | |||
| multiple_nums *= multiple_shape[i]; | |||
| } | |||
| for (size_t i = 0; i < multiple_nums; ++i) { | |||
| multiples_.emplace_back(multiples_addr[i]); | |||
| } | |||
| TileMultipleCompute(); | |||
| } | |||
| tile_parameter_.data_size_ = sizeof(T); | |||
| if (one_dim_tile_) { | |||
| auto task = [&x_addr, &y_addr, this](size_t start, size_t end) { | |||
| TileSimple(x_addr, y_addr, start, end, &tile_parameter_); | |||
| @@ -57,7 +57,19 @@ class TileCPUKernel : public CPUKernel { | |||
| size_t input_size_{0}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr(), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), TileCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ | |||
| @@ -36,7 +36,7 @@ namespace { | |||
| constexpr unsigned int kLstmReserveIndex = 3; | |||
| AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||
| const TypeId &input_type, const TypeId &output_type, | |||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type) { | |||
| const abstract::BaseShapePtr &origin_shape, const TypeId &origin_type) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::string input_format = format; | |||
| std::string output_format = format; | |||
| @@ -52,10 +52,29 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| cast->set_kernel_info(kernel_info); | |||
| } | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | |||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); | |||
| if (origin_shape->IsDynamic()) { | |||
| AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cast); | |||
| AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cast); | |||
| AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cast); | |||
| } | |||
| AnfAlgo::SetNodeAttr("dst_type", TypeIdToType(output_type), cast); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | |||
| AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, cast.get()); | |||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); | |||
| std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kCastOpName, cast); | |||
| if (cpu_kernel == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Operator[Cast] " << cast->kernel_info() << " is not support."; | |||
| } | |||
| try { | |||
| cpu_kernel->Init(cast); | |||
| cpu_kernel->InitDynamicKernel(cast); | |||
| auto cpu_dynamic_kernel = cpu_kernel->DynamicKernel(); | |||
| MS_EXCEPTION_IF_NULL(cpu_dynamic_kernel); | |||
| cpu_dynamic_kernel->Initialize(); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(EXCEPTION) << e.what() << "\nTrace: " << trace::DumpSourceLines(cast); | |||
| } | |||
| AnfAlgo::SetKernelMod(cpu_kernel, cast.get()); | |||
| return cast; | |||
| } | |||
| @@ -74,7 +93,7 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| continue; | |||
| } | |||
| const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); | |||
| const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second); | |||
| const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second); | |||
| if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) { | |||
| auto cast = | |||
| @@ -107,8 +126,8 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn | |||
| } | |||
| auto used_node_index = static_cast<size_t>(used_node_list->at(j).second - 1); | |||
| auto cur_input = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(used_node), used_node_index); | |||
| const std::vector<size_t> origin_shape = | |||
| AnfAlgo::GetPrevNodeOutputInferShape(utils::cast<CNodePtr>(used_node), i); | |||
| const abstract::BaseShapePtr origin_shape = | |||
| AnfAlgo::GetPrevNodeOutputDetailShape(utils::cast<CNodePtr>(used_node), i); | |||
| auto cast = | |||
| AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, device_type, infer_type, origin_shape, infer_type); | |||
| MS_EXCEPTION_IF_NULL(cast); | |||
| @@ -56,6 +56,11 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An | |||
| MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); | |||
| return nullptr; | |||
| } | |||
| } else if (device == kCPUDevice) { | |||
| if (DynamicShapeConstInputToAttrCPU.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttrCPU.end()) { | |||
| MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); | |||
| return nullptr; | |||
| } | |||
| } else { | |||
| if (DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) { | |||
| MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); | |||
| @@ -104,14 +104,16 @@ class TwoReshapeEliminater : public AnfVisitor { | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (IsPrimitiveCNode(node, prim::kPrimReshape)) { | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| // {PrimReshape, X, Y} | |||
| if (inputs.size() != 3) { | |||
| return; | |||
| if (prim_ == nullptr && x_ == nullptr) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimReshape)) { | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| // {PrimReshape, X, Y} | |||
| if (inputs.size() != 3) { | |||
| return; | |||
| } | |||
| prim_ = GetValueNode<PrimitivePtr>(inputs[0]); | |||
| x_ = inputs[1]; | |||
| } | |||
| prim_ = GetValueNode<PrimitivePtr>(inputs[0]); | |||
| x_ = inputs[1]; | |||
| } else { | |||
| shape_ = node; | |||
| } | |||
| @@ -408,13 +408,11 @@ void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *di | |||
| MS_EXCEPTION_IF_NULL(arg_tensor); | |||
| MS_EXCEPTION_IF_NULL(arg_tensor->shape()); | |||
| (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape(); | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { | |||
| const auto &min_shape = arg_tensor->shape()->min_shape(); | |||
| const auto &max_shape = arg_tensor->shape()->max_shape(); | |||
| if (!min_shape.empty() && !max_shape.empty()) { | |||
| (*dic)[ATTR_MIN_SHAPE] = min_shape; | |||
| (*dic)[ATTR_MAX_SHAPE] = max_shape; | |||
| } | |||
| const auto &min_shape = arg_tensor->shape()->min_shape(); | |||
| const auto &max_shape = arg_tensor->shape()->max_shape(); | |||
| if (!min_shape.empty() && !max_shape.empty()) { | |||
| (*dic)[ATTR_MIN_SHAPE] = min_shape; | |||
| (*dic)[ATTR_MAX_SHAPE] = max_shape; | |||
| } | |||
| auto min_value = arg_tensor->get_min_value(); | |||
| @@ -69,8 +69,7 @@ using OpExecInfoPtr = std::shared_ptr<OpExecInfo>; | |||
| const std::set<std::string> ignore_infer_prim = {"mixed_precision_cast"}; | |||
| const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"}; | |||
| const std::set<std::string> dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape", "EmbeddingLookup", | |||
| "Transpose"}; | |||
| const std::set<std::string> dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "EmbeddingLookup", "Transpose"}; | |||
| } // namespace pynative | |||
| } // namespace mindspore | |||
| @@ -162,11 +162,24 @@ int GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, const std::vect | |||
| void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { | |||
| MS_EXCEPTION_IF_NULL(kernel_attr); | |||
| TypeId input_dtype = kernel_attr->GetInputAttr(0).first; | |||
| size_t attr_num = kernel_attr->GetInputSize(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t i = 1; i < input_num; ++i) { | |||
| kernel_attr->AddInputAttr(input_dtype); | |||
| if (attr_num == 0) { | |||
| MS_LOG(EXCEPTION) << "Input size is empty"; | |||
| return; // To pass the CI Check_Cppcheck | |||
| } | |||
| // Only two layouts are supported: a single dynamic input (like Concat), or | |||
| // several dynamic inputs that each expand to the same number of tensors (like DynamicStitch). | |||
| std::string format = kOpFormat_DEFAULT; | |||
| std::vector<DataType> attr_list; | |||
| size_t each_attr_input_num = input_num / attr_num; | |||
| for (size_t i = 0; i < attr_num; ++i) { | |||
| TypeId input_dtype = kernel_attr->GetInputAttr(i).first; | |||
| for (size_t j = 0; j < each_attr_input_num; ++j) { | |||
| (void)attr_list.emplace_back(input_dtype, format); | |||
| } | |||
| } | |||
| kernel_attr->SetInputAttrList(attr_list); | |||
| TypeId output_dtype = kernel_attr->GetOutputAttr(0).first; | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| @@ -28,13 +28,14 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace cpu { | |||
| using DataType = std::pair<TypeId, std::string>; | |||
| void SetKernelInfo(const CNodePtr &apply_kernel_ptr); | |||
| // Indicate whether the kernel input/output number are variable. | |||
| bool IsDynamicParamKernel(const std::string &op_name); | |||
| class KernelAttr { | |||
| public: | |||
| using DataType = std::pair<TypeId, std::string>; | |||
| KernelAttr() : all_same_(0) {} | |||
| ~KernelAttr() = default; | |||
| @@ -56,6 +57,9 @@ class KernelAttr { | |||
| const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } | |||
| const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } | |||
| bool GetAllSame() const { return all_same_; } | |||
| void SetInputAttrList(const std::vector<DataType> &addr_list) { | |||
| input_type_.assign(addr_list.begin(), addr_list.end()); | |||
| } | |||
| size_t GetInputSize() const { return input_type_.size(); } | |||
| size_t GetOutputSize() const { return output_type_.size(); } | |||
| @@ -107,6 +107,10 @@ void KernelActor::RunOpControlWithInputTensor(AID *const input_control, OpContex | |||
| PushInputDeviceTensor(input_tensors); | |||
| // When all the inputs are collected, then allocate memory and callback launch. | |||
| if (CheckRunningCondition(context)) { | |||
| if (is_dynamic_shape_) { | |||
| device_contexts_[0]->UpdateDynamicShape(kernel_); | |||
| } | |||
| FetchOutputDeviceTensor(); | |||
| if (memory_alloc_list_.size() > 0) { | |||
| SendMemoryAllocReq(context); | |||
| @@ -365,6 +369,11 @@ void KernelActor::PreLaunchKernel(OpContext<DeviceTensor> *) { | |||
| } | |||
| void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) { | |||
| // The size of output address may be changed in dynamic shape scenario. | |||
| if (is_dynamic_shape_) { | |||
| UpdateOutputAddrSize(); | |||
| } | |||
| running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_); | |||
| // The input is invalid and needs to be erased when finish kernel launch. | |||
| @@ -383,6 +392,18 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) { | |||
| } | |||
| } | |||
| void KernelActor::UpdateOutputAddrSize() { | |||
| auto &output_addresses = kernel_info_->output_address_list(); | |||
| for (size_t i = 0; i < output_addresses.size(); ++i) { | |||
| auto output_address = output_addresses[i].get(); | |||
| MS_EXCEPTION_IF_NULL(output_address); | |||
| auto output_addr_size = AnfAlgo::GetOutputTensorMemSize(kernel_, i); | |||
| if (output_addr_size != output_address->GetSize()) { | |||
| output_address->SetSize(output_addr_size); | |||
| } | |||
| } | |||
| } | |||
| void KernelActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const { | |||
| if (recorder_aid_ != nullptr) { | |||
| MS_EXCEPTION_IF_NULL(kernel_); | |||
| @@ -94,6 +94,10 @@ class KernelActor : public DebugAwareActor { | |||
| // The processing after kernel launch: 1.erase input, 2.free memory, 3.send output. | |||
| void PostLaunchKernel(OpContext<DeviceTensor> *const context); | |||
| // The size of output address may be changed in dynamic shape scenario, for example, the output shape of operator | |||
| // 'Unique' will change after PostExecute, the output address size should update. | |||
| void UpdateOutputAddrSize(); | |||
| // The info of kernel. | |||
| CNodePtr kernel_; | |||
| KernelInfo *kernel_info_; | |||
| @@ -215,6 +215,10 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const { | |||
| #endif | |||
| cpu_kernel->Init(node); | |||
| cpu_kernel->InitDynamicKernel(node); | |||
| auto cpu_dynamic_kernel = cpu_kernel->DynamicKernel(); | |||
| MS_EXCEPTION_IF_NULL(cpu_dynamic_kernel); | |||
| cpu_dynamic_kernel->Initialize(); | |||
| AnfAlgo::SetKernelMod(cpu_kernel, node.get()); | |||
| } | |||
| #ifdef ENABLE_AKG | |||
| @@ -223,6 +227,25 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const { | |||
| #endif | |||
| } | |||
| void CPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) { | |||
| MS_LOG(EXCEPTION) << "Akg kernels do not support dynamic shape by now."; | |||
| } | |||
| kernel::CPUKernel *cpu_kernel = dynamic_cast<kernel::CPUKernel *>(kernel_mod); | |||
| MS_EXCEPTION_IF_NULL(cpu_kernel); | |||
| device::DynamicKernelPtr dynamic_kernel = cpu_kernel->DynamicKernel(); | |||
| MS_EXCEPTION_IF_NULL(dynamic_kernel); | |||
| dynamic_kernel->InferShape(); | |||
| dynamic_kernel->UpdateArgs(); | |||
| } | |||
| void CPUDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| // Remove reorder after PS feature finish adapting push/pull in auto_monad. | |||
| @@ -49,6 +49,7 @@ class CPUDeviceContext : public DeviceContext { | |||
| void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override; | |||
| void CreateKernel(const std::vector<CNodePtr> &nodes) const override; | |||
| void UpdateDynamicShape(const CNodePtr &kernel) const override; | |||
| void PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const override; | |||
| @@ -360,7 +360,8 @@ void GPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const { | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER); | |||
| bool is_pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode; | |||
| if (is_pynative_infer || is_pynative_mode) { | |||
| std::vector<int64_t> dynamic_shape_depends = abstract::GetDependsFormMap(kernel); | |||
| if ((is_pynative_infer || is_pynative_mode) && dynamic_shape_depends.empty()) { | |||
| return; | |||
| } | |||
| @@ -685,6 +685,10 @@ const std::set<std::string> DynamicShapeConstInputToAttr = { | |||
| kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kReduceMinOpName, | |||
| kReduceMeanOpName, kReduceMaxOpName, kReduceAllOpName, kReduceAnyOpName, kConcatOpName}; | |||
| const std::set<std::string> DynamicShapeConstInputToAttrCPU = { | |||
| kCastOpName, kExpandDimsOpName, kEmbeddingLookupOpName, kReduceMinOpName, kReduceMeanOpName, | |||
| kReduceMaxOpName, kReduceAllOpName, kReduceAnyOpName, kConcatOpName, kReduceSumOpName}; | |||
| const std::set<std::string> DynamicShapeConstInputToAttrGPU = { | |||
| kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName, | |||
| kReduceMinOpName, kReduceMeanOpName, kReduceMaxOpName, kReduceAllOpName, kReduceAnyOpName, kConcatOpName}; | |||
| @@ -261,6 +261,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConcatOffset(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -820,14 +820,23 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| if (x_min_shape.empty()) { | |||
| x_min_shape = x_shape; | |||
| } | |||
| ValuePtr sh = primitive->GetAttr("shape"); | |||
| MS_EXCEPTION_IF_NULL(sh); | |||
| auto reshape_value_tuple = sh->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(reshape_value_tuple); | |||
| auto reshape_tuple = reshape_value_tuple->value(); | |||
| if (args_spec_list.size() == 2) { | |||
| auto input_value = args_spec_list[1]->BuildValue(); | |||
| if (input_value->isa<tensor::Tensor>()) { | |||
| shape = CheckAndConvertUtils::CheckTensorIntValue("reshape args value", input_value, op_name); | |||
| } else { | |||
| shape = CheckAndConvertUtils::CheckAttrTupleInt("reshape args value", input_value, op_name); | |||
| } | |||
| } else { | |||
| ValuePtr sh = primitive->GetAttr("shape"); | |||
| MS_EXCEPTION_IF_NULL(sh); | |||
| auto reshape_value_tuple = sh->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(reshape_value_tuple); | |||
| auto reshape_tuple = reshape_value_tuple->value(); | |||
| (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| } | |||
| auto max_shape = shape; | |||
| auto min_shape = shape; | |||
| @@ -867,6 +876,15 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| max_shape[index] = infer_max_value; | |||
| } | |||
| int64_t shape_num = 1; | |||
| for (int64_t value : shape) { | |||
| shape_num = LongMulWithOverflowCheck(value, shape_num); | |||
| } | |||
| if (shape_num != x_num) { | |||
| MS_LOG(EXCEPTION) << "The accumulate of x_shape must equal to out_shape, but got x_shape: " << x_shape | |||
| << ", and out_shape: " << shape; | |||
| } | |||
| AbstractTensorPtr ret = | |||
| std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| return ret; | |||
| @@ -980,6 +998,36 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive | |||
| return std::make_shared<AbstractTensor>(kBool, output_shape); | |||
| } | |||
| AbstractBasePtr InferImplConcatOffset(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| const std::string op_name = primitive->name(); | |||
| if (args_spec_list.empty()) { | |||
| MS_LOG(EXCEPTION) << "args_spec_list is empty."; | |||
| } | |||
| AbstractTuplePtr arg = nullptr; | |||
| AbstractTensorPtr tensor_base = nullptr; | |||
| size_t tuple_len = 0; | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| if (args_spec_list[0]->isa<AbstractTuple>()) { | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||
| tuple_len = arg->elements().size(); | |||
| tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0); | |||
| } else if (args_spec_list[0]->isa<AbstractTensor>()) { | |||
| tuple_len = args_spec_list.size(); | |||
| tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(tensor_base); | |||
| ShapeVector shape_base = tensor_base->shape()->shape(); | |||
| size_t rank = shape_base.size(); | |||
| ShapeVector out_shape{SizeToLong(tuple_len), SizeToLong(rank)}; | |||
| TypePtr out_type = kInt64; | |||
| return std::make_shared<AbstractTensor>(out_type, std::make_shared<Shape>(out_shape)); | |||
| } | |||
| AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| @@ -36,7 +36,8 @@ | |||
| #include "abstract/infer_functions.h" | |||
| #include "utils/ms_context.h" | |||
| #include "ops/tile.h" | |||
| #include "ops/slice.h" | |||
| #include "ops/grad/slice_grad.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| @@ -53,9 +54,13 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| const auto kGatherV2 = prim::kPrimGatherV2->name(); | |||
| const auto kDynamicShape = prim::kPrimDynamicShape->name(); | |||
| const auto kRange = prim::kPrimRange->name(); | |||
| const auto kConv2DBackpropFilter = prim::kPrimConv2DBackpropFilter->name(); | |||
| const auto kConv2DBackpropInput = prim::kPrimConv2DBackpropInput->name(); | |||
| const auto kTile = prim::kPrimTile->name(); | |||
| const auto kSlice = prim::kPrimSlice->name(); | |||
| const auto kSliceGrad = prim::kPrimSliceGrad->name(); | |||
| const auto kReshape = prim::kPrimReshape->name(); | |||
| // common dynamic shape depends | |||
| static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = {{kUnsortedSegmentSum, {2}}, | |||
| {kUnsortedSegmentMin, {2}}, | |||
| @@ -69,7 +74,11 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| {kOneHot, {1, 3}}, | |||
| {kDropoutGenMask, {0}}, | |||
| {kStridedSlice, {1, 2, 3}}, | |||
| {kStridedSliceGrad, {1, 2, 3, 4}}}; | |||
| {kStridedSliceGrad, {1, 2, 3, 4}}, | |||
| {kTile, {1}}, | |||
| {kReshape, {1}}, | |||
| {kSlice, {1, 2}}, | |||
| {kSliceGrad, {2, 3}}}; | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| @@ -252,8 +261,11 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { | |||
| {prim::kPrimShape, {InferImplShape, nullptr, false}}, | |||
| {prim::kPrimTranspose, {InferImplTranspose, nullptr, true}}, | |||
| {prim::kPrimStridedSlice, {ops::StridedSliceInfer, nullptr, true}}, | |||
| {prim::kPrimSlice, {ops::SliceInfer, nullptr, true}}, | |||
| {prim::kPrimSliceGrad, {ops::SliceGradInfer, nullptr, true}}, | |||
| {prim::kPrimReshape, {InferImplReshape, nullptr, true}}, | |||
| {prim::kPrimConcat, {InferImplConcat, nullptr, true}}, | |||
| {prim::kPrimConcatOffset, {InferImplConcatOffset, nullptr, true}}, | |||
| {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}}, | |||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, nullptr, true}}, | |||
| {prim::kPrimTransData, {InferImplTransData, nullptr, true}}, | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/grad/slice_grad.h" | |||
| #include <set> | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr SliceGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 4, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape()); | |||
| auto input_shape = shape_map[kShape]; | |||
| auto min_shape = shape_map[kMinShape]; | |||
| auto max_shape = shape_map[kMaxShape]; | |||
| if (max_shape.empty() && min_shape.empty()) { | |||
| return std::make_shared<abstract::Shape>(input_shape); | |||
| } | |||
| return std::make_shared<abstract::Shape>(input_shape, min_shape, max_shape); | |||
| } | |||
| TypePtr SliceGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto prim_name = prim->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("slice_grad_prim_infer", input_args.size(), kEqual, 4, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(input_args[1]); | |||
| auto x_type_map = input_args[1]->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(x_type_map); | |||
| auto x_dtype = x_type_map->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(x_dtype); | |||
| std::set<TypePtr> template_types = {kTensorType}; | |||
| return CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, template_types, prim_name); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr SliceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return abstract::MakeAbstract(SliceGradInferShape(primitive, input_args), SliceGradInferType(primitive, input_args)); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameSliceGrad, SliceGrad); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_SLICE_GRAD_H_ | |||
| #define MINDSPORE_CORE_OPS_SLICE_GRAD_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameSliceGrad = "SliceGrad"; | |||
| class MS_CORE_API SliceGrad : public PrimitiveC { | |||
| public: | |||
| SliceGrad() : PrimitiveC(kNameSliceGrad) { InitIOName({"dy", "x", "begin", "size"}, {"output"}); } | |||
| ~SliceGrad() = default; | |||
| MS_DECLARE_PARENT(SliceGrad, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr SliceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimSliceGradPtr = std::shared_ptr<SliceGrad>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_SLICE_GRAD_H_ | |||
| @@ -26,6 +26,58 @@ | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 3, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||
| auto min_shape = shape_map[kMinShape]; | |||
| auto max_shape = shape_map[kMaxShape]; | |||
| std::vector<std::vector<int64_t>> input_values; | |||
| // get begin and size value | |||
| for (size_t i = 1; i <= 2; ++i) { | |||
| std::vector<int64_t> tmp_input; | |||
| auto input_value = input_args[i]->BuildValue(); | |||
| if (input_value->isa<tensor::Tensor>()) { | |||
| tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, prim_name); | |||
| } else { | |||
| tmp_input = CheckAndConvertUtils::CheckAttrTupleInt("slice args value", input_value, prim_name); | |||
| } | |||
| (void)input_values.emplace_back(tmp_input); | |||
| } | |||
| if (max_shape.empty() && min_shape.empty()) { | |||
| return std::make_shared<abstract::Shape>(input_values[1]); | |||
| } | |||
| return std::make_shared<abstract::Shape>(input_values[1], min_shape, max_shape); | |||
| } | |||
| TypePtr SliceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto prim_name = prim->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("slice_prim_infer", input_args.size(), kEqual, 3, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||
| auto x_type_map = input_args[0]->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(x_type_map); | |||
| auto x_dtype = x_type_map->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(x_dtype); | |||
| std::set<TypePtr> template_types = {kTensorType}; | |||
| return CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, template_types, prim_name); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr SliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return abstract::MakeAbstract(SliceInferShape(primitive, input_args), SliceInferType(primitive, input_args)); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameSlice, Slice); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -38,6 +38,9 @@ class MS_CORE_API Slice : public PrimitiveC { | |||
| /// \brief Init. | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr SliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimSlicePtr = std::shared_ptr<Slice>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -61,9 +61,14 @@ abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vect | |||
| auto input_shape = shape_map[kShape]; | |||
| auto min_shape = shape_map[kMinShape]; | |||
| auto max_shape = shape_map[kMaxShape]; | |||
| auto get_cast_temp = input_args[1]->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(get_cast_temp); | |||
| auto multiples_v = GetValue<std::vector<int64_t>>(get_cast_temp->BuildValue()); | |||
| std::vector<int64_t> multiples_v; | |||
| auto multiple_value = input_args[1]->BuildValue(); | |||
| if (multiple_value->isa<tensor::Tensor>()) { | |||
| multiples_v = CheckAndConvertUtils::CheckTensorIntValue("tile multiples value", multiple_value, prim_name); | |||
| } else { | |||
| multiples_v = CheckAndConvertUtils::CheckAttrTupleInt("tile multiples value", multiple_value, prim_name); | |||
| } | |||
| auto infer_shape = GetInferShape(input_shape, multiples_v); | |||
| if (max_shape.empty() && min_shape.empty()) { | |||
| return std::make_shared<abstract::Shape>(infer_shape); | |||
| @@ -182,6 +182,8 @@ def get_bprop_reshape(self): | |||
| def bprop(x, shp, out, dout): | |||
| shapex = shape_op(x) | |||
| if -1 in shapex: | |||
| shapex = dyn_shape_op(x) | |||
| return reshape(dout, shapex), zeros_like(shp) | |||
| return bprop | |||
| @@ -252,10 +254,29 @@ def _tile_shape(multiples, shapex): | |||
| @bprop_getters.register(P.Tile) | |||
| def get_bprop_tile(self): | |||
| """Generate bprop for Tile""" | |||
| tuple_to_array = P.TupleToArray() | |||
| cast = P.Cast() | |||
| stack_op = P.Stack(1) | |||
| ones = P.Ones() | |||
| concat = P.Concat() | |||
| def bprop(x, multiples, out, dout): | |||
| shapex = shape_op(x) | |||
| r_shape = _tile_shape(multiples, shapex) | |||
| if isinstance(multiples, tuple): | |||
| r_shape = _tile_shape(multiples, shapex) | |||
| else: | |||
| len_multi = size_op(multiples) | |||
| rank = len(shapex) | |||
| shape_tensor = cast(tuple_to_array(shapex), mstype.int64) | |||
| if len_multi > rank: | |||
| one_tensor = ones((len_multi - rank,), mstype.int64) | |||
| shape_tensor = concat((shape_tensor, one_tensor)) | |||
| elif len_multi < rank: | |||
| one_tensor = ones((rank - len_multi,), mstype.int64) | |||
| multiples = concat((multiples, one_tensor)) | |||
| tile_shape = stack_op((multiples, shape_tensor)) | |||
| r_shape = reshape(tile_shape, (-1,)) | |||
| # 0 represents the start index, and 2 represents the step | |||
| axis = F.make_range(0, len(r_shape), 2) | |||
| dx = reduce_sum(reshape(dout, r_shape), axis) | |||
| @@ -340,9 +361,16 @@ def get_bprop_concat(self): | |||
| out_offset = G.ConcatOffset(len(x), axis)(x) | |||
| input_nums = len(x) | |||
| input_shapes = () | |||
| for i in range(input_nums): | |||
| input_shapes = input_shapes + (shape_op(x[i]),) | |||
| is_uniform = _concat_grad_uniform(input_shapes, input_nums) | |||
| if isinstance(out_offset, tuple): | |||
| for i in range(input_nums): | |||
| input_shapes = input_shapes + (shape_op(x[i]),) | |||
| is_uniform = _concat_grad_uniform(input_shapes, input_nums) | |||
| else: | |||
| # for dynamic shape | |||
| for i in range(input_nums): | |||
| input_shapes = input_shapes + (dyn_shape_op(x[i]),) | |||
| is_uniform = False | |||
| if isinstance(x, list): | |||
| dx = [] | |||
| if is_uniform: | |||
| @@ -27,14 +27,18 @@ from ..functional import broadcast_gradient_args, reduced_shape, tuple_div | |||
| from .grad_base import bprop_getters | |||
| from ..primitive import constexpr | |||
| from ..composite.multitype_ops import _constexpr_utils as const_utils | |||
| from ..operations._inner_ops import DynamicStitch | |||
| from ...common import Tensor | |||
| shape_op = P.Shape() | |||
| dyn_shape_op = P.DynamicShape() | |||
| reduce_prod = P.ReduceProd() | |||
| reduce_sum = P.ReduceSum() | |||
| reshape = P.Reshape() | |||
| tile = P.Tile() | |||
| is_sub_class = P.IsSubClass() | |||
| to_array = P.TupleToArray() | |||
| real_div = P.RealDiv() | |||
| def binop_grad_common(x, y, dx, dy): | |||
| """ | |||
| @@ -61,11 +65,37 @@ def binop_grad_common(x, y, dx, dy): | |||
| return reduce_dx, reduce_dy | |||
def _dyn_reduced_shape(input_shape, axis):
    """
    Build the reduced shape for a dynamically shaped input.

    `input_shape` is cast to int32 and stitched (via DynamicStitch) with a block of ones,
    one per reduced axis. `axis` may be a Tensor, an int, a sequence of ints, or empty
    (meaning every axis).
    """
    input_shape = P.Cast()(input_shape, ms.int32)
    if isinstance(axis, Tensor):
        input_rank = P.Rank()(input_shape)
        wrapped_axis = (axis + input_rank) % input_rank
        ones_shape = shape_op(wrapped_axis)
    else:
        wrapped_axis = ()
        input_rank = len(input_shape)
        if isinstance(axis, int):
            axis = (axis,)
        elif not axis:
            axis = range(input_rank)
        for i in axis:
            wrapped_axis += ((i + input_rank)%input_rank,)
        ones_shape = (len(wrapped_axis),)
    # NOTE(review): the wrapped (non-negative) axes only determine how many ones are created;
    # the stitch indices below use `axis` as given - confirm negative axes are handled as intended.
    stitch_indices = [to_array(range(input_rank)), to_array(axis)]
    stitch_data = [input_shape, P.Fill()(ms.int32, ones_shape, 1)]
    return DynamicStitch()(stitch_indices, stitch_data)
def _sum_grad(x, axis, dout):
    """
    Grad definition for `Sum` operation.

    Reshapes `dout` to the reduced shape with the reduced axes kept, then tiles it
    back over the input.

    Args:
        x: Forward input of the reduction.
        axis: Axes that were reduced.
        dout: Gradient flowing into the reduction's output.

    Returns:
        `dout` reshaped to the kept-dims shape and tiled by input_shape / kept-dims shape.
    """
    input_shape = shape_op(x)
    # Pick the shape helpers only after checking for unknown (-1) dimensions, so the
    # static helpers never see a shape that contains -1.
    if -1 in input_shape:
        input_shape = dyn_shape_op(x)
        output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis)
        tile_scaling = real_div(input_shape, output_shape_kept_dims)
    else:
        output_shape_kept_dims = reduced_shape(input_shape, axis)
        tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    return tile(grad, tile_scaling)
| @@ -788,8 +818,16 @@ def get_bprop_reduce_mean(self): | |||
def bprop(x, axis, out, dout):
    """
    Backward function returned by `get_bprop_reduce_mean`.

    The gradient of the sum (`_sum_grad`) is divided by the ratio between the
    element counts of `x` and `out`.

    Args:
        x: Forward input.
        axis: Reduction axes.
        out: Forward output.
        dout: Gradient with respect to `out`.

    Returns:
        Tuple `(dx, zeros_like(axis))`.
    """
    grad = _sum_grad(x, axis, dout)
    shape_x = shape_op(x)
    shape_out = shape_op(out)
    # Compute the divisor only after checking for unknown (-1) dimensions, so
    # F.shape_mul is never applied to a shape that contains -1.
    if -1 in shape_x:
        shape_x = dyn_shape_op(x)
        shape_out = dyn_shape_op(out)
        div_shape = reduce_prod(shape_x) / reduce_prod(shape_out)
        dx = div_op(grad, cast(div_shape, dtype(grad)))
    else:
        div_shape = F.shape_mul(shape_x) / F.shape_mul(shape_out)
        dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
    return dx, zeros_like(axis)
| return bprop | |||
| @@ -1172,6 +1210,7 @@ def get_bprop_imag(self): | |||
| return bprop | |||
| @bprop_getters.register(P.ScalarCast) | |||
| def get_bprop_scalar_cast(self): | |||
| """Generate bprop for ScalarCast""" | |||
| @@ -20,6 +20,9 @@ from .pow import _pow_cpu | |||
| from .real_div import _real_div_cpu | |||
| from .div import _div_cpu | |||
| from .concat import _concat_cpu | |||
| from .concat_offset import _concat_offset_cpu | |||
| from .dynamic_shape import _dynamic_shape_cpu | |||
| from .dynamic_stitch import _dynamic_stitch_cpu | |||
| from .split import _split_cpu | |||
| from .adam import _adam_cpu | |||
| from .adam_weight_decay import _adam_weight_decay_cpu | |||
| @@ -55,6 +58,7 @@ from .reduce_mean import _reduce_mean_cpu | |||
| from .reduce_max import _reduce_max_cpu | |||
| from .reduce_sum import _reduce_sum_cpu | |||
| from .reduce_min import _reduce_min_cpu | |||
| from .reduce_prod import _reduce_prod_cpu | |||
| from .reduce_all import _reduce_all_cpu | |||
| from .reduce_any import _reduce_any_cpu | |||
| from .transpose import _transpose_cpu | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ConcatOffset op""" | |||
| from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType | |||
# Registration record for the CPU "ConcatOffset" op: one required input "x",
# one required output "y"; every dtype_format row ends with I64_Default.
concat_offset_op_info = (
    CpuRegOp("ConcatOffset")
    .input(0, "x", "required")
    .output(0, "y", "required")
    .dtype_format(DataType.F32_Default, DataType.I64_Default)
    .dtype_format(DataType.I8_Default, DataType.I64_Default)
    .dtype_format(DataType.I16_Default, DataType.I64_Default)
    .dtype_format(DataType.I32_Default, DataType.I64_Default)
    .dtype_format(DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.U8_Default, DataType.I64_Default)
    .dtype_format(DataType.U16_Default, DataType.I64_Default)
    .dtype_format(DataType.U32_Default, DataType.I64_Default)
    .dtype_format(DataType.U64_Default, DataType.I64_Default)
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default)
    .get_op_info()
)


@op_info_register(concat_offset_op_info)
def _concat_offset_cpu():
    """Register the ConcatOffset CPU op info; the decorator does the work, the body is empty."""
    return
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """DynamicShape op""" | |||
| from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType | |||
# Registration record for the CPU "DynamicShape" op: one required input "x",
# one required output "y"; every dtype_format row ends with I64_Default.
dynamic_shape_op_info = (
    CpuRegOp("DynamicShape")
    .input(0, "x", "required")
    .output(0, "y", "required")
    .dtype_format(DataType.F32_Default, DataType.I64_Default)
    .dtype_format(DataType.I8_Default, DataType.I64_Default)
    .dtype_format(DataType.I16_Default, DataType.I64_Default)
    .dtype_format(DataType.I32_Default, DataType.I64_Default)
    .dtype_format(DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.U8_Default, DataType.I64_Default)
    .dtype_format(DataType.U16_Default, DataType.I64_Default)
    .dtype_format(DataType.U32_Default, DataType.I64_Default)
    .dtype_format(DataType.U64_Default, DataType.I64_Default)
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default)
    .get_op_info()
)


@op_info_register(dynamic_shape_op_info)
def _dynamic_shape_cpu():
    """Register the DynamicShape CPU op info; the decorator does the work, the body is empty."""
    return
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """DynamicStitch op""" | |||
| from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType | |||
# Registration record for the CPU "DynamicStitch" op: two dynamic inputs
# ("indices", "data") and one required output "y". Every dtype_format row starts
# with I32_Default and repeats the same dtype in its second and third positions.
dynamic_stitch_op_info = (
    CpuRegOp("DynamicStitch")
    .input(0, "indices", "dynamic")
    .input(1, "data", "dynamic")
    .output(0, "y", "required")
    .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default)
    .dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I16_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default)
    .dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.U16_Default)
    .dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.U32_Default)
    .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default)
    .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default)
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default)
    .dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default)
    .get_op_info()
)


@op_info_register(dynamic_stitch_op_info)
def _dynamic_stitch_cpu():
    """Register the DynamicStitch CPU op info; the decorator does the work, the body is empty."""
    return
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ReduceProd op""" | |||
| from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType | |||
# Registration record for the CPU "ReduceProd" op: one required input "x",
# one required output "y"; each dtype_format row repeats the same dtype twice
# (F32, F64, I32, I64).
reduce_prod_op_info = (
    CpuRegOp("ReduceProd")
    .input(0, "x", "required")
    .output(0, "y", "required")
    .dtype_format(DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.F64_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I64_Default)
    .get_op_info()
)


@op_info_register(reduce_prod_op_info)
def _reduce_prod_cpu():
    """Register the ReduceProd CPU op info; the decorator does the work, the body is empty."""
    return
| @@ -31,7 +31,6 @@ tile_op_info = CpuRegOp("Tile") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(tile_op_info) | |||
| def _tile_cpu(): | |||
| """Tile cpu register""" | |||
| @@ -16,5 +16,4 @@ bprop.32:x* | |||
| bprop.32:y* | |||
| bprop.32:out* | |||
| bprop.32:dout2 | |||
| bprop.32:[CNode]35:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.32:[CNode]35:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -9,4 +9,4 @@ s | |||
| bprop.13:x* | |||
| bprop.13:out* | |||
| bprop.13:dout2 | |||
| bprop.13:[CNode]15:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.13:[CNode]15:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -9,4 +9,4 @@ s | |||
| bprop.16:x* | |||
| bprop.16:out* | |||
| bprop.16:dout2 | |||
| bprop.16:[CNode]18:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.16:[CNode]18:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.21:x* | |||
| bprop.21:y* | |||
| bprop.21:out* | |||
| bprop.21:dout2 | |||
| bprop.21:[CNode]24:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.21:[CNode]24:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.25:x* | |||
| bprop.25:y* | |||
| bprop.25:out* | |||
| bprop.25:dout2 | |||
| bprop.25:[CNode]28:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.25:[CNode]28:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -12,4 +12,4 @@ bprop.60:x* | |||
| bprop.60:y* | |||
| bprop.60:out* | |||
| bprop.60:dout2 | |||
| bprop.60:[CNode]62:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.60:[CNode]62:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -9,4 +9,4 @@ s | |||
| bprop.67:x* | |||
| bprop.67:out* | |||
| bprop.67:dout2 | |||
| bprop.67:[CNode]69:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.67:[CNode]69:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -6,4 +6,4 @@ l | |||
| bprop.19:x* | |||
| bprop.19:out* | |||
| bprop.19:dout2 | |||
| bprop.19:[CNode]20:1:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.19:[CNode]20:1:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -17,4 +17,4 @@ | |||
| bprop.110:keep_prob* | |||
| bprop.110:out* | |||
| bprop.110:dout2 | |||
| bprop.110:[CNode]114:4:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.110:[CNode]114:4:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -12,4 +12,4 @@ | |||
| bprop.50:keep_prob* | |||
| bprop.50:out* | |||
| bprop.50:dout2 | |||
| bprop.50:[CNode]53:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.50:[CNode]53:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.70:x* | |||
| bprop.70:y* | |||
| bprop.70:out* | |||
| bprop.70:dout2 | |||
| bprop.70:[CNode]73:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.70:[CNode]73:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.82:x* | |||
| bprop.82:y* | |||
| bprop.82:out* | |||
| bprop.82:dout2 | |||
| bprop.82:[CNode]85:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.82:[CNode]85:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.78:x* | |||
| bprop.78:y* | |||
| bprop.78:out* | |||
| bprop.78:dout2 | |||
| bprop.78:[CNode]81:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.78:[CNode]81:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,8 +16,4 @@ bprop.63:x* | |||
| bprop.63:y* | |||
| bprop.63:out* | |||
| bprop.63:dout2 | |||
| <<<<<<< HEAD | |||
| bprop.63:[CNode]66:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| ======= | |||
| bprop.63:[CNode]66:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b94098bb48fd28f82b2c43a7cc640206eb58ba7b1f20ae3da88df883f2a08687246ba6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| >>>>>>> Add complex ops and bprop of real�conj�imag ops | |||
| bprop.63:[CNode]66:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -5,4 +5,4 @@ f | |||
| bprop.2:x* | |||
| bprop.2:out* | |||
| bprop.2:dout2 | |||
| bprop.2:[CNode]3:1:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.2:[CNode]3:1:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -9,8 +9,4 @@ s | |||
| bprop.29:x* | |||
| bprop.29:out* | |||
| bprop.29:dout2 | |||
| <<<<<<< HEAD | |||
| bprop.29:[CNode]31:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| ======= | |||
| bprop.29:[CNode]31:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b94098bb48fd28f82b2c43a7cc640206eb58ba7b1f20ae3da88df883f2a08687246ba6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| >>>>>>> Add complex ops and bprop of real�conj�imag ops | |||
| bprop.29:[CNode]31:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.90:x* | |||
| bprop.90:y* | |||
| bprop.90:out* | |||
| bprop.90:dout2 | |||
| bprop.90:[CNode]93:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.90:[CNode]93:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.86:x* | |||
| bprop.86:y* | |||
| bprop.86:out* | |||
| bprop.86:dout2 | |||
| bprop.86:[CNode]89:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.86:[CNode]89:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -15,4 +15,4 @@ | |||
| bprop.45:num* | |||
| bprop.45:out* | |||
| bprop.45:dout2 | |||
| bprop.45:[CNode]49:4:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.45:[CNode]49:4:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.94:x* | |||
| bprop.94:y* | |||
| bprop.94:out* | |||
| bprop.94:dout2 | |||
| bprop.94:[CNode]97:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.94:[CNode]97:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -9,4 +9,4 @@ s | |||
| bprop.39:x* | |||
| bprop.39:out* | |||
| bprop.39:dout2 | |||
| bprop.39:[CNode]41:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.39:[CNode]41:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.98:x* | |||
| bprop.98:y* | |||
| bprop.98:out* | |||
| bprop.98:dout2 | |||
| bprop.98:[CNode]101:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.98:[CNode]101:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -16,4 +16,4 @@ bprop.74:x* | |||
| bprop.74:y* | |||
| bprop.74:out* | |||
| bprop.74:dout2 | |||
| bprop.74:[CNode]77:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.74:[CNode]77:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -19,4 +19,4 @@ | |||
| bprop.54:off_value* | |||
| bprop.54:out* | |||
| bprop.54:dout2 | |||
| bprop.54:[CNode]59:5:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.54:[CNode]59:5:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -7,4 +7,4 @@ l | |||
| bprop.7:x* | |||
| bprop.7:out* | |||
| bprop.7:dout2 | |||
| bprop.7:[CNode]9:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.7:[CNode]9:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -7,4 +7,4 @@ l | |||
| bprop.4:x* | |||
| bprop.4:out* | |||
| bprop.4:dout2 | |||
| bprop.4:[CNode]6:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.4:[CNode]6:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -8,4 +8,4 @@ f | |||
| bprop.0:x* | |||
| bprop.0:out* | |||
| bprop.0:dout2 | |||
| bprop.0:[CNode]1:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.0:[CNode]1:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -12,4 +12,4 @@ | |||
| bprop.102:axis* | |||
| bprop.102:out* | |||
| bprop.102:dout2 | |||
| bprop.102:[CNode]105:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.102:[CNode]105:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -12,4 +12,4 @@ | |||
| bprop.106:axis* | |||
| bprop.106:out* | |||
| bprop.106:dout2 | |||
| bprop.106:[CNode]109:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.106:[CNode]109:3:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -9,4 +9,4 @@ s | |||
| bprop.42:x* | |||
| bprop.42:out* | |||
| bprop.42:dout2 | |||
| bprop.42:[CNode]44:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.42:[CNode]44:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -9,4 +9,4 @@ s | |||
| bprop.36:x* | |||
| bprop.36:out* | |||
| bprop.36:dout2 | |||
| bprop.36:[CNode]38:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.36:[CNode]38:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -9,4 +9,4 @@ r | |||
| bprop.10:x* | |||
| bprop.10:out* | |||
| bprop.10:dout2 | |||
| bprop.10:[CNode]12:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8c5478c3baea09296b77804c1b1016500b393d9513a0b382238539de4f65e5fc8e2ce5b00361039d0c4e69bd7f109a58f426d56b2a58bd719de8724107c0b940a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.10:[CNode]12:2:€13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e81c03ee41e2ae73005664666a652e28787ff0c69916d109f77e40345be916e60b2366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22308a904f11db2ab38215d47cda4850a3914997fcd06109bfaa0989d884b51a65565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d5f6486474eab638624ee1777f9652b2edd3f5f2873c570e2992de1d8b0e878737e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6dae47635340bc244097f692e3d086cc9ae28fd823d60946e421051e28dbccdadf5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -267,8 +267,18 @@ class ConcatOffset(PrimitiveWithInfer): | |||
| axis = self.axis | |||
| x_shp = input_x['shape'] | |||
| x_type = input_x['dtype'] | |||
| offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name) | |||
| self.add_prim_attr('T', x_type[0].element_type()) | |||
| # if input_x is dynamic shape | |||
| for each in x_shp: | |||
| if -1 in each: | |||
| return { | |||
| 'shape': [len(x_shp), len(x_shp[0])], | |||
| 'dtype': mstype.int64, | |||
| 'value': None | |||
| } | |||
| offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name) | |||
| offset_values = [] | |||
| for i in range(len(x_shp)): | |||
| values = [] | |||
| @@ -1791,12 +1801,22 @@ class SliceGrad(PrimitiveWithInfer): | |||
| def __infer__(self, dy, x, begin, size): | |||
| dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value'] | |||
| dy_shape_len = len(dy_shape) | |||
| for i in range(dy_shape_len): | |||
| validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name) | |||
| validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name) | |||
| if (size_value is not None) and (-1 not in x_shape): | |||
| for i in range(dy_shape_len): | |||
| validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name) | |||
| validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name) | |||
| if 'max_shape' in x: | |||
| max_shape = x['max_shape'] | |||
| min_shape = x['min_shape'] | |||
| else: | |||
| max_shape = [1] * dy_shape_len | |||
| min_shape = [1] * dy_shape_len | |||
| return {'shape': x_shape, | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| 'value': None, | |||
| 'max_shape': max_shape, | |||
| 'min_shape': min_shape} | |||
| class NLLLossGrad(PrimitiveWithInfer): | |||
| @@ -484,6 +484,24 @@ class Reshape(PrimitiveWithInfer): | |||
| def __infer__(self, x, shape): | |||
| shape_v = shape['value'] | |||
| if not shape_v and shape['shape']: | |||
| # the target shape has no constant value, but its own shape is available | |||
| shape_rank = shape['shape'][0] | |||
| # unknown dims for shape | |||
| if shape_rank == -1: | |||
| shape_rank = 1 | |||
| out_shape = [-1] * shape_rank | |||
| min_shape = [1] * shape_rank | |||
| max_shape = [1] * shape_rank | |||
| return { | |||
| 'shape': out_shape, | |||
| 'dtype': x['dtype'], | |||
| 'value': None, | |||
| 'max_shape': max_shape, | |||
| 'min_shape': min_shape | |||
| } | |||
| x_shp = x['shape'] | |||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||
| validator.check_value_type("shape", shape_v, [tuple], self.name) | |||
| @@ -500,7 +518,7 @@ class Reshape(PrimitiveWithInfer): | |||
| else: | |||
| dim_prod *= shp_i | |||
| arr_prod = np.prod(x_shp) | |||
| if arr_prod <= 0: | |||
| if -1 in x_shp: | |||
| if 'max_shape' in x: | |||
| x_max_shape = x['max_shape'] | |||
| else: | |||
| @@ -516,9 +534,7 @@ class Reshape(PrimitiveWithInfer): | |||
| if neg_index != -1: | |||
| max_shape[neg_index] = int(max_arr_prod / dim_prod) | |||
| min_shape[neg_index] = int(min_arr_prod / dim_prod) | |||
| else: | |||
| raise ValueError(f"For '{self.name}', the 'input_shape' must have -1 in the case of dynamic shape, " | |||
| f"but got {shape_v}.") | |||
| out = {'shape': shape['value'], | |||
| 'dtype': x['dtype'], | |||
| 'value': None, | |||
| @@ -1964,20 +1980,39 @@ class Tile(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output']) | |||
| def check_elim(self, base_tensor, multiplier): | |||
| if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): | |||
| raise TypeError(f"For '{self.name}', the type of ('input_x', 'multiples') should be (Tensor, tuple), " | |||
| f"but got ({type(base_tensor).__name__}, {type(multiplier).__name__}).") | |||
| if not isinstance(base_tensor, Tensor): | |||
| raise TypeError(f"For '{self.name}', the type of 'input_x' should be Tensor, " | |||
| f"but got {type(base_tensor).__name__}.") | |||
| if all(v == 1 for v in multiplier): | |||
| return (True, base_tensor) | |||
| return (False, None) | |||
| def __infer__(self, x, multiples): | |||
| multiples_v = multiples['value'] | |||
| if multiples_v is None: | |||
| rank = max(len(x['shape']), multiples['shape'][0]) | |||
| out_shape = [-1] * rank | |||
| if 'max_shape' not in x: | |||
| max_shape = x['shape'] | |||
| min_shape = x['shape'] | |||
| else: | |||
| max_shape = x['max_shape'] | |||
| min_shape = x['min_shape'] | |||
| return {'shape': out_shape, | |||
| 'dtype': x['dtype'], | |||
| 'value': None, | |||
| 'min_shape': min_shape, | |||
| 'max_shape': max_shape | |||
| } | |||
| x_shp = x['shape'] | |||
| validator.check_value_type("multiples", multiples_v, [tuple], self.name) | |||
| validator.check_value_type( | |||
| "multiples", multiples_v, [tuple], self.name) | |||
| for i, multiple in enumerate(multiples_v): | |||
| validator.check_positive_int(multiple, "multiples[%d]" % i, self.name) | |||
| validator.check_value_type("x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name) | |||
| validator.check_positive_int( | |||
| multiple, "multiples[%d]" % i, self.name) | |||
| validator.check_value_type( | |||
| "x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name) | |||
| len_sub = len(multiples_v) - len(x_shp) | |||
| multiples_w = None | |||
| if len_sub == 0: | |||
| @@ -2800,25 +2835,35 @@ class Slice(PrimitiveWithInfer): | |||
| def __infer__(self, x, begin, size): | |||
| x_shape = x['shape'] | |||
| x_shp_len = len(x_shape) | |||
| validator.check_valid_input('begin', begin['value'], self.name) | |||
| validator.check_valid_input('size', size['value'], self.name) | |||
| begin_v, size_v = begin['value'], size['value'] | |||
| if begin_v is None or size_v is None: | |||
| return {'shape': None, | |||
| out_shape = [-1] * size['shape'][0] | |||
| if 'max_shape' in x: | |||
| max_shape = x['max_shape'] | |||
| min_shape = x['min_shape'] | |||
| else: | |||
| min_shape = x['shape'] | |||
| max_shape = x['shape'] | |||
| return {'shape': out_shape, | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| 'value': None, | |||
| 'min_shape': min_shape, | |||
| 'max_shape': max_shape} | |||
| validator.check_valid_input('begin', begin['value'], self.name) | |||
| validator.check_valid_input('size', size['value'], self.name) | |||
| validator.check_value_type("input begin", begin_v, [tuple, list], self.name) | |||
| validator.check_value_type("input size", size_v, [tuple, list], self.name) | |||
| for key, value in zip(('begin', 'size'), (begin_v, size_v)): | |||
| validator.check(f'len of {key}', len(value), | |||
| 'len x\'s dim', x_shp_len) | |||
| for i in range(x_shp_len): | |||
| validator.check_positive_int(size_v[i], f'input size[{i}]') | |||
| validator.check_non_negative_int(begin_v[i], f'input begin[{i}]') | |||
| if x_shape[i] < begin_v[i] + size_v[i]: | |||
| y = begin_v[i] + size_v[i] | |||
| raise ValueError(f"For '{self.name}', the sliced shape can not be greater than origin shape, but got " | |||
| f"sliced shape is {y}, and origin shape is {x_shape}.") | |||
| if -1 not in x_shape: | |||
| for i in range(x_shp_len): | |||
| validator.check_positive_int(size_v[i], f'input size[{i}]') | |||
| validator.check_non_negative_int(begin_v[i], f'input begin[{i}]') | |||
| if x_shape[i] < begin_v[i] + size_v[i]: | |||
| y = begin_v[i] + size_v[i] | |||
| raise ValueError(f"For '{self.name}', the sliced shape can not be greater than origin shape, " | |||
| f"but got sliced shape is {y}, and origin shape is {x_shape}.") | |||
| return {'shape': size_v, | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| @@ -0,0 +1,335 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import sys | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops.composite import GradOperation | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="CPU") | |||
| class TileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.tile = P.Tile() | |||
| def construct(self, x, multiples): | |||
| out = self.tile(x, multiples) | |||
| return out | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_tile_multiple_tensor_cpu(): | |||
| """ | |||
| /// Feature: Tile op dynamic shape | |||
| /// Description: Tile forward with dynamic shape | |||
| /// Expectation: Equal to expected value | |||
| """ | |||
| if sys.platform != 'linux': | |||
| return | |||
| multiples_1 = Tensor(np.array([2, 1]), mstype.int64) | |||
| multiples_2 = Tensor(np.array([4, 1]), mstype.int64) | |||
| x = Tensor(np.array([[1, 2, 3, 4]]), mstype.float32) | |||
| tile_net = TileNet() | |||
| expect_1 = np.array([[1., 2., 3., 4.], | |||
| [1., 2., 3., 4.]]) | |||
| expect_2 = np.array([[1., 2., 3., 4.], | |||
| [1., 2., 3., 4.], | |||
| [1., 2., 3., 4.], | |||
| [1., 2., 3., 4.]]) | |||
| expect = [expect_1, expect_2] | |||
| for i, multiples in enumerate([multiples_1, multiples_2]): | |||
| output = tile_net(x, multiples) | |||
| assert (output.asnumpy() == expect[i]).all() | |||
| class GradTile(nn.Cell): | |||
| def __init__(self, network): | |||
| super().__init__() | |||
| self.grad = GradOperation(sens_param=True) | |||
| self.network = network | |||
| self.unique = P.Unique() | |||
| self.reshape = P.Reshape() | |||
| def construct(self, input_x, multiples, grad): | |||
| dy = self.unique(grad)[0] | |||
| dy = self.reshape(dy, (2, 4)) | |||
| return self.grad(self.network)(input_x, multiples, dy) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_tile_multiple_tensor_grad_cpu(): | |||
| """ | |||
| /// Feature: Tile op dynamic shape | |||
| /// Description: Tile backward with dynamic shape | |||
| /// Expectation: Equal to expected value | |||
| """ | |||
| if sys.platform != 'linux': | |||
| return | |||
| multiples = Tensor(np.array([2, 1]), mstype.int64) | |||
| x0 = Tensor(np.array([[1, 2, 3, 4]]), mstype.float32) | |||
| tile_net = GradTile(TileNet()) | |||
| dout = Tensor(np.arange(1, 9), mstype.float32) | |||
| output = tile_net(x0, multiples, dout) | |||
| expect = np.array([[6., 8., 10., 12.]]) | |||
| assert (output.asnumpy() == expect).all() | |||
| class ConcatOffsetNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.unique = P.Unique() | |||
| self.concat_offset = G.ConcatOffset(3, 0) | |||
| self.reshape = P.Reshape() | |||
| def construct(self, x, y, z): | |||
| x = self.reshape(self.unique(x)[0], (-1, 1, 2, 1)) | |||
| y = self.reshape(self.unique(y)[0], (-1, 1, 2, 1)) | |||
| z = self.reshape(self.unique(z)[0], (-1, 1, 2, 1)) | |||
| out = self.concat_offset((x, y, z)) | |||
| return out | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_concat_offset_dynamic_cpu(): | |||
| """ | |||
| /// Feature: ConcatOffset op dynamic shape | |||
| /// Description: ConcatOffset forward with dynamic shape | |||
| /// Expectation: Equal to expected value | |||
| """ | |||
| if sys.platform != 'linux': | |||
| return | |||
| x = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32) | |||
| x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32) | |||
| x3 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32) | |||
| net = ConcatOffsetNet() | |||
| out = net(x, x2, x3) | |||
| expect = np.array([[0, 0, 0, 0], | |||
| [3, 0, 0, 0], | |||
| [6, 0, 0, 0]]) | |||
| if isinstance(out, tuple): | |||
| assert (np.array(out) == expect).all() | |||
| else: | |||
| assert (out.asnumpy() == expect).all() | |||
| class ConcatNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.unique = P.Unique() | |||
| self.concat = P.Concat() | |||
| self.reshape = P.Reshape() | |||
| def construct(self, x, y, z, shape_tensor): | |||
| x = self.reshape(x, shape_tensor) | |||
| y = self.reshape(y, shape_tensor) | |||
| z = self.reshape(z, shape_tensor) | |||
| out = self.concat((x, y, z)) | |||
| return out | |||
| class GradConcat(nn.Cell): | |||
| def __init__(self, network): | |||
| super().__init__() | |||
| self.grad = GradOperation(sens_param=True) | |||
| self.network = network | |||
| self.unique = P.Unique() | |||
| self.reshape = P.Reshape() | |||
| def construct(self, x, y, z, shape, grad): | |||
| # grad = self.reshape(grad, (-1,)) | |||
| dy = self.reshape(self.unique(grad)[0], (-1, 1, 2, 1)) | |||
| return self.grad(self.network)(x, y, z, shape, dy) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_concat_dynamic_grad_cpu(): | |||
| """ | |||
| /// Feature: Concat op dynamic shape | |||
| /// Description: Concat backward with dynamic shape | |||
| /// Expectation: Equal to expected value | |||
| """ | |||
| if sys.platform != 'linux': | |||
| return | |||
| x = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32) | |||
| x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32) | |||
| x3 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.float32) | |||
| shape = Tensor(np.array([3, 1, 2, 1]), mstype.int64) | |||
| dout = Tensor(np.arange(1, 19), mstype.float32) | |||
| net = GradConcat(ConcatNet()) | |||
| output = net(x, x2, x3, shape, dout) | |||
| expect = np.array([1., 2., 3., 4., 5., 6.]) | |||
| assert (output.asnumpy() == expect).all() | |||
| class SliceNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.slice = P.Slice() | |||
| def construct(self, x, begin, size): | |||
| return self.slice(x, begin, size) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_slice_begin_size_tensor_cpu(): | |||
| """ | |||
| /// Feature: Slice op dynamic shape | |||
| /// Description: Slice forward with dynamic shape | |||
| /// Expectation: Equal to expected value | |||
| """ | |||
| if sys.platform != 'linux': | |||
| return | |||
| x = Tensor( | |||
| np.array([[[1, -1, 1], [2, -2, 2]], [[3, -3, 3], [4, -4, 4]], [[5, -5, 5], [6, -6, 6]]]), mstype.float32) | |||
| begin = Tensor( | |||
| np.array([0, 1, 0]), mstype.int64) | |||
| size = Tensor( | |||
| np.array([2, 1, 2]), mstype.int64) | |||
| slice_net = SliceNet() | |||
| output = slice_net(x, begin, size) | |||
| expect = np.array([[[2., -2.]], | |||
| [[4., -4.]]]) | |||
| assert (output.asnumpy() == expect).all() | |||
| class GradSlice(nn.Cell): | |||
| def __init__(self, network): | |||
| super().__init__() | |||
| self.grad = GradOperation(sens_param=True) | |||
| self.network = network | |||
| self.unique = P.Unique() | |||
| self.reshape = P.Reshape() | |||
| def construct(self, input_x, begin, size, grad): | |||
| # grad = self.reshape(grad, (-1,)) | |||
| dy = self.unique(grad)[0] | |||
| dy = self.reshape(dy, size) | |||
| return self.grad(self.network)(input_x, begin, size, dy) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_slice_begin_size_tensor_grad(): | |||
| """ | |||
| /// Feature: Slice op dynamic shape | |||
| /// Description: Slice backward with dynamic shape | |||
| /// Expectation: Equal to expected value | |||
| """ | |||
| if sys.platform != 'linux': | |||
| return | |||
| dy = Tensor(np.array([1, 2, 3, 4]), mstype.float32) | |||
| x = Tensor( | |||
| np.array([[[1, -1, 1], [2, -2, 2]], [[3, -3, 3], [4, -4, 4]], [[5, -5, 5], [6, -6, 6]]]), mstype.float32) | |||
| begin = Tensor( | |||
| np.array([0, 1, 0]), mstype.int64) | |||
| size = Tensor( | |||
| np.array([2, 1, 2]), mstype.int64) | |||
| net = GradSlice(SliceNet()) | |||
| output = net(x, begin, size, dy) | |||
| expect = np.array([[[0., 0., 0.], | |||
| [1., 2., 0.]], | |||
| [[0., 0., 0.], | |||
| [3., 4., 0.]], | |||
| [[0., 0., 0.], | |||
| [0., 0., 0.]]]) | |||
| assert (output.asnumpy() == expect).all() | |||
| class ReduceMeanNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.reduce_mean = P.ReduceMean(keep_dims=True) | |||
| self.reshape = P.Reshape() | |||
| self.tile = P.Tile() | |||
| def construct(self, x, shape): | |||
| y = self.reshape(x, shape) | |||
| return self.reduce_mean(y, 0) | |||
| class GradReduceMean(nn.Cell): | |||
| def __init__(self, network): | |||
| super().__init__() | |||
| self.grad = GradOperation(get_all=True, sens_param=True) | |||
| self.network = network | |||
| self.reshape = P.Reshape() | |||
| self.unique = P.Unique() | |||
| def construct(self, input_x, shape, grad): | |||
| grad = self.reshape(self.unique(grad)[0], (1, 2)) | |||
| return self.grad(self.network)(input_x, shape, grad) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_reducemean_dynamic_cpu(): | |||
| """ | |||
| /// Feature: ReduceMean op dynamic shape | |||
| /// Description: ReduceMean forward with dynamic shape | |||
| /// Expectation: Equal to expected value | |||
| """ | |||
| if sys.platform != 'linux': | |||
| return | |||
| x = Tensor(np.array([10, 10, 2, 2]), mstype.float32) | |||
| x2 = Tensor(np.array([2, 2]), mstype.int64) | |||
| reduce_mean = ReduceMeanNet() | |||
| out = reduce_mean(x, x2) | |||
| expect = np.array([[6., 6.]]) | |||
| assert (out.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_reducemean_dynamic_grad_cpu(): | |||
| """ | |||
| /// Feature: ReduceMean op dynamic shape | |||
| /// Description: ReduceMean backward with dynamic shape | |||
| /// Expectation: Equal to expected value | |||
| """ | |||
| if sys.platform != 'linux': | |||
| return | |||
| x = Tensor(np.array([10, 10, 2, 2]), mstype.float32) | |||
| x2 = Tensor(np.array([2, 2]), mstype.int64) | |||
| dout = Tensor(np.array([1, 3]), mstype.float32) | |||
| reduce_mean = GradReduceMean(ReduceMeanNet()) | |||
| out = reduce_mean(x, x2, dout) | |||
| expect = np.array([[0.5, 1.5, 0.5, 1.5]]) | |||
| assert (out[0].asnumpy() == expect).all() | |||
| @@ -83,10 +83,35 @@ class NetReduceLogic(nn.Cell): | |||
| self.reduce_any(indice, self.axis3),) | |||
| class NetReduceProd(nn.Cell): | |||
| def __init__(self): | |||
| super(NetReduceProd, self).__init__() | |||
| self.axis0 = 0 | |||
| self.axis1 = 1 | |||
| self.axis2 = -1 | |||
| self.axis3 = (0, 1) | |||
| self.axis4 = () | |||
| self.reduce_prod = P.ReduceProd(False) | |||
| self.reduce_prod_keep = P.ReduceProd(True) | |||
| @ms_function | |||
| def construct(self, indices): | |||
| return (self.reduce_prod(indices, self.axis0), | |||
| self.reduce_prod(indices, self.axis1), | |||
| self.reduce_prod(indices, self.axis2), | |||
| self.reduce_prod(indices, self.axis3), | |||
| self.reduce_prod_keep(indices, self.axis4)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_reduce(): | |||
| """ | |||
| /// Feature: Reduce | |||
| /// Description: reduce tensor elements, include reduce_mean, reduce_max, etc. | |||
| /// Expectation: Equal to numpy results | |||
| """ | |||
| reduce = NetReduce() | |||
| indice = Tensor(np.array([ | |||
| [[0., 2., 1., 4., 0., 2.], [3., 1., 2., 2., 4., 0.]], | |||
| @@ -151,6 +176,11 @@ def test_reduce(): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_reduce_logic(): | |||
| """ | |||
| /// Feature: Reduce logic | |||
| /// Description: Include reduce_all, reduce_any | |||
| /// Expectation: Equal to numpy results | |||
| """ | |||
| reduce_logic = NetReduceLogic() | |||
| indice_bool = Tensor([[[False, True, True, True, False, True], | |||
| [True, True, True, True, True, False]], | |||
| @@ -179,5 +209,38 @@ def test_reduce_logic(): | |||
| assert (output[7].asnumpy() == expect_any_4).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_reduce_prod(): | |||
| """ | |||
| /// Feature: Reduce prod | |||
| /// Description: Product of tensor elements | |||
| /// Expectation: Equal to numpy results | |||
| """ | |||
| reduce_prod = NetReduceProd() | |||
| indices = Tensor(np.array([[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]], | |||
| [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]], | |||
| [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]]).astype(np.float32)) | |||
| output = reduce_prod(indices) | |||
| expect_prod_0 = np.array([[28, 28, 28, 28, 28, 28], | |||
| [80, 80, 80, 80, 80, 80], | |||
| [162, 162, 162, 162, 162, 162]]).astype(np.float32) | |||
| expect_prod_1 = np.array([[6, 6, 6, 6, 6, 6], | |||
| [120, 120, 120, 120, 120, 120], | |||
| [504, 504, 504, 504, 504, 504]]).astype(np.float32) | |||
| expect_prod_2 = np.array([[1.00000e+00, 6.40000e+01, 7.29000e+02], | |||
| [4.09600e+03, 1.56250e+04, 4.66560e+04], | |||
| [1.17649e+05, 2.62144e+05, 5.31441e+05]]).astype(np.float32) | |||
| expect_prod_3 = np.array([362880, 362880, 362880, 362880, 362880, 362880]).astype(np.float32) | |||
| expect_prod_4 = np.array([[[2.2833798e+33]]]).astype(np.float32) | |||
| assert (output[0].asnumpy() == expect_prod_0).all() | |||
| assert (output[1].asnumpy() == expect_prod_1).all() | |||
| assert (output[2].asnumpy() == expect_prod_2).all() | |||
| assert (output[3].asnumpy() == expect_prod_3).all() | |||
| assert (output[4].asnumpy() == expect_prod_4).all() | |||
| test_reduce() | |||
| test_reduce_logic() | |||