From 6dbed701bcf10166ab60a8577ccf0b5c58fb9d8f Mon Sep 17 00:00:00 2001 From: Tadas Baltrusaitis Date: Fri, 11 Aug 2017 14:56:45 -0400 Subject: [PATCH] Fix with CNN inference sizes. --- .../src/FaceDetectorMTCNN.cpp | 83 ++++++++++++++----- 1 file changed, 64 insertions(+), 19 deletions(-) diff --git a/lib/local/LandmarkDetector/src/FaceDetectorMTCNN.cpp b/lib/local/LandmarkDetector/src/FaceDetectorMTCNN.cpp index 5b88c660..94969ed3 100644 --- a/lib/local/LandmarkDetector/src/FaceDetectorMTCNN.cpp +++ b/lib/local/LandmarkDetector/src/FaceDetectorMTCNN.cpp @@ -269,27 +269,45 @@ std::vector> CNN::Inference(const cv::Mat& input_img) } if (layer_type == 2) { - // Concatenate all the maps - cv::Mat_ input_concat = input_maps[0].t(); - input_concat = input_concat.reshape(0, 1); - for (size_t in = 1; in < input_maps.size(); ++in) + if(input_maps.size() > 1) { - cv::Mat_ add = input_maps[in].t(); - add = add.reshape(0, 1); - cv::vconcat(input_concat, add, input_concat); + // Concatenate all the maps + cv::Size orig_size = input_maps[0].size(); + cv::Mat_ input_concat = input_maps[0].t(); + input_concat = input_concat.reshape(0, 1); + + for (size_t in = 1; in < input_maps.size(); ++in) + { + cv::Mat_ add = input_maps[in].t(); + add = add.reshape(0, 1); + cv::vconcat(input_concat, add, input_concat); + } + + input_concat = input_concat.t() * cnn_fully_connected_layers_weights[fully_connected_layer]; + + // Add biases + for (size_t k = 0; k < cnn_fully_connected_layers_biases[fully_connected_layer].rows; ++k) + { + input_concat.col(k) = input_concat.col(k) + cnn_fully_connected_layers_biases[fully_connected_layer].at(k); + } + + outputs.clear(); + // Resize and add as output + for (size_t k = 0; k < cnn_fully_connected_layers_biases[fully_connected_layer].rows; ++k) + { + cv::Mat_ reshaped = input_concat.col(k).clone(); + reshaped = reshaped.reshape(1, orig_size.width).t(); + outputs.push_back(reshaped); + } } - - input_concat = input_concat.t() * cnn_fully_connected_layers_weights[fully_connected_layer]; - - for (size_t k = 0; k < cnn_fully_connected_layers_biases[fully_connected_layer].rows; ++k) + else { - input_concat.col(k) = input_concat.col(k) + cnn_fully_connected_layers_biases[fully_connected_layer].at(k); + cv::Mat out = input_maps[0].t() * cnn_fully_connected_layers_weights[fully_connected_layer] + cnn_fully_connected_layers_biases[fully_connected_layer].t(); + outputs.clear(); + outputs.push_back(out); } - outputs.clear(); - outputs.push_back(input_concat); - fully_connected_layer++; } if (layer_type == 3) // PReLU @@ -513,6 +531,35 @@ void FaceDetectorMTCNN::Read(string location) } } +cv::Mat_ generate_bounding_boxes(cv::Mat_ heatmap, vector > corrections, double scale, double threshold, int face_support) +{ + // use heatmap to generate bounding boxes in the original image space + + // Correction for the pooling + int stride = 2; + + // Offsets for, x, y, width and height + //cv::Mat_ dx1 = corrections.col(1); + //cv::Mat_ dy1 = corrections.col(2); + //cv::Mat_ dx2 = corrections.col(3); + //cv::Mat_ dy2 = corrections.col(4); + + // Find the parts of a heatmap above the threshold(x, y, and indices) + cv::Mat_ mask = heatmap >= threshold; + + // Find the corresponding scores and bbox corrections + //score = heatmap(inds); + //correction = [dx1(inds) dy1(inds) dx2(inds) dy2(inds)]; + + // Correcting for Matlab's format + //bboxes = [y - 1 x - 1]; + //bboxes = [fix((stride*(bboxes)+1) / scale) fix((stride*(bboxes)+face_support) / scale) score correction]; + + return cv::Mat_(); + +} + + // The actual MTCNN face detection step bool FaceDetectorMTCNN::DetectFaces(vector >& o_regions, const cv::Mat& input_img, std::vector& o_confidences, int min_face_size, double t1, double t2, double t3) { @@ -552,14 +599,12 @@ bool FaceDetectorMTCNN::DetectFaces(vector >& o_regions, const std::vector > pnet_out = PNet.Inference(normalised_img); - // TODO resize appropriately the output - cv::Mat_ out_prob; cv::exp(pnet_out[0]- pnet_out[1], out_prob); out_prob = 1.0 / (1.0 + out_prob); - cv::imshow("out_map", out_prob); - cv::waitKey(0); + // Grab the detections + } return true;