|
|
|
@@ -304,15 +304,19 @@ class WithEvalCell(Cell): |
|
|
|
>>> eval_net = nn.WithEvalCell(net, loss_fn) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, network, loss_fn): |
|
|
|
def __init__(self, network, loss_fn, add_cast_fp32=False): |
|
|
|
super(WithEvalCell, self).__init__(auto_prefix=False) |
|
|
|
self._network = network |
|
|
|
self._loss_fn = loss_fn |
|
|
|
self.add_cast_fp32 = add_cast_fp32 |
|
|
|
|
|
|
|
|
|
|
|
def construct(self, data, label): |
|
|
|
outputs = self._network(data) |
|
|
|
label = _mp_cast_helper(mstype.float32, label) |
|
|
|
loss = self._loss_fn(F.cast(outputs, mstype.float32), label) |
|
|
|
if self.add_cast_fp32: |
|
|
|
label = _mp_cast_helper(mstype.float32, label) |
|
|
|
outputs = F.cast(outputs, mstype.float32) |
|
|
|
loss = self._loss_fn(outputs, label) |
|
|
|
return loss, outputs, label |
|
|
|
|
|
|
|
|
|
|
|
|