// AmSoftmax operator: CPU-side kernel stubs plus operator factory and
// registration for MXNet's legacy OperatorProperty interface.
//
// NOTE(review): template argument lists in this file appear to be missing
// (e.g. `template inline void`, bare `Tensor`, bare `std::vector`,
// `CreateOp(` / `AmSoftmaxOp(` with no `<...>`). As written this does not
// compile; restore them from the original source / amsoftmax-inl.h — TODO
// confirm.
#include "./amsoftmax-inl.h"

namespace mshadow {

// Forward pass stub: unconditionally aborts with "Not Implemented." via
// LOG(FATAL). Presumably the real kernel exists only for another device
// (e.g. GPU) elsewhere — TODO confirm.
//   x, w, label   - inputs (names suggest data, weight and labels; shapes
//                   not visible here — TODO confirm)
//   out, oout     - outputs written by a real implementation (the
//                   distinction between them is not visible here)
//   margin, s     - scalar hyperparameters; presumably the additive margin
//                   and scale of AM-Softmax — confirm against AmSoftmaxParam
template inline void AmSoftmaxForward(const Tensor &x,
                                      const Tensor &w,
                                      const Tensor &label,
                                      const Tensor &out,
                                      const Tensor &oout,
                                      const DType margin,
                                      const DType s) {
  LOG(FATAL) << "Not Implemented.";
}

// Backward pass stub: unconditionally aborts with "Not Implemented." via
// LOG(FATAL).
//   x, w, label, out, oout - same tensors as in AmSoftmaxForward
//   o_grad                 - incoming gradient (presumably w.r.t. out)
//   x_grad, w_grad         - gradients a real implementation would write
//                            for x and w
//   workspace              - scratch tensor (usage not visible here)
//   margin, s              - same scalars as in AmSoftmaxForward
template inline void AmSoftmaxBackward(const Tensor &x,
                                       const Tensor &w,
                                       const Tensor &label,
                                       const Tensor &out,
                                       const Tensor &oout,
                                       const Tensor &o_grad,
                                       const Tensor &x_grad,
                                       const Tensor &w_grad,
                                       const Tensor &workspace,
                                       const DType margin,
                                       const DType s) {
  LOG(FATAL) << "Not Implemented.";
}

}  // namespace mshadow

namespace mxnet {
namespace op {

// Explicit specialization of the operator factory (presumably for the CPU
// device; the specialization argument appears to be missing — TODO
// confirm). Allocates an AmSoftmaxOp with `new` for the real-valued DType
// selected from the runtime `dtype` and returns the raw pointer; ownership
// presumably passes to the caller.
template<>
Operator *CreateOp(AmSoftmaxParam param, int dtype) {
  Operator *op = NULL;
  // Maps the runtime dtype code to a compile-time DType; handling of
  // unsupported dtypes is defined by the macro, not here.
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new AmSoftmaxOp(param);
  })
  return op;
}

// Builds the operator for `ctx`. It first runs type and shape inference on
// the inputs and CHECK-fails if either reports failure. It then dispatches to
// CreateOp with param_ and the first input's dtype.
Operator *AmSoftmaxProp::CreateOperatorEx(Context ctx,
                                          std::vector *in_shape,
                                          std::vector *in_type) const {
  std::vector out_shape, aux_shape;
  std::vector out_type, aux_type;
  CHECK(InferType(in_type, &out_type, &aux_type));
  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
  // There is no explicit return statement; DO_BIND_DISPATCH presumably
  // selects the device from `ctx` and returns CreateOp<device>(...) —
  // TODO confirm against the macro definition.
  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}

// Registers AmSoftmaxParam's fields with the dmlc parameter system.
DMLC_REGISTER_PARAMETER(AmSoftmaxParam);

// Registers the "AmSoftmax" operator with three symbolic inputs (data,
// weight, label) plus the fields of AmSoftmaxParam as keyword arguments.
// NOTE(review): the description "AmSoftmax from " looks truncated (a
// reference or link may have been lost) — TODO restore.
MXNET_REGISTER_OP_PROPERTY(AmSoftmax, AmSoftmaxProp)
.describe("AmSoftmax from ")
.add_argument("data", "Symbol", "data")
.add_argument("weight", "Symbol", "weight")
.add_argument("label", "Symbol", "label")
.add_arguments(AmSoftmaxParam::__FIELDS__());

}  // namespace op
}  // namespace mxnet