| @@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Engine | |||||
| var data = iterator.next(); | var data = iterator.next(); | ||||
| var outputs = train_step(data[0], data[1]); | var outputs = train_step(data[0], data[1]); | ||||
| tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | ||||
| throw new NotImplementedException(""); | |||||
| return null; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -38,7 +38,6 @@ namespace Tensorflow.Keras.Engine | |||||
| // The _minimize call does a few extra steps unnecessary in most cases, | // The _minimize call does a few extra steps unnecessary in most cases, | ||||
| // such as loss scaling and gradient clipping. | // such as loss scaling and gradient clipping. | ||||
| _minimize(tape, optimizer, loss, trainable_variables); | _minimize(tape, optimizer, loss, trainable_variables); | ||||
| compiled_metrics.update_state(y, y_pred); | compiled_metrics.update_state(y, y_pred); | ||||
| return new[] { ("loss", loss) }; | return new[] { ("loss", loss) }; | ||||
| } | } | ||||
| @@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Optimizers | |||||
| return control_flow_ops.no_op(); | return control_flow_ops.no_op(); | ||||
| var apply_state = _prepare(var_list); | var apply_state = _prepare(var_list); | ||||
| if(experimental_aggregate_gradients) | |||||
| // if(experimental_aggregate_gradients) | |||||
| { | { | ||||
| // var reduced_grads = _aggregate_gradients(grads_and_vars); | // var reduced_grads = _aggregate_gradients(grads_and_vars); | ||||
| _distributed_apply(grads_and_vars, name, apply_state); | _distributed_apply(grads_and_vars, name, apply_state); | ||||
| @@ -84,6 +84,9 @@ namespace Tensorflow.Keras.Optimizers | |||||
| void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) | void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) | ||||
| { | { | ||||
| _resource_apply_dense(var, grad, apply_state); | _resource_apply_dense(var, grad, apply_state); | ||||
| // if var.constraint is not None: | |||||
| // with ops.control_dependencies([update_op]): | |||||
| // return var.assign(var.constraint(var)) | |||||
| } | } | ||||
| protected virtual Operation _resource_apply_dense(IVariableV1 var, | protected virtual Operation _resource_apply_dense(IVariableV1 var, | ||||
| @@ -67,6 +67,21 @@ namespace Tensorflow | |||||
| name: name); | name: name); | ||||
| } | } | ||||
| public static Tensor assign(IVariableV1 @ref, object value, | |||||
| bool validate_shape = true, | |||||
| bool use_locking = true, | |||||
| string name = null) | |||||
| { | |||||
| if (@ref.dtype.is_ref_dtype()) | |||||
| return gen_state_ops.assign(@ref, | |||||
| value, | |||||
| validate_shape: validate_shape, | |||||
| use_locking: use_locking, | |||||
| name: name); | |||||
| else | |||||
| return @ref.assign(value, name: name); | |||||
| } | |||||
| public static Tensor assign_sub(IVariableV1 @ref, | public static Tensor assign_sub(IVariableV1 @ref, | ||||
| Tensor value, | Tensor value, | ||||
| bool use_locking = false, | bool use_locking = false, | ||||