|
|
|
@@ -26,6 +26,7 @@ class MXNetNode |
|
|
|
{ |
|
|
|
public: |
|
|
|
bool has_attr(const char* key) const; |
|
|
|
bool is_attr_scalar(const char* key) const; |
|
|
|
|
|
|
|
class AttrProxy |
|
|
|
{ |
|
|
|
@@ -102,6 +103,18 @@ bool MXNetNode::has_attr(const char* key) const |
|
|
|
return it != attrs.end(); |
|
|
|
} |
|
|
|
|
|
|
|
bool MXNetNode::is_attr_scalar(const char* key) const |
|
|
|
{ |
|
|
|
const std::map<std::string, std::string>::const_iterator it = attrs.find(key); |
|
|
|
if (it == attrs.end()) |
|
|
|
return false; |
|
|
|
|
|
|
|
if (it->second.empty()) |
|
|
|
return false; |
|
|
|
|
|
|
|
return it->second[0] != '('; |
|
|
|
} |
|
|
|
|
|
|
|
int MXNetNode::attr_i(const char* key) const |
|
|
|
{ |
|
|
|
const std::map<std::string, std::string>::const_iterator it = attrs.find(key); |
|
|
|
@@ -154,6 +167,12 @@ std::vector<int> MXNetNode::attr_ai(const char* key) const |
|
|
|
// (1,2,3) |
|
|
|
std::vector<int> list; |
|
|
|
|
|
|
|
if (is_attr_scalar(key)) |
|
|
|
{ |
|
|
|
list.push_back(attr_i(key)); |
|
|
|
return list; |
|
|
|
} |
|
|
|
|
|
|
|
int i = 0; |
|
|
|
int c = 0; |
|
|
|
int nconsumed = 0; |
|
|
|
@@ -200,6 +219,12 @@ std::vector<float> MXNetNode::attr_af(const char* key) const |
|
|
|
// (0.1,0.2,0.3) |
|
|
|
std::vector<float> list; |
|
|
|
|
|
|
|
if (is_attr_scalar(key)) |
|
|
|
{ |
|
|
|
list.push_back(attr_f(key)); |
|
|
|
return list; |
|
|
|
} |
|
|
|
|
|
|
|
float i = 0.f; |
|
|
|
int c = 0; |
|
|
|
int nconsumed = 0; |
|
|
|
|