GitOrigin-RevId: 57a3b9d418
tags/v1.10.0
@@ -51,14 +51,7 @@ class _Hashable:
         return self.value == o.value


-def _matmul(
-    inp1,
-    inp2,
-    transpose_a=False,
-    transpose_b=False,
-    compute_mode="default",
-    format="default",
-):
+def _matmul(inp1, inp2, transpose_a=False, transpose_b=False, compute_mode="default"):
     dim1, dim2 = inp1.ndim, inp2.ndim
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
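Note: the private `_matmul` helper loses its `format` keyword; the remaining transpose flags behave like the usual matrix-multiply transposes. A minimal NumPy sketch of that semantics (illustration only, not MegEngine code; `compute_mode` and format handling omitted):

```python
import numpy as np

def matmul_sketch(inp1, inp2, transpose_a=False, transpose_b=False):
    # Transpose the last two axes on request, then multiply, mirroring what
    # the transpose_a / transpose_b keywords of _matmul express.
    a = np.swapaxes(inp1, -1, -2) if transpose_a else inp1
    b = np.swapaxes(inp2, -1, -2) if transpose_b else inp2
    return a @ b

# matmul_sketch(np.ones((2, 3)), np.ones((2, 4)), transpose_a=True).shape == (3, 4)
```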
@@ -1206,6 +1206,7 @@ def batch_norm(
         if x is None:
             x = Const(value, inp.dtype, inp.device)
+            x.format = inp.format
         shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
         (result,) = apply(builtin.Broadcast(), x, shape)
         return result
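The `make_full_if_none` helper inside `batch_norm` materializes a constant for a missing running statistic and now tags it with the input's format before broadcasting. A hedged NumPy analogue of what that helper computes (the format tag itself has no NumPy equivalent and is only noted in a comment):

```python
import numpy as np

def make_full_if_none_sketch(x, value, pshape=(1, 4, 1, 1)):
    # If the running statistic is missing, materialize a constant broadcast to
    # the per-channel parameter shape (param_dim="dim_1c11" -> (1, C, 1, 1)).
    # The real helper additionally sets x.format = inp.format so the Broadcast
    # op's format-transformation rule sees the right layout.
    if x is None:
        x = np.broadcast_to(np.asarray(value, dtype="float32"), pshape)
    return x

running_mean = make_full_if_none_sketch(None, 0)  # zeros, shape (1, 4, 1, 1)
running_var = make_full_if_none_sketch(None, 1)   # ones, shape (1, 4, 1, 1)
```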
@@ -1227,14 +1228,14 @@ def batch_norm(
     if not training:
         op = builtin.BatchNorm(
-            fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
         )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret
     else:
         op = builtin.BatchNorm(
-            avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps
+            avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
         )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
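In the training branch, `avg_factor=1 - momentum` is the weight given to the freshly computed batch statistics, matching the usual exponential-moving-average update of the running mean and variance (MegEngine's `momentum` defaults to 0.9). A small NumPy sketch of that update rule, for orientation only:

```python
import numpy as np

def update_running_stat(running, batch_stat, momentum=0.9):
    # avg_factor = 1 - momentum: the new batch statistic contributes
    # (1 - momentum), the previous running value keeps a weight of momentum.
    return momentum * running + (1 - momentum) * batch_stat

# update_running_stat(np.zeros(3), np.ones(3)) -> array([0.1, 0.1, 0.1])
```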
@@ -272,6 +272,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
     x = Const(value, inp.dtype, inp.device)
     if inp.ndim == 0:
         return x
+    # set x's format to use FormatTransformation rule for Broadcast.
+    x.format = inp.format
     return broadcast_to(x, inp.shape)
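With the two added lines, `full_like` keeps the format tag of its input instead of returning a default-format tensor after the broadcast. A hedged usage sketch, assuming the public `megengine.tensor(..., format="nhwc")` and `functional.full_like` APIs shown elsewhere in this diff:

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.ones((1, 2, 2, 3), dtype="float32"), format="nhwc")
y = F.full_like(x, 0.5)
# After this patch the broadcasted constant is expected to carry the same
# format tag as x; the attribute name (.format) is taken from the diff.
print(y.format)  # expected: "nhwc"
```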
@@ -91,13 +91,14 @@ class Optimizer(metaclass=ABCMeta):
         else:
             param_group["params"] = list(param_group["params"])

-        for param in param_group["params"]:
-            if not isinstance(param, Parameter):
-                raise TypeError(
-                    "optimizer can only optimize Parameters, but one of the params is "
-                    + str(type(param))
-                )
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+        with _config._override(auto_format_convert=False):
+            for param in param_group["params"]:
+                if not isinstance(param, Parameter):
+                    raise TypeError(
+                        "optimizer can only optimize Parameters, but one of the params is "
+                        + str(type(param))
+                    )
+                param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))

         for name, default in self._defaults.items():
             if default is required and name not in param_group:
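`add_param_group` now resets parameters inside `_config._override(auto_format_convert=False)`, presumably so that reading `param.numpy()` and rebuilding the Tensor does not go through an automatic layout conversion. The override is essentially a save-set-restore context manager; a generic sketch of that pattern (illustrative only, not the real `_config` module):

```python
import contextlib

class Config:
    auto_format_convert = True

@contextlib.contextmanager
def override(cfg, **kwargs):
    # Temporarily set attributes on cfg and restore them on exit,
    # even if the body raises, like a simplified _config._override.
    saved = {k: getattr(cfg, k) for k in kwargs}
    try:
        for k, v in kwargs.items():
            setattr(cfg, k, v)
        yield cfg
    finally:
        for k, v in saved.items():
            setattr(cfg, k, v)

cfg = Config()
with override(cfg, auto_format_convert=False):
    assert cfg.auto_format_convert is False  # e.g. reset Parameters here
assert cfg.auto_format_convert is True
```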
@@ -58,7 +58,6 @@ def run_around_tests():
         "benchmark_kernel": config.benchmark_kernel,
         "deterministic_kernel": config.deterministic_kernel,
         "compute_mode": config._compute_mode,
-        "conv_format": config._conv_format,
         "amp_enabled": amp.enabled,
         "convert_inputs": _get_convert_inputs(),
         "amp_dtype_autocast": _get_amp_dtype_autocast(),
@@ -82,7 +81,6 @@ def run_around_tests():
         "benchmark_kernel": config.benchmark_kernel,
         "deterministic_kernel": config.deterministic_kernel,
         "compute_mode": config._compute_mode,
-        "conv_format": config._conv_format,
         "amp_enabled": amp.enabled,
         "convert_inputs": _get_convert_inputs(),
         "amp_dtype_autocast": _get_amp_dtype_autocast(),
@@ -386,13 +386,6 @@ def test_backward_conv2d_dimshuffle(is_symbolic):
             return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2)

     inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
-    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    # w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
-    # b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
-    # grads = [
-    #     np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
-    #     np.array([12, 12, 12]).reshape((1, 1, 1, 3)),
-    # ]
     _compare_backward([inp], Net(), is_symbolic)
@@ -403,37 +396,10 @@ def test_backward_groupconv2d_bn(is_symbolic):
             super().__init__()
             self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
             self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
-            # self.bn = M.BatchNorm2d(2048)
+            self.bn = M.BatchNorm2d(2048)

         def forward(self, inp):
-            # test manually convert to NHWC, usually used in detection head
-            return self.conv1(self.conv0(inp))
+            return self.bn(self.conv1(self.conv0(inp)))

     inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
     _compare_backward([inp], Net(), is_symbolic)
-
-    # def func(x, w, b, bn_w, bn_b):
-    #     x = F.conv2d(x, w, b, groups=2)
-    #     x = F.batch_norm(
-    #         x,
-    #         running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
-    #         running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
-    #         weight=bn_w,
-    #         bias=bn_b,
-    #         training=True,
-    #         inplace=True,
-    #     )
-    #     return x
-
-    # data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    # w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc")
-    # b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # grads = [
-    #     np.array([66, 210]).reshape((2, 1, 1, 1, 1)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    # ]
-    # _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic)
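After the cleanup, `test_backward_groupconv2d_bn` runs two grouped convolutions followed by a `BatchNorm2d` and compares backward results via the `_compare_backward` helper (not reproduced here). A standalone forward-pass sketch of the same module, using the `megengine.module` APIs from the diff:

```python
import numpy as np
import megengine as mge
import megengine.module as M

class GroupConvBN(M.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
        self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
        self.bn = M.BatchNorm2d(2048)

    def forward(self, inp):
        return self.bn(self.conv1(self.conv0(inp)))

net = GroupConvBN()
inp = mge.tensor(np.ones((32, 32, 56, 56), dtype="float32"))
out = net(inp)
print(out.shape)  # (32, 2048, 13, 13): two 3x3 stride-2 convs, no padding
```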