|
|
|
@@ -154,7 +154,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): |
|
|
|
loss_scale = loss_scale_manager.get_loss_scale() |
|
|
|
update_cell = loss_scale_manager.get_update_cell() |
|
|
|
if update_cell is not None: |
|
|
|
if not (context.get_context("enable_ge") or (context.get_context("device_target") == "GPU")): |
|
|
|
# only cpu not support `TrainOneStepWithLossScaleCell` for control flow. |
|
|
|
if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU": |
|
|
|
raise ValueError("Only `loss_scale_manager=None` and " |
|
|
|
"`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`" |
|
|
|
"are supported in current version. If you use `O2` option, please" |
|
|
|
|