This commit is contained in:
nttstar
2019-02-15 14:59:09 +08:00
parent 61b371fd63
commit 4d059ccf8d

View File

@@ -13,7 +13,7 @@ class AccMetric(mx.metric.EvalMetric):
def update(self, labels, preds):
self.count+=1
label = labels[0]
pred_label = preds[0]
pred_label = preds[1]
#print('ACC', label.shape, pred_label.shape)
if pred_label.shape != label.shape:
pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)