| @@ -222,7 +222,8 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include <algorithm> | |||
| #include <iterator> | |||
| #include "abstract/infer_functions.h" | |||
| @@ -385,5 +386,35 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string &op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto perm = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| auto input_shp = input->shape()->shape(); | |||
| auto perm_val = perm->BuildValue(); | |||
| if (perm_val->isa<AnyValue>()) { | |||
| MS_LOG(EXCEPTION) << "Perm can't be anything: " << args_spec_list[1]->ToString(); | |||
| } | |||
| auto perm_val_data = perm_val->cast<ValueTuplePtr>()->value(); | |||
| ShapeVector perm_vec; | |||
| (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| ShapeVector result_shp; | |||
| std::set<size_t> indices; | |||
| for (size_t i = 0; i < perm_vec.size(); i++) { | |||
| size_t idx = static_cast<size_t>(perm_vec[i]); | |||
| if (indices.find(idx) != indices.end()) { | |||
| MS_LOG(EXCEPTION) << "Perm values must be unique"; | |||
| } | |||
| if (idx >= perm_vec.size()) { | |||
| MS_LOG(EXCEPTION) << "One value in perm is " << idx << ", not in range [0, " << perm_vec.size() << ")"; | |||
| } | |||
| result_shp.push_back(input_shp[idx]); | |||
| indices.insert(idx); | |||
| } | |||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp)); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -60,6 +60,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | |||
| {prim::kPrimShape, {InferImplShape, false}}, | |||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, | |||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | |||
| @@ -589,7 +589,7 @@ class Squeeze(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class Transpose(PrimitiveWithInfer): | |||
| class Transpose(PrimitiveWithCheck): | |||
| """ | |||
| Permutes the dimensions of input tensor according to input permutation. | |||
| @@ -621,32 +621,13 @@ class Transpose(PrimitiveWithInfer): | |||
| """Initialize Transpose""" | |||
| self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) | |||
| def __infer__(self, x, perm): | |||
| x_shape = x['shape'] | |||
| p_value = perm['value'] | |||
| x_type = x['dtype'] | |||
| validator.check_value_type("p_value", p_value, [tuple], self.name) | |||
| validator.check_subclass("x_type", x_type, mstype.tensor, self.name) | |||
| if len(x_shape) != len(p_value): | |||
| def check_shape(self, x, perm): | |||
| validator.check_value_type("perm", perm, [tuple], self.name) | |||
| if len(x) != len(perm): | |||
| raise ValueError('The dimension of x and perm must be equal.') | |||
| tmp = list(p_value) | |||
| for i, dim in enumerate(p_value): | |||
| validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name) | |||
| validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name) | |||
| tmp.remove(dim) | |||
| if dim in tmp: | |||
| raise ValueError('The value of perm is wrong.') | |||
| out_shapes = [] | |||
| for i in p_value: | |||
| out_shapes.append(x_shape[i]) | |||
| out = {'shape': tuple(out_shapes), | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| return out | |||
| def check_dtype(self, x, perm): | |||
| validator.check_subclass("x", x, mstype.tensor, self.name) | |||
| class Unique(Primitive): | |||
| """ | |||