|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Controlling dump behavior."""
- from warnings import warn
-
- import mindspore.context as context
- from mindspore._c_expression import security
-
-
- def set_dump(target, enabled=True):
- """
- Enable or disable dump for the target and its contents.
-
- Target should be an instance of Cell or Primitive. The default enabled
- status for a cell or primitive is False. Please note that this API takes
- effect only when the dump_mode field in dump config file is 2. See the
- `dump document <https://mindspore.cn/docs/programming_guide/zh-CN/master/dump_in_graph_mode.html>`_
- for details.
-
- .. warning::
- This is an experimental prototype that is subject to change or deletion.
-
- Note:
- 1. This API is only effective for GRAPH_MODE with Ascend backend.
- 2. When target is a cell and enabled is True, this API will the enable
- dump for the primitive operator members of the cell instance and
- its child cell instances recursively. If an operator is not a
- member of the cell instance, the dump flag will not be set for
- this operator (e.g. functional operators used directly in
- construct method). To make this API effective, please use
- self.some_op = SomeOp() in your cell's __init__ method.
- 3. After using set_dump(cell, True), operators in forward computation
- of the cell will be dumped. Most backward computation (computation
- generated by the grad operations) will not be dumped by design.
- However, due to the graph optimization, a few backward computation
- data will still be dumped. You can ignore the backward computation
- data which contains "Gradients" in their filenames.
- 4. This API is not designed to use in the middle of training process.
- If you call this API in the middle of training, it only takes effect
- for the later compiled graphs. If there is no new graph compilation,
- you will see no effect.
- 5. For operator SparseSoftmaxCrossEntropyWithLogits, the forward
- computation and backward computation use the same set of
- operators. So you can only see dump data from backward computation.
- Please note that operator SoftmaxCrossEntropyWithLogits will also use
- the above operator internally when initialized with sparse=True and
- reduction="mean".
-
- Args:
- target (Union[Cell, Primitive]): The Cell instance or Primitive instance
- to which the dump flag is set.
- enabled (bool): True means enable dump, False means disable dump.
- Default: True.
-
- Supported Platforms:
- ``Ascend``
-
- Examples:
- >>> # Please set the dump config file and environment variable before
- >>> # running this example to actually get the dump data.
- >>> # See the document of this API for details.
- >>> import numpy as np
- >>>
- >>> import mindspore.nn as nn
- >>> import mindspore.context as context
- >>> from mindspore import Tensor, set_dump
- >>>
- >>> context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
- >>>
- >>> class MyNet(nn.Cell):
- ... def __init__(self):
- ... super().__init__()
- ... self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
- ... self.relu1 = nn.ReLU()
- ...
- ... def construct(self, x):
- ... x = self.conv1(x)
- ... x = self.relu1(x)
- ... return x
- >>>
- >>> net = MyNet()
- >>> set_dump(net.conv1)
- >>> input_tensor = Tensor(np.ones([1, 5, 10, 10], dtype=np.float32))
- >>> net(input_tensor)
- """
- if security.enable_security():
- raise ValueError('The set_dump API is not supported, please recompile '
- 'source without "-s on".')
-
- import mindspore.nn as nn # avoid circular import
- from mindspore.ops import Primitive
- if not isinstance(target, nn.Cell) and not isinstance(target, Primitive):
- raise ValueError(f"The \"target\" parameter must be an instance of "
- f"Cell or Primitive, "
- f"but got an instance of {type(target)}.")
-
- if not isinstance(enabled, bool):
- raise ValueError("The \"enabled\" parameter must be bool.")
-
- # Checking for device target and mode.
- current_target = context.get_context("device_target")
- if current_target != "Ascend":
- # We will not return here in case user changed device_target later.
- warn("Current device_target is {}, which is not supported by set_dump. "
- "Only Ascend device target is supported currently. "
- "If you have Ascend device, consider set device_target to Ascend "
- "before calling set_dump.".format(current_target))
-
- current_mode = context.get_context("mode")
- if current_mode != context.GRAPH_MODE:
- # We will not return here in case user changed mode later.
- warn(
- "Current mode is PYNATIVE_MODE, which is not supported by set_dump. "
- "Only GRAPH_MODE is supported currently. "
- "Consider set mode to GRAPH_MODE "
- "before calling set_dump.")
-
- # The actual set dump logic.
- mode = "true" if enabled else "false"
- if isinstance(target, nn.Cell):
- primitives = getattr(target, "_primitives", {})
- for value in primitives.values():
- if value:
- value.add_prim_attr("dump", mode)
- for cell in target.cells():
- set_dump(cell, enabled)
- return
-
- if isinstance(target, Primitive):
- target.add_prim_attr("dump", mode)
- return
|