diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index e89c756924..017d512a15 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -224,7 +224,7 @@ class EmbeddingLookup(Cell): elif slice_mode == "table_row_slice" and is_auto_parallel: if target == 'DEVICE': indices_shape_size = 1 - self.gather_revert.shard(((1, 1), (1,))) + self.gather_revert.shard(((1, 1), (get_group_size(),))) self.forward_unique = True indices_strategy = (1,)*indices_shape_size self.gatherv2.shard(((get_group_size(), 1), indices_strategy)) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2fe71579c1..0ed8b0a2cb 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1883,18 +1883,27 @@ class UnsortedSegmentSum(PrimitiveWithInfer): shp = [num_segments_v] shp += x_shp[segment_ids_shp_len:] - if 'max_shape' in x: - output_incoming = x['max_shape'] + if "max_value" in num_segments and "min_value" in num_segments: + output_max_shape = list(num_segments['max_value']) + output_min_shape = list(num_segments['min_value']) + else: + if isinstance(num_segments_type, type(mstype.tensor)): + raise ValueError("In dynamic shape scene, the num_segments should contains max_value and min_value") output_max_shape = [num_segments_v] - output_max_shape += output_incoming[segment_ids_shp_len:] + output_min_shape = [num_segments_v] + if 'max_shape' in x and 'min_shape' in x: + max_output_incoming = x['max_shape'] + min_output_incoming = x['min_shape'] else: - output_max_shape = x_shp - out = {'shape': shp, - 'max_shape': output_max_shape, - 'min_shape': [1] * segment_ids_shp_len + x_shp[segment_ids_shp_len:], - 'dtype': mstype.tensor_type(x_type.element_type()), - 'value': None} - return out + max_output_incoming = x_shp + min_output_incoming = x_shp + output_max_shape += max_output_incoming[segment_ids_shp_len:] + output_min_shape += min_output_incoming[segment_ids_shp_len:] + return {'shape': shp, + 'max_shape': output_max_shape, + 'min_shape': output_min_shape, + 'dtype': mstype.tensor_type(x_type.element_type()), + 'value': None} class UnsortedSegmentMin(PrimitiveWithCheck): @@ -2688,6 +2697,26 @@ class StridedSlice(PrimitiveWithInfer): ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v) value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type()) + if "max_value" in x and "min_value" in x: + validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name) + validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name) + max_value_np = np.array(x["max_value"]) + min_value_np = np.array(x["min_value"]) + slice_index = [] + for begin_i, end_i, strides_i in zip(begin_v, end_v, strides_v): + s = slice(begin_i, end_i, strides_i) + slice_index.append(s) + slice_index = tuple(slice_index) + max_value_slice = max_value_np[slice_index] + min_value_slice = min_value_np[slice_index] + max_value_slice = tuple(max_value_slice.tolist()) + min_value_slice = tuple(min_value_slice.tolist()) + return {'shape': ret_shape, + 'dtype': x['dtype'], + 'value': value, + 'max_value': max_value_slice, + 'min_value': min_value_slice} + return {'shape': ret_shape, 'dtype': x['dtype'], 'value': value} diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index d179224ead..6dd469532f 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -207,8 +207,6 @@ class WideDeepModel(nn.Cell): target = 'CPU' self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE) - if target == 'DEVICE': - self.wide_mul.shard(((1, 1, 1), (1, 1, 1))) if config.deep_table_slice_mode == "column_slice": self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE)