From: @ginfung Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -453,7 +453,7 @@ class Validator: | |||
| return padding | |||
| @staticmethod | |||
| def check_subclass(arg_name, type_, template_types, prim_name): | |||
| def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None): | |||
| """Checks whether some type is subclass of another type""" | |||
| if not isinstance(template_types, Iterable): | |||
| template_types = (template_types,) | |||
| @@ -467,9 +467,12 @@ class Validator: | |||
| hit = True | |||
| break | |||
| if not hit: | |||
| if addition_error_info is None: | |||
| addition_error_info = '' | |||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | |||
| raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass' | |||
| f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.') | |||
| f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.' | |||
| f' {addition_error_info}') | |||
| @staticmethod | |||
| def check_const_input(arg_name, arg_value, prim_name): | |||
| @@ -401,7 +401,9 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr & | |||
| } | |||
| if (tail_type_ == kGradFirst) { | |||
| if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && (*sequeue)[1]->isa<abstract::AbstractUndetermined>()) { | |||
| if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && | |||
| ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || | |||
| ((*sequeue)[1]->BuildType() != nullptr && (*sequeue)[1]->BuildType()->isa<Number>()))) { | |||
| ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))})); | |||
| } else { | |||
| ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{}))); | |||
| @@ -413,7 +415,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr & | |||
| for (size_t i = 1; i < sequeue->size(); ++i) { | |||
| if (tail_type_ == kGradAll) { | |||
| MS_EXCEPTION_IF_NULL((*sequeue)[i]); | |||
| if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>()) { | |||
| if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || | |||
| ((*sequeue)[i]->BuildType() != nullptr && (*sequeue)[i]->BuildType()->isa<Number>())) { | |||
| elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); | |||
| } | |||
| } else { | |||
| @@ -224,7 +224,8 @@ void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNode | |||
| // b = Load(para1, u2) | |||
| // u3 = UpdateState(u2, x) | |||
| void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) { | |||
| AnfNodePtr other_input = nullptr; | |||
| // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load. | |||
| AnfNodePtr other_input = load; | |||
| for (size_t i = 1; i < make_tuple->size(); i++) { | |||
| if (make_tuple->input(i) != load) { | |||
| other_input = make_tuple->input(i); | |||
| @@ -489,7 +489,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) { | |||
| continue; | |||
| } | |||
| AbstractBasePtr par_abs = param_node->abstract(); | |||
| if (par_abs->isa<abstract::AbstractUndetermined>()) { | |||
| if (par_abs->isa<abstract::AbstractUndetermined>() || | |||
| (par_abs->BuildType() != nullptr && par_abs->BuildType()->isa<Number>())) { | |||
| new_paras.push_back(param_node); | |||
| } | |||
| } | |||
| @@ -98,7 +98,7 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) | |||
| AbstractBasePtr ArgsToAbstract(const ValuePtr &value) { | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| bool broaden = value->isa<MetaTensor>(); | |||
| bool broaden = value->isa<MetaTensor>() || value->isa<Scalar>(); | |||
| return abstract::FromValue(value, broaden); | |||
| } | |||
| @@ -142,6 +142,21 @@ std::string GetCompileExceptionInfo() { | |||
| return oss.str(); | |||
| } | |||
| void SetGpuLoopSink(const ResourcePtr &resource_) { | |||
| auto func_graph = resource_->func_graph(); | |||
| if (func_graph != nullptr && func_graph->manager() != nullptr) { | |||
| auto manager = func_graph->manager(); | |||
| size_t graph_nums = manager->func_graphs().size(); | |||
| int64_t sinksize = ConfigManager::GetInstance().iter_num(); | |||
| if (graph_nums == 1) { | |||
| resource_->set_gpu_loopsink(true, sinksize); | |||
| } else { | |||
| resource_->set_gpu_loopsink(false, sinksize); | |||
| } | |||
| MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource_->gpu_loopsink_flag() << ", set loopsink size to " | |||
| << sinksize; | |||
| } | |||
| } | |||
| } // namespace | |||
| py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) { | |||
| @@ -704,19 +719,7 @@ void Pipeline::Run() { | |||
| MS_LOG(DEBUG) << "Action " << action.first << " end."; | |||
| }; | |||
| if (action.first == "task_emit") { | |||
| auto func_graph = resource_->func_graph(); | |||
| if (func_graph != nullptr && func_graph->manager() != nullptr) { | |||
| auto manager = func_graph->manager(); | |||
| size_t graph_nums = manager->func_graphs().size(); | |||
| int64_t sinksize = ConfigManager::GetInstance().iter_num(); | |||
| if (graph_nums == 1) { | |||
| resource_->set_gpu_loopsink(true, sinksize); | |||
| } else { | |||
| resource_->set_gpu_loopsink(false, sinksize); | |||
| } | |||
| MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource_->gpu_loopsink_flag() << ", set loopsink size to " | |||
| << sinksize; | |||
| } | |||
| SetGpuLoopSink(resource_); | |||
| } | |||
| if (!result) { | |||
| MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first; | |||
| @@ -210,7 +210,7 @@ class _MindSporeFunction: | |||
| return None | |||
| new_inputs = [] | |||
| for i in args_list: | |||
| if isinstance(i, Tensor): | |||
| if isinstance(i, (Tensor, int, float)): | |||
| new_inputs.append(i) | |||
| return self._executor(tuple(new_inputs), phase) | |||
| @@ -88,7 +88,7 @@ std::string AbstractBase::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return Clone(); } | |||
| AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); } | |||
| AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | |||
| MS_EXCEPTION_IF_NULL(other); | |||
| @@ -171,10 +171,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| return args_spec_list[0]; | |||
| } | |||
| auto depends = args_spec_list[0]->Broaden(); | |||
| // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value. | |||
| if (depends->isa<AbstractScalar>()) { | |||
| depends->set_value(kAnyValue); | |||
| } | |||
| return depends; | |||
| } | |||
| @@ -609,7 +609,7 @@ class Cell(Cell_): | |||
| new_inputs = [] | |||
| for i in inputs: | |||
| if isinstance(i, Tensor): | |||
| if isinstance(i, (Tensor, int, float)): | |||
| new_inputs.append(i) | |||
| if self._auto_parallel_mode: | |||
| @@ -199,10 +199,10 @@ class ForwardValueAndGrad(Cell): | |||
| If the sensor_param is True, a sensitivity (gradient with respect to output) needs to be transferred through | |||
| the location parameter or key-value pair parameter. If the value is transferred through the key-value pair | |||
| parameter, the key must be sens. | |||
| sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||
| Inputs: | |||
| - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. | |||
| - sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||
| Outputs: | |||
| - **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running. | |||
| @@ -242,7 +242,7 @@ class ForwardValueAndGrad(Cell): | |||
| >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0) | |||
| """ | |||
| def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False): | |||
| def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False, sens=1.0): | |||
| super(ForwardValueAndGrad, self).__init__(auto_prefix=False) | |||
| if not isinstance(network, (Cell, FunctionType, MethodType)): | |||
| raise TypeError(f"The type of training network should be cell, function type or method type, " | |||
| @@ -259,19 +259,16 @@ class ForwardValueAndGrad(Cell): | |||
| self.get_all = get_all | |||
| self.get_by_list = get_by_list | |||
| self.sens_param = sens_param | |||
| self.sens = sens | |||
| self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param) | |||
| def construct(self, *inputs): | |||
| weights = self.weights | |||
| if self.sens_param: | |||
| sens = inputs[-1] | |||
| inputs = inputs[:-1] | |||
| else: | |||
| sens = None | |||
| loss = self.network(*inputs) | |||
| if self.sens_param: | |||
| if not isinstance(sens, Tensor): | |||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), sens) | |||
| sens = self.sens | |||
| if not isinstance(self.sens, Tensor): | |||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
| grads = self.grad(self.network, weights)(*inputs, sens) | |||
| else: | |||
| grads = self.grad(self.network, weights)(*inputs) | |||
| @@ -223,7 +223,8 @@ class DType(PrimitiveWithInfer): | |||
| """Initialize DType""" | |||
| def __infer__(self, x): | |||
| validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) | |||
| addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.' | |||
| validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info) | |||
| out = {'shape': (), | |||
| 'dtype': mstype.type_type, | |||
| 'value': x['dtype'].element_type()} | |||
| @@ -414,13 +414,14 @@ def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1): | |||
| weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters())) | |||
| optimizer = Momentum(weights, 0.1, 0.9) | |||
| train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True) | |||
| train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True, | |||
| sens=1.0) | |||
| losses = [] | |||
| for i in range(0, epoch): | |||
| data = Tensor(np.ones([batch_size, 3, 224, 224] | |||
| ).astype(np.float32) * 0.01) | |||
| label = Tensor(np.ones([batch_size]).astype(np.int32)) | |||
| loss, grads = train_network(data, label, 1.0) | |||
| loss, grads = train_network(data, label) | |||
| grads = F.identity(grads) | |||
| optimizer(grads) | |||
| losses.append(loss) | |||
| @@ -439,13 +440,14 @@ def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=33 | |||
| weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters())) | |||
| optimizer = Momentum(weights, 0.1, 0.9) | |||
| train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True) | |||
| train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True, | |||
| sens=1.0) | |||
| losses = [] | |||
| for i in range(0, epoch): | |||
| data = Tensor(np.ones([batch_size, 3, 224, 224] | |||
| ).astype(np.float32) * 0.01) | |||
| label = Tensor(np.ones([batch_size]).astype(np.int32)) | |||
| loss, grads = train_network(data, label, 1.0) | |||
| loss, grads = train_network(data, label) | |||
| grads = F.identity(grads) | |||
| optimizer(grads) | |||
| losses.append(loss) | |||
| @@ -95,15 +95,23 @@ def test_ReduceAll(): | |||
| assert output[3].shape == expect3.shape | |||
| x_1 = np.array([[True, True], [True, False], [False, False]]) | |||
| axis_1 = 0 | |||
| x_2 = np.array([[True, True], [True, True], [True, False], [False, False]]) | |||
| axis_2 = 0 | |||
| class ReduceAllDynamic(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, x, axis): | |||
| super(ReduceAllDynamic, self).__init__() | |||
| self.reduceall = P.ReduceAll(False) | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.x = x | |||
| self.axis = axis | |||
| def construct(self, x, axis): | |||
| x = self.test_dynamic(x) | |||
| return self.reduceall(x, axis) | |||
| def construct(self): | |||
| dynamic_x = self.test_dynamic(self.x) | |||
| return self.reduceall(dynamic_x, self.axis) | |||
| @pytest.mark.level0 | |||
| @@ -111,18 +119,14 @@ class ReduceAllDynamic(nn.Cell): | |||
| @pytest.mark.env_onecard | |||
| def test_reduce_all_dynamic(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = ReduceAllDynamic() | |||
| net1 = ReduceAllDynamic(Tensor(x_1), axis_1) | |||
| net2 = ReduceAllDynamic(Tensor(x_2), axis_2) | |||
| x_1 = np.array([[True, True], [True, False], [False, False]]) | |||
| axis_1 = 0 | |||
| expect_1 = np.all(x_1, axis=axis_1, keepdims=False) | |||
| x_2 = np.array([[True, True], [True, True], [True, False], [False, False]]) | |||
| axis_2 = 0 | |||
| expect_2 = np.all(x_2, axis=axis_2, keepdims=False) | |||
| output_1 = net(Tensor(x_1), axis_1) | |||
| output_2 = net(Tensor(x_2), axis_2) | |||
| output1 = net1() | |||
| output2 = net2() | |||
| np.testing.assert_almost_equal(output_1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output_2.asnumpy(), expect_2) | |||
| np.testing.assert_almost_equal(output1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output2.asnumpy(), expect_2) | |||
| @@ -95,15 +95,23 @@ def test_ReduceAny(): | |||
| assert output[3].shape == expect3.shape | |||
| x_1 = np.array([[True, True], [True, False], [False, False]]) | |||
| axis_1 = 0 | |||
| x_2 = np.array([[True, True], [True, True], [True, False], [False, False]]) | |||
| axis_2 = 0 | |||
| class ReduceAnyDynamic(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, x, axis): | |||
| super(ReduceAnyDynamic, self).__init__() | |||
| self.reduceany = P.ReduceAny(False) | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.x = x | |||
| self.axis = axis | |||
| def construct(self, x, axis): | |||
| x = self.test_dynamic(x) | |||
| return self.reduceany(x, axis) | |||
| def construct(self): | |||
| dynamic_x = self.test_dynamic(self.x) | |||
| return self.reduceany(dynamic_x, self.axis) | |||
| @pytest.mark.level0 | |||
| @@ -111,18 +119,14 @@ class ReduceAnyDynamic(nn.Cell): | |||
| @pytest.mark.env_onecard | |||
| def test_reduce_any_dynamic(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = ReduceAnyDynamic() | |||
| net1 = ReduceAnyDynamic(Tensor(x_1), axis_1) | |||
| net2 = ReduceAnyDynamic(Tensor(x_2), axis_2) | |||
| x_1 = np.array([[True, True], [True, False], [False, False]]) | |||
| axis_1 = 0 | |||
| expect_1 = np.any(x_1, axis=axis_1, keepdims=False) | |||
| x_2 = np.array([[True, True], [True, True], [True, False], [False, False]]) | |||
| axis_2 = 0 | |||
| expect_2 = np.any(x_2, axis=axis_2, keepdims=False) | |||
| output_1 = net(Tensor(x_1), axis_1) | |||
| output_2 = net(Tensor(x_2), axis_2) | |||
| output1 = net1() | |||
| output2 = net2() | |||
| np.testing.assert_almost_equal(output_1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output_2.asnumpy(), expect_2) | |||
| np.testing.assert_almost_equal(output1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output2.asnumpy(), expect_2) | |||
| @@ -179,36 +179,41 @@ def test_ReduceMax(): | |||
| assert np.all(diff8 < error8) | |||
| x_1 = x8 | |||
| axis_1 = 0 | |||
| x_2 = x1 | |||
| axis_2 = 0 | |||
| class ReduceMaxDynamic(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, x, axis): | |||
| super(ReduceMaxDynamic, self).__init__() | |||
| self.reducemax = P.ReduceMax(False) | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.x = x | |||
| self.axis = axis | |||
| def construct(self, x, axis): | |||
| x = self.test_dynamic(x) | |||
| return self.reducemax(x, axis) | |||
| def construct(self): | |||
| dynamic_x = self.test_dynamic(self.x) | |||
| return self.reducemax(dynamic_x, self.axis) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_reduce_max_dynamic(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = ReduceMaxDynamic() | |||
| net1 = ReduceMaxDynamic(Tensor(x_1), axis_1) | |||
| net2 = ReduceMaxDynamic(Tensor(x_2), axis_2) | |||
| x_1 = x8 | |||
| axis_1 = 0 | |||
| expect_1 = np.max(x_1, axis=0, keepdims=False) | |||
| x_2 = x1 | |||
| axis_2 = 0 | |||
| expect_2 = np.max(x_2, axis=0, keepdims=False) | |||
| output_1 = net(Tensor(x_1), axis_1) | |||
| output_2 = net(Tensor(x_2), axis_2) | |||
| output1 = net1() | |||
| output2 = net2() | |||
| np.testing.assert_almost_equal(output1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output2.asnumpy(), expect_2) | |||
| np.testing.assert_almost_equal(output_1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output_2.asnumpy(), expect_2) | |||
| class ReduceMaxTypeNet(nn.Cell): | |||
| def __init__(self, nptype): | |||
| @@ -268,14 +268,16 @@ def test_ReduceMean(): | |||
| assert output[14].shape == expect14.shape | |||
| class ReduceMeanDynamic(nn.Cell): | |||
| def __init__(self, keepdims=False): | |||
| def __init__(self, x, axis, keepdims=False): | |||
| super(ReduceMeanDynamic, self).__init__() | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.reducemean = P.ReduceMean(keep_dims=keepdims) | |||
| self.x = x | |||
| self.axis = axis | |||
| def construct(self, input_x, axis): | |||
| input_x = self.test_dynamic(input_x) | |||
| output = self.reducemean(input_x, axis) | |||
| def construct(self): | |||
| dynamic_x = self.test_dynamic(self.x) | |||
| output = self.reducemean(dynamic_x, self.axis) | |||
| return output | |||
| @pytest.mark.level0 | |||
| @@ -283,32 +285,30 @@ class ReduceMeanDynamic(nn.Cell): | |||
| @pytest.mark.env_onecard | |||
| def test_dynamic_reduce_mean_keepdims_true(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = ReduceMeanDynamic(keepdims=True) | |||
| x_tensor_1 = Tensor(x14) | |||
| output_1 = net(x_tensor_1, axis14) | |||
| x_tensor_2 = Tensor(x0) | |||
| output_2 = net(x_tensor_2, axis0) | |||
| net1 = ReduceMeanDynamic(Tensor(x14), axis14, keepdims=True) | |||
| net2 = ReduceMeanDynamic(Tensor(x0), axis0, keepdims=True) | |||
| output1 = net1() | |||
| output2 = net2() | |||
| expect_1 = np.mean(x14, axis=np_axis14, keepdims=True) | |||
| diff_1 = abs(output_1.asnumpy() - expect_1) | |||
| diff_1 = abs(output1.asnumpy() - expect_1) | |||
| error_1 = np.ones(shape=expect_1.shape) * 1.0e-5 | |||
| assert np.all(diff_1 < error_1) | |||
| assert output_1.shape == expect_1.shape | |||
| assert output1.shape == expect_1.shape | |||
| expect_2 = np.mean(x0, axis=axis0, keepdims=True) | |||
| diff_2 = abs(output_2.asnumpy() - expect_2) | |||
| diff_2 = abs(output2.asnumpy() - expect_2) | |||
| error_2 = np.ones(shape=expect_2.shape) * 1.0e-5 | |||
| assert np.all(diff_2 < error_2) | |||
| assert output_2.shape == expect_2.shape | |||
| assert output2.shape == expect_2.shape | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_dynamic_reduce_mean_keepdims_false(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = ReduceMeanDynamic(keepdims=False) | |||
| x_tensor = Tensor(x12) | |||
| output = net(x_tensor, axis12) | |||
| net = ReduceMeanDynamic(Tensor(x12), axis12, keepdims=False) | |||
| output = net() | |||
| expect = np.mean(x12, axis=axis12, keepdims=False) | |||
| diff = abs(output.asnumpy() - expect) | |||
| @@ -179,33 +179,37 @@ def test_ReduceMin(): | |||
| assert np.all(diff8 < error8) | |||
| x_1 = x8 | |||
| axis_1 = 0 | |||
| x_2 = x1 | |||
| axis_2 = 0 | |||
| class ReduceMinDynamic(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, x, axis): | |||
| super(ReduceMinDynamic, self).__init__() | |||
| self.reducemin = P.ReduceMin(False) | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.x = x | |||
| self.axis = axis | |||
| def construct(self, x, axis): | |||
| x = self.test_dynamic(x) | |||
| return self.reducemin(x, axis) | |||
| def construct(self): | |||
| dynamic_x = self.test_dynamic(self.x) | |||
| return self.reducemin(dynamic_x, self.axis) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_reduce_min_dynamic(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = ReduceMinDynamic() | |||
| net1 = ReduceMinDynamic(Tensor(x_1), axis_1) | |||
| net2 = ReduceMinDynamic(Tensor(x_2), axis_2) | |||
| x_1 = x8 | |||
| axis_1 = 0 | |||
| expect_1 = np.min(x_1, axis=0, keepdims=False) | |||
| x_2 = x1 | |||
| axis_2 = 0 | |||
| expect_2 = np.min(x_2, axis=0, keepdims=False) | |||
| output_1 = net(Tensor(x_1), axis_1) | |||
| output_2 = net(Tensor(x_2), axis_2) | |||
| output1 = net1() | |||
| output2 = net2() | |||
| np.testing.assert_almost_equal(output_1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output_2.asnumpy(), expect_2) | |||
| np.testing.assert_almost_equal(output1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output2.asnumpy(), expect_2) | |||
| @@ -270,15 +270,23 @@ def test_ReduceSum(): | |||
| assert output[14].shape == expect14.shape | |||
| x_1 = x8 | |||
| axis_1 = 0 | |||
| x_2 = x1 | |||
| axis_2 = 0 | |||
| class ReduceSumDynamic(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, x, axis): | |||
| super(ReduceSumDynamic, self).__init__() | |||
| self.reducesum = P.ReduceSum(True) | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.x = x | |||
| self.axis = axis | |||
| def construct(self, x, axis): | |||
| x = self.test_dynamic(x) | |||
| return self.reducesum(x, axis) | |||
| def construct(self): | |||
| dynamic_x = self.test_dynamic(self.x) | |||
| return self.reducesum(dynamic_x, self.axis) | |||
| @pytest.mark.level0 | |||
| @@ -286,21 +294,18 @@ class ReduceSumDynamic(nn.Cell): | |||
| @pytest.mark.env_onecard | |||
| def test_reduce_sum_dynamic(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = ReduceSumDynamic() | |||
| net1 = ReduceSumDynamic(Tensor(x_1), axis_1) | |||
| net2 = ReduceSumDynamic(Tensor(x_2), axis_2) | |||
| x_1 = x8 | |||
| axis_1 = 0 | |||
| expect_1 = np.sum(x_1, axis=axis_1, keepdims=True) | |||
| x_2 = x1 | |||
| axis_2 = 0 | |||
| expect_2 = np.sum(x_2, axis=axis_2, keepdims=True) | |||
| output_1 = net(Tensor(x_1), axis_1) | |||
| output_2 = net(Tensor(x_2), axis_2) | |||
| output1 = net1() | |||
| output2 = net2() | |||
| np.testing.assert_almost_equal(output1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output2.asnumpy(), expect_2) | |||
| np.testing.assert_almost_equal(output_1.asnumpy(), expect_1) | |||
| np.testing.assert_almost_equal(output_2.asnumpy(), expect_2) | |||
| class ReduceSumTypeNet(nn.Cell): | |||
| def __init__(self, nptype): | |||
| @@ -32,18 +32,26 @@ TEST_F(TestUtils, test_join) { | |||
| AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false); | |||
| AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false); | |||
| AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true); | |||
| abs_s_anything->set_value(kAnyValue); | |||
| AbstractBasePtr res_s1 = abs_s1->Join(abs_s2); | |||
| ASSERT_EQ(*res_s1, *abs_s_anything); | |||
| // AbstractTuple join; | |||
| std::vector<int64_t> list1 = {1, 2, 3, 4, 5}; | |||
| std::vector<int64_t> list2 = {5, 4, 3, 2, 1}; | |||
| AbstractBasePtr abs_t1 = FromValue(list1, true); | |||
| AbstractBasePtr abs_t2 = FromValue(list2, true); | |||
| AbstractBasePtr res_t1 = abs_t1->Join(abs_t2); | |||
| ASSERT_EQ(res_t1, abs_t1); | |||
| abs_s1 = FromValue(static_cast<int64_t>(1), false); | |||
| AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); | |||
| AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); | |||
| AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything})); | |||
| AbstractBasePtr res_t1 = t1->Join(t2); | |||
| res_t1 = t1->Join(t2); | |||
| ASSERT_EQ(res_t1, t1); | |||
| res_t1 = t1->Join(t3); | |||
| @@ -111,11 +111,8 @@ TEST_F(TestOptLib, test_inline) { | |||
| // add infer and renormalize | |||
| std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>(); | |||
| AbstractBasePtrList args_spec_list; | |||
| tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3}); | |||
| tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3}); | |||
| AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true); | |||
| AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true); | |||
| AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true); | |||
| AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true); | |||
| args_spec_list.push_back(abstract_v1); | |||
| args_spec_list.push_back(abstract_v2); | |||
| AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list); | |||
| @@ -184,7 +184,7 @@ TEST_F(TestData, test_broaden) { | |||
| AbstractBasePtr s2 = s1->Broaden(); | |||
| ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack()); | |||
| ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1)); | |||
| ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>()); | |||
| ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>()); | |||
| AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), | |||
| AnalysisContext::DummyContext()); | |||
| @@ -196,7 +196,7 @@ TEST_F(TestData, test_broaden) { | |||
| AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get()); | |||
| ASSERT_TRUE(l2_cast != nullptr); | |||
| AbstractBasePtr csr = AbstractJoin(l2_cast->elements()); | |||
| ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>()); | |||
| ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>()); | |||
| } | |||
| } // namespace abstract | |||
| @@ -20,7 +20,6 @@ from mindspore import Tensor, Parameter | |||
| from mindspore import context | |||
| from mindspore import dtype as mstype | |||
| from mindspore.nn import Cell | |||
| from mindspore.ops import operations as P | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ | |||
| @@ -684,27 +683,6 @@ def test_tensor_assign_bool_index(): | |||
| net4(Ta, Tensor(u_scalar, mstype.int32)) | |||
| def test_trivial_call_function_twice_with_diff_key_value_para(): | |||
| class Net(Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.arange = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) | |||
| self.concat = P.Concat(axis=0) | |||
| def compute(self, x, is_decoder): | |||
| if is_decoder: | |||
| return self.arange[:x] | |||
| return self.arange[1:x + 1] | |||
| def construct(self): | |||
| result1 = self.compute(7, is_decoder=True) | |||
| result2 = self.compute(6, is_decoder=False) | |||
| return self.concat((result1, result2)) | |||
| net = Net() | |||
| net() | |||
| test_cases = [ | |||
| ('TensorAssignWithTupleEllipsis2', { | |||
| 'block': TensorAssignWithTupleEllipsis2(), | |||
| @@ -19,7 +19,7 @@ from mindspore import Tensor, ms_function | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True) | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| @ms_function | |||
| @@ -33,8 +33,7 @@ def test_scalar_compute(): | |||
| p = (3, 4) | |||
| q = [5, 6] | |||
| w = {"x": 7, "y": 8} | |||
| ret = compute(int_x, int_y, p, q, w) | |||
| assert ret == -1 | |||
| compute(int_x, int_y, p, q, w) | |||
| def test_tensor_compute(): | |||
| @@ -45,6 +45,17 @@ class GradNet(nn.Cell): | |||
| return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag) | |||
| class GradNet1(nn.Cell): | |||
| def __init__(self, net, get_all): | |||
| super(GradNet1, self).__init__() | |||
| self.forward_net = net | |||
| self.sens = Tensor(np.ones((2, 2), np.float32) * 5) | |||
| self.grad_all = C.GradOperation(get_all=get_all) | |||
| def construct(self, tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c): | |||
| return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c) | |||
| x = Tensor(np.ones((2, 2), np.float32)) | |||
| y = Tensor(np.ones((2, 2), np.float32) * 2) | |||
| z = Tensor(np.ones((2, 2), np.float32) * 3) | |||
| @@ -68,33 +79,18 @@ forward_net = FirstInputTupleNet() | |||
| grad_all_inputs_net = GradNet(forward_net, get_all=True) | |||
| def test_outermost_net_inputs_including_non_tensor(): | |||
| forward_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0) | |||
| forward_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1) | |||
| def test_grad_net_inputs_including_non_tensor(): | |||
| assert len(grad_all_inputs_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)) == 2 | |||
| assert len(grad_all_inputs_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)) == 2 | |||
| def test_grad_first_input_net(): | |||
| class FirstInputTensorNet(nn.Cell): | |||
| def __init__(self): | |||
| super(FirstInputTensorNet, self).__init__() | |||
| def construct(self, tensor_x, tuple_a, list_b, tensor_y, scalar, dict_c, flag): | |||
| if flag: | |||
| return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"] | |||
| return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"] | |||
| def construct(self, tensor_x, tuple_a, list_b, tensor_y, tensor_z, dict_c): | |||
| return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - tensor_z + dict_c["y"] | |||
| grad_fist_input_tensor_net = GradNet(FirstInputTensorNet(), get_all=False) | |||
| ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, sl, args_d0, flag_0) | |||
| grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False) | |||
| ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, y, args_d0) | |||
| assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32)) | |||
| grad_fist_input_tuple_net = GradNet(forward_net, get_all=False) | |||
| assert not grad_fist_input_tuple_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0) | |||
| def test_net_inputs_including_str(): | |||
| with pytest.raises(TypeError) as err: | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================ | |||
| """ test_framstruct """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| @@ -76,11 +76,13 @@ def dynamic_make_tuple(x, lower, upper): | |||
| def test_dynamic_make_tuple(): | |||
| assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2) | |||
| # Dynamically recursively creating static type is invalid in mindspore, as mindspore is a static language. | |||
| with pytest.raises(RuntimeError): | |||
| dynamic_make_tuple(2, 1, 5) | |||
| def test_make_tuple(): | |||
| # Staticly recursively creating static type is valid in mindspore. | |||
| # Statically recursively creating static type is valid in mindspore. | |||
| @ms_function | |||
| def make_tuple(x): | |||
| out = () | |||