|
|
|
@@ -16,10 +16,79 @@ |
|
|
|
|
|
|
|
#include <set> |
|
|
|
#include <memory> |
|
|
|
#include <algorithm> |
|
|
|
#include "ops/gather.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace ops { |
|
|
|
REGISTER_PRIMITIVE_C(kNameGather, Gather); |
|
|
|
// gather |
|
|
|
AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
const std::string &op_name = primitive->name(); |
|
|
|
abstract::CheckArgsSize(op_name, input_args, 3); |
|
|
|
abstract::AbstractTensorPtr params = |
|
|
|
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0); |
|
|
|
abstract::AbstractTensorPtr indices = |
|
|
|
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1); |
|
|
|
// check |
|
|
|
std::set<TypePtr> valid_params_types = {kTensorType}; |
|
|
|
CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name); |
|
|
|
std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64}; |
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name); |
|
|
|
CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name); |
|
|
|
|
|
|
|
bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()); |
|
|
|
bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty()); |
|
|
|
int64_t axis_val = 0; |
|
|
|
// 3rd input is a Tensor when Gather is a dynamic shape operator |
|
|
|
if (input_args[2]->isa<abstract::AbstractTensor>()) { |
|
|
|
auto axis = input_args[2]->cast<abstract::AbstractTensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(axis); |
|
|
|
auto axis_value_ptr = axis->BuildValue(); |
|
|
|
MS_EXCEPTION_IF_NULL(axis_value_ptr); |
|
|
|
auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(axis_tensor); |
|
|
|
axis_val = *static_cast<int64_t *>(axis_tensor->data_c()); |
|
|
|
} else if (input_args[2]->isa<abstract::AbstractScalar>()) { |
|
|
|
auto axis = input_args[2]->cast<abstract::AbstractScalarPtr>(); |
|
|
|
axis_val = GetValue<int64_t>(axis->BuildValue()); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[2]->type_name(); |
|
|
|
} |
|
|
|
auto params_shp = params->shape()->shape(); |
|
|
|
auto indices_shp = indices->shape()->shape(); |
|
|
|
auto params_rank = static_cast<int64_t>(params_shp.size()); |
|
|
|
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis_val, kIncludeLeft, {-params_rank, params_rank}, op_name); |
|
|
|
// either inputs or both can be dynamic and computation requires min/max shapes for both |
|
|
|
ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape(); |
|
|
|
ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape(); |
|
|
|
ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape(); |
|
|
|
ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape(); |
|
|
|
// check axis_val within interval: [-params_rank, params_rank) |
|
|
|
if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) { |
|
|
|
MS_LOG(EXCEPTION) << "For Gather - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) " |
|
|
|
<< "Got " << axis_val << "."; |
|
|
|
} |
|
|
|
if (axis_val < 0) { |
|
|
|
axis_val += params_rank; |
|
|
|
} |
|
|
|
auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector ¶ms_vec) -> ShapeVector { |
|
|
|
ShapeVector out_vec; |
|
|
|
std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); |
|
|
|
copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec)); |
|
|
|
copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec)); |
|
|
|
return out_vec; |
|
|
|
}; |
|
|
|
ShapeVector out_shape = calc_shape(indices_shp, params_shp); |
|
|
|
if (ind_dyn || param_dyn) { |
|
|
|
ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min); |
|
|
|
ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max); |
|
|
|
return std::make_shared<abstract::AbstractTensor>( |
|
|
|
params->element(), std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape)); |
|
|
|
} |
|
|
|
return std::make_shared<abstract::AbstractTensor>(params->element(), std::make_shared<abstract::Shape>(out_shape)); |
|
|
|
} |
|
|
|
REGISTER_PRIMITIVE_EVAL_IMPL(Gather, prim::kPrimGather, GatherInfer, nullptr, true); |
|
|
|
} // namespace ops |
|
|
|
} // namespace mindspore |