fix: inconsistent shape error while training embedding layertags/v0.110.4-Transformer-Model
| @@ -49,12 +49,25 @@ namespace Tensorflow.Framework | |||||
| public static implicit operator Tensor(IndexedSlices indexedSlices) | public static implicit operator Tensor(IndexedSlices indexedSlices) | ||||
| { | { | ||||
| return indexedSlices.values; | |||||
| return _indexed_slices_to_tensor(indexedSlices); | |||||
| } | } | ||||
| public static implicit operator IndexedSlices(Tensor tensor) | public static implicit operator IndexedSlices(Tensor tensor) | ||||
| { | { | ||||
| return tensor.Tag as IndexedSlices; | return tensor.Tag as IndexedSlices; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Converts an IndexedSlices object `value` to a Tensor. | |||||
| /// </summary> | |||||
| /// <param name="indexedSlices"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="as_ref"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false) | |||||
| { | |||||
| return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -110,6 +110,17 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
| var output_array = model.predict(input_array); | var output_array = model.predict(input_array); | ||||
| Assert.AreEqual((32, 10, 64), output_array.shape); | Assert.AreEqual((32, 10, 64), output_array.shape); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void EmbeddingGrad() | |||||
| { | |||||
| var inputs = keras.layers.Input(shape: new[] { 32, 10 }); | |||||
| var outputs = keras.layers.Embedding(1000, 64, input_length: 10).Apply(inputs); | |||||
| var model = keras.Model(inputs: inputs, outputs: outputs); | |||||
| var input_array = np.random.randint(1000, size: (1, 32, 10)); | |||||
| var output_array = np.random.random(size: (1, 32, 10, 64)); | |||||
| model.compile("rmsprop", "mse", new[] { "accuracy" }); | |||||
| model.fit(input_array, output_array); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense | /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense | ||||