This commit is contained in:
Jia Guo
2018-01-11 19:54:23 +08:00
parent 7e3ff9614c
commit 0e8fdb70a9

View File

@@ -269,7 +269,10 @@ def get_symbol(args, arg_params, aux_params):
b = sin_t*sin_m
new_zy = new_zy - b
new_zy = new_zy*s
new_zy = mx.sym.where(cond, new_zy, zy)
zy_keep = zy
_zy = sin_t*(-1.0*s*m)
zy_keep += _zy
new_zy = mx.sym.where(cond, new_zy, zy_keep)
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)