Browse Source

fixed erfc, geswitch, merge etc

tags/v0.3.0-alpha
jiangjinsheng 6 years ago
parent
commit
6cd71c1bd5
3 changed files with 6 additions and 6 deletions
  1. +1
    -0
      mindspore/ops/operations/__init__.py
  2. +3
    -4
      mindspore/ops/operations/control_ops.py
  3. +2
    -2
      mindspore/ops/operations/math_ops.py

+ 1
- 0
mindspore/ops/operations/__init__.py View File

@@ -142,6 +142,7 @@ __all__ = [
'ReLUV2',
'Elu',
'Erf',
'Erfc',
'Sigmoid',
'HSwish',
'HSigmoid',


+ 3
- 4
mindspore/ops/operations/control_ops.py View File

@@ -84,7 +84,7 @@ class GeSwitch(PrimitiveWithInfer):
the true branch will be activated, or vise verse.

Inputs:
- **data** (Tensor) - The data to be used for switch control.
- **data** (Union[Tensor, Number]) - The data to be used for switch control.
- **pred** (Tensor) - It should be a scalar whose type is bool and shape is `()`, It is used as condition for
switch control.
Outputs:
@@ -144,7 +144,7 @@ class Merge(PrimitiveWithInfer):
One and only one of the inputs should be selected as the output

Inputs:
- **inputs** (Tuple) - The data to be merged.
- **inputs** (Tuple) - The data to be merged. All tuple elements should have same data type.

Outputs:
tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape of `inputs` element.
@@ -171,6 +171,5 @@ class Merge(PrimitiveWithInfer):
for i, item in enumerate(inputs):
args['inputs[%d]' % i] = item

validator.check_tensor_type_same(
args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32)

+ 2
- 2
mindspore/ops/operations/math_ops.py View File

@@ -1397,14 +1397,14 @@ class EqualCount(PrimitiveWithInfer):
"""
Computes the number of the same elements of two tensors.

The two input tensors should have same shape.
The two input tensors should have same shape and same data type.

Inputs:
- **input_x** (Tensor) - The first input tensor.
- **input_y** (Tensor) - The second input tensor.

Outputs:
Tensor, with the type as `mindspore.int32` and size as (1,).
Tensor, with the type same as input tensor and size as (1,).

Examples:
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)


Loading…
Cancel
Save