diff --git a/src/train_softmax.py b/src/train_softmax.py index 4856137..77ac68a 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -269,7 +269,10 @@ def get_symbol(args, arg_params, aux_params): b = sin_t*sin_m new_zy = new_zy - b new_zy = new_zy*s - new_zy = mx.sym.where(cond, new_zy, zy) + zy_keep = zy + _zy = sin_t*(-1.0*s*m) + zy_keep += _zy + new_zy = mx.sym.where(cond, new_zy, zy_keep) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)