Merge pull request !3026 from ZPaC/add-front-end-ps-optim-expression
@@ -38,6 +38,14 @@ static inline std::string GetEnv(const std::string &envvar) {
   return std::string(value);
 }
 
+static inline int SetEnv(const char *envname, const char *envvar, int overwrite = 1) {
+#if defined(_WIN32)
+  return 0;
+#else
+  return ::setenv(envname, envvar, overwrite);
+#endif
+}
+
 }  // namespace common
 }  // namespace mindspore
@@ -72,6 +72,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(kSpaceToBatchOpName, {1});
   Register(kBatchToSpaceOpName, {1});
   Register(kPadOpName, {1});
+  Register(kPushOpName, {1});
 }
 
 ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() {
@@ -30,6 +30,7 @@
 #include "transform/df_graph_manager.h"
 #endif
 #include "ir/tensor.h"
+#include "common/utils.h"
 
 namespace mindspore {
 #ifdef ENABLE_GE
@@ -168,6 +169,11 @@ bool MsContext::OpenTsd() {
     return true;
   }
 
+  auto role = common::GetEnv("MS_ROLE");
+  if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
+    return true;
+  }
+
   unsigned int device_id;
   unsigned int rank_size = 1;
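The hunk above makes MsContext::OpenTsd() return early when the MS_ROLE environment variable marks the process as the scheduler or a parameter server. A minimal launcher-side sketch of setting that variable, assuming each role is exported per process before MindSpore initializes (the launch procedure itself is not part of this PR; the role strings are taken from the check above):

import os

# Hypothetical per-process setup: declare the role before MindSpore starts, so
# OpenTsd() skips TSD for scheduler and parameter-server processes.
os.environ["MS_ROLE"] = "MS_PSERVER"   # or "MS_SCHED" for the scheduler process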
@@ -173,6 +173,10 @@ constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad";
 constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
 constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
 constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";
+constexpr auto kPushOpName = "Push";
+constexpr auto kPullOpName = "Pull";
+constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
+constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
@@ -234,6 +238,8 @@ constexpr auto kAttrSizeSplits = "size_splits";
 constexpr auto kAttrOutputDefault = "output_default";
 constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
 constexpr auto kAttrOffset = "offset";
+constexpr auto kAttrPsKey = "ps_key";
+constexpr auto kAttrOptimizerType = "optim_type";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";
@@ -286,12 +292,24 @@ const std::set<std::string> kOpFormatList = {
   kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC};
 const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
 const std::set<std::string> kOptOperatorSet = {
-  kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName,
-  kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName,
-  kApplyAdaMaxOpName, kApplyAddSignOpName, kApplyCenteredRMSPOpName,
-  kApplyFtrlOpName, kApplyFtrlV2OpName, kApplyGradientDescentOpName,
-  kApplyPowerSignOpName, kApplyProximalAdagradOpName, kApplyProximalGradientDescentOpName,
+  kMomentumOpName,
+  kApplyMomentumOpName,
+  kApplyAdadeltaOpName,
+  kApplyAdagradOpName,
+  kApplyAdagradDAName,
+  kApplyAdamOpName,
+  kApplyAdaMaxOpName,
+  kApplyAddSignOpName,
+  kApplyCenteredRMSPOpName,
+  kApplyFtrlOpName,
+  kApplyFtrlV2OpName,
+  kApplyGradientDescentOpName,
+  kApplyPowerSignOpName,
+  kApplyProximalAdagradOpName,
+  kApplyProximalGradientDescentOpName,
   kApplyRMSPropOpName,
+  kPushOpName,
+  kPullOpName,
 };
 const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
@@ -65,6 +65,7 @@ class Parameter:
         self.has_indexed_slices_grad = has_indexed_slices_grad
         self._is_init = False
         self._sliced = False
+        self.is_param_ps = False
         if context.get_context("mode") == context.PYNATIVE_MODE:
             self.init_data()
@@ -75,6 +76,9 @@ class Parameter:
     def __parameter__(self):
         """For parse check."""
 
+    def set_param_ps(self):
+        self.is_param_ps = True
+
     @property
     def name(self):
         """Get the name of the parameter."""
@@ -831,6 +831,20 @@ class Cell:
         self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
         self.enable_hook = True
 
+    def set_param_ps(self, recurse=True):
+        """
+        Set whether the trainable parameters are updated by the parameter server.
+
+        Note:
+            This only works when running the task in parameter server mode.
+
+        Args:
+            recurse (bool): Whether to also set the trainable parameters of subcells. Default: True.
+        """
+        params = self.trainable_params(recurse)
+        for param in params:
+            param.set_param_ps()
+
 
 class GraphKernel(Cell):
     """
     Base class for GraphKernel.
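A minimal usage sketch for the new set_param_ps APIs; the toy network below is hypothetical and only illustrates marking parameters for parameter-server updates, either per Cell or per Parameter:

import mindspore.nn as nn

class ToyNet(nn.Cell):                 # hypothetical network, not part of this PR
    def __init__(self):
        super(ToyNet, self).__init__()
        self.fc1 = nn.Dense(64, 32)
        self.fc2 = nn.Dense(32, 10)

    def construct(self, x):
        return self.fc2(self.fc1(x))

net = ToyNet()
net.set_param_ps()                     # marks every trainable Parameter (is_param_ps = True)
# Alternatively, mark a single parameter instead of the whole cell:
net.fc2.weight.set_param_ps()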
@@ -20,14 +20,14 @@ The optimizer is used to calculate and update the gradients.
 """
 from .optimizer import Optimizer
 from .momentum import Momentum
-from .adam import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
+from .adam import Adam, PSAdam, AdamWeightDecay, AdamWeightDecayDynamicLR
 from .lamb import Lamb
 from .sgd import SGD
 from .lars import LARS
-from .ftrl import FTRL
+from .ftrl import FTRL, PSFTRL
 from .rmsprop import RMSProp
 from .proximal_ada_grad import ProximalAdagrad
 from .lazyadam import LazyAdam
 
-__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam',
-           'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad']
+__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam',
+           'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad']
@@ -27,6 +27,7 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer
 
 _adam_opt = C.MultitypeFuncGraph("adam_opt")
+_adam_push_pull_opt = C.MultitypeFuncGraph("_adam_push_pull_opt")
 
 
 @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
@@ -129,6 +130,31 @@ def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, b
                                                    eps, gradient))
     return success
 
+
+@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                              "Tensor", "Tuple", "Tensor", "Tensor", "Tensor")
+def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
+                                   moment1, moment2):
+    """Apply the Adam optimizer to the weight through push and pull when the gradient is sparse."""
+    success = True
+    op_shape = P.Shape()
+    shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
+              op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
+              op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
+    success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
+                                           eps, gradient[1], gradient[0]), shapes), params))
+    return success
+
+
+@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                              "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
+def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
+                                       moment1, moment2):
+    """Apply the Adam optimizer to the weight through push and pull when the gradient is a dense Tensor."""
+    success = True
+    op_shape = P.Shape()
+    success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
+                                          (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
+    return success
+
 
 class Adam(Optimizer):
     r"""
@@ -274,6 +300,51 @@ class Adam(Optimizer):
                                     gradients, params, moment1, moment2)
         return success
 
+
+class PSAdam(Optimizer):
+    '''Has the same usage as the Adam optimizer, except that the parameters are updated through the parameter server.'''
+    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
+                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
+        super(PSAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
+        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
+        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
+        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
+
+        self.beta1 = Tensor(beta1, mstype.float32)
+        self.beta2 = Tensor(beta2, mstype.float32)
+        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
+        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
+        self.eps = Tensor(eps, mstype.float32)
+
+        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
+        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
+
+        self.hyper_map = C.HyperMap()
+        self.push = P.Push("Adam", [0, 1, 2])
+        self.push.add_prim_attr("primitive_target", "CPU")
+        self.pull = P.Pull()
+        self.pull.add_prim_attr("primitive_target", "CPU")
+
+    def construct(self, gradients):
+        params = self.parameters
+        moment1 = self.moment1
+        moment2 = self.moment2
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
+
+        beta1_power = self.beta1_power * self.beta1
+        self.beta1_power = beta1_power
+        beta2_power = self.beta2_power * self.beta2
+        self.beta2_power = beta2_power
+        if self.is_group_lr:
+            success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power,
+                                          self.beta1, self.beta2, self.eps),
+                                lr, gradients, params, moment1, moment2)
+        else:
+            success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power,
+                                          self.beta1, self.beta2, self.eps, lr),
+                                gradients, params, moment1, moment2)
+        return success
+
 
 class AdamWeightDecay(Optimizer):
     """
@@ -22,6 +22,7 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer, _apply_decay, _grad_scale
 
 _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
+_ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
 
 
 @_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
@@ -41,6 +42,26 @@ def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gra
     success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
     return success
 
+
+@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple",
+                              "Tensor", "Tensor")
+def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
+                                          weight, moment):
+    """Apply the FTRL optimizer to the weight through push and pull when the gradient is sparse."""
+    success = True
+    op_shape = P.Shape()
+    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0]))
+    success = F.depend(success, pull(push((gradient[1], gradient[0]), shapes), weight))
+    return success
+
+
+@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor",
+                              "Tensor", "Tensor")
+def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
+                                              weight, moment):
+    """Apply the FTRL optimizer to the weight through push and pull when the gradient is a dense Tensor."""
+    success = True
+    op_shape = P.Shape()
+    success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
+                                          (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
+    return success
+
 
 def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None):
     """Check param."""
@@ -131,3 +152,37 @@ class FTRL(Optimizer):
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
                             linear, grads, params, moments)
         return success
+
+
+class PSFTRL(Optimizer):
+    """Has the same usage as the FTRL optimizer, except that the parameters are updated through the parameter server."""
+    def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
+                 use_locking=False, loss_scale=1.0, weight_decay=0.0):
+        super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
+        if self.is_group:
+            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
+        _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name)
+        self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
+        self.linear = self.parameters.clone(prefix="linear", init='zeros')
+        self.l1 = l1
+        self.l2 = l2
+        self.lr_power = lr_power
+        self.weight_decay = weight_decay
+        self.decay_tf = tuple((lambda: True)() for x in self.parameters)
+
+        self.hyper_map = C.HyperMap()
+        self.push = P.Push("Ftrl", [0, 1, 2])
+        self.push.add_prim_attr("primitive_target", "CPU")
+        self.pull = P.Pull()
+        self.pull.add_prim_attr("primitive_target", "CPU")
+
+    def construct(self, grads):
+        params = self.parameters
+        moments = self.moments
+        linear = self.linear
+        lr = self.learning_rate
+        if self.weight_decay > 0.0:
+            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
+
+        grads = self.scale_grad(grads)
+        success = self.map_(F.partial(_ftrl_push_pull_opt, self.push, self.pull, lr, self.l1, self.l2, self.lr_power),
+                            linear, grads, params, moments)
+        return success
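A similar sketch for PSFTRL, under the same stand-in-model assumption; note that, like FTRL, it rejects group parameters, and a weight_decay > 0.0 applies _apply_decay to every parameter before the push:

import mindspore.nn as nn
from mindspore.nn.optim import PSFTRL          # exported by this PR

net = nn.Dense(32, 10)                         # hypothetical stand-in for a real model
net.set_param_ps()

# Pass a flat list of parameters; group settings raise a RuntimeError.
opt = PSFTRL(net.trainable_params(), initial_accum=0.1, learning_rate=0.001,
             lr_power=-0.5, l1=0.0, l2=0.0, weight_decay=0.0)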
@@ -78,7 +78,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
                      ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
                      ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
 from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
-                        CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
+                        CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
 from .thor_ops import *
 
 __all__ = [
@@ -333,7 +333,9 @@ __all__ = [
     "Mod",
     "PopulationCount",
     "ParallelConcat",
-    "EmbeddingLookup"
+    "EmbeddingLookup",
+    "Push",
+    "Pull"
 ]
 
 __all__.sort()
@@ -488,3 +488,54 @@ class PopulationCount(PrimitiveWithInfer):
         args = {"x": x_dtype}
         validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name)
         return mstype.tensor_type(mstype.uint8)
+
+
+class Push(PrimitiveWithInfer):
+    """
+    Pushes the inputs of the corresponding optimizer to the parameter server.
+
+    Args:
+        optim_type (string): The optimizer type. Default: 'ApplyMomentum'.
+        only_shape_indices (list): The indices of the inputs for which only the shape
+            is pushed to the parameter server. Default: None.
+
+    Inputs:
+        - **optim_inputs** (tuple) - The inputs of this kind of optimizer.
+        - **optim_input_shapes** (tuple) - The shapes of the inputs.
+
+    Outputs:
+        Tensor, the key of the weight which needs to be updated.
+    """
+
+    @prim_attr_register
+    def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None):
+        """init Push"""
+        self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key'])
+
+    def infer_shape(self, inputs, shapes):
+        return [1]
+
+    def infer_dtype(self, inputs, shapes):
+        return mstype.uint64
+
+
+class Pull(PrimitiveWithInfer):
+    """
+    Pulls the weight from the parameter server.
+
+    Inputs:
+        - **key** (Tensor) - The key of the weight.
+        - **weight** (Tensor) - The weight to be updated.
+
+    Outputs:
+        None.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init Pull"""
+        self.init_prim_io_names(inputs=['key', 'weight'], outputs=['output'])
+
+    def infer_shape(self, key_shape, weight_shape):
+        return [1]
+
+    def infer_dtype(self, key_dtype, weight_dtype):
+        return mstype.float32
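For reference, a minimal sketch of the Push/Pull wiring used by the PS optimizers above; the attribute values mirror PSAdam, and the commented call pattern is what happens inside an optimizer's construct():

from mindspore.ops import operations as P

# Push sends the optimizer inputs (plus the shapes of the server-side states)
# and returns a uint64 key; Pull then fetches the updated weight for that key.
push = P.Push("Adam", [0, 1, 2])               # optim_type, only_shape_indices
push.add_prim_attr("primitive_target", "CPU")  # run the primitive on the host side
pull = P.Pull()
pull.add_prim_attr("primitive_target", "CPU")

# Inside construct():
#   key = push(optim_inputs, optim_input_shapes)
#   pull(key, weight)    # 'weight' is the Parameter to be refreshed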