You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.CrossEntropyLoss.rst 918 B

1234567891011121314151617
  1. .. py:class:: mindspore.nn.transformer.CrossEntropyLoss(parallel_config=default_dpmp_config)
  2. 计算输入和输出之间的交叉熵损失。
  3. **参数:**
  4. - **parallel_config** (OpParallelConfig) - 表示并行配置。默认值为 `default_dpmp_config` ,表示一个带有默认参数的 `OpParallelConfig` 实例。
  5. **输入:**
  6. - **logits** (Tensor) - shape为(N, C)的Tensor。表示的输出logits。其中N表示任意大小的维度,C表示类别个数。数据类型必须为float16或float32。
  7. - **labels** (Tensor) - shape为(N, )的Tensor。表示样本的真实标签,其中每个元素的取值区间为[0,C)。
  8. - **input_mask** (Tensor) - shape为(N, )的Tensor。input_mask表示是否有填充输入。1表示有效,0表示无效,其中元素值为0的位置不会计算进损失值。
  9. **输出:**
  10. Tensor,表示对应的交叉熵损失。