Browse Source

!23158 Add note for api

Merge pull request !23158 from liuyang/master_note
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
67c1395cb4
3 changed files with 28 additions and 1 deletions
  1. +17
    -0
      mindspore/nn/cell.py
  2. +6
    -0
      mindspore/nn/wrap/loss_scale.py
  3. +5
    -1
      mindspore/train/loss_scale_manager.py

+ 17
- 0
mindspore/nn/cell.py View File

@@ -328,6 +328,9 @@ class Cell(Cell_):
return tuple(res)

def cast_inputs(self, inputs, dst_type):
"""
Cast inputs to specified type.
"""
res = list()
for item in inputs:
if isinstance(item, tuple):
@@ -971,6 +974,9 @@ class Cell(Cell_):
yield param

def check_names(self):
"""
Check the names of cell parameters.
"""
names = set("")
for value, param in self.parameters_and_names():
if param.name in names:
@@ -1116,6 +1122,11 @@ class Cell(Cell_):
return cells

def add_flags(self, **flags):
"""
Add customized attributes for cell.

This method is also called when the cell class is instantiated and the class parameter 'flag' is set to True.
"""
if not hasattr(self, "_mindspore_flags"):
self._mindspore_flags = {}
self._mindspore_flags.update({**flags})
@@ -1123,6 +1134,9 @@ class Cell(Cell_):
return self

def add_flags_recursive(self, **flags):
"""
If a cell contains child cells, this method can recursively customize attributes of all cells.
"""
self.add_flags(**flags)
for cell in self.cells():
cell.add_flags_recursive(**flags)
@@ -1133,6 +1147,9 @@ class Cell(Cell_):
self._cell_init_args += str({**args})

def get_flags(self):
"""
Get the attributes of cell's flags.
"""
if not hasattr(self, "_mindspore_flags"):
self._mindspore_flags = {}
return self._mindspore_flags


+ 6
- 0
mindspore/nn/wrap/loss_scale.py View File

@@ -134,6 +134,9 @@ class DynamicLossScaleUpdateCell(Cell):
self.const_true = Tensor(True, dtype=mstype.bool_)

def get_loss_scale(self):
"""
Get Loss Scale value.
"""
return self.loss_scale_value

def construct(self, loss_scale, overflow):
@@ -205,6 +208,9 @@ class FixedLossScaleUpdateCell(Cell):
self.loss_scale_value = loss_scale_value

def get_loss_scale(self):
"""
Get Loss Scale value.
"""
return self.loss_scale_value

def construct(self, _, overflow):


+ 5
- 1
mindspore/train/loss_scale_manager.py View File

@@ -19,7 +19,11 @@ from .. import nn


class LossScaleManager:
"""Loss scale manager abstract class."""
"""
Loss scale manager abstract class.

Derive FixedLossScaleManager and DynamicLossScaleManager that override all LossScaleManager's method.
"""
def get_loss_scale(self):
"""Get loss scale value."""



Loading…
Cancel
Save