GitOrigin-RevId: fd2fe8bec9
@@ -230,7 +230,7 @@ for name, mode in [
 def subgraph(
     name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
 ):
-    if device.physical_name.startswith("cpu"):
+    if not device.physical_name.startswith("gpu"):
         gopt_level = None  # disable jit and compile
         jit_fusion = False
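The old predicate only caught comp nodes whose physical name starts with "cpu"; the new one disables graph optimization and JIT fusion for every non-GPU device. A tiny self-contained sketch of the difference (the device-name strings below are illustrative assumptions, not an exhaustive list of MegEngine comp node names):

    # "multithread0:0" stands in for any non-CPU, non-GPU comp node
    # that the old check let through.
    for name in ["cpu0:0", "multithread0:0", "gpu0:0"]:
        old_check = name.startswith("cpu")       # True only for "cpu0:0"
        new_check = not name.startswith("gpu")   # True for both non-GPU names
        print(name, old_check, new_check)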
@@ -370,7 +370,15 @@ def subgraph_fn(
             jit_fusion=jit_fusion,
             custom_grad=custom_grad,
         )(func)
-        return lambda *args: apply(op(), *args)
+
+        def wrapped_func(*args):
+            if custom_grad:
+                outputs = op()(*args)
+            else:
+                outputs = apply(op(), *args)
+            return outputs
+
+        return wrapped_func
     else:
         return interpret_subgraph(func, dtype, device)
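Replacing the lambda with wrapped_func lets the two op kinds use different call conventions: a custom-grad op is itself callable, while a plain compiled subgraph op must be dispatched through apply(). A self-contained toy of that dispatch pattern (every class and function below is a stand-in, not a MegEngine API):

    class CustomGradOp:
        """Stand-in for an op that carries its own backward rule."""
        def __call__(self, *args):
            return tuple(2 * a for a in args)

    class PlainOp:
        """Stand-in for an opaque op descriptor; only apply() runs it."""

    def apply(op, *args):
        return tuple(a + 1 for a in args)

    def make_wrapper(op_factory, custom_grad):
        def wrapped(*args):
            if custom_grad:
                return op_factory()(*args)     # call the op instance directly
            return apply(op_factory(), *args)  # dispatch through apply()
        return wrapped

    print(make_wrapper(CustomGradOp, True)(1, 2))   # (2, 4)
    print(make_wrapper(PlainOp, False)(1, 2))       # (2, 3)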
@@ -988,7 +988,6 @@ def _get_softplus_op(dtype=None, device=None):
         device=device,
         nr_inputs=1,
         jit_fusion=True,
-        # gopt_level=0,
         custom_grad=True,
     )
     def softplus(inputs, f, c):
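For reference, softplus(x) = log(1 + exp(x)), which overflows for large x when evaluated naively. A stable NumPy formulation (a reference sketch only, not the body of the decorated subgraph above):

    import numpy as np

    def softplus_ref(x):
        # log(1 + exp(x)) rewritten as relu(x) + log1p(exp(-|x|)),
        # which avoids overflow for large positive x.
        return np.maximum(x, 0.0) + np.log1p(np.exp(-np.abs(x)))

    print(softplus_ref(np.array([-50.0, 0.0, 50.0])))  # ~[0.0, 0.693, 50.0]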
@@ -18,14 +18,7 @@ from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
 from ..core.tensor.array_method import _broadcast, _remove_axis
-from ..core.tensor.utils import (
-    astensor1d,
-    convert_inputs,
-    get_device,
-    isscalar,
-    setscalar,
-    subgraph_fn,
-)
+from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
 from ..device import get_default_device
 from ..tensor import Tensor
 from .elemwise import ceil
@@ -821,8 +814,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
     where = _get_where_op(dtype=dtype, device=device)
     (oup,) = where(mask, x, y)
-    if isscalar(mask):
-        setscalar(oup)
     return oup
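With the manual scalar bookkeeping gone, where simply returns whatever the fused op produces. A minimal usage sketch (assuming megengine.functional.where keeps its documented elementwise-select semantics):

    import megengine.functional as F
    from megengine import tensor

    mask = tensor([[True, False], [False, True]])
    x = tensor([[1.0, 2.0], [3.0, 4.0]])
    y = tensor([[5.0, 6.0], [7.0, 8.0]])
    print(F.where(mask, x, y))  # picks x where mask is True: [[1, 6], [7, 4]]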
@@ -67,7 +67,7 @@ void init_common(py::module m) {
                     [](const CompNode& cn) { return cn.to_string_logical(); })
             .def_property_readonly(
                     "physical_name",
-                    [](const CompNode& cn) { return cn.to_string(); })
+                    [](const CompNode& cn) { return cn.to_string_physical(); })
             .def_property_readonly(
                     "get_mem_status_bytes",
                     [](const CompNode& cn) {
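On the Python side, the rebinding means CompNode.physical_name now reports the resolved physical device rather than the logical spec string. A hedged sketch (the exact strings depend on the build and hardware; "gpu0:0" assumes a CUDA device is present):

    import megengine as mge

    x = mge.tensor([1.0], device="xpux")
    # Before this patch, physical_name fell back to to_string(); after it,
    # it resolves to the concrete comp node, e.g. "gpu0:0" on a CUDA machine.
    print(x.device.physical_name)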