| @@ -75,7 +75,7 @@ AnfNodePtr CreateInt32Tensor(int64_t value) { | |||||
| if (it != int_tensor_map.end()) { | if (it != int_tensor_map.end()) { | ||||
| return it->second; | return it->second; | ||||
| } | } | ||||
| mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(py::int_(value), kInt32); | |||||
| mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(value, kInt32); | |||||
| ValuePtr value_ptr = MakeValue(tensor_ptr); | ValuePtr value_ptr = MakeValue(tensor_ptr); | ||||
| auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr); | auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr); | ||||
| int_tensor_map[value] = anf_node_ptr; | int_tensor_map[value] = anf_node_ptr; | ||||
| @@ -382,7 +382,7 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32); | tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32); | ||||
| *tensor_mask = kValueNodeTensorMask; | *tensor_mask = kValueNodeTensorMask; | ||||
| } else if (py::isinstance<py::int_>(input_object)) { | } else if (py::isinstance<py::int_>(input_object)) { | ||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt64); | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64); | |||||
| *tensor_mask = kValueNodeTensorMask; | *tensor_mask = kValueNodeTensorMask; | ||||
| } else if (py::isinstance<py::array>(input_object)) { | } else if (py::isinstance<py::array>(input_object)) { | ||||
| tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr); | tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr); | ||||
| @@ -20,16 +20,13 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <list> | |||||
| #include <utility> | #include <utility> | ||||
| #include <cfloat> | #include <cfloat> | ||||
| #include "abstract/abstract_value.h" | |||||
| #include "ir/value.h" | #include "ir/value.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/shape_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| bool ValueToBool(const ValuePtr &v, bool *value) { | bool ValueToBool(const ValuePtr &v, bool *value) { | ||||
| @@ -37,13 +34,13 @@ bool ValueToBool(const ValuePtr &v, bool *value) { | |||||
| if (v->isa<BoolImm>()) { | if (v->isa<BoolImm>()) { | ||||
| *value = v->cast<BoolImmPtr>()->value(); | *value = v->cast<BoolImmPtr>()->value(); | ||||
| } else if (v->isa<Int32Imm>()) { | } else if (v->isa<Int32Imm>()) { | ||||
| *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true; | |||||
| *value = v->cast<Int32ImmPtr>()->value() != 0; | |||||
| } else if (v->isa<UInt32Imm>()) { | } else if (v->isa<UInt32Imm>()) { | ||||
| *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true; | |||||
| *value = v->cast<UInt32ImmPtr>()->value() != 0; | |||||
| } else if (v->isa<FP32Imm>()) { | } else if (v->isa<FP32Imm>()) { | ||||
| *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true; | |||||
| *value = v->cast<FP32ImmPtr>()->value() != 0; | |||||
| } else if (v->isa<FP64Imm>()) { | } else if (v->isa<FP64Imm>()) { | ||||
| *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true; | |||||
| *value = v->cast<FP64ImmPtr>()->value() != 0; | |||||
| } else if (v->isa<tensor::Tensor>()) { | } else if (v->isa<tensor::Tensor>()) { | ||||
| auto tensor = v->cast<tensor::TensorPtr>(); | auto tensor = v->cast<tensor::TensorPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| @@ -65,11 +62,11 @@ bool BaseRefToInt(const ValuePtr &v, int64_t *value) { | |||||
| auto tensor = v->cast<tensor::TensorPtr>(); | auto tensor = v->cast<tensor::TensorPtr>(); | ||||
| (void)tensor->data_sync(); | (void)tensor->data_sync(); | ||||
| if (tensor->Dtype()->ToString() == "Int32") { | if (tensor->Dtype()->ToString() == "Int32") { | ||||
| int32_t *tensor_data = static_cast<int32_t *>(tensor->data_c()); | |||||
| auto *tensor_data = static_cast<int32_t *>(tensor->data_c()); | |||||
| auto vb = tensor_data[0]; | auto vb = tensor_data[0]; | ||||
| *value = static_cast<int64_t>(vb); | *value = static_cast<int64_t>(vb); | ||||
| } else if (tensor->Dtype()->ToString() == "Int64") { | } else if (tensor->Dtype()->ToString() == "Int64") { | ||||
| int64_t *tensor_data = static_cast<int64_t *>(tensor->data_c()); | |||||
| auto *tensor_data = static_cast<int64_t *>(tensor->data_c()); | |||||
| auto vb = tensor_data[0]; | auto vb = tensor_data[0]; | ||||
| *value = vb; | *value = vb; | ||||
| } else { | } else { | ||||
| @@ -86,39 +83,19 @@ bool BaseRefToBool(const BaseRef &v, bool *value) { | |||||
| return ValueToBool(utils::cast<ValuePtr>(v), value); | return ValueToBool(utils::cast<ValuePtr>(v), value); | ||||
| } else if (utils::isa<bool>(v)) { | } else if (utils::isa<bool>(v)) { | ||||
| auto vb = utils::cast<bool>(v); | auto vb = utils::cast<bool>(v); | ||||
| if (vb == true) { | |||||
| *value = true; | |||||
| } else { | |||||
| *value = false; | |||||
| } | |||||
| *value = vb; | |||||
| } else if (utils::isa<int>(v)) { | } else if (utils::isa<int>(v)) { | ||||
| auto vb = utils::cast<int>(v); | auto vb = utils::cast<int>(v); | ||||
| if (vb == 0) { | |||||
| *value = false; | |||||
| } else { | |||||
| *value = true; | |||||
| } | |||||
| *value = vb != 0; | |||||
| } else if (utils::isa<unsigned int>(v)) { | } else if (utils::isa<unsigned int>(v)) { | ||||
| auto vb = utils::cast<unsigned int>(v); | auto vb = utils::cast<unsigned int>(v); | ||||
| if (vb == 0) { | |||||
| *value = false; | |||||
| } else { | |||||
| *value = true; | |||||
| } | |||||
| *value = vb != 0; | |||||
| } else if (utils::isa<float>(v)) { | } else if (utils::isa<float>(v)) { | ||||
| auto vb = utils::cast<float>(v); | auto vb = utils::cast<float>(v); | ||||
| if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) { | |||||
| *value = false; | |||||
| } else { | |||||
| *value = true; | |||||
| } | |||||
| *value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON); | |||||
| } else if (utils::isa<double>(v)) { | } else if (utils::isa<double>(v)) { | ||||
| auto vb = utils::cast<double>(v); | auto vb = utils::cast<double>(v); | ||||
| if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) { | |||||
| *value = false; | |||||
| } else { | |||||
| *value = true; | |||||
| } | |||||
| *value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON); | |||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "value is not supported to cast to be bool"; | MS_LOG(DEBUG) << "value is not supported to cast to be bool"; | ||||
| return false; | return false; | ||||
| @@ -187,13 +164,13 @@ bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMap | |||||
| return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node); | return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node); | ||||
| } | } | ||||
| bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph, | |||||
| bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph, | |||||
| NodeMapEquiv *const equiv_node) { | NodeMapEquiv *const equiv_node) { | ||||
| std::unordered_set<AnfNodePtr> done; | std::unordered_set<AnfNodePtr> done; | ||||
| std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo; | std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo; | ||||
| todo.push(std::make_pair(root1, root2)); | todo.push(std::make_pair(root1, root2)); | ||||
| while (todo.size() > 0) { | |||||
| while (!todo.empty()) { | |||||
| AnfNodePtr node1 = todo.top().first; | AnfNodePtr node1 = todo.top().first; | ||||
| if (done.count(node1) > 0) { | if (done.count(node1) > 0) { | ||||
| todo.pop(); | todo.pop(); | ||||
| @@ -231,7 +208,7 @@ bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equ | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph, | |||||
| bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph, | |||||
| NodeMapEquiv *const equiv_node) { | NodeMapEquiv *const equiv_node) { | ||||
| auto fg1_fg2 = std::make_pair(fg1, fg2); | auto fg1_fg2 = std::make_pair(fg1, fg2); | ||||
| if (equiv_func_graph == nullptr) { | if (equiv_func_graph == nullptr) { | ||||
| @@ -267,23 +244,35 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { | |||||
| if (scalar == nullptr) { | if (scalar == nullptr) { | ||||
| MS_EXCEPTION(ArgumentError) << "Nullptr Error!"; | MS_EXCEPTION(ArgumentError) << "Nullptr Error!"; | ||||
| } | } | ||||
| tensor::TensorPtr tensor = nullptr; | |||||
| if (scalar->isa<FloatImm>()) { | |||||
| tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32); | |||||
| } else if (scalar->isa<Int32Imm>()) { | |||||
| tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32); | |||||
| } else if (scalar->isa<Int64Imm>()) { | |||||
| tensor = std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), kInt64); | |||||
| } else if (scalar->isa<BoolImm>()) { | |||||
| const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0; | |||||
| tensor = std::make_shared<tensor::Tensor>(bool_value, kBool); | |||||
| } else { | |||||
| auto type = scalar->type(); | |||||
| auto type_str = (type == nullptr) ? "nullptr" : type->ToString(); | |||||
| MS_LOG(EXCEPTION) << "Invalid scalar type: " << type_str; | |||||
| TypePtr data_type = scalar->type(); | |||||
| MS_EXCEPTION_IF_NULL(data_type); | |||||
| TypeId type_id = data_type->type_id(); | |||||
| switch (type_id) { | |||||
| case kNumberTypeBool: | |||||
| return std::make_shared<tensor::Tensor>(GetValue<bool>(scalar), data_type); | |||||
| case kNumberTypeInt8: | |||||
| return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type); | |||||
| case kNumberTypeInt16: | |||||
| return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type); | |||||
| case kNumberTypeInt32: | |||||
| return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type); | |||||
| case kNumberTypeInt64: | |||||
| return std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), data_type); | |||||
| case kNumberTypeUInt8: | |||||
| return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type); | |||||
| case kNumberTypeUInt16: | |||||
| return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type); | |||||
| case kNumberTypeUInt32: | |||||
| return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type); | |||||
| case kNumberTypeUInt64: | |||||
| return std::make_shared<tensor::Tensor>(GetValue<uint64_t>(scalar), data_type); | |||||
| case kNumberTypeFloat32: | |||||
| return std::make_shared<tensor::Tensor>(GetValue<float>(scalar), data_type); | |||||
| case kNumberTypeFloat64: | |||||
| return std::make_shared<tensor::Tensor>(GetValue<double>(scalar), data_type); | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "When convert scalar to tensor, the scalar type: " << data_type << "is valid."; | |||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| return tensor; | |||||
| } | } | ||||
| void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) { | void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) { | ||||
| @@ -301,7 +290,7 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> * | |||||
| } | } | ||||
| } | } | ||||
| } else if (value->isa<tensor::Tensor>()) { | } else if (value->isa<tensor::Tensor>()) { | ||||
| tensor::TensorPtr tensor = value->cast<tensor::TensorPtr>(); | |||||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| tensors->push_back(tensor); | tensors->push_back(tensor); | ||||
| } | } | ||||
| @@ -57,7 +57,8 @@ enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 }; | |||||
| using FuncGraphPairMapEquiv = std::unordered_map<std::pair<FuncGraphPtr, FuncGraphPtr>, EquivState, PairHasher>; | using FuncGraphPairMapEquiv = std::unordered_map<std::pair<FuncGraphPtr, FuncGraphPtr>, EquivState, PairHasher>; | ||||
| using NodeMapEquiv = std::unordered_map<AnfNodePtr, AnfNodePtr>; | using NodeMapEquiv = std::unordered_map<AnfNodePtr, AnfNodePtr>; | ||||
| bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node); | |||||
| bool Isomorphic(const FuncGraphPtr &g1, const FuncGraphPtr &g2, FuncGraphPairMapEquiv *equiv_func_graph, | |||||
| NodeMapEquiv *equiv_node); | |||||
| tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); | tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); | ||||
| @@ -491,6 +491,16 @@ Tensor::Tensor(double input, const TypePtr &data_type) | |||||
| data_(MakeTensorData(data_type_, {}, input)), | data_(MakeTensorData(data_type_, {}, input)), | ||||
| id_(MakeId()) {} | id_(MakeId()) {} | ||||
| Tensor::Tensor(uint64_t input, const TypePtr &data_type) | |||||
| : MetaTensor(TypeIdOf(data_type, kNumberTypeUInt64), {}), | |||||
| data_(MakeTensorData(data_type_, {}, input)), | |||||
| id_(MakeId()) {} | |||||
| Tensor::Tensor(bool input, const TypePtr &data_type) | |||||
| : MetaTensor(TypeIdOf(data_type, kNumberTypeBool), {}), | |||||
| data_(MakeTensorData(data_type_, {}, input)), | |||||
| id_(MakeId()) {} | |||||
| bool Tensor::operator==(const Tensor &tensor) const { | bool Tensor::operator==(const Tensor &tensor) const { | ||||
| return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_)); | return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_)); | ||||
| } | } | ||||
| @@ -172,6 +172,18 @@ class Tensor : public MetaTensor { | |||||
| // param data_type [TypeId] data type | // param data_type [TypeId] data type | ||||
| explicit Tensor(double input, const TypePtr &data_type = nullptr); | explicit Tensor(double input, const TypePtr &data_type = nullptr); | ||||
| // brief Create 0 dimension tensor from a uint scalar. | |||||
| // | |||||
| // param input [uint] the data for tensor | |||||
| // param data_type [TypeId] data type | |||||
| explicit Tensor(uint64_t input, const TypePtr &data_type = nullptr); | |||||
| // brief Create 0 dimension tensor from a bool scalar. | |||||
| // | |||||
| // param input [bool] the data for tensor | |||||
| // param data_type [TypeId] data type | |||||
| explicit Tensor(bool input, const TypePtr &data_type = nullptr); | |||||
| ~Tensor() override = default; | ~Tensor() override = default; | ||||
| MS_DECLARE_PARENT(Tensor, MetaTensor); | MS_DECLARE_PARENT(Tensor, MetaTensor); | ||||
| @@ -88,6 +88,7 @@ class L1Regularizer(Cell): | |||||
| l1_regularization = self.scale * self.reduce_sum(self.abs(weights)) | l1_regularization = self.scale * self.reduce_sum(self.abs(weights)) | ||||
| return l1_regularization | return l1_regularization | ||||
| class Dropout(Cell): | class Dropout(Cell): | ||||
| r""" | r""" | ||||
| Dropout layer for the input. | Dropout layer for the input. | ||||
| @@ -210,6 +211,7 @@ class Flatten(Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return F.reshape(x, (F.shape(x)[0], -1)) | return F.reshape(x, (F.shape(x)[0], -1)) | ||||
| @constexpr | @constexpr | ||||
| def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel): | def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel): | ||||
| """get broadcast_weight_bias shape""" | """get broadcast_weight_bias shape""" | ||||
| @@ -217,6 +219,7 @@ def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel): | |||||
| broad_bias_shape = x_shape[:-1] + (out_channel,) | broad_bias_shape = x_shape[:-1] + (out_channel,) | ||||
| return broad_weight_shape, broad_bias_shape | return broad_weight_shape, broad_bias_shape | ||||
| class Dense(Cell): | class Dense(Cell): | ||||
| r""" | r""" | ||||
| The dense connected layer. | The dense connected layer. | ||||
| @@ -262,6 +265,7 @@ class Dense(Cell): | |||||
| [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ] | [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ] | ||||
| [ 1.0739875 4.0155234 0.94188046 -5.459526 ]] | [ 1.0739875 4.0155234 0.94188046 -5.459526 ]] | ||||
| """ | """ | ||||
| @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels']) | @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels']) | ||||
| def __init__(self, | def __init__(self, | ||||
| in_channels, | in_channels, | ||||
| @@ -323,7 +327,6 @@ class Dense(Cell): | |||||
| x = self.activation(x) | x = self.activation(x) | ||||
| return x | return x | ||||
| def extend_repr(self): | def extend_repr(self): | ||||
| s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels) | s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels) | ||||
| if self.has_bias: | if self.has_bias: | ||||
| @@ -339,10 +342,12 @@ def _is_equal_one(x): | |||||
| return False | return False | ||||
| return bool(x.asnumpy().mean() == 1.0) | return bool(x.asnumpy().mean() == 1.0) | ||||
| @constexpr | @constexpr | ||||
| def _dtype_check(x_dtype): | def _dtype_check(x_dtype): | ||||
| if x_dtype not in [mstype.float32, mstype.float16]: | if x_dtype not in [mstype.float32, mstype.float16]: | ||||
| raise TypeError("The input type must be float32 or float16.") | |||||
| raise TypeError("The input type must be float32 or float16.") | |||||
| @constexpr | @constexpr | ||||
| def _is_float_dtype(dtype): | def _is_float_dtype(dtype): | ||||
| @@ -539,7 +544,6 @@ class OneHot(Cell): | |||||
| return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype)) | return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype)) | ||||
| class Pad(Cell): | class Pad(Cell): | ||||
| """ | """ | ||||
| Pads the input tensor according to the paddings and mode. | Pads the input tensor according to the paddings and mode. | ||||
| @@ -672,6 +676,7 @@ class Interpolate(Cell): | |||||
| >>> print(result.shape) | >>> print(result.shape) | ||||
| (1, 1, 5, 5) | (1, 1, 5, 5) | ||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Interpolate, self).__init__() | super(Interpolate, self).__init__() | ||||
| @@ -767,6 +772,7 @@ class Tril(Cell): | |||||
| [[1 0] | [[1 0] | ||||
| [3 4]] | [3 4]] | ||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Tril, self).__init__() | super(Tril, self).__init__() | ||||
| self.dtype = P.DType() | self.dtype = P.DType() | ||||
| @@ -809,6 +815,7 @@ class Triu(Cell): | |||||
| [[1 2] | [[1 2] | ||||
| [0 4]] | [0 4]] | ||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Triu, self).__init__() | super(Triu, self).__init__() | ||||
| self.dtype = P.DType() | self.dtype = P.DType() | ||||
| @@ -859,6 +866,7 @@ class MatrixDiag(Cell): | |||||
| [[ 1. 0.] | [[ 1. 0.] | ||||
| [ 0. -1.]] | [ 0. -1.]] | ||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(MatrixDiag, self).__init__() | super(MatrixDiag, self).__init__() | ||||
| self.matrix_diag = inner.MatrixDiag() | self.matrix_diag = inner.MatrixDiag() | ||||
| @@ -895,6 +903,7 @@ class MatrixDiagPart(Cell): | |||||
| [-1. 1.] | [-1. 1.] | ||||
| [-1. 1.]] | [-1. 1.]] | ||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(MatrixDiagPart, self).__init__() | super(MatrixDiagPart, self).__init__() | ||||
| self.matrix_diag_part = inner.MatrixDiagPart() | self.matrix_diag_part = inner.MatrixDiagPart() | ||||
| @@ -936,6 +945,7 @@ class MatrixSetDiag(Cell): | |||||
| [[-1. 0.] | [[-1. 0.] | ||||
| [ 0. 1.]]] | [ 0. 1.]]] | ||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(MatrixSetDiag, self).__init__() | super(MatrixSetDiag, self).__init__() | ||||
| self.matrix_set_diag = inner.MatrixSetDiag() | self.matrix_set_diag = inner.MatrixSetDiag() | ||||
| @@ -407,7 +407,7 @@ class ParameterUpdate(Cell): | |||||
| >>> param = network.parameters_dict()['weight'] | >>> param = network.parameters_dict()['weight'] | ||||
| >>> update = nn.ParameterUpdate(param) | >>> update = nn.ParameterUpdate(param) | ||||
| >>> update.phase = "update_param" | >>> update.phase = "update_param" | ||||
| >>> weight = Tensor(0.001, mindspore.float32) | |||||
| >>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32) | |||||
| >>> update(weight) | >>> update(weight) | ||||
| """ | """ | ||||