|
|
|
@@ -118,11 +118,9 @@ const size_t UndeterminedShapeType::fields_num = 6; |
|
|
|
|
|
|
|
std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs; |
|
|
|
void InitUndeterminedFromEnv(const std::string &sparse_shape_types) { |
|
|
|
if (!g_undetermined_configs.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
std::string tmp; |
|
|
|
std::stringstream input(sparse_shape_types); |
|
|
|
g_undetermined_configs.clear(); |
|
|
|
while (std::getline(input, tmp, ';')) { |
|
|
|
auto config = UndeterminedShapeType(tmp); |
|
|
|
g_undetermined_configs.insert(std::make_pair(config.param_name(), config)); |
|
|
|
@@ -145,17 +143,19 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt |
|
|
|
|
|
|
|
if (!key->sparse_grad().empty()) { |
|
|
|
// Will be fixed once undetermined type ready |
|
|
|
auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); |
|
|
|
if (sparse_shape_types.empty()) { |
|
|
|
sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2"; |
|
|
|
if (g_undetermined_configs.empty()) { |
|
|
|
auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); |
|
|
|
MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types; |
|
|
|
if (sparse_shape_types.empty()) { |
|
|
|
sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2"; |
|
|
|
} |
|
|
|
InitUndeterminedFromEnv(sparse_shape_types); |
|
|
|
} |
|
|
|
InitUndeterminedFromEnv(sparse_shape_types); |
|
|
|
|
|
|
|
auto shape_types = g_undetermined_configs.find(key->sparse_grad()); |
|
|
|
if (shape_types == g_undetermined_configs.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Param " << key->ToString() |
|
|
|
<< " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES: " |
|
|
|
<< sparse_shape_types; |
|
|
|
<< " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES"; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString(); |
|
|
|
AbstractBasePtrList sparse_list; |
|
|
|
|