| @@ -31,11 +31,18 @@ namespace Tensorflow.Framework | |||||
| _values = values; | _values = values; | ||||
| _indices = indices; | _indices = indices; | ||||
| _dense_shape = dense_shape; | _dense_shape = dense_shape; | ||||
| _values.Tag = this; | |||||
| } | } | ||||
| public static implicit operator Tensor(IndexedSlices indexedSlices) | public static implicit operator Tensor(IndexedSlices indexedSlices) | ||||
| { | { | ||||
| return indexedSlices.values; | return indexedSlices.values; | ||||
| } | } | ||||
| public static implicit operator IndexedSlices(Tensor tensor) | |||||
| { | |||||
| return tensor.Tag as IndexedSlices; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -156,7 +156,7 @@ namespace Tensorflow.Gradients | |||||
| // For axis 0 gathers, build an appropriately shaped IndexedSlices. | // For axis 0 gathers, build an appropriately shaped IndexedSlices. | ||||
| if((int)axis_static == 0) | if((int)axis_static == 0) | ||||
| { | { | ||||
| var params_tail_shape = params_shape[1]; | |||||
| var params_tail_shape = params_shape[new NumSharp.Slice(start:1)]; | |||||
| var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0); | var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0); | ||||
| var values = array_ops.reshape(grad, values_shape); | var values = array_ops.reshape(grad, values_shape); | ||||
| indices = array_ops.reshape(indices, indices_size); | indices = array_ops.reshape(indices, indices_size); | ||||
| @@ -223,8 +223,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx }); | var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx }); | ||||
| // TODO | // TODO | ||||
| throw new NotImplementedException("_result = _UniqueOutput._make(_result)"); | |||||
| // return _op.outputs[0]; | |||||
| //var _result = _UniqueOutput._make(_op.outputs); | |||||
| return (_op.outputs[0], _op.outputs[1]); | |||||
| } | } | ||||
| public static Tensor where() | public static Tensor where() | ||||
| @@ -58,6 +58,11 @@ namespace Tensorflow | |||||
| private TF_Output? _tf_output; | private TF_Output? _tf_output; | ||||
| /// <summary> | |||||
| /// used for keep other pointer when do implicit operating | |||||
| /// </summary> | |||||
| public object Tag { get; set; } | |||||
| public int[] shape | public int[] shape | ||||
| { | { | ||||
| get | get | ||||
| @@ -219,11 +224,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public Tensor this[int start, int? stop, int? step] | |||||
| public Tensor this[Slice slice] | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| var slice_spec = new int[] { start }; | |||||
| var slice_spec = new int[] { slice.Start.Value }; | |||||
| var begin = new List<int>(); | var begin = new List<int>(); | ||||
| var end = new List<int>(); | var end = new List<int>(); | ||||
| var strides = new List<int>(); | var strides = new List<int>(); | ||||
| @@ -236,14 +241,16 @@ namespace Tensorflow | |||||
| foreach (var s in slice_spec) | foreach (var s in slice_spec) | ||||
| { | { | ||||
| begin.Add(s); | begin.Add(s); | ||||
| if (stop == null) | |||||
| if(slice.Stop.HasValue) | |||||
| { | |||||
| end.Add(slice.Stop.Value); | |||||
| } | |||||
| else | |||||
| { | { | ||||
| end.Add(0); | end.Add(0); | ||||
| end_mask |= (1 << index); | end_mask |= (1 << index); | ||||
| } | } | ||||
| else | |||||
| end.Add(s + 1); | |||||
| strides.Add(1); | |||||
| strides.Add(slice.Step); | |||||
| index += 1; | index += 1; | ||||
| } | } | ||||
| @@ -277,7 +284,57 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public Tensor this[int slice_spec] => this[slice_spec, null, null]; | |||||
| public Tensor this[int start] | |||||
| { | |||||
| get | |||||
| { | |||||
| var slice_spec = new int[] { start }; | |||||
| var begin = new List<int>(); | |||||
| var end = new List<int>(); | |||||
| var strides = new List<int>(); | |||||
| var index = 0; | |||||
| var (new_axis_mask, shrink_axis_mask) = (0, 0); | |||||
| var (begin_mask, end_mask) = (0, 0); | |||||
| var ellipsis_mask = 0; | |||||
| foreach (var s in slice_spec) | |||||
| { | |||||
| begin.Add(s); | |||||
| end.Add(s + 1); | |||||
| strides.Add(1); | |||||
| shrink_axis_mask |= (1 << index); | |||||
| index += 1; | |||||
| } | |||||
| return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
| { | |||||
| string name = scope; | |||||
| if (begin != null) | |||||
| { | |||||
| var (packed_begin, packed_end, packed_strides) = | |||||
| (array_ops.stack(begin.ToArray()), | |||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| return gen_array_ops.strided_slice( | |||||
| this, | |||||
| packed_begin, | |||||
| packed_end, | |||||
| packed_strides, | |||||
| begin_mask: begin_mask, | |||||
| end_mask: end_mask, | |||||
| shrink_axis_mask: shrink_axis_mask, | |||||
| new_axis_mask: new_axis_mask, | |||||
| ellipsis_mask: ellipsis_mask, | |||||
| name: name); | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| }); | |||||
| } | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| @@ -227,9 +227,8 @@ namespace Tensorflow | |||||
| public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices) | public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices) | ||||
| { | { | ||||
| var (unique_indices, new_index_positions) = array_ops.unique(indices); | var (unique_indices, new_index_positions) = array_ops.unique(indices); | ||||
| var summed_values = math_ops.unsorted_segment_sum( | |||||
| values, new_index_positions, | |||||
| array_ops.shape(unique_indices)[0]); | |||||
| var shape = array_ops.shape(unique_indices)[0]; | |||||
| var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape); | |||||
| return (summed_values, unique_indices); | return (summed_values, unique_indices); | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Framework; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -29,14 +29,16 @@ namespace Tensorflow | |||||
| public Operation update_op(Optimizer optimizer, Tensor g) | public Operation update_op(Optimizer optimizer, Tensor g) | ||||
| { | { | ||||
| var update_op = optimizer._apply_dense(g, _v); | |||||
| return update_op; | |||||
| } | |||||
| public Operation update_op(Optimizer optimizer, IndexedSlices g) | |||||
| { | |||||
| var update_op = optimizer._apply_dense(g, _v); | |||||
| Operation update_op = null; | |||||
| if (g.Tag == null) | |||||
| { | |||||
| update_op = optimizer._apply_dense(g, _v); | |||||
| } | |||||
| else if (g.Tag is IndexedSlices) | |||||
| { | |||||
| return optimizer._apply_sparse_duplicate_indices(g, _v); | |||||
| } | |||||
| return update_op; | return update_op; | ||||
| } | } | ||||