diff --git a/3rdparty/operator/amsoftmax-inl.h b/3rdparty/operator/amsoftmax-inl.h new file mode 100644 index 0000000..d899bbb --- /dev/null +++ b/3rdparty/operator/amsoftmax-inl.h @@ -0,0 +1,287 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file amsoftmax-inl.h + * \brief AmSoftmax from + * \author Jia Guo + */ +#ifndef MXNET_OPERATOR_AMSOFTMAX_INL_H_ +#define MXNET_OPERATOR_AMSOFTMAX_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +namespace amsoftmax_enum { +enum AmSoftmaxOpInputs {kData, kWeight, kLabel}; +enum AmSoftmaxOpOutputs {kOut, kOOut}; +enum AmSoftmaxResource {kTempSpace}; +} + +struct AmSoftmaxParam : public dmlc::Parameter { + float margin; + float s; + int num_hidden; + int verbose; + float eps; + DMLC_DECLARE_PARAMETER(AmSoftmaxParam) { + DMLC_DECLARE_FIELD(margin).set_default(0.5).set_lower_bound(0.0) + .describe("AmSoftmax margin"); + DMLC_DECLARE_FIELD(s).set_default(64.0).set_lower_bound(1.0) + .describe("s to X"); + DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1) + .describe("Number of hidden nodes of the output"); + DMLC_DECLARE_FIELD(verbose).set_default(0) + .describe("Log for beta change"); + DMLC_DECLARE_FIELD(eps).set_default(1e-10f) + .describe("l2 eps"); + } +}; + +template +class AmSoftmaxOp : public Operator { + public: + explicit AmSoftmaxOp(AmSoftmaxParam param) { + this->param_ = param; + count_ = 0; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 3); + CHECK_EQ(out_data.size(), 2); + CHECK_EQ(req.size(), 2); + CHECK_EQ(req[amsoftmax_enum::kOut], kWriteTo); + Stream *stream = ctx.get_stream(); + const int n = in_data[amsoftmax_enum::kData].size(0); //batch size + const int m = in_data[amsoftmax_enum::kWeight].size(0);//num classes + Tensor x = in_data[amsoftmax_enum::kData].FlatTo2D(stream); + Tensor w = in_data[amsoftmax_enum::kWeight].FlatTo2D(stream); + Tensor label = in_data[amsoftmax_enum::kLabel].get_with_shape(Shape1(n), stream); + Tensor out = out_data[amsoftmax_enum::kOut].FlatTo2D(stream); + Tensor oout = out_data[amsoftmax_enum::kOOut].get_with_shape(Shape2(n,1), stream); + //Tensor workspace = ctx.requested[amsoftmax_enum::kTempSpace].get_space_typed(Shape2(n, 1), stream); +#if defined(__CUDACC__) + CHECK_EQ(stream->blas_handle_ownership_, Stream::OwnHandle) + << "Must init CuBLAS handle in stream"; +#endif + // original fully connected + out = dot(x, w.T()); + if (ctx.is_train) { + const DType margin = static_cast(param_.margin); + const DType s = static_cast(param_.s); + AmSoftmaxForward(x, w, label, out, oout, margin, s); + } + } + + //virtual void GradNorm(mshadow::Tensor grad, mshadow::Stream* s) { + // using namespace mshadow; + // using namespace mshadow::expr; + // Tensor grad_cpu(grad.shape_); + // AllocSpace(&grad_cpu); + // Copy(grad_cpu, grad, s); + // DType grad_norm = param_.eps; + // for(uint32_t i=0;i grad, mshadow::Stream* s) { + using namespace mshadow; + using namespace mshadow::expr; + Tensor grad_cpu(grad.shape_); + AllocSpace(&grad_cpu); + Copy(grad_cpu, grad, s); + DType grad_norm = param_.eps; + for(uint32_t i=0;i tensor, mshadow::Stream* s) { + using namespace mshadow; + using namespace mshadow::expr; + Tensor tensor_cpu(tensor.shape_); + AllocSpace(&tensor_cpu); + Copy(tensor_cpu, tensor, s); + for(uint32_t i=0;i &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK_EQ(in_data.size(), 3); + CHECK_EQ(out_data.size(), 2); + CHECK_GE(in_grad.size(), 2); + CHECK_GE(req.size(), 2); + CHECK_EQ(req[amsoftmax_enum::kData], kWriteTo); + CHECK_EQ(req[amsoftmax_enum::kWeight], kWriteTo); + Stream *stream = ctx.get_stream(); + const int n = in_data[amsoftmax_enum::kData].size(0); + const int m = in_data[amsoftmax_enum::kWeight].size(0); + Tensor x = in_data[amsoftmax_enum::kData].FlatTo2D(stream); + Tensor w = in_data[amsoftmax_enum::kWeight].FlatTo2D(stream); + Tensor label = in_data[amsoftmax_enum::kLabel].get_with_shape(Shape1(n), stream); + Tensor out = out_data[amsoftmax_enum::kOut].FlatTo2D(stream); + Tensor oout = out_data[amsoftmax_enum::kOOut].get_with_shape(Shape2(n,1), stream); + Tensor o_grad = out_grad[amsoftmax_enum::kOut].FlatTo2D(stream); + Tensor x_grad = in_grad[amsoftmax_enum::kData].FlatTo2D(stream); + Tensor w_grad = in_grad[amsoftmax_enum::kWeight].FlatTo2D(stream); + Tensor workspace = ctx.requested[amsoftmax_enum::kTempSpace].get_space_typed(Shape2(n, 1), stream); +#if defined(__CUDACC__) + CHECK_EQ(stream->blas_handle_ownership_, Stream::OwnHandle) + << "Must init CuBLAS handle in stream"; +#endif + // original fully connected + x_grad = dot(o_grad, w); + w_grad = dot(o_grad.T(), x); + // large margin fully connected + const DType margin = static_cast(param_.margin); + const DType s = static_cast(param_.s); + AmSoftmaxBackward(x, w, label, out, oout, o_grad, x_grad, w_grad, workspace, margin, s); + count_+=1; + if (param_.verbose) { + if(count_%param_.verbose==0) { + DType n = GradNorm(x_grad, stream); + LOG(INFO)<<"x_grad norm:"< +Operator *CreateOp(AmSoftmaxParam param, int dtype); + +#if DMLC_USE_CXX11 +class AmSoftmaxProp : public OperatorProperty { + public: + void Init(const std::vector > &kwargs) override { + param_.Init(kwargs); + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + + std::vector ListArguments() const override { + return {"data", "weight", "label"}; + } + + std::vector ListOutputs() const override { + return {"output", "ooutput"}; + } + + int NumOutputs() const override { + return 2; + } + + int NumVisibleOutputs() const override { + return 1; + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 3) << "Input:[data, label, weight]"; + const TShape &dshape = in_shape->at(amsoftmax_enum::kData); + const TShape &lshape = in_shape->at(amsoftmax_enum::kLabel); + CHECK_EQ(dshape.ndim(), 2) << "data shape should be (batch_size, feature_dim)"; + CHECK_EQ(lshape.ndim(), 1) << "label shape should be (batch_size,)"; + const int n = dshape[0]; + const int feature_dim = dshape[1]; + const int m = param_.num_hidden; + SHAPE_ASSIGN_CHECK(*in_shape, amsoftmax_enum::kWeight, Shape2(m, feature_dim)); + out_shape->clear(); + out_shape->push_back(Shape2(n, m)); // output + out_shape->push_back(Shape2(n, 1)); // output + aux_shape->clear(); + return true; + } + + std::vector BackwardResource( + const std::vector &in_shape) const override { + return {ResourceRequest::kTempSpace}; + } + + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return {out_grad[amsoftmax_enum::kOut], + in_data[amsoftmax_enum::kData], + in_data[amsoftmax_enum::kWeight], in_data[amsoftmax_enum::kLabel]}; + } + + std::string TypeString() const override { + return "AmSoftmax"; + } + + OperatorProperty *Copy() const override { + auto ptr = new AmSoftmaxProp(); + ptr->param_ = param_; + return ptr; + } + + Operator *CreateOperator(Context ctx) const override { + LOG(FATAL) << "Not Implemented."; + return NULL; + } + + Operator *CreateOperatorEx(Context ctx, std::vector *in_shape, + std::vector *in_type) const override; + + private: + AmSoftmaxParam param_; +}; +#endif // DMLC_USE_CXX11 + +} // namespace op +} // namespace mxnet + +#endif diff --git a/3rdparty/operator/amsoftmax.cc b/3rdparty/operator/amsoftmax.cc new file mode 100644 index 0000000..4075fd3 --- /dev/null +++ b/3rdparty/operator/amsoftmax.cc @@ -0,0 +1,64 @@ +#include "./amsoftmax-inl.h" + +namespace mshadow { + +template +inline void AmSoftmaxForward(const Tensor &x, + const Tensor &w, + const Tensor &label, + const Tensor &out, + const Tensor &oout, + const DType margin, + const DType s) { + LOG(FATAL) << "Not Implemented."; +} + +template +inline void AmSoftmaxBackward(const Tensor &x, + const Tensor &w, + const Tensor &label, + const Tensor &out, + const Tensor &oout, + const Tensor &o_grad, + const Tensor &x_grad, + const Tensor &w_grad, + const Tensor &workspace, + const DType margin, + const DType s) { + LOG(FATAL) << "Not Implemented."; +} + +} // namespace mshadow + +namespace mxnet { +namespace op { + +template<> +Operator *CreateOp(AmSoftmaxParam param, int dtype) { + Operator *op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new AmSoftmaxOp(param); + }) + return op; +} + +Operator *AmSoftmaxProp::CreateOperatorEx(Context ctx, std::vector *in_shape, + std::vector *in_type) const { + std::vector out_shape, aux_shape; + std::vector out_type, aux_type; + CHECK(InferType(in_type, &out_type, &aux_type)); + CHECK(InferShape(in_shape, &out_shape, &aux_shape)); + DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); +} + +DMLC_REGISTER_PARAMETER(AmSoftmaxParam); + +MXNET_REGISTER_OP_PROPERTY(AmSoftmax, AmSoftmaxProp) +.describe("AmSoftmax from ") +.add_argument("data", "Symbol", "data") +.add_argument("weight", "Symbol", "weight") +.add_argument("label", "Symbol", "label") +.add_arguments(AmSoftmaxParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/3rdparty/operator/amsoftmax.cu b/3rdparty/operator/amsoftmax.cu new file mode 100644 index 0000000..dfd7f1f --- /dev/null +++ b/3rdparty/operator/amsoftmax.cu @@ -0,0 +1,195 @@ +#include "./amsoftmax-inl.h" +#include + +namespace mshadow { +namespace cuda { + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + + +template +__global__ void AmSoftmaxForwardKernel(const Tensor x, + const Tensor w, + const Tensor label, + Tensor out, + Tensor oout, + const DType margin, + const DType s) { + const int n = x.size(0); //batch size + const int feature_dim = x.size(1); //embedding size, 512 for example + const int m = w.size(0);//num classes + const DType cos_m = cos(margin); + const DType sin_m = sin(margin); + CUDA_KERNEL_LOOP(i, n) { + const int yi = static_cast(label[i]); + const DType fo_i_yi = out[i][yi]; + oout[i][0] = fo_i_yi; + if(fo_i_yi>=0.0) { + const DType cos_t = fo_i_yi / s; + const DType sin_t = sqrt(1.0-cos_t*cos_t); + out[i][yi] = fo_i_yi*cos_m - (s*sin_t*sin_m); + } + } +} + +template +inline void AmSoftmaxForward(const Tensor &x, + const Tensor &w, + const Tensor &label, + const Tensor &out, + const Tensor &oout, + const DType margin, + const DType s) { + const int n = x.size(0); + const int m = w.size(0); + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid((n + kBaseThreadNum - 1) / kBaseThreadNum); + AmSoftmaxForwardKernel<<>>(x, w, label, out, oout, margin, s); +} + + +template +__global__ void AmSoftmaxBackwardXKernel(const Tensor x, + const Tensor w, + const Tensor label, + const Tensor out, + const Tensor oout, + const Tensor o_grad, + Tensor x_grad, + const Tensor workspace, + const DType margin, + const DType s) { + const int nthreads = x.size(0) * x.size(1); + //const int nthreads = x.size(0); + const int feature_dim = x.size(1); + const DType cos_m = cos(margin); + const DType nsin_m = sin(margin)*-1.0; + const DType ss = s*s; + CUDA_KERNEL_LOOP(idx, nthreads) { + const int i = idx / feature_dim; + const int l = idx % feature_dim; + //const int i = idx; + const int yi = static_cast(label[i]); + if(oout[i][0]>=0.0) { + //x_grad[i][l] -= o_grad[i][yi] * w[yi][l]; + //c = 1-cost*cost, = sint*sint + const DType cost = oout[i][0]/s; + const DType c = 1.0-cost*cost; + const DType dc_dx = -2.0/ss*oout[i][0]*w[yi][l]; + const DType d_sint_dc = 1.0/(2*sqrt(c)); + const DType d_sint_dx = dc_dx*d_sint_dc; + const DType df_dx = cos_m*w[yi][l] + s*nsin_m*d_sint_dx; + x_grad[i][l] += o_grad[i][yi] * (df_dx - w[yi][l]); + } + } +} + +template +__global__ void AmSoftmaxBackwardWKernel(const Tensor x, + const Tensor w, + const Tensor label, + const Tensor out, + const Tensor oout, + const Tensor o_grad, + Tensor w_grad, + const Tensor workspace, + const DType margin, + const DType s) { + const int nthreads = w.size(0) * w.size(1); + const int n = x.size(0); + const int feature_dim = w.size(1); + const DType cos_m = cos(margin); + const DType nsin_m = sin(margin)*-1.0; + const DType ss = s*s; + CUDA_KERNEL_LOOP(idx, nthreads) { + const int j = idx / feature_dim; + const int l = idx % feature_dim; + DType dw = 0; + for (int i = 0; i < n; ++i) { + const int yi = static_cast(label[i]); + if (yi == j&&oout[i][0]>=0.0) { + const DType cost = oout[i][0]/s; + const DType c = 1.0-cost*cost; + const DType dc_dw = -2.0/ss*oout[i][0]*x[i][l]; + const DType d_sint_dc = 1.0/(2*sqrt(c)); + const DType d_sint_dw = dc_dw*d_sint_dc; + const DType df_dw = cos_m*x[i][l] + s*nsin_m*d_sint_dw; + dw += o_grad[i][yi] * (df_dw - x[i][l]); + } + } + w_grad[j][l] += dw; + } +} + +template +inline void AmSoftmaxBackward(const Tensor &x, + const Tensor &w, + const Tensor &label, + const Tensor &out, + const Tensor &oout, + const Tensor &o_grad, + const Tensor &x_grad, + const Tensor &w_grad, + const Tensor &workspace, + const DType margin, + const DType s) { + const int n = x.size(0); + const int feature_dim = x.size(1); + const int m = w.size(0); + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid((n + kBaseThreadNum - 1) / kBaseThreadNum); + dimGrid.x = ((n * feature_dim + kBaseThreadNum - 1) / kBaseThreadNum); + AmSoftmaxBackwardXKernel<<>>(x, w, label, out, oout, o_grad, x_grad, workspace, + margin, s); + dimGrid.x = ((m * feature_dim + kBaseThreadNum - 1) / kBaseThreadNum); + AmSoftmaxBackwardWKernel<<>>(x, w, label, out, oout, o_grad, w_grad, workspace, + margin, s); +} + +} // namespace cuda + +template +inline void AmSoftmaxForward(const Tensor &x, + const Tensor &w, + const Tensor &label, + const Tensor &out, + const Tensor &oout, + const DType margin, + const DType s) { + cuda::AmSoftmaxForward(x, w, label, out, oout, margin, s); +} + +template +inline void AmSoftmaxBackward(const Tensor &x, + const Tensor &w, + const Tensor &label, + const Tensor &out, + const Tensor &oout, + const Tensor &o_grad, + const Tensor &x_grad, + const Tensor &w_grad, + const Tensor &workspace, + const DType margin, + const DType s) { + cuda::AmSoftmaxBackward(x, w, label, out, oout, o_grad, x_grad, w_grad, workspace, margin, s); +} + +} // namespace mshadow + +namespace mxnet { +namespace op { + +template<> +Operator *CreateOp(AmSoftmaxParam param, int dtype) { + Operator *op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new AmSoftmaxOp(param); + }) + return op; +} + +} // namespace op +} // namespace mxnet diff --git a/3rdparty/operator/lsoftmax-inl.h b/3rdparty/operator/lsoftmax-inl.h index 33d51bf..a457eb4 100644 --- a/3rdparty/operator/lsoftmax-inl.h +++ b/3rdparty/operator/lsoftmax-inl.h @@ -78,6 +78,7 @@ class LSoftmaxOp : public Operator { float _beta = std::atof(env_p); if (param_.verbose) { LOG(INFO)<<"beta:"<<_beta; + LOG(INFO)<