Another optimization of moving the bias term addition to the matrix multiplication.

This commit is contained in:
Tadas Baltrusaitis
2017-08-23 20:43:19 +01:00
parent b57bdd2127
commit a2a3bd08c2

View File

@@ -537,7 +537,7 @@ void convolution_direct(std::vector<cv::Mat_<float> >& outputs, const std::vecto
int yB = height_in - height_k + 1;
int xB = width_n - width_k + 1;
cv::Mat_<float> input_matrix(input_maps.size() * height_k * width_k, yB * xB);
cv::Mat_<float> input_matrix(input_maps.size() * height_k * width_k + 1.0, yB * xB, 1.0f);
// Comibine im2col accross channels to prepare for matrix multiplication
for (size_t i = 0; i < input_maps.size(); ++i)
@@ -545,15 +545,13 @@ void convolution_direct(std::vector<cv::Mat_<float> >& outputs, const std::vecto
im2col_bias(input_maps[i], width_k, height_k, input_matrix(cv::Rect(0, i * height_k * width_k, yB * xB, height_k * width_k)));
}
// Actual multiplication
// Actual convolution (through multiplication)
cv::Mat_<float> out = weight_matrix * input_matrix;
// Move back to vectors and reshape accordingly (also add the bias)
for (size_t k = 0; k < out.rows; ++k)
{
cv::Mat_<float> reshaped = out.row(k) + biases[k];
reshaped = reshaped.reshape(1, yB);
outputs.push_back(reshaped);
outputs.push_back(out.row(k).reshape(1, yB));
}
}
@@ -844,8 +842,19 @@ void CNN::Read(const string& location)
k_flat.copyTo(weight_matrix(cv::Rect(k, i * kernels_rearr[0][0].rows * kernels_rearr[0][0].cols, 1, kernels_rearr[0][0].rows * kernels_rearr[0][0].cols)));
}
}
// Transpose the weight matrix for more convenient computation
cnn_convolutional_layers_weights.push_back(weight_matrix.t());
weight_matrix = weight_matrix.t();
// Add a bias term to the weight matrix for efficiency
cv::Mat_<float> W(weight_matrix.rows, weight_matrix.cols + 1, 1.0);
for (size_t k = 0; k < weight_matrix.rows; ++k)
{
W.at<float>(k, weight_matrix.cols) = biases[k];
}
weight_matrix.copyTo(W(cv::Rect(0, 0, weight_matrix.cols, weight_matrix.rows)));
cnn_convolutional_layers_weights.push_back(W);
}
else if (layer_type == 1)