Updating dlib and linking to it differently, increasing compilation speed and reducing dependency issues with it.

This commit is contained in:
Tadas Baltrusaitis
2018-06-08 16:48:21 +01:00
parent 8bcc5c46ca
commit cd24af000f
724 changed files with 91433 additions and 92254 deletions

View File

@@ -1,432 +0,0 @@
#
# This is a CMake makefile. You can find the cmake utility and
# information about it at http://www.cmake.org
#
# setting this makes CMake allow normal looking if else statements
SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true)
cmake_minimum_required(VERSION 2.4)
# Suppress cmake warnings about changes in new versions.
if(COMMAND cmake_policy)
cmake_policy(SET CMP0003 NEW)
endif()
add_definitions(-DDLIB_HAVE_SSE2)
add_definitions(-DDLIB_HAVE_SSE3)
add_definitions(-DDLIB_HAVE_SSE41)
# make macros that can add #define directives to the entire project. Not just
# to the dlib library itself. I.e. to dlib and to any projects that depend
# on dlib.
macro ( add_global_define def_name )
if (NOT CMAKE_CXX_FLAGS MATCHES "-D${def_name}")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D${def_name}"
CACHE STRING "Flags used by the compiler during all C++ builds."
FORCE)
endif ()
endmacro()
macro ( remove_global_define def_name )
if (CMAKE_CXX_FLAGS MATCHES " -D${def_name}")
string (REGEX REPLACE " -D${def_name}" "" temp_var ${CMAKE_CXX_FLAGS})
set (CMAKE_CXX_FLAGS "${temp_var}"
CACHE STRING "Flags used by the compiler during all C++ builds."
FORCE)
endif ()
endmacro()
# Make sure ENABLE_ASSERTS is defined for debug builds
if (NOT CMAKE_CXX_FLAGS_DEBUG MATCHES "-DENABLE_ASSERTS")
set (CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DENABLE_ASSERTS"
CACHE STRING "Flags used by the compiler during C++ debug builds."
FORCE)
endif ()
# Don't try to call add_library(dlib) and setup dlib's stuff if it has already
# been done by some other part of the current cmake project. We do this
# because it avoids getting warnings/errors about cmake policy CMP0002. This
# happens when a project tries to call add_subdirectory() on dlib more than
# once. This most often happens when the top level of a project depends on two
# or more other things which both depend on dlib.
if (NOT TARGET dlib)
set (DLIB_ISO_CPP_ONLY_STR
"Enable this if you don't want to compile any non-ISO C++ code (i.e. you don't use any of the API Wrappers)" )
set (DLIB_NO_GUI_SUPPORT_STR
"Enable this if you don't want to compile any of the dlib GUI code" )
set (DLIB_ENABLE_STACK_TRACE_STR
"Enable this if you want to turn on the DLIB_STACK_TRACE macros" )
set (DLIB_ENABLE_ASSERTS_STR
"Enable this if you want to turn on the DLIB_ASSERT macro" )
set (DLIB_USE_BLAS_STR
"Disable this if you don't want to use a BLAS library" )
set (DLIB_USE_LAPACK_STR
"Disable this if you don't want to use a LAPACK library" )
set (DLIB_LINK_WITH_LIBPNG_STR
"Disable this if you don't want to link against libpng" )
set (DLIB_LINK_WITH_LIBJPEG_STR
"Disable this if you don't want to link against libjpeg" )
set (DLIB_LINK_WITH_SQLITE3_STR
"Disable this if you don't want to link against sqlite3" )
set (DLIB_LINK_WITH_FFTW_STR
"Disable this if you don't want to link against fftw" )
option(DLIB_ISO_CPP_ONLY ${DLIB_ISO_CPP_ONLY_STR} OFF)
option(DLIB_NO_GUI_SUPPORT ${DLIB_NO_GUI_SUPPORT_STR} OFF)
option(DLIB_ENABLE_STACK_TRACE ${DLIB_ENABLE_STACK_TRACE_STR} OFF)
option(DLIB_ENABLE_ASSERTS ${DLIB_ENABLE_ASSERTS_STR} OFF)
option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} ON)
option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} ON)
option(DLIB_LINK_WITH_LIBPNG ${DLIB_LINK_WITH_LIBPNG_STR} ON)
option(DLIB_LINK_WITH_LIBJPEG ${DLIB_LINK_WITH_LIBJPEG_STR} ON)
option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} ON)
option(DLIB_LINK_WITH_FFTW ${DLIB_LINK_WITH_FFTW_STR} ON)
set(source_files
include/dlib/base64/base64_kernel_1.cpp
include/dlib/bigint/bigint_kernel_1.cpp
include/dlib/bigint/bigint_kernel_2.cpp
include/dlib/bit_stream/bit_stream_kernel_1.cpp
include/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp
include/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp
include/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp
include/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp
include/dlib/md5/md5_kernel_1.cpp
include/dlib/tokenizer/tokenizer_kernel_1.cpp
include/dlib/unicode/unicode.cpp
include/dlib/data_io/image_dataset_metadata.cpp)
if (DLIB_ISO_CPP_ONLY)
add_library(dlib STATIC ${source_files} )
else()
set(source_files ${source_files}
include/dlib/sockets/sockets_kernel_1.cpp
include/dlib/bsp/bsp.cpp
include/dlib/dir_nav/dir_nav_kernel_1.cpp
include/dlib/dir_nav/dir_nav_kernel_2.cpp
include/dlib/dir_nav/dir_nav_extensions.cpp
include/dlib/linker/linker_kernel_1.cpp
include/dlib/logger/extra_logger_headers.cpp
include/dlib/logger/logger_kernel_1.cpp
include/dlib/logger/logger_config_file.cpp
include/dlib/misc_api/misc_api_kernel_1.cpp
include/dlib/misc_api/misc_api_kernel_2.cpp
include/dlib/sockets/sockets_extensions.cpp
include/dlib/sockets/sockets_kernel_2.cpp
include/dlib/sockstreambuf/sockstreambuf.cpp
include/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp
include/dlib/server/server_kernel.cpp
include/dlib/server/server_iostream.cpp
include/dlib/server/server_http.cpp
include/dlib/threads/multithreaded_object_extension.cpp
include/dlib/threads/threaded_object_extension.cpp
include/dlib/threads/threads_kernel_1.cpp
include/dlib/threads/threads_kernel_2.cpp
include/dlib/threads/threads_kernel_shared.cpp
include/dlib/threads/thread_pool_extension.cpp
include/dlib/timer/timer.cpp
include/dlib/stack_trace.cpp
)
# we want to link to the right stuff depending on our platform.
if (WIN32 AND NOT CYGWIN) ###############################################################################
if (DLIB_NO_GUI_SUPPORT)
set (dlib_needed_libraries ws2_32)
else()
set (dlib_needed_libraries ws2_32 comctl32 gdi32 imm32)
endif()
elseif(APPLE) ############################################################################
find_library(pthreadlib pthread)
set (dlib_needed_libraries ${pthreadlib})
if (NOT DLIB_NO_GUI_SUPPORT)
find_library(xlib X11)
# make sure X11 is in the include path
find_path(xlib_path Xlib.h
PATHS
/Developer/SDKs/MacOSX10.4u.sdk/usr/X11R6/include
PATH_SUFFIXES X11
)
if (xlib AND xlib_path)
get_filename_component(x11_path ${xlib_path} PATH CACHE)
include_directories(${x11_path})
set(dlib_needed_libraries ${dlib_needed_libraries} ${xlib} )
else()
message(" *****************************************************************************")
message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***")
message(" *** Make sure libx11-dev is installed if you want GUI support ***")
message(" *****************************************************************************")
set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE )
endif()
endif()
mark_as_advanced(pthreadlib xlib xlib_path x11_path)
else () ##################################################################################
find_library(pthreadlib pthread)
set (dlib_needed_libraries ${pthreadlib})
# link to the nsl library if it exists. this is something you need sometimes
find_library(nsllib nsl)
if (nsllib)
set (dlib_needed_libraries ${dlib_needed_libraries} ${nsllib})
endif ()
# link to the socket library if it exists. this is something you need on solaris
find_library(socketlib socket)
if (socketlib)
set (dlib_needed_libraries ${dlib_needed_libraries} ${socketlib})
endif ()
if (NOT DLIB_NO_GUI_SUPPORT)
include(FindX11)
if (X11_FOUND)
include_directories(${X11_INCLUDE_DIR})
set (dlib_needed_libraries ${dlib_needed_libraries} ${X11_LIBRARIES})
else()
message(" *****************************************************************************")
message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***")
message(" *** Make sure libx11-dev is installed if you want GUI support ***")
message(" *****************************************************************************")
set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE )
endif()
endif()
mark_as_advanced(nsllib pthreadlib socketlib)
endif () ##################################################################################
if (NOT DLIB_NO_GUI_SUPPORT)
set(source_files ${source_files}
include/dlib/gui_widgets/fonts.cpp
include/dlib/gui_widgets/widgets.cpp
include/dlib/gui_widgets/drawable.cpp
include/dlib/gui_widgets/canvas_drawing.cpp
include/dlib/gui_widgets/style.cpp
include/dlib/gui_widgets/base_widgets.cpp
include/dlib/gui_core/gui_core_kernel_1.cpp
include/dlib/gui_core/gui_core_kernel_2.cpp
)
endif()
if (DLIB_LINK_WITH_LIBPNG)
# try to find libpng
set(ZLIB_FIND_QUIETLY ON)
set(PNG_FIND_QUIETLY ON)
include(FindPNG)
if (PNG_FOUND)
include_directories(${PNG_INCLUDE_DIR})
set (dlib_needed_libraries ${dlib_needed_libraries} ${PNG_LIBRARY})
else()
# If we can't find libpng then statically compile it in.
include_directories(external/libpng external/zlib)
set(source_files ${source_files}
include/dlib/external/libpng/png.c
include/dlib/external/libpng/pngerror.c
include/dlib/external/libpng/pngget.c
include/dlib/external/libpng/pngmem.c
include/dlib/external/libpng/pngpread.c
include/dlib/external/libpng/pngread.c
include/dlib/external/libpng/pngrio.c
include/dlib/external/libpng/pngrtran.c
include/dlib/external/libpng/pngrutil.c
include/dlib/external/libpng/pngset.c
include/dlib/external/libpng/pngtrans.c
include/dlib/external/libpng/pngwio.c
include/dlib/external/libpng/pngwrite.c
include/dlib/external/libpng/pngwtran.c
include/dlib/external/libpng/pngwutil.c
include/dlib/external/zlib/adler32.c
include/dlib/external/zlib/compress.c
include/dlib/external/zlib/crc32.c
include/dlib/external/zlib/deflate.c
include/dlib/external/zlib/gzclose.c
include/dlib/external/zlib/gzlib.c
include/dlib/external/zlib/gzread.c
include/dlib/external/zlib/gzwrite.c
include/dlib/external/zlib/infback.c
include/dlib/external/zlib/inffast.c
include/dlib/external/zlib/inflate.c
include/dlib/external/zlib/inftrees.c
include/dlib/external/zlib/trees.c
include/dlib/external/zlib/uncompr.c
include/dlib/external/zlib/zutil.c
)
endif()
set(source_files ${source_files}
include/dlib/image_loader/png_loader.cpp
include/dlib/image_saver/save_png.cpp
)
endif()
if (DLIB_LINK_WITH_LIBJPEG)
# try to find libjpeg
include(FindJPEG)
if (JPEG_FOUND)
include_directories(${JPEG_INCLUDE_DIR})
set (dlib_needed_libraries ${dlib_needed_libraries} ${JPEG_LIBRARY})
else()
# If we can't find libjpeg then statically compile it in.
include_directories(external/libjpeg)
set(source_files ${source_files}
include/dlib/external/libjpeg/jcomapi.cpp
include/dlib/external/libjpeg/jdapimin.cpp
include/dlib/external/libjpeg/jdapistd.cpp
include/dlib/external/libjpeg/jdatasrc.cpp
include/dlib/external/libjpeg/jdcoefct.cpp
include/dlib/external/libjpeg/jdcolor.cpp
include/dlib/external/libjpeg/jddctmgr.cpp
include/dlib/external/libjpeg/jdhuff.cpp
include/dlib/external/libjpeg/jdinput.cpp
include/dlib/external/libjpeg/jdmainct.cpp
include/dlib/external/libjpeg/jdmarker.cpp
include/dlib/external/libjpeg/jdmaster.cpp
include/dlib/external/libjpeg/jdmerge.cpp
include/dlib/external/libjpeg/jdphuff.cpp
include/dlib/external/libjpeg/jdpostct.cpp
include/dlib/external/libjpeg/jdsample.cpp
include/dlib/external/libjpeg/jerror.cpp
include/dlib/external/libjpeg/jidctflt.cpp
include/dlib/external/libjpeg/jidctfst.cpp
include/dlib/external/libjpeg/jidctint.cpp
include/dlib/external/libjpeg/jidctred.cpp
include/dlib/external/libjpeg/jmemmgr.cpp
include/dlib/external/libjpeg/jmemnobs.cpp
include/dlib/external/libjpeg/jquant1.cpp
include/dlib/external/libjpeg/jquant2.cpp
include/dlib/external/libjpeg/jutils.cpp )
endif()
set(source_files ${source_files}
include/dlib/image_loader/jpeg_loader.cpp
)
endif()
if (DLIB_USE_BLAS OR DLIB_USE_LAPACK)
# Try to find BLAS and LAPACK
include(cmake_find_blas.txt)
if (DLIB_USE_BLAS)
if (blas_found)
set (dlib_needed_libraries ${dlib_needed_libraries} ${blas_libraries})
else()
set(DLIB_USE_BLAS OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE )
endif()
endif()
if (DLIB_USE_LAPACK)
if (lapack_found)
set (dlib_needed_libraries ${dlib_needed_libraries} ${lapack_libraries})
else()
set(DLIB_USE_LAPACK OFF CACHE STRING ${DLIB_USE_LAPACK_STR} FORCE )
endif()
endif()
endif()
if (DLIB_LINK_WITH_SQLITE3)
find_library(sqlite sqlite3)
# make sure sqlite3.h is in the include path
find_path(sqlite_path sqlite3.h)
if (sqlite AND sqlite_path)
get_filename_component(sqlite_path2 ${sqlite_path} PATH CACHE)
include_directories(${sqlite_path2})
set(dlib_needed_libraries ${dlib_needed_libraries} ${sqlite} )
else()
set(DLIB_LINK_WITH_SQLITE3 OFF CACHE STRING ${DLIB_LINK_WITH_SQLITE3_STR} FORCE )
endif()
mark_as_advanced(sqlite sqlite_path sqlite_path2)
endif()
if (DLIB_LINK_WITH_FFTW)
find_library(fftw fftw3)
# make sure fftw3.h is in the include path
find_path(fftw_path fftw3.h)
if (fftw AND fftw_path)
get_filename_component(fftw_path2 ${fftw_path} PATH CACHE)
include_directories(${fftw_path2})
set(dlib_needed_libraries ${dlib_needed_libraries} ${fftw} )
else()
set(DLIB_LINK_WITH_FFTW OFF CACHE STRING ${DLIB_LINK_WITH_SQLITE3_STR} FORCE )
endif()
mark_as_advanced(fftw fftw_path fftw_path2)
endif()
add_library(dlib STATIC ${source_files} )
target_link_libraries(dlib ${dlib_needed_libraries} )
endif () ##### end of if NOT DLIB_ISO_CPP_ONLY ##########################################################
#test for some things that really should be true about the compiler
include(TestForSTDNamespace)
include(TestForANSIStreamHeaders)
if (DLIB_LINK_WITH_LIBPNG AND NOT DLIB_ISO_CPP_ONLY)
add_global_define(DLIB_PNG_SUPPORT)
else()
remove_global_define(DLIB_PNG_SUPPORT)
endif()
if (DLIB_LINK_WITH_LIBJPEG AND NOT DLIB_ISO_CPP_ONLY)
add_global_define(DLIB_JPEG_SUPPORT)
else()
remove_global_define(DLIB_JPEG_SUPPORT)
endif()
if (DLIB_LINK_WITH_FFTW AND NOT DLIB_ISO_CPP_ONLY)
add_global_define(DLIB_USE_FFTW)
else()
remove_global_define(DLIB_USE_FFTW)
endif()
if (DLIB_USE_BLAS AND blas_found)
add_global_define(DLIB_USE_BLAS)
else()
remove_global_define(DLIB_USE_BLAS)
endif()
if (DLIB_USE_LAPACK AND lapack_found)
add_global_define(DLIB_USE_LAPACK)
else()
remove_global_define(DLIB_USE_LAPACK)
endif()
if (DLIB_ISO_CPP_ONLY)
add_global_define(DLIB_ISO_CPP_ONLY)
else()
remove_global_define(DLIB_ISO_CPP_ONLY)
endif()
if (DLIB_NO_GUI_SUPPORT)
add_global_define(DLIB_NO_GUI_SUPPORT)
else()
remove_global_define(DLIB_NO_GUI_SUPPORT)
endif()
if (DLIB_ENABLE_STACK_TRACE)
add_global_define(DLIB_ENABLE_STACK_TRACE)
else()
remove_global_define(DLIB_ENABLE_STACK_TRACE)
endif()
if (DLIB_ENABLE_ASSERTS)
add_global_define(ENABLE_ASSERTS)
else()
remove_global_define(ENABLE_ASSERTS)
endif()
endif()

View File

@@ -1,227 +0,0 @@
#
# This is a CMake makefile. You can find the cmake utility and
# information about it at http://www.cmake.org
#
#
# This cmake file tries to find installed BLAS and LAPACK libraries.
# It looks for an installed copy of the Intel MKL library first and then
# attempts to find some other BLAS and LAPACK libraries if you don't have
# the Intel MKL.
#
# blas_found - True if BLAS is available
# lapack_found - True if LAPACK is available
# blas_libraries - link against these to use BLAS library
# lapack_libraries - link against these to use LAPACK library
# setting this makes CMake allow normal looking if else statements
SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true)
SET(blas_found 0)
SET(lapack_found 0)
if (UNIX)
message(STATUS "Searching for BLAS and LAPACK")
include(CheckTypeSize)
check_type_size( "void*" SIZE_OF_VOID_PTR)
if (SIZE_OF_VOID_PTR EQUAL 8)
set( mkl_search_path
/opt/intel/mkl/*/lib/em64t
/opt/intel/mkl/lib/intel64
/opt/intel/lib/intel64
)
find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path})
else()
set( mkl_search_path
/opt/intel/mkl/*/lib/32
/opt/intel/mkl/lib/ia32
/opt/intel/lib/ia32
)
find_library(mkl_intel mkl_intel ${mkl_search_path})
endif()
include(CheckLibraryExists)
# Search for the needed libraries from the MKL. We will try to link against the mkl_rt
# file first since this way avoids linking bugs in some cases.
find_library(mkl_rt mkl_rt ${mkl_search_path})
mark_as_advanced( mkl_rt )
# if we found the MKL
if ( mkl_rt)
set(blas_libraries ${mkl_rt} )
set(lapack_libraries ${mkl_rt} )
set(blas_found 1)
set(lapack_found 1)
set(found_intel_mkl 1)
message(STATUS "Found Intel MKL BLAS/LAPACK library")
endif()
if (NOT found_intel_mkl)
# Search for the needed libraries from the MKL. This time try looking for a different
# set of MKL files and try to link against those.
find_library(mkl_core mkl_core ${mkl_search_path})
find_library(mkl_thread mkl_intel_thread ${mkl_search_path})
find_library(mkl_iomp iomp5 ${mkl_search_path})
find_library(mkl_pthread pthread ${mkl_search_path})
mark_as_advanced( mkl_intel mkl_core mkl_thread mkl_iomp mkl_pthread)
# If we found the MKL
if (mkl_intel AND mkl_core AND mkl_thread AND mkl_iomp AND mkl_pthread)
set(blas_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread})
set(lapack_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread})
set(blas_found 1)
set(lapack_found 1)
set(found_intel_mkl 1)
message(STATUS "Found Intel MKL BLAS/LAPACK library")
endif()
endif()
# try to find some other LAPACK libraries if we didn't find the MKL
set(extra_paths
/usr/lib64
/usr/lib64/atlas-sse3
/usr/lib64/atlas-sse2
/usr/lib64/atlas
/usr/lib
/usr/lib/atlas-sse3
/usr/lib/atlas-sse2
/usr/lib/atlas)
if (NOT lapack_found)
find_library(lapack_lib NAMES lapack lapack-3 PATHS ${extra_paths})
if (lapack_lib)
set(lapack_libraries ${lapack_lib})
set(lapack_found 1)
message(STATUS "Found LAPACK library")
endif()
mark_as_advanced( lapack_lib)
endif()
# try to find some other BLAS libraries if we didn't find the MKL
if (NOT blas_found)
find_library(atlas_lib atlas PATHS ${extra_paths})
find_library(cblas_lib cblas PATHS ${extra_paths})
if (atlas_lib AND cblas_lib)
set(blas_libraries ${atlas_lib} ${cblas_lib})
set(blas_found 1)
message(STATUS "Found ATLAS BLAS library")
endif()
mark_as_advanced( atlas_lib cblas_lib)
endif()
if (NOT blas_found)
find_library(cblas_lib cblas PATHS ${extra_paths})
if (cblas_lib)
set(blas_libraries ${cblas_lib})
set(blas_found 1)
message(STATUS "Found CBLAS library")
endif()
mark_as_advanced( cblas_lib)
endif()
if (NOT blas_found)
find_library(generic_blas blas PATHS ${extra_paths})
if (generic_blas)
set(blas_libraries ${generic_blas})
set(blas_found 1)
message(STATUS "Found BLAS library")
endif()
mark_as_advanced( generic_blas)
endif()
# Make sure we really found a CBLAS library. That is, it needs to expose
# the proper cblas link symbols. So here we test if one of them is present
# and assume everything is good if it is. Note that we don't do this check if
# we found the Intel MKL since for some reason CHECK_FUNCTION_EXISTS doesn't work
# with it. But it's fine since the MKL should always have cblas.
if (blas_found AND NOT found_intel_mkl)
INCLUDE (CheckFunctionExists)
set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries})
CHECK_FUNCTION_EXISTS(cblas_ddot HAVE_CBLAS)
if (NOT HAVE_CBLAS)
message(STATUS "BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK")
set(blas_found 0)
set(lapack_found 0)
endif()
endif()
if (NOT blas_found)
message(" *****************************************************************************")
message(" *** No BLAS library found so using dlib's built in BLAS. However, if you ***")
message(" *** install an optimized BLAS such as openblas or the Intel MKL your code ***")
message(" *** will run faster. On Ubuntu you can install openblas by executing: ***")
message(" *** sudo apt-get install libopenblas-dev liblapack-dev ***")
message(" *****************************************************************************")
endif()
elseif(WIN32 AND NOT MINGW)
message(STATUS "Searching for BLAS and LAPACK")
include(CheckTypeSize)
check_type_size( "void*" SIZE_OF_VOID_PTR)
if (SIZE_OF_VOID_PTR EQUAL 8)
set( mkl_search_path
"C:/Program Files (x86)/Intel/Composer XE/mkl/lib/intel64"
"C:/Program Files (x86)/Intel/Composer XE/compiler/lib/intel64"
"C:/Program Files/Intel/Composer XE/mkl/lib/intel64"
"C:/Program Files/Intel/Composer XE/compiler/lib/intel64"
)
find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path})
else()
set( mkl_search_path
"C:/Program Files (x86)/Intel/Composer XE/mkl/lib/ia32"
"C:/Program Files (x86)/Intel/Composer XE/compiler/lib/ia32"
"C:/Program Files/Intel/Composer XE/mkl/lib/ia32"
"C:/Program Files/Intel/Composer XE/compiler/lib/ia32"
)
find_library(mkl_intel mkl_intel_c ${mkl_search_path})
endif()
# Search for the needed libraries from the MKL.
find_library(mkl_core mkl_core ${mkl_search_path})
find_library(mkl_thread mkl_intel_thread ${mkl_search_path})
find_library(mkl_iomp libiomp5md ${mkl_search_path})
mark_as_advanced( mkl_intel mkl_core mkl_thread mkl_iomp)
# If we found the MKL
if (mkl_intel AND mkl_core AND mkl_thread AND mkl_iomp )
set(blas_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} )
set(lapack_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} )
set(blas_found 1)
set(lapack_found 1)
message(STATUS "Found Intel MKL BLAS/LAPACK library")
if (MSVC)
# need to set /bigobj when statically linking with the MKL on
# visual studio or it doesn't work right.
if (NOT CMAKE_CXX_FLAGS MATCHES "/bigobj")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj"
CACHE STRING "Flags used by the compiler during all C++ builds."
FORCE)
endif ()
endif()
endif()
endif()

View File

@@ -4,5 +4,9 @@
<ClCompile>
<AdditionalIncludeDirectories>$(SolutionDir)lib\3rdParty\dlib\include\dlib\..;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<AdditionalLibraryDirectories>$(SolutionDir)lib\3rdParty\dlib\lib\$(PlatformTarget)\$(PlatformToolset)\$(Configuration);%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
<AdditionalDependencies>dlib.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
</Project>

View File

@@ -1,217 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="14.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|Win32">
<Configuration>Release</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<ItemGroup>
<ClCompile Include="include\dlib\all\source.cpp" />
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGUID>{B47A5F12-2567-44E9-AE49-35763EC82149}</ProjectGUID>
<Keyword>Win32Proj</Keyword>
<Platform>Win32</Platform>
<ProjectName>dlib</ProjectName>
<WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseOfMfc>false</UseOfMfc>
<CharacterSet>Unicode</CharacterSet>
<PlatformToolset>v140</PlatformToolset>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseOfMfc>false</UseOfMfc>
<CharacterSet>Unicode</CharacterSet>
<PlatformToolset>v140</PlatformToolset>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseOfMfc>false</UseOfMfc>
<CharacterSet>Unicode</CharacterSet>
<PlatformToolset>v140</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseOfMfc>false</UseOfMfc>
<CharacterSet>Unicode</CharacterSet>
<PlatformToolset>v140</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<Import Project="..\tbb\tbb_d.props" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="..\tbb\tbb_d.props" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<Import Project="..\tbb\tbb.props" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="..\tbb\tbb.props" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup>
<_ProjectFileVersion>10.0.20506.1</_ProjectFileVersion>
<OutDir Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">$(SolutionDir)$(Configuration)\</OutDir>
<IntDir Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">$(ProjectDir)$(Configuration)\</IntDir>
<TargetName Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">dlib</TargetName>
<TargetName Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">dlib</TargetName>
<TargetExt Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">.lib</TargetExt>
<TargetExt Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">.lib</TargetExt>
<OutDir Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">$(SolutionDir)$(Configuration)\</OutDir>
<IntDir Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">$(ProjectDir)$(Configuration)\</IntDir>
<TargetName Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">$(ProjectName)</TargetName>
<TargetName Condition="'$(Configuration)|$(Platform)'=='Release|x64'">$(ProjectName)</TargetName>
<TargetExt Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">.lib</TargetExt>
<TargetExt Condition="'$(Configuration)|$(Platform)'=='Release|x64'">.lib</TargetExt>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<AdditionalIncludeDirectories>$(SolutionDir)dlib/include/dlib/..;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<BasicRuntimeChecks>EnableFastChecks</BasicRuntimeChecks>
<CompileAs>CompileAsCpp</CompileAs>
<DebugInformationFormat>ProgramDatabase</DebugInformationFormat>
<EnableEnhancedInstructionSet>StreamingSIMDExtensions2</EnableEnhancedInstructionSet>
<ExceptionHandling>Sync</ExceptionHandling>
<InlineFunctionExpansion>Disabled</InlineFunctionExpansion>
<Optimization>Disabled</Optimization>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<RuntimeLibrary>MultiThreadedDebugDLL</RuntimeLibrary>
<RuntimeTypeInfo>true</RuntimeTypeInfo>
<WarningLevel>Level3</WarningLevel>
<PreprocessorDefinitions>DLIB_NO_GUI_SUPPORT;WIN32;_WINDOWS;_DEBUG;ENABLE_ASSERTS;DLIB_HAVE_SSE2;DLIB_HAVE_SSE3;DLIB_HAVE_SSE41;CMAKE_INTDIR="Debug";%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AssemblerListingLocation>Debug</AssemblerListingLocation>
<ObjectFileName>$(IntDir)</ObjectFileName>
<ProgramDataBaseFileName>$(IntDir)vc$(PlatformToolsetVersion).pdb</ProgramDataBaseFileName>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
</ClCompile>
<Lib>
<LinkTimeCodeGeneration>true</LinkTimeCodeGeneration>
</Lib>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<AdditionalIncludeDirectories>$(SolutionDir)dlib/include/dlib/..;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<BasicRuntimeChecks>EnableFastChecks</BasicRuntimeChecks>
<CompileAs>CompileAsCpp</CompileAs>
<DebugInformationFormat>ProgramDatabase</DebugInformationFormat>
<EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet>
<ExceptionHandling>Sync</ExceptionHandling>
<InlineFunctionExpansion>Disabled</InlineFunctionExpansion>
<Optimization>Disabled</Optimization>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<RuntimeLibrary>MultiThreadedDebugDLL</RuntimeLibrary>
<RuntimeTypeInfo>true</RuntimeTypeInfo>
<WarningLevel>Level3</WarningLevel>
<PreprocessorDefinitions>DLIB_NO_GUI_SUPPORT;WIN64;_WINDOWS;_DEBUG;ENABLE_ASSERTS;DLIB_HAVE_SSE2;DLIB_HAVE_SSE3;DLIB_HAVE_SSE41;CMAKE_INTDIR="Debug";%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AssemblerListingLocation>Debug</AssemblerListingLocation>
<ObjectFileName>$(IntDir)</ObjectFileName>
<ProgramDataBaseFileName>$(IntDir)vc$(PlatformToolsetVersion).pdb</ProgramDataBaseFileName>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
</ClCompile>
<Lib>
<LinkTimeCodeGeneration>true</LinkTimeCodeGeneration>
</Lib>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<AdditionalIncludeDirectories>$(SolutionDir)dlib/include/dlib/..;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<CompileAs>CompileAsCpp</CompileAs>
<EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet>
<ExceptionHandling>Sync</ExceptionHandling>
<InlineFunctionExpansion>AnySuitable</InlineFunctionExpansion>
<Optimization>Full</Optimization>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<RuntimeLibrary>MultiThreadedDLL</RuntimeLibrary>
<RuntimeTypeInfo>true</RuntimeTypeInfo>
<WarningLevel>Level3</WarningLevel>
<DebugInformationFormat>
</DebugInformationFormat>
<PreprocessorDefinitions>DLIB_NO_GUI_SUPPORT;WIN32;_WINDOWS;NDEBUG;DLIB_HAVE_SSE2;DLIB_HAVE_SSE3;DLIB_HAVE_SSE41;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AssemblerListingLocation>$(IntDir)</AssemblerListingLocation>
<ObjectFileName>$(IntDir)</ObjectFileName>
<ProgramDataBaseFileName>$(IntDir)vc$(PlatformToolsetVersion).pdb</ProgramDataBaseFileName>
<FavorSizeOrSpeed>Speed</FavorSizeOrSpeed>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
</ClCompile>
<ResourceCompile>
<PreprocessorDefinitions>WIN32;_WINDOWS;DLIB_PNG_SUPPORT;DLIB_JPEG_SUPPORT;NDEBUG;DLIB_HAVE_SSE2;DLIB_HAVE_SSE3;DLIB_HAVE_SSE41;CMAKE_INTDIR=\"Release\";%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AdditionalIncludeDirectories>$(SolutionDir)/dlib/..;$(SolutionDir)/dlib/external/libpng;$(SolutionDir)/dlib/external/zlib;$(SolutionDir)/dlib/external/libjpeg;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ResourceCompile>
<Midl>
<AdditionalIncludeDirectories>$(SolutionDir)/dlib/..;$(SolutionDir)/dlib/external/libpng;$(SolutionDir)/dlib/external/zlib;$(SolutionDir)/dlib/external/libjpeg;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<OutputDirectory>$(IntDir)</OutputDirectory>
<HeaderFileName>%(Filename).h</HeaderFileName>
<TypeLibraryName>%(Filename).tlb</TypeLibraryName>
<InterfaceIdentifierFileName>%(Filename)_i.c</InterfaceIdentifierFileName>
<ProxyFileName>%(Filename)_p.c</ProxyFileName>
</Midl>
<Lib>
<LinkTimeCodeGeneration>true</LinkTimeCodeGeneration>
</Lib>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<AdditionalIncludeDirectories>$(SolutionDir)dlib/include/dlib/..;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<CompileAs>CompileAsCpp</CompileAs>
<EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet>
<ExceptionHandling>Sync</ExceptionHandling>
<InlineFunctionExpansion>AnySuitable</InlineFunctionExpansion>
<Optimization>Full</Optimization>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<RuntimeLibrary>MultiThreadedDLL</RuntimeLibrary>
<RuntimeTypeInfo>true</RuntimeTypeInfo>
<WarningLevel>Level3</WarningLevel>
<DebugInformationFormat>
</DebugInformationFormat>
<PreprocessorDefinitions>DLIB_NO_GUI_SUPPORT;WIN64;_WINDOWS;NDEBUG;DLIB_HAVE_SSE2;DLIB_HAVE_SSE3;DLIB_HAVE_SSE41;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AssemblerListingLocation>$(IntDir)</AssemblerListingLocation>
<ObjectFileName>$(IntDir)</ObjectFileName>
<ProgramDataBaseFileName>$(IntDir)vc$(PlatformToolsetVersion).pdb</ProgramDataBaseFileName>
<FavorSizeOrSpeed>Speed</FavorSizeOrSpeed>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
</ClCompile>
<ResourceCompile>
<PreprocessorDefinitions>WIN32;_WINDOWS;DLIB_PNG_SUPPORT;DLIB_JPEG_SUPPORT;NDEBUG;DLIB_HAVE_SSE2;DLIB_HAVE_SSE3;DLIB_HAVE_SSE41;CMAKE_INTDIR=\"Release\";%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AdditionalIncludeDirectories>$(SolutionDir)/dlib/..;$(SolutionDir)/dlib/external/libpng;$(SolutionDir)/dlib/external/zlib;$(SolutionDir)/dlib/external/libjpeg;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ResourceCompile>
<Midl>
<AdditionalIncludeDirectories>$(SolutionDir)/dlib/..;$(SolutionDir)/dlib/external/libpng;$(SolutionDir)/dlib/external/zlib;$(SolutionDir)/dlib/external/libjpeg;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<OutputDirectory>$(IntDir)</OutputDirectory>
<HeaderFileName>%(Filename).h</HeaderFileName>
<TypeLibraryName>%(Filename).tlb</TypeLibraryName>
<InterfaceIdentifierFileName>%(Filename)_i.c</InterfaceIdentifierFileName>
<ProxyFileName>%(Filename)_p.c</ProxyFileName>
</Midl>
<Lib>
<LinkTimeCodeGeneration>true</LinkTimeCodeGeneration>
</Lib>
</ItemDefinitionGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

View File

@@ -1,111 +0,0 @@
# This is a CMake file that sets up the add_python_module() macro. This macro
# lets you easily make python modules that use dlib.
#
# The macro takes the module name as its first argument and then a list of
# source files to compile into the module. See ../tools/python/CMakeLists.txt
# for an example.
#
# It also sets up a macro called install_${module_name}_to() where
# ${module_name} is whatever you named your module. This install_*_to() macro
# takes a folder name and creates an install target that will copy the compiled
# python module to that folder when you run "make install". Note that the path
# given to install_*_to() is relative to your CMakeLists.txt file.
# A list of various paths you need to search on windows since people install
# boost in a bunch of different places.
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH}
"C:/local/boost_1_*"
"C:/Program Files (x86)/boost/boost_1_*"
"C:/Program Files/boost/boost_1_*")
set(BOOST_LIBRARYDIR "C:/local/boost_1_*/lib32-msvc-*")
#SET(Boost_USE_STATIC_LIBS OFF)
#SET(Boost_USE_MULTITHREADED ON)
#SET(Boost_USE_STATIC_RUNTIME OFF)
set(Boost_NO_BOOST_CMAKE ON)
FIND_PACKAGE(Boost 1.41.0 COMPONENTS python REQUIRED)
FIND_PACKAGE(PythonLibs 2.6 REQUIRED)
if (WIN32 AND NOT Boost_LIBRARIES)
message(FATAL_ERROR "We couldn't find the right version of boost python. If you installed boost and you are still "
"getting this error then you might have installed a version of boost that was compiled with a different "
"version of visual studio than the one you are using. So you have to make sure that the version of "
"visual studio is the same version that was used to compile the copy of boost you are using.")
endif()
INCLUDE_DIRECTORIES("${Boost_INCLUDE_DIRS}")
if (PYTHON_INCLUDE_PATH)
INCLUDE_DIRECTORIES("${PYTHON_INCLUDE_PATH}" )
else()
INCLUDE_DIRECTORIES("${PYTHON_INCLUDE_DIRS}" )
endif()
if (CMAKE_COMPILER_IS_GNUCXX)
add_definitions("-fPIC")
endif()
# include dlib so we can link against it
string(REGEX REPLACE "add_python_module$" "" dlib_path ${CMAKE_CURRENT_LIST_FILE})
include(${dlib_path}/cmake)
# We put the extra _ on the end of the name just so it's possible to
# have a module name of dlib and not get a conflict with the target named
# dlib in ../dlib/cmake. We use the target OUPUT_NAME property to ensure the
# output name is set to what the user asked for (i.e. no _).
macro(add_python_module module_name module_sources )
ADD_LIBRARY(${module_name}_ SHARED ${module_sources} ${ARGN} )
TARGET_LINK_LIBRARIES(${module_name}_ ${Boost_LIBRARIES} ${PYTHON_LIBRARIES} dlib)
if(WIN32 AND NOT CYGWIN)
SET_TARGET_PROPERTIES( ${module_name}_
PROPERTIES
PREFIX ""
SUFFIX ".pyd"
OUTPUT_NAME ${module_name}
)
elseif(CYGWIN)
SET_TARGET_PROPERTIES( ${module_name}_
PROPERTIES
PREFIX ""
SUFFIX ".dll"
OUTPUT_NAME ${module_name}
)
else()
SET_TARGET_PROPERTIES( ${module_name}_
PROPERTIES
PREFIX ""
SUFFIX ".so"
OUTPUT_NAME ${module_name}
)
endif()
macro(install_${module_name}_to path)
# Determine the path to our CMakeLists.txt file.
string(REGEX REPLACE "CMakeLists.txt$" "" base_path ${CMAKE_CURRENT_LIST_FILE})
INSTALL(TARGETS ${module_name}_
DESTINATION "${base_path}/${path}"
)
# On windows we will usually need to have the boost-python .dll files in the same folder or
# you will get an error about how they can't be found. So copy the boost .dll files along with
# your module to the install folder to avoid this.
if (WIN32)
list(GET Boost_LIBRARIES 1 boostlibs1)
list(GET Boost_LIBRARIES 3 boostlibs2)
string(REGEX REPLACE ".lib$" ".dll" boostdlls1 ${boostlibs1})
string(REGEX REPLACE ".lib$" ".dll" boostdlls2 ${boostlibs2})
INSTALL(FILES ${boostdlls1} ${boostdlls2}
DESTINATION "${base_path}/${path}"
)
endif()
endmacro()
endmacro()

View File

@@ -1,12 +1,37 @@
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifdef DLIB_ALL_SOURCE_END
#include "dlib_basic_cpp_build_tutorial.txt"
#endif
#ifndef DLIB_ALGs_
#define DLIB_ALGs_
// this file contains miscellaneous stuff
// Give people who forget the -std=c++11 option a reminder
#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))) || \
(defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4) || (__clang_major__ >= 3)))
#if __cplusplus < 201103
#error "Dlib requires C++11 support. Give your compiler the -std=c++11 option to enable it."
#endif
#endif
#if defined __NVCC__
// Disable the "statement is unreachable" message since it will go off on code that is
// actually reachable but just happens to not be reachable sometimes during certain
// template instantiations.
#pragma diag_suppress code_is_unreachable
#endif
#ifdef _MSC_VER
#if _MSC_VER < 1900
#error "dlib versions newer than v19.1 use C++11 and therefore require Visual Studio 2015 or newer."
#endif
// Disable the following warnings for Visual Studio
// this is to disable the "'this' : used in base member initializer list"
@@ -43,6 +68,16 @@
// This warning happens often in generic code that works with functions and isn't useful.
#pragma warning(disable : 4180)
// Disable "warning C4290: C++ exception specification ignored except to indicate a function is not __declspec(nothrow)"
#pragma warning(disable : 4290)
// DNN module uses template-based network declaration that leads to very long
// type names. Visual Studio will produce Warning C4503 in such cases. https://msdn.microsoft.com/en-us/library/074af4b6.aspx says
// that correct binaries are still produced even when this warning happens, but linker errors from visual studio, if they occurr could be confusing.
#pragma warning( disable: 4503 )
#endif
#ifdef __BORLANDC__
@@ -71,6 +106,7 @@ namespace std
#include <algorithm> // for std::swap
#include <new> // for std::bad_alloc
#include <cstdlib>
#include <stddef.h>
#include <limits> // for std::numeric_limits for is_finite()
#include "assert.h"
#include "error.h"
@@ -275,7 +311,7 @@ namespace dlib
typename A,
typename B
>
bool operator> (
constexpr bool operator> (
const A& a,
const B& b
) { return b < a; }
@@ -286,7 +322,7 @@ namespace dlib
typename A,
typename B
>
bool operator!= (
constexpr bool operator!= (
const A& a,
const B& b
) { return !(a == b); }
@@ -297,7 +333,7 @@ namespace dlib
typename A,
typename B
>
bool operator<= (
constexpr bool operator<= (
const A& a,
const B& b
) { return !(b < a); }
@@ -308,7 +344,7 @@ namespace dlib
typename A,
typename B
>
bool operator>= (
constexpr bool operator>= (
const A& a,
const B& b
) { return !(a < b); }
@@ -480,6 +516,13 @@ namespace dlib
// ----------------------------------------------------------------------------------------
struct general_ {};
struct special_ : general_ {};
template<typename> struct int_ { typedef int type; };
// ----------------------------------------------------------------------------------------
/*!A is_same_object
This is a templated function which checks if both of its arguments are actually
@@ -759,10 +802,10 @@ namespace dlib
abs<4>::value == 4
!*/
template <long x, typename enabled=void>
struct tabs { const static long value = x; };
template <long x>
struct tabs<x,typename enable_if_c<(x < 0)>::type> { const static long value = -x; };
template <long x, typename enabled=void>
struct tabs { const static long value = x; };
template <long x>
struct tabs<x,typename enable_if_c<(x < 0)>::type> { const static long value = -x; };
// ----------------------------------------------------------------------------------------
@@ -774,10 +817,10 @@ namespace dlib
abs<4,7>::value == 7
!*/
template <long x, long y, typename enabled=void>
struct tmax { const static long value = x; };
template <long x, long y>
struct tmax<x,y,typename enable_if_c<(y > x)>::type> { const static long value = y; };
template <long x, long y, typename enabled=void>
struct tmax { const static long value = x; };
template <long x, long y>
struct tmax<x,y,typename enable_if_c<(y > x)>::type> { const static long value = y; };
// ----------------------------------------------------------------------------------------
@@ -789,12 +832,12 @@ namespace dlib
abs<4,7>::value == 4
!*/
template <long x, long y, typename enabled=void>
struct tmin { const static long value = x; };
template <long x, long y>
struct tmin<x,y,typename enable_if_c<(y < x)>::type> { const static long value = y; };
template <long x, long y, typename enabled=void>
struct tmin { const static long value = x; };
template <long x, long y>
struct tmin<x,y,typename enable_if_c<(y < x)>::type> { const static long value = y; };
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
#define DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(testname, returnT, funct_name, args) \
struct _two_bytes_##testname { char a[2]; }; \
@@ -1025,6 +1068,88 @@ namespace dlib
void* const data;
};
// ----------------------------------------------------------------------------------------
template <
typename T,
typename F
>
auto max_scoring_element(
const T& container,
F score_func
) -> decltype(std::make_pair(*container.begin(), 0.0))
/*!
requires
- container has .begin() and .end(), allowing it to be enumerated.
- score_func() is a function that takes an element of the container and returns a double.
ensures
- This function finds the element of container that has the largest score,
according to score_func(), and returns a std::pair containing that maximal
element along with the score.
- If the container is empty then make_pair(a default initialized object, -infinity) is returned.
!*/
{
double best_score = -std::numeric_limits<double>::infinity();
auto best_i = container.begin();
for (auto i = container.begin(); i != container.end(); ++i)
{
auto score = score_func(*i);
if (score > best_score)
{
best_score = score;
best_i = i;
}
}
using item_type = typename std::remove_reference<decltype(*best_i)>::type;
if (best_i == container.end())
return std::make_pair(item_type(), best_score);
else
return std::make_pair(*best_i, best_score);
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename F
>
auto min_scoring_element(
const T& container,
F score_func
) -> decltype(std::make_pair(*container.begin(), 0.0))
/*!
requires
- container has .begin() and .end(), allowing it to be enumerated.
- score_func() is a function that takes an element of the container and returns a double.
ensures
- This function finds the element of container that has the smallest score,
according to score_func(), and returns a std::pair containing that minimal
element along with the score.
- If the container is empty then make_pair(a default initialized object, infinity) is returned.
!*/
{
double best_score = std::numeric_limits<double>::infinity();
auto best_i = container.begin();
for (auto i = container.begin(); i != container.end(); ++i)
{
auto score = score_func(*i);
if (score < best_score)
{
best_score = score;
best_i = i;
}
}
using item_type = typename std::remove_reference<decltype(*best_i)>::type;
if (best_i == container.end())
return std::make_pair(item_type(), best_score);
else
return std::make_pair(*best_i, best_score);
}
// ----------------------------------------------------------------------------------------
}

View File

@@ -1,76 +0,0 @@
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ALL_SOURCe_
#define DLIB_ALL_SOURCe_
// ISO C++ code
#include "../base64/base64_kernel_1.cpp"
//#include "../bigint/bigint_kernel_1.cpp"
//#include "../bigint/bigint_kernel_2.cpp"
#include "../bit_stream/bit_stream_kernel_1.cpp"
#include "../entropy_decoder/entropy_decoder_kernel_1.cpp"
#include "../entropy_decoder/entropy_decoder_kernel_2.cpp"
#include "../entropy_encoder/entropy_encoder_kernel_1.cpp"
#include "../entropy_encoder/entropy_encoder_kernel_2.cpp"
//#include "../md5/md5_kernel_1.cpp"
#include "../tokenizer/tokenizer_kernel_1.cpp"
//#include "../unicode/unicode.cpp"
//#include "../data_io/image_dataset_metadata.cpp"
#ifndef DLIB_ISO_CPP_ONLY
// Code that depends on OS specific APIs
// include this first so that it can disable the older version
// of the winsock API when compiled in windows.
#include "../sockets/sockets_kernel_1.cpp"
//#include "../bsp/bsp.cpp"
//#include "../dir_nav/dir_nav_kernel_1.cpp"
//#include "../dir_nav/dir_nav_kernel_2.cpp"
//#include "../dir_nav/dir_nav_extensions.cpp"
//#include "../linker/linker_kernel_1.cpp"
//#include "../logger/extra_logger_headers.cpp"
//#include "../logger/logger_kernel_1.cpp"
//#include "../logger/logger_config_file.cpp"
#include "../misc_api/misc_api_kernel_1.cpp"
#include "../misc_api/misc_api_kernel_2.cpp"
#include "../sockets/sockets_extensions.cpp"
#include "../sockets/sockets_kernel_2.cpp"
#include "../sockstreambuf/sockstreambuf.cpp"
#include "../sockstreambuf/sockstreambuf_unbuffered.cpp"
//#include "../server/server_kernel.cpp"
//#include "../server/server_iostream.cpp"
//#include "../server/server_http.cpp"
//#include "../threads/multithreaded_object_extension.cpp"
#include "../threads/threaded_object_extension.cpp"
#include "../threads/threads_kernel_1.cpp"
#include "../threads/threads_kernel_2.cpp"
#include "../threads/threads_kernel_shared.cpp"
//#include "../threads/thread_pool_extension.cpp"
#include "../timer/timer.cpp"
//#include "../stack_trace.cpp"
#ifdef DLIB_PNG_SUPPORT
#include "../image_loader/png_loader.cpp"
#include "../image_saver/save_png.cpp"
#endif
#ifdef DLIB_JPEG_SUPPORT
#include "../image_loader/jpeg_loader.cpp"
#endif
#ifndef DLIB_NO_GUI_SUPPORT
#include "../gui_widgets/fonts.cpp"
#include "../gui_widgets/widgets.cpp"
#include "../gui_widgets/drawable.cpp"
#include "../gui_widgets/canvas_drawing.cpp"
#include "../gui_widgets/style.cpp"
#include "../gui_widgets/base_widgets.cpp"
#include "../gui_core/gui_core_kernel_1.cpp"
#include "../gui_core/gui_core_kernel_2.cpp"
#endif // DLIB_NO_GUI_SUPPORT
#endif // DLIB_ISO_CPP_ONLY
#endif // DLIB_ALL_SOURCe_

View File

@@ -1,9 +0,0 @@
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ALL_CONSOLe_
#define DLIB_ALL_CONSOLe_
#error "This file has been replaced. Instead you should add dlib/all/source.cpp to your project"
#endif // DLIB_ALL_CONSOLe_

View File

@@ -1,9 +0,0 @@
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ALL_GUi_
#define DLIB_ALL_GUi_
#error "This file has been replaced. Instead you should add dlib/all/source.cpp to your project"
#endif // DLIB_ALL_GUi_

View File

@@ -4,7 +4,9 @@
#define DLIB_AnY_H_
#include "any_abstract.h"
#include "../smart_pointers.h"
#include "../algs.h"
#include <memory>
#include <typeinfo>
namespace dlib
@@ -136,7 +138,7 @@ namespace dlib
virtual ~base() {}
virtual void copy_to (
scoped_ptr<base>& dest
std::unique_ptr<base>& dest
) const = 0;
};
@@ -148,14 +150,14 @@ namespace dlib
derived(const T& val) : item(val) {}
virtual void copy_to (
scoped_ptr<base>& dest
std::unique_ptr<base>& dest
) const
{
dest.reset(new derived<T>(item));
}
};
scoped_ptr<base> data;
std::unique_ptr<base> data;
};
// ----------------------------------------------------------------------------------------

View File

@@ -4,7 +4,6 @@
#define DLIB_AnY_DECISION_FUNCTION_Hh_
#include "any.h"
#include "../smart_pointers.h"
#include "any_decision_function_abstract.h"
@@ -148,7 +147,7 @@ namespace dlib
virtual ~base() {}
virtual void copy_to (
scoped_ptr<base>& dest
std::unique_ptr<base>& dest
) const = 0;
virtual result_type evaluate (
@@ -164,7 +163,7 @@ namespace dlib
derived(const T& val) : item(val) {}
virtual void copy_to (
scoped_ptr<base>& dest
std::unique_ptr<base>& dest
) const
{
dest.reset(new derived<T>(item));
@@ -178,7 +177,7 @@ namespace dlib
}
};
scoped_ptr<base> data;
std::unique_ptr<base> data;
};
// ----------------------------------------------------------------------------------------

View File

@@ -4,7 +4,6 @@
#define DLIB_AnY_FUNCTION_Hh_
#include "any.h"
#include "../smart_pointers.h"
#include "any_function_abstract.h"
@@ -32,6 +31,16 @@ namespace dlib
typedef void arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 0;
};
@@ -53,6 +62,16 @@ namespace dlib
typedef void arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 1;
};
@@ -74,6 +93,16 @@ namespace dlib
typedef void arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 2;
};
@@ -95,6 +124,16 @@ namespace dlib
typedef void arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 3;
};
@@ -117,6 +156,16 @@ namespace dlib
typedef void arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 4;
};
@@ -139,6 +188,16 @@ namespace dlib
typedef void arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 5;
};
@@ -161,6 +220,16 @@ namespace dlib
typedef void arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 6;
};
@@ -184,6 +253,16 @@ namespace dlib
typedef void arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 7;
};
@@ -207,6 +286,16 @@ namespace dlib
typedef A8 arg8_type;
typedef void arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 8;
};
@@ -230,6 +319,16 @@ namespace dlib
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef void arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 9;
};
@@ -254,10 +353,415 @@ namespace dlib
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef void arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 10;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef void arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 11;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef void arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 12;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12,
typename A13
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef A13 arg13_type;
typedef void arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 13;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12,
typename A13,
typename A14
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef A13 arg13_type;
typedef A14 arg14_type;
typedef void arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 14;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12,
typename A13,
typename A14,
typename A15
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef A13 arg13_type;
typedef A14 arg14_type;
typedef A15 arg15_type;
typedef void arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 15;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12,
typename A13,
typename A14,
typename A15,
typename A16
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef A13 arg13_type;
typedef A14 arg14_type;
typedef A15 arg15_type;
typedef A16 arg16_type;
typedef void arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 16;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12,
typename A13,
typename A14,
typename A15,
typename A16,
typename A17
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef A13 arg13_type;
typedef A14 arg14_type;
typedef A15 arg15_type;
typedef A16 arg16_type;
typedef A17 arg17_type;
typedef void arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 17;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12,
typename A13,
typename A14,
typename A15,
typename A16,
typename A17,
typename A18
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef A13 arg13_type;
typedef A14 arg14_type;
typedef A15 arg15_type;
typedef A16 arg16_type;
typedef A17 arg17_type;
typedef A18 arg18_type;
typedef void arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 18;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12,
typename A13,
typename A14,
typename A15,
typename A16,
typename A17,
typename A18,
typename A19
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18,A19)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef A13 arg13_type;
typedef A14 arg14_type;
typedef A15 arg15_type;
typedef A16 arg16_type;
typedef A17 arg17_type;
typedef A18 arg18_type;
typedef A19 arg19_type;
typedef void arg20_type;
const static unsigned long num_args = 19;
};
template <
typename T,
typename A1, typename A2, typename A3,
typename A4, typename A5, typename A6,
typename A7, typename A8, typename A9,
typename A10,
typename A11,
typename A12,
typename A13,
typename A14,
typename A15,
typename A16,
typename A17,
typename A18,
typename A19,
typename A20
>
struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18,A19,A20)>
{
typedef T result_type;
typedef A1 arg1_type;
typedef A2 arg2_type;
typedef A3 arg3_type;
typedef A4 arg4_type;
typedef A5 arg5_type;
typedef A6 arg6_type;
typedef A7 arg7_type;
typedef A8 arg8_type;
typedef A9 arg9_type;
typedef A10 arg10_type;
typedef A11 arg11_type;
typedef A12 arg12_type;
typedef A13 arg13_type;
typedef A14 arg14_type;
typedef A15 arg15_type;
typedef A16 arg16_type;
typedef A17 arg17_type;
typedef A18 arg18_type;
typedef A19 arg19_type;
typedef A20 arg20_type;
const static unsigned long num_args = 20;
};
// ----------------------------------------------------------------------------------------
template <

View File

@@ -166,7 +166,7 @@ struct Tbase
{
virtual ~Tbase() {}
virtual result_type evaluate () const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -177,7 +177,7 @@ struct Tbase<T (A1)>
{
virtual ~Tbase() {}
virtual T evaluate ( A1) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -188,7 +188,7 @@ struct Tbase<T (A1,A2)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -199,7 +199,7 @@ struct Tbase<T (A1,A2,A3)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2,A3) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -211,7 +211,7 @@ struct Tbase<T (A1,A2,A3,A4)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2,A3,A4) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -223,7 +223,7 @@ struct Tbase<T (A1,A2,A3,A4,A5)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2,A3,A4,A5) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -235,7 +235,7 @@ struct Tbase<T (A1,A2,A3,A4,A5,A6)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2,A3,A4,A5,A6) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -248,7 +248,7 @@ struct Tbase<T (A1,A2,A3,A4,A5,A6,A7)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2,A3,A4,A5,A6,A7) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -261,7 +261,7 @@ struct Tbase<T (A1,A2,A3,A4,A5,A6,A7,A8)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -274,7 +274,7 @@ struct Tbase<T (A1,A2,A3,A4,A5,A6,A7,A8,A9)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8,A9) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
template <
@@ -288,7 +288,7 @@ struct Tbase<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10)>
{
virtual ~Tbase() {}
virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10) const = 0;
virtual void copy_to ( scoped_ptr<Tbase>& dest) const = 0;
virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
};
typedef Tbase<function_type> base;
@@ -318,7 +318,7 @@ static typename disable_if<is_function<T>,const T&>::type deref (const U& item)
typename funct_type<T>::type item; \
derived() {} \
derived(const T& val) : item(copy(val)) {} \
virtual void copy_to ( scoped_ptr<base>& dest) const \
virtual void copy_to ( std::unique_ptr<base>& dest) const \
{ dest.reset(new derived(deref<T>(item))); }
template <typename T, typename FT>
@@ -508,7 +508,7 @@ struct derived<T,result_type (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10)> : public base
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
};
scoped_ptr<base> data;
std::unique_ptr<base> data;
#undef DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE

View File

@@ -4,7 +4,6 @@
#define DLIB_AnY_TRAINER_H_
#include "any.h"
#include "../smart_pointers.h"
#include "any_decision_function.h"
@@ -157,7 +156,7 @@ namespace dlib
) const = 0;
virtual void copy_to (
scoped_ptr<base>& dest
std::unique_ptr<base>& dest
) const = 0;
};
@@ -169,7 +168,7 @@ namespace dlib
derived(const T& val) : item(val) {}
virtual void copy_to (
scoped_ptr<base>& dest
std::unique_ptr<base>& dest
) const
{
dest.reset(new derived<T>(item));
@@ -184,7 +183,7 @@ namespace dlib
}
};
scoped_ptr<base> data;
std::unique_ptr<base> data;
};
// ----------------------------------------------------------------------------------------

View File

@@ -90,8 +90,26 @@ namespace dlib
_at_start(true)
{}
array(const array&) = delete;
array& operator=(array&) = delete;
array(
array&& item
) : array()
{
swap(item);
}
array& operator=(
array&& item
)
{
swap(item);
return *this;
}
explicit array (
unsigned long new_size
size_t new_size
) :
array_size(0),
max_array_size(0),
@@ -110,22 +128,22 @@ namespace dlib
);
inline const T& operator[] (
unsigned long pos
size_t pos
) const;
inline T& operator[] (
unsigned long pos
size_t pos
);
void set_size (
unsigned long size
size_t size
);
inline unsigned long max_size(
inline size_t max_size(
) const;
void set_max_size(
unsigned long max
size_t max
);
void swap (
@@ -133,7 +151,7 @@ namespace dlib
);
// functions from the enumerable interface
inline unsigned long size (
inline size_t size (
) const;
inline bool at_start (
@@ -158,7 +176,7 @@ namespace dlib
);
void resize (
unsigned long new_size
size_t new_size
);
const T& back (
@@ -178,6 +196,10 @@ namespace dlib
T& item
);
void push_back (
T&& item
);
typedef T* iterator;
typedef const T* const_iterator;
iterator begin() { return array_elements; }
@@ -190,18 +212,14 @@ namespace dlib
typename mem_manager::template rebind<T>::other pool;
// data members
unsigned long array_size;
unsigned long max_array_size;
size_t array_size;
size_t max_array_size;
T* array_elements;
mutable T* pos;
T* last_pos;
mutable bool _at_start;
// restricted functions
array(array<T>&); // copy constructor
array<T>& operator=(array<T>&); // assignment operator
};
template <
@@ -229,7 +247,7 @@ namespace dlib
serialize(item.max_size(),out);
serialize(item.size(),out);
for (unsigned long i = 0; i < item.size(); ++i)
for (size_t i = 0; i < item.size(); ++i)
serialize(item[i],out);
}
catch (serialization_error e)
@@ -249,12 +267,12 @@ namespace dlib
{
try
{
unsigned long max_size, size;
size_t max_size, size;
deserialize(max_size,in);
deserialize(size,in);
item.set_max_size(max_size);
item.set_size(size);
for (unsigned long i = 0; i < size; ++i)
for (size_t i = 0; i < size; ++i)
deserialize(item[i],in);
}
catch (serialization_error e)
@@ -314,7 +332,7 @@ namespace dlib
>
const T& array<T,mem_manager>::
operator[] (
unsigned long pos
size_t pos
) const
{
// make sure requires clause is not broken
@@ -337,7 +355,7 @@ namespace dlib
>
T& array<T,mem_manager>::
operator[] (
unsigned long pos
size_t pos
)
{
// make sure requires clause is not broken
@@ -360,7 +378,7 @@ namespace dlib
>
void array<T,mem_manager>::
set_size (
unsigned long size
size_t size
)
{
// make sure requires clause is not broken
@@ -386,7 +404,7 @@ namespace dlib
typename T,
typename mem_manager
>
unsigned long array<T,mem_manager>::
size_t array<T,mem_manager>::
size (
) const
{
@@ -401,7 +419,7 @@ namespace dlib
>
void array<T,mem_manager>::
set_max_size(
unsigned long max
size_t max
)
{
reset();
@@ -439,7 +457,7 @@ namespace dlib
typename T,
typename mem_manager
>
unsigned long array<T,mem_manager>::
size_t array<T,mem_manager>::
max_size (
) const
{
@@ -457,8 +475,8 @@ namespace dlib
array<T,mem_manager>& item
)
{
unsigned long array_size_temp = item.array_size;
unsigned long max_array_size_temp = item.max_array_size;
auto array_size_temp = item.array_size;
auto max_array_size_temp = item.max_array_size;
T* array_elements_temp = item.array_elements;
item.array_size = array_size;
@@ -627,7 +645,7 @@ namespace dlib
>
void array<T,mem_manager>::
resize (
unsigned long new_size
size_t new_size
)
{
if (this->max_size() < new_size)
@@ -635,7 +653,7 @@ namespace dlib
array temp;
temp.set_max_size(new_size);
temp.set_size(new_size);
for (unsigned long i = 0; i < this->size(); ++i)
for (size_t i = 0; i < this->size(); ++i)
{
exchange((*this)[i],temp[i]);
}
@@ -750,7 +768,7 @@ namespace dlib
array temp;
temp.set_max_size(this->size()*2 + 1);
temp.set_size(this->size()+1);
for (unsigned long i = 0; i < this->size(); ++i)
for (size_t i = 0; i < this->size(); ++i)
{
exchange((*this)[i],temp[i]);
}
@@ -764,6 +782,17 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
push_back (
T&& item
) { push_back(item); }
// ----------------------------------------------------------------------------------------
template <typename T, typename MM>

View File

@@ -66,7 +66,7 @@ namespace dlib
!*/
explicit array (
unsigned long new_size
size_t new_size
);
/*!
ensures
@@ -85,6 +85,25 @@ namespace dlib
- all memory associated with *this has been released
!*/
array(
array&& item
);
/*!
ensures
- move constructs *this from item. Therefore, the state of item is
moved into *this and #item has a valid but unspecified state.
!*/
array& operator=(
array&& item
);
/*!
ensures
- move assigns *this from item. Therefore, the state of item is
moved into *this and #item has a valid but unspecified state.
- returns a reference to #*this
!*/
void clear (
);
/*!
@@ -97,7 +116,7 @@ namespace dlib
!*/
const T& operator[] (
unsigned long pos
size_t pos
) const;
/*!
requires
@@ -107,7 +126,7 @@ namespace dlib
!*/
T& operator[] (
unsigned long pos
size_t pos
);
/*!
requires
@@ -117,7 +136,7 @@ namespace dlib
!*/
void set_size (
unsigned long size
size_t size
);
/*!
requires
@@ -136,7 +155,7 @@ namespace dlib
if it does throw then the call to set_size() has no effect
!*/
unsigned long max_size(
size_t max_size(
) const;
/*!
ensures
@@ -144,7 +163,7 @@ namespace dlib
!*/
void set_max_size(
unsigned long max
size_t max
);
/*!
ensures
@@ -179,7 +198,7 @@ namespace dlib
!*/
void resize (
unsigned long new_size
size_t new_size
);
/*!
ensures
@@ -255,6 +274,11 @@ namespace dlib
If an exception is thrown then it has no effect on *this.
!*/
void push_back (T&& item) { push_back(item); }
/*!
enable push_back from rvalues
!*/
typedef T* iterator;
typedef const T* const_iterator;

View File

@@ -13,6 +13,11 @@ namespace dlib
{
typedef T pixel_type;
};
template <typename T, typename mm>
struct image_traits<const array2d<T,mm> >
{
typedef T pixel_type;
};
template <typename T, typename mm>
inline long num_rows( const array2d<T,mm>& img) { return img.nr(); }

View File

@@ -60,6 +60,9 @@ namespace dlib
typedef T type;
typedef mem_manager mem_manager_type;
typedef T* iterator;
typedef const T* const_iterator;
// -----------------------------------
@@ -72,7 +75,7 @@ namespace dlib
- (*this)[x] == data[x]
!*/
friend class array2d;
friend class array2d<T,mem_manager>;
friend class row_helper;
public:
@@ -160,6 +163,24 @@ namespace dlib
set_size(rows,cols);
}
array2d(const array2d&) = delete; // copy constructor
array2d& operator=(const array2d&) = delete; // assignment operator
#ifdef DLIB_HAS_RVALUE_REFERENCES
array2d(array2d&& item) : array2d()
{
swap(item);
}
array2d& operator= (
array2d&& rhs
)
{
swap(rhs);
return *this;
}
#endif
virtual ~array2d (
) { clear(); }
@@ -294,8 +315,8 @@ namespace dlib
}
}
unsigned long size (
) const { return static_cast<unsigned long>(nc_ * nr_); }
size_t size (
) const { return static_cast<size_t>(nc_) * static_cast<size_t>(nr_); }
long width_step (
) const
@@ -303,6 +324,27 @@ namespace dlib
return nc_*sizeof(T);
}
iterator begin()
{
return data;
}
iterator end()
{
return data+size();
}
const_iterator begin() const
{
return data;
}
const_iterator end() const
{
return data+size();
}
private:
@@ -315,10 +357,6 @@ namespace dlib
T* last;
mutable bool at_start_;
// restricted functions
array2d(array2d&); // copy constructor
array2d& operator=(array2d&); // assignment operator
};
// ----------------------------------------------------------------------------------------

View File

@@ -64,6 +64,8 @@ namespace dlib
typedef T type;
typedef mem_manager mem_manager_type;
typedef T* iterator;
typedef const T* const_iterator;
// ----------------------------------------
@@ -122,6 +124,18 @@ namespace dlib
- std::bad_alloc
!*/
array2d(const array2d&) = delete; // copy constructor
array2d& operator=(const array2d&) = delete; // assignment operator
array2d(
array2d&& item
);
/*!
ensures
- Moves the state of item into *this.
- #item is in a valid but unspecified state.
!*/
array2d (
long rows,
long cols
@@ -218,6 +232,16 @@ namespace dlib
- swaps *this and item
!*/
array2d& operator= (
array2d&& rhs
);
/*!
ensures
- Moves the state of item into *this.
- #item is in a valid but unspecified state.
- returns #*this
!*/
long width_step (
) const;
/*!
@@ -233,11 +257,41 @@ namespace dlib
An example of such an object is the dlib::cv_image.
!*/
private:
iterator begin(
);
/*!
ensures
- returns a random access iterator pointing to the first element in this
object.
- The iterator will iterate over the elements of the object in row major
order.
!*/
// restricted functions
array2d(array2d&); // copy constructor
array2d& operator=(array2d&); // assignment operator
iterator end(
);
/*!
ensures
- returns a random access iterator pointing to one past the end of the last
element in this object.
!*/
const_iterator begin(
) const;
/*!
ensures
- returns a random access iterator pointing to the first element in this
object.
- The iterator will iterate over the elements of the object in row major
order.
!*/
const_iterator end(
) const;
/*!
ensures
- returns a random access iterator pointing to one past the end of the last
element in this object.
!*/
};

View File

@@ -20,12 +20,53 @@
// (C) Copyright Gennaro Prota 2003.
// (C) Copyright Eric Friedman 2003.
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef BOOST_JOIN
#define BOOST_JOIN( X, Y ) BOOST_DO_JOIN( X, Y )
#define BOOST_DO_JOIN( X, Y ) BOOST_DO_JOIN2(X,Y)
#define BOOST_DO_JOIN2( X, Y ) X##Y
//
#ifndef DLIB_BOOST_JOIN
#define DLIB_BOOST_JOIN( X, Y ) DLIB_BOOST_DO_JOIN( X, Y )
#define DLIB_BOOST_DO_JOIN( X, Y ) DLIB_BOOST_DO_JOIN2(X,Y)
#define DLIB_BOOST_DO_JOIN2( X, Y ) X##Y
#endif
// figure out if the compiler has rvalue references.
#if defined(__clang__)
# if __has_feature(cxx_rvalue_references)
# define DLIB_HAS_RVALUE_REFERENCES
# endif
# if __has_feature(cxx_generalized_initializers)
# define DLIB_HAS_INITIALIZER_LISTS
# endif
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__)
# define DLIB_HAS_RVALUE_REFERENCES
# define DLIB_HAS_INITIALIZER_LISTS
#elif defined(_MSC_VER) && _MSC_VER >= 1800
# define DLIB_HAS_INITIALIZER_LISTS
# define DLIB_HAS_RVALUE_REFERENCES
#elif defined(_MSC_VER) && _MSC_VER >= 1600
# define DLIB_HAS_RVALUE_REFERENCES
#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X)
# define DLIB_HAS_RVALUE_REFERENCES
# define DLIB_HAS_INITIALIZER_LISTS
#endif
#if defined(__APPLE__) && defined(__GNUC_LIBSTD__) && ((__GNUC_LIBSTD__-0) * 100 + __GNUC_LIBSTD_MINOR__-0 <= 402)
// Apple has not updated libstdc++ in some time and anything under 4.02 does not have <initializer_list> for sure.
# undef DLIB_HAS_INITIALIZER_LISTS
#endif
// figure out if the compiler has static_assert.
#if defined(__clang__)
# if __has_feature(cxx_static_assert)
# define DLIB_HAS_STATIC_ASSERT
# endif
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__)
# define DLIB_HAS_STATIC_ASSERT
#elif defined(_MSC_VER) && _MSC_VER >= 1600
# define DLIB_HAS_STATIC_ASSERT
#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X)
# define DLIB_HAS_STATIC_ASSERT
#endif
// -----------------------------
namespace dlib
@@ -37,6 +78,9 @@ namespace dlib
template <typename T> struct assert_are_same_type<T,T> {enum{value=1};};
template <typename T, typename U> struct assert_are_not_same_type {enum{value=1}; };
template <typename T> struct assert_are_not_same_type<T,T> {};
template <typename T, typename U> struct assert_types_match {enum{value=0};};
template <typename T> struct assert_types_match<T,T> {enum{value=1};};
}
@@ -48,14 +92,22 @@ namespace dlib
#define DLIB_NO_WARN_UNUSED
#endif
#define COMPILE_TIME_ASSERT(expression) \
DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_CTA, __LINE__)[::dlib::compile_time_assert<(bool)(expression)>::value]
// Use the newer static_assert if it's available since it produces much more readable error
// messages.
#ifdef DLIB_HAS_STATIC_ASSERT
#define COMPILE_TIME_ASSERT(expression) static_assert(expression, "Failed assertion")
#define ASSERT_ARE_SAME_TYPE(type1, type2) static_assert(::dlib::assert_types_match<type1,type2>::value, "These types should be the same but aren't.")
#define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) static_assert(!::dlib::assert_types_match<type1,type2>::value, "These types should NOT be the same.")
#else
#define COMPILE_TIME_ASSERT(expression) \
DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_CTA, __LINE__)[::dlib::compile_time_assert<(bool)(expression)>::value]
#define ASSERT_ARE_SAME_TYPE(type1, type2) \
DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_AAST, __LINE__)[::dlib::assert_are_same_type<type1,type2>::value]
#define ASSERT_ARE_SAME_TYPE(type1, type2) \
DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AAST, __LINE__)[::dlib::assert_are_same_type<type1,type2>::value]
#define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) \
DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_AANST, __LINE__)[::dlib::assert_are_not_same_type<type1,type2>::value]
#define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) \
DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AANST, __LINE__)[::dlib::assert_are_not_same_type<type1,type2>::value]
#endif
// -----------------------------
@@ -88,7 +140,7 @@ namespace dlib
#define DLIB_FUNCTION_NAME "unknown function"
#endif
#define DLIB_CASSERT(_exp,_message) \
#define DLIBM_CASSERT(_exp,_message) \
{if ( !(_exp) ) \
{ \
dlib_assert_breakpoint(); \
@@ -101,12 +153,22 @@ namespace dlib
throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \
}}
// This macro is not needed if you have a real C++ compiler. It's here to work around bugs in Visual Studio's preprocessor.
#define DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(x) x
// Make it so the 2nd argument of DLIB_CASSERT is optional. That is, you can call it like
// DLIB_CASSERT(exp) or DLIB_CASSERT(exp,message).
#define DLIBM_CASSERT_1_ARGS(exp) DLIBM_CASSERT(exp,"")
#define DLIBM_CASSERT_2_ARGS(exp,message) DLIBM_CASSERT(exp,message)
#define DLIBM_GET_3TH_ARG(arg1, arg2, arg3, ...) arg3
#define DLIBM_CASSERT_CHOOSER(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_GET_3TH_ARG(__VA_ARGS__, DLIBM_CASSERT_2_ARGS, DLIBM_CASSERT_1_ARGS))
#define DLIB_CASSERT(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_CASSERT_CHOOSER(__VA_ARGS__)(__VA_ARGS__))
#ifdef ENABLE_ASSERTS
#define DLIB_ASSERT(_exp,_message) DLIB_CASSERT(_exp,_message)
#define DLIB_ASSERT(...) DLIB_CASSERT(__VA_ARGS__)
#define DLIB_IF_ASSERT(exp) exp
#else
#define DLIB_ASSERT(_exp,_message)
#define DLIB_ASSERT(...) {}
#define DLIB_IF_ASSERT(exp)
#endif
@@ -126,8 +188,8 @@ namespace dlib
!*/
// Use the fact that in C++03 you can't put non-PODs into a union.
#define DLIB_ASSERT_HAS_STANDARD_LAYOUT(type) \
union BOOST_JOIN(DAHSL_,__LINE__) { type TYPE_NOT_STANDARD_LAYOUT; }; \
DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DAHSL2_,__LINE__)[sizeof(BOOST_JOIN(DAHSL_,__LINE__))];
union DLIB_BOOST_JOIN(DAHSL_,__LINE__) { type TYPE_NOT_STANDARD_LAYOUT; }; \
DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DAHSL2_,__LINE__)[sizeof(DLIB_BOOST_JOIN(DAHSL_,__LINE__))];
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

View File

@@ -1,403 +0,0 @@
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BASE64_KERNEL_1_CPp_
#define DLIB_BASE64_KERNEL_1_CPp_
#include "base64_kernel_1.h"
#include <iostream>
#include <sstream>
#include <climits>
namespace dlib
{
// ----------------------------------------------------------------------------------------
base64::line_ending_type base64::
line_ending (
) const
{
return eol_style;
}
// ----------------------------------------------------------------------------------------
void base64::
set_line_ending (
line_ending_type eol_style_
)
{
eol_style = eol_style_;
}
// ----------------------------------------------------------------------------------------
base64::
base64 (
) :
encode_table(0),
decode_table(0),
bad_value(100),
eol_style(LF)
{
try
{
encode_table = new char[64];
decode_table = new unsigned char[UCHAR_MAX];
}
catch (...)
{
if (encode_table) delete [] encode_table;
if (decode_table) delete [] decode_table;
throw;
}
// now set up the tables with the right stuff
encode_table[0] = 'A';
encode_table[17] = 'R';
encode_table[34] = 'i';
encode_table[51] = 'z';
encode_table[1] = 'B';
encode_table[18] = 'S';
encode_table[35] = 'j';
encode_table[52] = '0';
encode_table[2] = 'C';
encode_table[19] = 'T';
encode_table[36] = 'k';
encode_table[53] = '1';
encode_table[3] = 'D';
encode_table[20] = 'U';
encode_table[37] = 'l';
encode_table[54] = '2';
encode_table[4] = 'E';
encode_table[21] = 'V';
encode_table[38] = 'm';
encode_table[55] = '3';
encode_table[5] = 'F';
encode_table[22] = 'W';
encode_table[39] = 'n';
encode_table[56] = '4';
encode_table[6] = 'G';
encode_table[23] = 'X';
encode_table[40] = 'o';
encode_table[57] = '5';
encode_table[7] = 'H';
encode_table[24] = 'Y';
encode_table[41] = 'p';
encode_table[58] = '6';
encode_table[8] = 'I';
encode_table[25] = 'Z';
encode_table[42] = 'q';
encode_table[59] = '7';
encode_table[9] = 'J';
encode_table[26] = 'a';
encode_table[43] = 'r';
encode_table[60] = '8';
encode_table[10] = 'K';
encode_table[27] = 'b';
encode_table[44] = 's';
encode_table[61] = '9';
encode_table[11] = 'L';
encode_table[28] = 'c';
encode_table[45] = 't';
encode_table[62] = '+';
encode_table[12] = 'M';
encode_table[29] = 'd';
encode_table[46] = 'u';
encode_table[63] = '/';
encode_table[13] = 'N';
encode_table[30] = 'e';
encode_table[47] = 'v';
encode_table[14] = 'O';
encode_table[31] = 'f';
encode_table[48] = 'w';
encode_table[15] = 'P';
encode_table[32] = 'g';
encode_table[49] = 'x';
encode_table[16] = 'Q';
encode_table[33] = 'h';
encode_table[50] = 'y';
// we can now fill out the decode_table by using the encode_table
for (int i = 0; i < UCHAR_MAX; ++i)
{
decode_table[i] = bad_value;
}
for (unsigned char i = 0; i < 64; ++i)
{
decode_table[(unsigned char)encode_table[i]] = i;
}
}
// ----------------------------------------------------------------------------------------
base64::
~base64 (
)
{
delete [] encode_table;
delete [] decode_table;
}
// ----------------------------------------------------------------------------------------
void base64::
encode (
std::istream& in_,
std::ostream& out_
) const
{
using namespace std;
streambuf& in = *in_.rdbuf();
streambuf& out = *out_.rdbuf();
unsigned char inbuf[3];
unsigned char outbuf[4];
streamsize status = in.sgetn(reinterpret_cast<char*>(&inbuf),3);
unsigned char c1, c2, c3, c4, c5, c6;
int counter = 19;
// while we haven't hit the end of the input stream
while (status != 0)
{
if (counter == 0)
{
counter = 19;
// write a newline
char ch;
switch (eol_style)
{
case CR:
ch = '\r';
if (out.sputn(&ch,1)!=1)
throw std::ios_base::failure("error occured in the base64 object");
break;
case LF:
ch = '\n';
if (out.sputn(&ch,1)!=1)
throw std::ios_base::failure("error occured in the base64 object");
break;
case CRLF:
ch = '\r';
if (out.sputn(&ch,1)!=1)
throw std::ios_base::failure("error occured in the base64 object");
ch = '\n';
if (out.sputn(&ch,1)!=1)
throw std::ios_base::failure("error occured in the base64 object");
break;
default:
DLIB_CASSERT(false,"this should never happen");
}
}
--counter;
if (status == 3)
{
// encode the bytes in inbuf to base64 and write them to the output stream
c1 = inbuf[0]&0xfc;
c2 = inbuf[0]&0x03;
c3 = inbuf[1]&0xf0;
c4 = inbuf[1]&0x0f;
c5 = inbuf[2]&0xc0;
c6 = inbuf[2]&0x3f;
outbuf[0] = c1>>2;
outbuf[1] = (c2<<4)|(c3>>4);
outbuf[2] = (c4<<2)|(c5>>6);
outbuf[3] = c6;
outbuf[0] = encode_table[outbuf[0]];
outbuf[1] = encode_table[outbuf[1]];
outbuf[2] = encode_table[outbuf[2]];
outbuf[3] = encode_table[outbuf[3]];
// write the encoded bytes to the output stream
if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
{
throw std::ios_base::failure("error occured in the base64 object");
}
// get 3 more input bytes
status = in.sgetn(reinterpret_cast<char*>(&inbuf),3);
continue;
}
else if (status == 2)
{
// we are at the end of the input stream and need to add some padding
// encode the bytes in inbuf to base64 and write them to the output stream
c1 = inbuf[0]&0xfc;
c2 = inbuf[0]&0x03;
c3 = inbuf[1]&0xf0;
c4 = inbuf[1]&0x0f;
c5 = 0;
outbuf[0] = c1>>2;
outbuf[1] = (c2<<4)|(c3>>4);
outbuf[2] = (c4<<2)|(c5>>6);
outbuf[3] = '=';
outbuf[0] = encode_table[outbuf[0]];
outbuf[1] = encode_table[outbuf[1]];
outbuf[2] = encode_table[outbuf[2]];
// write the encoded bytes to the output stream
if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
{
throw std::ios_base::failure("error occured in the base64 object");
}
break;
}
else // in this case status must be 1
{
// we are at the end of the input stream and need to add some padding
// encode the bytes in inbuf to base64 and write them to the output stream
c1 = inbuf[0]&0xfc;
c2 = inbuf[0]&0x03;
c3 = 0;
outbuf[0] = c1>>2;
outbuf[1] = (c2<<4)|(c3>>4);
outbuf[2] = '=';
outbuf[3] = '=';
outbuf[0] = encode_table[outbuf[0]];
outbuf[1] = encode_table[outbuf[1]];
// write the encoded bytes to the output stream
if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
{
throw std::ios_base::failure("error occured in the base64 object");
}
break;
}
} // while (status != 0)
// make sure the stream buffer flushes to its I/O channel
out.pubsync();
}
// ----------------------------------------------------------------------------------------
void base64::
decode (
std::istream& in_,
std::ostream& out_
) const
{
using namespace std;
streambuf& in = *in_.rdbuf();
streambuf& out = *out_.rdbuf();
unsigned char inbuf[4];
unsigned char outbuf[3];
int inbuf_pos = 0;
streamsize status = in.sgetn(reinterpret_cast<char*>(inbuf),1);
// only count this character if it isn't some kind of filler
if (status == 1 && decode_table[inbuf[0]] != bad_value )
++inbuf_pos;
unsigned char c1, c2, c3, c4, c5, c6;
streamsize outsize;
// while we haven't hit the end of the input stream
while (status != 0)
{
// if we have 4 valid characters
if (inbuf_pos == 4)
{
inbuf_pos = 0;
// this might be the end of the encoded data so we need to figure out if
// there was any padding applied.
outsize = 3;
if (inbuf[3] == '=')
{
if (inbuf[2] == '=')
outsize = 1;
else
outsize = 2;
}
// decode the incoming characters
inbuf[0] = decode_table[inbuf[0]];
inbuf[1] = decode_table[inbuf[1]];
inbuf[2] = decode_table[inbuf[2]];
inbuf[3] = decode_table[inbuf[3]];
// now pack these guys into bytes rather than 6 bit chunks
c1 = inbuf[0]<<2;
c2 = inbuf[1]>>4;
c3 = inbuf[1]<<4;
c4 = inbuf[2]>>2;
c5 = inbuf[2]<<6;
c6 = inbuf[3];
outbuf[0] = c1|c2;
outbuf[1] = c3|c4;
outbuf[2] = c5|c6;
// write the encoded bytes to the output stream
if (out.sputn(reinterpret_cast<char*>(&outbuf),outsize)!=outsize)
{
throw std::ios_base::failure("error occured in the base64 object");
}
}
// get more input characters
status = in.sgetn(reinterpret_cast<char*>(inbuf + inbuf_pos),1);
// only count this character if it isn't some kind of filler
if ((decode_table[inbuf[inbuf_pos]] != bad_value || inbuf[inbuf_pos] == '=') &&
status != 0)
++inbuf_pos;
} // while (status != 0)
if (inbuf_pos != 0)
{
ostringstream sout;
sout << inbuf_pos << " extra characters were found at the end of the encoded data."
<< " This may indicate that the data stream has been truncated.";
// this happens if we hit EOF in the middle of decoding a 24bit block.
throw decode_error(sout.str());
}
// make sure the stream buffer flushes to its I/O channel
out.pubsync();
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BASE64_KERNEL_1_CPp_

View File

@@ -93,7 +93,7 @@ namespace dlib
) const;
/*!
ensures
- reads data from in (until EOF is reached) and decodees it
- reads data from in (until EOF is reached), decodes it,
and writes it to out.
throws
- std::ios_base::failure

View File

@@ -5,6 +5,11 @@
#include "bayes_utils_abstract.h"
#include <algorithm>
#include <ctime>
#include <memory>
#include <vector>
#include "../string.h"
#include "../map.h"
#include "../matrix.h"
@@ -13,11 +18,7 @@
#include "../set.h"
#include "../algs.h"
#include "../noncopyable.h"
#include "../smart_pointers.h"
#include "../graph.h"
#include <vector>
#include <algorithm>
#include <ctime>
namespace dlib
{
@@ -355,7 +356,7 @@ namespace dlib
table.clear();
}
unsigned long size () const { return table.size(); }
size_t size () const { return table.size(); }
bool move_next() const { return table.move_next(); }
void reset() const { table.reset(); }
map_pair<assignment,double>& element()
@@ -1659,7 +1660,7 @@ namespace dlib
private:
scoped_ptr<bayesian_network_join_tree_helpers::bnjt> impl;
std::unique_ptr<bayesian_network_join_tree_helpers::bnjt> impl;
unsigned long num_nodes;
};

View File

@@ -367,7 +367,7 @@ namespace dlib
std::ostream& out
);
/*!
provides deserialization support
provides serialization support
!*/
void deserialize (
@@ -499,7 +499,7 @@ namespace dlib
std::ostream& out
);
/*!
provides deserialization support
provides serialization support
!*/
void deserialize (
@@ -615,7 +615,7 @@ namespace dlib
std::ostream& out
);
/*!
provides deserialization support
provides serialization support
!*/
void deserialize (

File diff suppressed because it is too large Load Diff

View File

@@ -12,7 +12,6 @@
namespace dlib
{
using namespace dlib::relational_operators; // defined in algs.h
class bigint_kernel_1
{
@@ -531,6 +530,10 @@ namespace dlib
}
}
inline bool operator> (const bigint_kernel_1& a, const bigint_kernel_1& b) { return b < a; }
inline bool operator!= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a == b); }
inline bool operator<= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(b < a); }
inline bool operator>= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a < b); }
}
#ifdef NO_MAKEFILE

File diff suppressed because it is too large Load Diff

View File

@@ -15,8 +15,6 @@
namespace dlib
{
using namespace dlib::relational_operators; // defined in algs.h
class bigint_kernel_2
{
/*!
@@ -557,6 +555,11 @@ namespace dlib
}
}
inline bool operator> (const bigint_kernel_2& a, const bigint_kernel_2& b) { return b < a; }
inline bool operator!= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a == b); }
inline bool operator<= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(b < a); }
inline bool operator>= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a < b); }
}
#ifdef NO_MAKEFILE

View File

@@ -10,7 +10,6 @@
namespace dlib
{
using namespace dlib::relational_operators; // defined in algs.h
class bigint
{
@@ -661,6 +660,10 @@ namespace dlib
provides deserialization support
!*/
inline bool operator> (const bigint& a, const bigint& b) { return b < a; }
inline bool operator!= (const bigint& a, const bigint& b) { return !(a == b); }
inline bool operator<= (const bigint& a, const bigint& b) { return !(b < a); }
inline bool operator>= (const bigint& a, const bigint& b) { return !(a < b); }
}
#endif // DLIB_BIGINT_KERNEl_ABSTRACT_

View File

@@ -1122,6 +1122,17 @@ namespace dlib
return *this;
}
// ----------------------------------------------------------------------------------------
template < typename bigint_base >
inline bool operator> (const bigint_kernel_c<bigint_base>& a, const bigint_kernel_c<bigint_base>& b) { return b < a; }
template < typename bigint_base >
inline bool operator!= (const bigint_kernel_c<bigint_base>& a, const bigint_kernel_c<bigint_base>& b) { return !(a == b); }
template < typename bigint_base >
inline bool operator<= (const bigint_kernel_c<bigint_base>& a, const bigint_kernel_c<bigint_base>& b) { return !(b < a); }
template < typename bigint_base >
inline bool operator>= (const bigint_kernel_c<bigint_base>& a, const bigint_kernel_c<bigint_base>& b) { return !(a < b); }
// ----------------------------------------------------------------------------------------
}

View File

@@ -168,7 +168,7 @@ namespace dlib
);
// functions from the enumerable interface
inline unsigned long size (
inline size_t size (
) const;
bool at_start (
@@ -597,7 +597,7 @@ namespace dlib
typename mem_manager,
typename compare
>
unsigned long binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
size_t binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
size (
) const
{

View File

@@ -169,7 +169,7 @@ namespace dlib
);
// functions from the enumerable interface
inline unsigned long size (
inline size_t size (
) const;
bool at_start (
@@ -543,7 +543,7 @@ namespace dlib
typename mem_manager,
typename compare
>
unsigned long binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
size_t binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
size (
) const
{

View File

@@ -1,200 +0,0 @@
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BIT_STREAM_KERNEL_1_CPp_
#define DLIB_BIT_STREAM_KERNEL_1_CPp_
#include "bit_stream_kernel_1.h"
#include "../algs.h"
#include <iostream>
namespace dlib
{
inline void swap (
bit_stream_kernel_1& a,
bit_stream_kernel_1& b
) { a.swap(b); }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void bit_stream_kernel_1::
clear (
)
{
if (write_mode)
{
write_mode = false;
// flush output buffer
if (buffer_size > 0)
{
buffer <<= 8 - buffer_size;
osp->write(reinterpret_cast<char*>(&buffer),1);
}
}
else
read_mode = false;
}
// ----------------------------------------------------------------------------------------
void bit_stream_kernel_1::
set_input_stream (
std::istream& is
)
{
isp = &is;
read_mode = true;
buffer_size = 0;
}
// ----------------------------------------------------------------------------------------
void bit_stream_kernel_1::
set_output_stream (
std::ostream& os
)
{
osp = &os;
write_mode = true;
buffer_size = 0;
}
// ----------------------------------------------------------------------------------------
void bit_stream_kernel_1::
close (
)
{
if (write_mode)
{
write_mode = false;
// flush output buffer
if (buffer_size > 0)
{
buffer <<= 8 - buffer_size;
osp->write(reinterpret_cast<char*>(&buffer),1);
}
}
else
read_mode = false;
}
// ----------------------------------------------------------------------------------------
bool bit_stream_kernel_1::
is_in_write_mode (
) const
{
return write_mode;
}
// ----------------------------------------------------------------------------------------
bool bit_stream_kernel_1::
is_in_read_mode (
) const
{
return read_mode;
}
// ----------------------------------------------------------------------------------------
void bit_stream_kernel_1::
write (
int bit
)
{
// flush buffer if necessary
if (buffer_size == 8)
{
buffer <<= 8 - buffer_size;
if (osp->rdbuf()->sputn(reinterpret_cast<char*>(&buffer),1) == 0)
{
throw std::ios_base::failure("error occured in the bit_stream object");
}
buffer_size = 0;
}
++buffer_size;
buffer <<= 1;
buffer += static_cast<unsigned char>(bit);
}
// ----------------------------------------------------------------------------------------
bool bit_stream_kernel_1::
read (
int& bit
)
{
// get new byte if necessary
if (buffer_size == 0)
{
if (isp->rdbuf()->sgetn(reinterpret_cast<char*>(&buffer), 1) == 0)
{
// if we didn't read anything then return false
return false;
}
buffer_size = 8;
}
// put the most significant bit from buffer into bit
bit = static_cast<int>(buffer >> 7);
// shift out the bit that was just read
buffer <<= 1;
--buffer_size;
return true;
}
// ----------------------------------------------------------------------------------------
void bit_stream_kernel_1::
swap (
bit_stream_kernel_1& item
)
{
std::istream* isp_temp = item.isp;
std::ostream* osp_temp = item.osp;
bool write_mode_temp = item.write_mode;
bool read_mode_temp = item.read_mode;
unsigned char buffer_temp = item.buffer;
unsigned short buffer_size_temp = item.buffer_size;
item.isp = isp;
item.osp = osp;
item.write_mode = write_mode;
item.read_mode = read_mode;
item.buffer = buffer;
item.buffer_size = buffer_size;
isp = isp_temp;
osp = osp_temp;
write_mode = write_mode_temp;
read_mode = read_mode_temp;
buffer = buffer_temp;
buffer_size = buffer_size_temp;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BIT_STREAM_KERNEL_1_CPp_

View File

@@ -1,5 +1,11 @@
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifdef DLIB_ALL_SOURCE_END
#include "dlib_basic_cpp_build_tutorial.txt"
#endif
#ifndef DLIB_BRIdGE_
#define DLIB_BRIdGE_

View File

@@ -3,17 +3,19 @@
#ifndef DLIB_BRIDGe_Hh_
#define DLIB_BRIDGe_Hh_
#include "bridge_abstract.h"
#include <iostream>
#include <memory>
#include <string>
#include "bridge_abstract.h"
#include "../pipe.h"
#include "../threads.h"
#include "../smart_pointers.h"
#include "../serialize.h"
#include "../sockets.h"
#include "../sockstreambuf.h"
#include "../logger.h"
#include "../algs.h"
#include <iostream>
namespace dlib
{
@@ -141,7 +143,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
namespace impl
namespace impl_brns
{
class impl_bridge_base
{
@@ -545,8 +547,8 @@ namespace dlib
signaler s;
bool receive_thread_active;
bool transmit_thread_active;
scoped_ptr<connection> con;
scoped_ptr<listener> list;
std::unique_ptr<connection> con;
std::unique_ptr<listener> list;
const unsigned short port;
const std::string ip;
transmit_pipe_type* const transmit_pipe;
@@ -594,26 +596,26 @@ namespace dlib
listen_on_port network_parameters,
bridge_transmit_decoration<T> transmit_pipe,
bridge_receive_decoration<R> receive_pipe
) { pimpl.reset(); pimpl.reset(new impl::impl_bridge<T,R>(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
template < typename T, typename R >
void reconfigure (
listen_on_port network_parameters,
bridge_receive_decoration<R> receive_pipe,
bridge_transmit_decoration<T> transmit_pipe
) { pimpl.reset(); pimpl.reset(new impl::impl_bridge<T,R>(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
template < typename T >
void reconfigure (
listen_on_port network_parameters,
bridge_transmit_decoration<T> transmit_pipe
) { pimpl.reset(); pimpl.reset(new impl::impl_bridge<T,T>(network_parameters.port, &transmit_pipe.p, 0)); }
) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,T>(network_parameters.port, &transmit_pipe.p, 0)); }
template < typename R >
void reconfigure (
listen_on_port network_parameters,
bridge_receive_decoration<R> receive_pipe
) { pimpl.reset(); pimpl.reset(new impl::impl_bridge<R,R>(network_parameters.port, 0, &receive_pipe.p)); }
) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<R,R>(network_parameters.port, 0, &receive_pipe.p)); }
@@ -623,26 +625,26 @@ namespace dlib
connect_to_ip_and_port network_parameters,
bridge_transmit_decoration<T> transmit_pipe,
bridge_receive_decoration<R> receive_pipe
) { pimpl.reset(); pimpl.reset(new impl::impl_bridge<T,R>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
template < typename T, typename R >
void reconfigure (
connect_to_ip_and_port network_parameters,
bridge_receive_decoration<R> receive_pipe,
bridge_transmit_decoration<T> transmit_pipe
) { pimpl.reset(); pimpl.reset(new impl::impl_bridge<T,R>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
template < typename R >
void reconfigure (
connect_to_ip_and_port network_parameters,
bridge_receive_decoration<R> receive_pipe
) { pimpl.reset(); pimpl.reset(new impl::impl_bridge<R,R>(network_parameters.ip, network_parameters.port, 0, &receive_pipe.p)); }
) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<R,R>(network_parameters.ip, network_parameters.port, 0, &receive_pipe.p)); }
template < typename T >
void reconfigure (
connect_to_ip_and_port network_parameters,
bridge_transmit_decoration<T> transmit_pipe
) { pimpl.reset(); pimpl.reset(new impl::impl_bridge<T,T>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, 0)); }
) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,T>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, 0)); }
bridge_status get_bridge_status (
@@ -656,7 +658,7 @@ namespace dlib
private:
scoped_ptr<impl::impl_bridge_base> pimpl;
std::unique_ptr<impl_brns::impl_bridge_base> pimpl;
};
// ----------------------------------------------------------------------------------------

View File

@@ -1,495 +0,0 @@
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BSP_CPph_
#define DLIB_BSP_CPph_
#include "bsp.h"
#include <stack>
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace dlib
{
namespace impl1
{
void connect_all (
map_id_to_con& cons,
const std::vector<network_address>& hosts,
unsigned long node_id
)
{
cons.clear();
for (unsigned long i = 0; i < hosts.size(); ++i)
{
scoped_ptr<bsp_con> con(new bsp_con(hosts[i]));
dlib::serialize(node_id, con->stream); // tell the other end our node_id
unsigned long id = i+1;
cons.add(id, con);
}
}
void connect_all_hostinfo (
map_id_to_con& cons,
const std::vector<hostinfo>& hosts,
unsigned long node_id,
std::string& error_string
)
{
cons.clear();
for (unsigned long i = 0; i < hosts.size(); ++i)
{
try
{
scoped_ptr<bsp_con> con(new bsp_con(hosts[i].addr));
dlib::serialize(node_id, con->stream); // tell the other end our node_id
con->stream.flush();
unsigned long id = hosts[i].node_id;
cons.add(id, con);
}
catch (std::exception&)
{
std::ostringstream sout;
sout << "Could not connect to " << hosts[i].addr;
error_string = sout.str();
break;
}
}
}
void send_out_connection_orders (
map_id_to_con& cons,
const std::vector<network_address>& hosts
)
{
// tell everyone their node ids
cons.reset();
while (cons.move_next())
{
dlib::serialize(cons.element().key(), cons.element().value()->stream);
}
// now tell them who to connect to
std::vector<hostinfo> targets;
for (unsigned long i = 0; i < hosts.size(); ++i)
{
hostinfo info(hosts[i], i+1);
dlib::serialize(targets, cons[info.node_id]->stream);
targets.push_back(info);
// let the other host know how many incoming connections to expect
const unsigned long num = hosts.size()-targets.size();
dlib::serialize(num, cons[info.node_id]->stream);
cons[info.node_id]->stream.flush();
}
}
// ------------------------------------------------------------------------------------
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace impl2
{
// These control bytes are sent before each message between nodes. Note that many
// of these are only sent between the control node (node 0) and the other nodes.
// This is because the controller node is responsible for handling the
// synchronization that needs to happen when all nodes block on calls to
// receive_data()
// at the same time.
// denotes a normal content message.
const static char MESSAGE_HEADER = 0;
// sent to the controller node when someone receives a message via receive_data().
const static char GOT_MESSAGE = 1;
// sent to the controller node when someone sends a message via send().
const static char SENT_MESSAGE = 2;
// sent to the controller node when someone enters a call to receive_data()
const static char IN_WAITING_STATE = 3;
// broadcast when a node terminates itself.
const static char NODE_TERMINATE = 5;
// broadcast by the controller node when it determines that all nodes are blocked
// on calls to receive_data() and there aren't any messages in flight. This is also
// what makes us go to the next epoch.
const static char SEE_ALL_IN_WAITING_STATE = 6;
// This isn't ever transmitted between nodes. It is used internally to indicate
// that an error occurred.
const static char READ_ERROR = 7;
// ------------------------------------------------------------------------------------
void read_thread (
impl1::bsp_con* con,
unsigned long node_id,
unsigned long sender_id,
impl1::thread_safe_message_queue& msg_buffer
)
{
try
{
while(true)
{
impl1::msg_data msg;
deserialize(msg.msg_type, con->stream);
msg.sender_id = sender_id;
if (msg.msg_type == MESSAGE_HEADER)
{
msg.data.reset(new std::vector<char>);
deserialize(msg.epoch, con->stream);
deserialize(*msg.data, con->stream);
}
msg_buffer.push_and_consume(msg);
if (msg.msg_type == NODE_TERMINATE)
break;
}
}
catch (std::exception& e)
{
impl1::msg_data msg;
msg.data.reset(new std::vector<char>);
vectorstream sout(*msg.data);
sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
sout << " Receiving processing node id: " << node_id << std::endl;
sout << " Error message in the exception: " << e.what() << std::endl;
msg.sender_id = sender_id;
msg.msg_type = READ_ERROR;
msg_buffer.push_and_consume(msg);
}
catch (...)
{
impl1::msg_data msg;
msg.data.reset(new std::vector<char>);
vectorstream sout(*msg.data);
sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
sout << " Receiving processing node id: " << node_id << std::endl;
msg.sender_id = sender_id;
msg.msg_type = READ_ERROR;
msg_buffer.push_and_consume(msg);
}
}
// ------------------------------------------------------------------------------------
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// IMPLEMENTATION OF bsp_context OBJECT MEMBERS
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void bsp_context::
close_all_connections_gracefully(
)
{
if (node_id() != 0)
{
_cons.reset();
while (_cons.move_next())
{
// tell the other end that we are intentionally dropping the connection
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
_cons.element().value()->stream.flush();
}
}
impl1::msg_data msg;
// now wait for all the other nodes to terminate
while (num_terminated_nodes < _cons.size() )
{
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
{
num_waiting_nodes = 0;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
++current_epoch;
}
if (!msg_buffer.pop(msg))
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
if (msg.msg_type == impl2::NODE_TERMINATE)
{
++num_terminated_nodes;
_cons[msg.sender_id]->terminated = true;
}
else if (msg.msg_type == impl2::READ_ERROR)
{
throw dlib::socket_error(msg.data_to_string());
}
else if (msg.msg_type == impl2::MESSAGE_HEADER)
{
throw dlib::socket_error("A BSP node received a message after it has terminated.");
}
else if (msg.msg_type == impl2::GOT_MESSAGE)
{
--num_waiting_nodes;
--outstanding_messages;
}
else if (msg.msg_type == impl2::SENT_MESSAGE)
{
++outstanding_messages;
}
else if (msg.msg_type == impl2::IN_WAITING_STATE)
{
++num_waiting_nodes;
}
}
if (node_id() == 0)
{
_cons.reset();
while (_cons.move_next())
{
// tell the other end that we are intentionally dropping the connection
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
_cons.element().value()->stream.flush();
}
if (outstanding_messages != 0)
{
std::ostringstream sout;
sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
sout << "have a corresponding call to receive().";
throw dlib::socket_error(sout.str());
}
}
}
// ----------------------------------------------------------------------------------------
bsp_context::
~bsp_context()
{
_cons.reset();
while (_cons.move_next())
{
_cons.element().value()->con->shutdown();
}
msg_buffer.disable();
// this will wait for all the threads to terminate
threads.clear();
}
// ----------------------------------------------------------------------------------------
bsp_context::
bsp_context(
unsigned long node_id_,
impl1::map_id_to_con& cons_
) :
outstanding_messages(0),
num_waiting_nodes(0),
num_terminated_nodes(0),
current_epoch(1),
_cons(cons_),
_node_id(node_id_)
{
// spawn a bunch of read threads, one for each connection
_cons.reset();
while (_cons.move_next())
{
scoped_ptr<thread_function> ptr(new thread_function(&impl2::read_thread,
_cons.element().value().get(),
_node_id,
_cons.element().key(),
ref(msg_buffer)));
threads.push_back(ptr);
}
}
// ----------------------------------------------------------------------------------------
bool bsp_context::
receive_data (
shared_ptr<std::vector<char> >& item,
unsigned long& sending_node_id
)
{
notify_control_node(impl2::IN_WAITING_STATE);
while (true)
{
// If there aren't any nodes left to give us messages then return right now.
// We need to check the msg_buffer size to make sure there aren't any
// unprocessed message there. Recall that this can happen because status
// messages always jump to the front of the message buffer. So we might have
// learned about the node terminations before processing their messages for us.
if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
{
return false;
}
// if all running nodes are currently blocking forever on receive_data()
if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
{
num_waiting_nodes = 0;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
// Note that the reason we have this epoch counter is so we can tell if a
// sent message is from before or after one of these "all nodes waiting"
// synchronization events. If we didn't have the epoch count we would have
// a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
// message before others and then sends out a message to another node
// before that node got the SEE_ALL_IN_WAITING_STATE message. Then that
// node would think the normal message came before SEE_ALL_IN_WAITING_STATE
// which would be bad.
++current_epoch;
return false;
}
impl1::msg_data data;
if (!msg_buffer.pop(data, current_epoch))
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
switch(data.msg_type)
{
case impl2::MESSAGE_HEADER: {
item = data.data;
sending_node_id = data.sender_id;
notify_control_node(impl2::GOT_MESSAGE);
return true;
} break;
case impl2::IN_WAITING_STATE: {
++num_waiting_nodes;
} break;
case impl2::GOT_MESSAGE: {
--outstanding_messages;
--num_waiting_nodes;
} break;
case impl2::SENT_MESSAGE: {
++outstanding_messages;
} break;
case impl2::NODE_TERMINATE: {
++num_terminated_nodes;
_cons[data.sender_id]->terminated = true;
} break;
case impl2::SEE_ALL_IN_WAITING_STATE: {
++current_epoch;
return false;
} break;
case impl2::READ_ERROR: {
throw dlib::socket_error(data.data_to_string());
} break;
default: {
throw dlib::socket_error("Unknown message received by dlib::bsp_context");
} break;
} // end switch()
} // end while (true)
}
// ----------------------------------------------------------------------------------------
void bsp_context::
notify_control_node (
char val
)
{
if (node_id() == 0)
{
using namespace impl2;
switch(val)
{
case SENT_MESSAGE: {
++outstanding_messages;
} break;
case GOT_MESSAGE: {
--outstanding_messages;
} break;
case IN_WAITING_STATE: {
// nothing to do in this case
} break;
default:
DLIB_CASSERT(false,"This should never happen");
}
}
else
{
serialize(val, _cons[0]->stream);
_cons[0]->stream.flush();
}
}
// ----------------------------------------------------------------------------------------
void bsp_context::
broadcast_byte (
char val
)
{
for (unsigned long i = 0; i < number_of_nodes(); ++i)
{
// don't send to yourself or to terminated nodes
if (i == node_id() || _cons[i]->terminated)
continue;
serialize(val, _cons[i]->stream);
_cons[i]->stream.flush();
}
}
// ----------------------------------------------------------------------------------------
void bsp_context::
send_data(
const std::vector<char>& item,
unsigned long target_node_id
)
{
using namespace impl2;
if (_cons[target_node_id]->terminated)
throw socket_error("Attempt to send a message to a node that has terminated.");
serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
serialize(current_epoch, _cons[target_node_id]->stream);
serialize(item, _cons[target_node_id]->stream);
_cons[target_node_id]->stream.flush();
notify_control_node(SENT_MESSAGE);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BSP_CPph_

View File

@@ -4,17 +4,19 @@
#define DLIB_BsP_Hh_
#include "bsp_abstract.h"
#include <memory>
#include <queue>
#include <vector>
#include "../sockets.h"
#include "../array.h"
#include "../smart_pointers.h"
#include "../sockstreambuf.h"
#include "../string.h"
#include "../serialize.h"
#include "../map.h"
#include "../ref.h"
#include "../vectorstream.h"
#include <queue>
#include <vector>
namespace dlib
{
@@ -41,7 +43,7 @@ namespace dlib
}
bsp_con(
scoped_ptr<connection>& conptr
std::unique_ptr<connection>& conptr
) :
buf(conptr),
stream(&buf),
@@ -53,13 +55,13 @@ namespace dlib
con->disable_nagle();
}
scoped_ptr<connection> con;
std::unique_ptr<connection> con;
sockstreambuf buf;
std::iostream stream;
bool terminated;
};
typedef dlib::map<unsigned long, scoped_ptr<bsp_con> >::kernel_1a_c map_id_to_con;
typedef dlib::map<unsigned long, std::unique_ptr<bsp_con> >::kernel_1a_c map_id_to_con;
void connect_all (
map_id_to_con& cons,
@@ -134,7 +136,7 @@ namespace dlib
)
{
cons.clear();
scoped_ptr<listener> list;
std::unique_ptr<listener> list;
const int status = create_listener(list, port);
if (status == PORTINUSE)
{
@@ -148,13 +150,13 @@ namespace dlib
port_notify_function(list->get_listening_port());
scoped_ptr<connection> con;
std::unique_ptr<connection> con;
if (list->accept(con))
{
throw socket_error("Error occurred while accepting new connection");
}
scoped_ptr<bsp_con> temp(new bsp_con(con));
std::unique_ptr<bsp_con> temp(new bsp_con(con));
unsigned long remote_node_id;
dlib::deserialize(remote_node_id, temp->stream);
@@ -197,7 +199,7 @@ namespace dlib
while (cons2.size() > 0)
{
unsigned long id;
scoped_ptr<bsp_con> temp;
std::unique_ptr<bsp_con> temp;
cons2.remove_any(id,temp);
cons.add(id,temp);
}
@@ -207,7 +209,7 @@ namespace dlib
struct msg_data
{
shared_ptr<std::vector<char> > data;
std::shared_ptr<std::vector<char> > data;
unsigned long sender_id;
char msg_type;
dlib::uint64 epoch;
@@ -420,7 +422,7 @@ namespace dlib
)
{
unsigned long id;
shared_ptr<std::vector<char> > temp;
std::shared_ptr<std::vector<char> > temp;
if (receive_data(temp,id))
throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message.");
}
@@ -459,7 +461,7 @@ namespace dlib
unsigned long& sending_node_id
)
{
shared_ptr<std::vector<char> > temp;
std::shared_ptr<std::vector<char> > temp;
if (receive_data(temp, sending_node_id))
{
vectorstream sin(*temp);
@@ -496,7 +498,7 @@ namespace dlib
!*/
bool receive_data (
shared_ptr<std::vector<char> >& item,
std::shared_ptr<std::vector<char> >& item,
unsigned long& sending_node_id
);
@@ -533,7 +535,7 @@ namespace dlib
impl1::map_id_to_con& _cons;
const unsigned long _node_id;
array<scoped_ptr<thread_function> > threads;
array<std::unique_ptr<thread_function> > threads;
// -----------------------------------

View File

@@ -1 +0,0 @@
#include "dlib_include_path_tutorial.txt"

View File

@@ -5,6 +5,8 @@
#include "clustering/modularity_clustering.h"
#include "clustering/chinese_whispers.h"
#include "clustering/spectral_cluster.h"
#include "clustering/bottom_up_cluster.h"
#include "svm/kkmeans.h"
#endif // DLIB_CLuSTERING_

View File

@@ -0,0 +1,253 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_
#define DLIB_BOTTOM_uP_CLUSTER_Hh_
#include <queue>
#include <map>
#include "bottom_up_cluster_abstract.h"
#include "../algs.h"
#include "../matrix.h"
#include "../disjoint_subsets.h"
#include "../graph_utils.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace buc_impl
{
inline void merge_sets (
matrix<double>& dists,
unsigned long dest,
unsigned long src
)
{
for (long r = 0; r < dists.nr(); ++r)
dists(dest,r) = dists(r,dest) = std::max(dists(r,dest), dists(r,src));
}
struct compare_dist
{
bool operator() (
const sample_pair& a,
const sample_pair& b
) const
{
return a.distance() > b.distance();
}
};
}
// ----------------------------------------------------------------------------------------
template <
typename EXP
>
unsigned long bottom_up_cluster (
const matrix_exp<EXP>& dists_,
std::vector<unsigned long>& labels,
unsigned long min_num_clusters,
double max_dist = std::numeric_limits<double>::infinity()
)
{
matrix<double> dists = matrix_cast<double>(dists_);
// make sure requires clause is not broken
DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0,
"\t unsigned long bottom_up_cluster()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t dists.nr(): " << dists.nr()
<< "\n\t dists.nc(): " << dists.nc()
<< "\n\t min_num_clusters: " << min_num_clusters
);
using namespace buc_impl;
labels.resize(dists.nr());
disjoint_subsets sets;
sets.set_size(dists.nr());
if (labels.size() == 0)
return 0;
// push all the edges in the graph into a priority queue so the best edges to merge
// come first.
std::priority_queue<sample_pair, std::vector<sample_pair>, compare_dist> que;
for (long r = 0; r < dists.nr(); ++r)
for (long c = r+1; c < dists.nc(); ++c)
que.push(sample_pair(r,c,dists(r,c)));
// Now start merging nodes.
for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter)
{
// find the next best thing to merge.
double best_dist = que.top().distance();
unsigned long a = sets.find_set(que.top().index1());
unsigned long b = sets.find_set(que.top().index2());
que.pop();
// we have been merging and modifying the distances, so make sure this distance
// is still valid and these guys haven't been merged already.
while(a == b || best_dist < dists(a,b))
{
// Haven't merged it yet, so put it back in with updated distance for
// reconsideration later.
if (a != b)
que.push(sample_pair(a, b, dists(a, b)));
best_dist = que.top().distance();
a = sets.find_set(que.top().index1());
b = sets.find_set(que.top().index2());
que.pop();
}
// now merge these sets if the best distance is small enough
if (best_dist > max_dist)
break;
unsigned long news = sets.merge_sets(a,b);
unsigned long olds = (news==a)?b:a;
merge_sets(dists, news, olds);
}
// figure out which cluster each element is in. Also make sure the labels are
// contiguous.
std::map<unsigned long, unsigned long> relabel;
for (unsigned long r = 0; r < labels.size(); ++r)
{
unsigned long l = sets.find_set(r);
// relabel to make contiguous
if (relabel.count(l) == 0)
{
unsigned long next = relabel.size();
relabel[l] = next;
}
labels[r] = relabel[l];
}
return relabel.size();
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
struct snl_range
{
snl_range() = default;
snl_range(double val) : lower(val), upper(val) {}
snl_range(double l, double u) : lower(l), upper(u) { DLIB_ASSERT(lower <= upper)}
double lower = 0;
double upper = 0;
double width() const { return upper-lower; }
bool operator<(const snl_range& item) const { return lower < item.lower; }
};
inline snl_range merge(const snl_range& a, const snl_range& b)
{
return snl_range(std::min(a.lower, b.lower), std::max(a.upper, b.upper));
}
inline double distance (const snl_range& a, const snl_range& b)
{
return std::max(a.lower,b.lower) - std::min(a.upper,b.upper);
}
inline std::ostream& operator<< (std::ostream& out, const snl_range& item )
{
out << "["<<item.lower<<","<<item.upper<<"]";
return out;
}
// ----------------------------------------------------------------------------------------
inline std::vector<snl_range> segment_number_line (
const std::vector<double>& x,
const double max_range_width
)
{
DLIB_CASSERT(max_range_width >= 0);
// create initial ranges, one for each value in x. So initially, all the ranges have
// width of 0.
std::vector<snl_range> ranges;
for (auto v : x)
ranges.push_back(v);
std::sort(ranges.begin(), ranges.end());
std::vector<snl_range> greedy_final_ranges;
if (ranges.size() == 0)
return greedy_final_ranges;
// We will try two different clustering strategies. One that does a simple greedy left
// to right sweep and another that does a bottom up agglomerative clustering. This
// first loop runs the greedy left to right sweep. Then at the end of this routine we
// will return the results that produced the tightest clustering.
greedy_final_ranges.push_back(ranges[0]);
for (size_t i = 1; i < ranges.size(); ++i)
{
auto m = merge(greedy_final_ranges.back(), ranges[i]);
if (m.width() <= max_range_width)
greedy_final_ranges.back() = m;
else
greedy_final_ranges.push_back(ranges[i]);
}
// Here we do the bottom up clustering. So compute the edges connecting our ranges.
// We will simply say there are edges between ranges if and only if they are
// immediately adjacent on the number line.
std::vector<sample_pair> edges;
for (size_t i = 1; i < ranges.size(); ++i)
edges.push_back(sample_pair(i-1,i, distance(ranges[i-1],ranges[i])));
std::sort(edges.begin(), edges.end(), order_by_distance<sample_pair>);
disjoint_subsets sets;
sets.set_size(ranges.size());
// Now start merging nodes.
for (auto edge : edges)
{
// find the next best thing to merge.
unsigned long a = sets.find_set(edge.index1());
unsigned long b = sets.find_set(edge.index2());
// merge it if it doesn't result in an interval that's too big.
auto m = merge(ranges[a], ranges[b]);
if (m.width() <= max_range_width)
{
unsigned long news = sets.merge_sets(a,b);
ranges[news] = m;
}
}
// Now create a list of the final ranges. We will do this by keeping track of which
// range we already added to final_ranges.
std::vector<snl_range> final_ranges;
std::vector<bool> already_output(ranges.size(), false);
for (unsigned long i = 0; i < sets.size(); ++i)
{
auto s = sets.find_set(i);
if (!already_output[s])
{
final_ranges.push_back(ranges[s]);
already_output[s] = true;
}
}
// only use the greedy clusters if they found a clustering with fewer clusters.
// Otherwise, the bottom up clustering probably produced a more sensible clustering.
if (final_ranges.size() <= greedy_final_ranges.size())
return final_ranges;
else
return greedy_final_ranges;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BOTTOM_uP_CLUSTER_Hh_

View File

@@ -0,0 +1,136 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_
#ifdef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_
#include "../matrix.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename EXP
>
unsigned long bottom_up_cluster (
const matrix_exp<EXP>& dists,
std::vector<unsigned long>& labels,
unsigned long min_num_clusters,
double max_dist = std::numeric_limits<double>::infinity()
);
/*!
requires
- dists.nr() == dists.nc()
- min_num_clusters > 0
- dists == trans(dists)
(l.e. dists should be symmetric)
ensures
- Runs a bottom up agglomerative clustering algorithm.
- Interprets dists as a matrix that gives the distances between dists.nr()
items. In particular, we take dists(i,j) to be the distance between the ith
and jth element of some set. This function clusters the elements of this set
into at least min_num_clusters (or dists.nr() if there aren't enough
elements). Additionally, within each cluster, the maximum pairwise distance
between any two cluster elements is <= max_dist.
- returns the number of clusters found.
- #labels.size() == dists.nr()
- for all valid i:
- #labels[i] == the cluster ID of the node with index i (i.e. the node
corresponding to the distances dists(i,*)).
- 0 <= #labels[i] < the number of clusters found
(i.e. cluster IDs are assigned contiguously and start at 0)
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
struct snl_range
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents an interval on the real number line. It is used
to store the outputs of the segment_number_line() routine defined below.
!*/
snl_range(
);
/*!
ensures
- #lower == 0
- #upper == 0
!*/
snl_range(
double val
);
/*!
ensures
- #lower == val
- #upper == val
!*/
snl_range(
double l,
double u
);
/*!
requires
- l <= u
ensures
- #lower == l
- #upper == u
!*/
double lower;
double upper;
double width(
) const { return upper-lower; }
/*!
ensures
- returns the width of this interval on the number line.
!*/
bool operator<(const snl_range& item) const { return lower < item.lower; }
/*!
ensures
- provides a total ordering of snl_range objects assuming they are
non-overlapping.
!*/
};
std::ostream& operator<< (std::ostream& out, const snl_range& item );
/*!
ensures
- prints item to out in the form [lower,upper].
!*/
// ----------------------------------------------------------------------------------------
std::vector<snl_range> segment_number_line (
const std::vector<double>& x,
const double max_range_width
);
/*!
requires
- max_range_width >= 0
ensures
- Finds a clustering of the values in x and returns the ranges that define the
clustering. This routine uses a combination of bottom up clustering and a
simple greedy scan to try and find the most compact set of ranges that
contain all the values in x.
- This routine has approximately linear runtime.
- Every value in x will be contained inside one of the returned snl_range
objects;
- All returned snl_range object's will have a width() <= max_range_width and
will also be non-overlapping.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_

View File

@@ -0,0 +1,80 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SPECTRAL_CLUSTEr_H_
#define DLIB_SPECTRAL_CLUSTEr_H_
#include "spectral_cluster_abstract.h"
#include <vector>
#include "../matrix.h"
#include "../svm/kkmeans.h"
namespace dlib
{
template <
typename kernel_type,
typename vector_type
>
std::vector<unsigned long> spectral_cluster (
const kernel_type& k,
const vector_type& samples,
const unsigned long num_clusters
)
{
DLIB_CASSERT(num_clusters > 0,
"\t std::vector<unsigned long> spectral_cluster(k,samples,num_clusters)"
<< "\n\t num_clusters can't be 0."
);
if (num_clusters == 1)
{
// nothing to do, just assign everything to the 0 cluster.
return std::vector<unsigned long>(samples.size(), 0);
}
// compute the similarity matrix.
matrix<double> K(samples.size(), samples.size());
for (long r = 0; r < K.nr(); ++r)
for (long c = r+1; c < K.nc(); ++c)
K(r,c) = K(c,r) = (double)k(samples[r], samples[c]);
for (long r = 0; r < K.nr(); ++r)
K(r,r) = 0;
matrix<double,0,1> D(K.nr());
for (long r = 0; r < K.nr(); ++r)
D(r) = sum(rowm(K,r));
D = sqrt(reciprocal(D));
K = diagm(D)*K*diagm(D);
matrix<double> u,w,v;
// Use the normal SVD routine unless the matrix is really big, then use the fast
// approximate version.
if (K.nr() < 1000)
svd3(K,u,w,v);
else
svd_fast(K,u,w,v, num_clusters+100, 5);
// Pick out the eigenvectors associated with the largest eigenvalues.
rsort_columns(v,w);
v = colm(v, range(0,num_clusters-1));
// Now build the normalized spectral vectors, one for each input vector.
std::vector<matrix<double,0,1> > spec_samps, centers;
for (long r = 0; r < v.nr(); ++r)
{
spec_samps.push_back(trans(rowm(v,r)));
const double len = length(spec_samps.back());
if (len != 0)
spec_samps.back() /= len;
}
// Finally do the K-means clustering
pick_initial_centers(num_clusters, centers, spec_samps);
find_clusters_using_kmeans(spec_samps, centers);
// And then compute the cluster assignments based on the output of K-means.
std::vector<unsigned long> assignments;
for (unsigned long i = 0; i < spec_samps.size(); ++i)
assignments.push_back(nearest_center(centers, spec_samps[i]));
return assignments;
}
}
#endif // DLIB_SPECTRAL_CLUSTEr_H_

View File

@@ -0,0 +1,43 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_
#ifdef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_
#include <vector>
namespace dlib
{
template <
typename kernel_type,
typename vector_type
>
std::vector<unsigned long> spectral_cluster (
const kernel_type& k,
const vector_type& samples,
const unsigned long num_clusters
);
/*!
requires
- samples must be something with an interface compatible with std::vector.
- The following expression must evaluate to a double or float:
k(samples[i], samples[j])
- num_clusters > 0
ensures
- Performs the spectral clustering algorithm described in the paper:
On spectral clustering: Analysis and an algorithm by Ng, Jordan, and Weiss.
and returns the results.
- This function clusters the input data samples into num_clusters clusters and
returns a vector that indicates which cluster each sample falls into. In
particular, we return an array A such that:
- A.size() == samples.size()
- A[i] == the cluster assignment of samples[i].
- for all valid i: 0 <= A[i] < num_clusters
- The "similarity" of samples[i] with samples[j] is given by
k(samples[i],samples[j]). This means that k() should output a number >= 0
and the number should be larger for samples that are more similar.
!*/
}
#endif // DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_

View File

@@ -1,73 +0,0 @@
# Don't add dlib if it's already been added to the cmake project
if (NOT TARGET dlib)
# Determine the path to dlib.
string(REGEX REPLACE "cmake$" "" dlib_path ${CMAKE_CURRENT_LIST_FILE})
if (CMAKE_COMPILER_IS_GNUCXX)
# By default, g++ won't warn or error if you forget to return a value in a
# function which requires you to do so. This option makes it give a warning
# for doing this.
add_definitions("-Wreturn-type")
endif()
# Setup some options to allow a user to enable SSE and AVX instruction use.
if (CMAKE_COMPILER_IS_GNUCXX OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang"
OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU"
OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel")
option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" OFF)
option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF)
option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF)
if(USE_AVX_INSTRUCTIONS)
add_definitions(-mavx)
elseif (USE_SSE4_INSTRUCTIONS)
add_definitions(-msse4)
elseif(USE_SSE2_INSTRUCTIONS)
add_definitions(-msse2)
endif()
elseif (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # else if using Visual Studio
# Use SSE2 by default when using Visual Studio.
option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" ON)
# Visual Studio 2005 didn't support SSE4
if (NOT MSVC80)
option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF)
endif()
# Visual Studio 2005 and 2008 didn't support AVX
if (NOT MSVC80 AND NOT MSVC90)
option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF)
endif()
include(CheckTypeSize)
check_type_size( "void*" SIZE_OF_VOID_PTR)
if(USE_AVX_INSTRUCTIONS)
add_definitions(/arch:AVX)
elseif (USE_SSE4_INSTRUCTIONS)
# Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes.
# So only give it when we are doing a 32 bit build.
if (SIZE_OF_VOID_PTR EQUAL 4)
add_definitions(/arch:SSE2)
endif()
add_definitions(-DDLIB_HAVE_SSE2)
add_definitions(-DDLIB_HAVE_SSE3)
add_definitions(-DDLIB_HAVE_SSE41)
elseif(USE_SSE2_INSTRUCTIONS)
# Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes.
# So only give it when we are doing a 32 bit build.
if (SIZE_OF_VOID_PTR EQUAL 4)
add_definitions(/arch:SSE2)
endif()
add_definitions(-DDLIB_HAVE_SSE2)
endif()
endif()
# Add folder containing dlib to the include search path.
INCLUDE_DIRECTORIES(${dlib_path}/..)
# This is really optional, but nice. It will make sure the build mode
# created by cmake is always release by default.
include(${dlib_path}/release_build_by_default)
add_subdirectory(${dlib_path} dlib_build)
endif()

View File

@@ -0,0 +1,35 @@
cmake_minimum_required(VERSION 2.8.12)
message(WARNING "add_global_compiler_switch() is deprecated. Use target_compile_options() instead")
# Make macros that can add compiler switches to the entire project. Not just
# to the current cmake folder being built.
macro ( add_global_compiler_switch switch_name )
# If removing the switch would change the flags then it's already present
# and we don't need to do anything.
string(REPLACE "${switch_name}" "" tempstr "${CMAKE_CXX_FLAGS}")
if ("${CMAKE_CXX_FLAGS}" STREQUAL "${tempstr}" )
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${switch_name}"
CACHE STRING "Flags used by the compiler during all C++ builds."
FORCE)
endif ()
endmacro()
macro ( remove_global_compiler_switch switch_name )
string(REPLACE "${switch_name}" "" tempstr "${CMAKE_CXX_FLAGS}")
if (NOT "${CMAKE_CXX_FLAGS}" STREQUAL "${tempstr}" )
set (CMAKE_CXX_FLAGS "${tempstr}"
CACHE STRING "Flags used by the compiler during all C++ builds."
FORCE)
endif ()
endmacro()
macro (add_global_define def_name)
add_global_compiler_switch(-D${def_name})
endmacro()
macro (remove_global_define def_name)
remove_global_compiler_switch(-D${def_name})
endmacro()

View File

@@ -0,0 +1,19 @@
# This script checks if your compiler and host processor can generate and then run programs with AVX instructions.
cmake_minimum_required(VERSION 2.8.12)
# Don't rerun this script if its already been executed.
if (DEFINED AVX_IS_AVAILABLE_ON_HOST)
return()
endif()
# Set to false unless we find out otherwise in the code below.
set(AVX_IS_AVAILABLE_ON_HOST 0)
try_compile(test_for_avx_worked ${PROJECT_BINARY_DIR}/avx_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_avx
avx_test)
if(test_for_avx_worked)
message (STATUS "AVX instructions can be executed by the host processor.")
set(AVX_IS_AVAILABLE_ON_HOST 1)
endif()

View File

@@ -0,0 +1,20 @@
# This script checks if __ARM_NEON__ is defined for your compiler
cmake_minimum_required(VERSION 2.8.12)
# Don't rerun this script if its already been executed.
if (DEFINED ARM_NEON_IS_AVAILABLE)
return()
endif()
# Set to false unless we find out otherwise in the code below.
set(ARM_NEON_IS_AVAILABLE 0)
# test if __ARM_NEON__ is defined
try_compile(test_for_neon_worked ${PROJECT_BINARY_DIR}/neon_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_neon
neon_test)
if(test_for_neon_worked)
message (STATUS "__ARM_NEON__ defined.")
set(ARM_NEON_IS_AVAILABLE 1)
endif()

View File

@@ -0,0 +1,385 @@
#
# This is a CMake makefile. You can find the cmake utility and
# information about it at http://www.cmake.org
#
#
# This cmake file tries to find installed BLAS and LAPACK libraries.
# It looks for an installed copy of the Intel MKL library first and then
# attempts to find some other BLAS and LAPACK libraries if you don't have
# the Intel MKL.
#
# blas_found - True if BLAS is available
# lapack_found - True if LAPACK is available
# found_intel_mkl - True if the Intel MKL library is available
# found_intel_mkl_headers - True if Intel MKL headers are available
# blas_libraries - link against these to use BLAS library
# lapack_libraries - link against these to use LAPACK library
# mkl_libraries - link against these to use the MKL library
# mkl_include_dir - add to the include path to use the MKL library
# openmp_libraries - Set to Intel's OpenMP library if and only if we
# find the MKL.
# setting this makes CMake allow normal looking if else statements
SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true)
SET(blas_found 0)
SET(lapack_found 0)
SET(found_intel_mkl 0)
SET(found_intel_mkl_headers 0)
SET(lapack_with_underscore 0)
SET(lapack_without_underscore 0)
message(STATUS "Searching for BLAS and LAPACK")
if (UNIX OR MINGW)
message(STATUS "Searching for BLAS and LAPACK")
if (BUILDING_MATLAB_MEX_FILE)
# # This commented out stuff would link directly to MATLAB's built in
# BLAS and LAPACK. But it's better to not link to anything and do a
#find_library(MATLAB_BLAS_LIBRARY mwblas PATHS ${MATLAB_LIB_FOLDERS} )
#find_library(MATLAB_LAPACK_LIBRARY mwlapack PATHS ${MATLAB_LIB_FOLDERS} )
#if (MATLAB_BLAS_LIBRARY AND MATLAB_LAPACK_LIBRARY)
# add_subdirectory(external/cblas)
# set(blas_libraries ${MATLAB_BLAS_LIBRARY} cblas )
# set(lapack_libraries ${MATLAB_LAPACK_LIBRARY} )
# set(blas_found 1)
# set(lapack_found 1)
# message(STATUS "Found MATLAB's BLAS and LAPACK libraries")
#endif()
# We need cblas since MATLAB doesn't provide cblas symbols.
add_subdirectory(external/cblas)
set(blas_libraries cblas )
set(blas_found 1)
set(lapack_found 1)
message(STATUS "Will link with MATLAB's BLAS and LAPACK at runtime (hopefully!)")
## Don't try to link to anything other than MATLAB's own internal blas
## and lapack libraries because doing so generally upsets MATLAB. So
## we just end here no matter what.
return()
endif()
# First, search for libraries via pkg-config, which is the cleanest path
find_package(PkgConfig)
pkg_check_modules(BLAS_REFERENCE cblas)
pkg_check_modules(LAPACK_REFERENCE lapack)
if (BLAS_REFERENCE_FOUND AND LAPACK_REFERENCE_FOUND)
set(blas_libraries "${BLAS_REFERENCE_LDFLAGS}")
set(lapack_libraries "${LAPACK_REFERENCE_LDFLAGS}")
set(blas_found 1)
set(lapack_found 1)
set(REQUIRES_LIBS "${REQUIRES_LIBS} cblas lapack")
message(STATUS "Found BLAS and LAPACK via pkg-config")
return()
endif()
include(CheckTypeSize)
check_type_size( "void*" SIZE_OF_VOID_PTR)
if (SIZE_OF_VOID_PTR EQUAL 8)
set( mkl_search_path
/opt/intel/mkl/*/lib/em64t
/opt/intel/mkl/lib/intel64
/opt/intel/lib/intel64
/opt/intel/mkl/lib
)
find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path})
mark_as_advanced(mkl_intel)
else()
set( mkl_search_path
/opt/intel/mkl/*/lib/32
/opt/intel/mkl/lib/ia32
/opt/intel/lib/ia32
)
find_library(mkl_intel mkl_intel ${mkl_search_path})
mark_as_advanced(mkl_intel)
endif()
include(CheckLibraryExists)
# Get mkl_include_dir
set(mkl_include_search_path
/opt/intel/mkl/include
/opt/intel/include
)
find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path})
mark_as_advanced(mkl_include_dir)
# Search for the needed libraries from the MKL. We will try to link against the mkl_rt
# file first since this way avoids linking bugs in some cases.
find_library(mkl_rt mkl_rt ${mkl_search_path})
find_library(openmp_libraries iomp5 ${mkl_search_path})
mark_as_advanced( mkl_rt openmp_libraries )
# if we found the MKL
if ( mkl_rt)
set(mkl_libraries ${mkl_rt} )
set(blas_libraries ${mkl_rt} )
set(lapack_libraries ${mkl_rt} )
set(blas_found 1)
set(lapack_found 1)
set(found_intel_mkl 1)
message(STATUS "Found Intel MKL BLAS/LAPACK library")
endif()
if (NOT found_intel_mkl)
# Search for the needed libraries from the MKL. This time try looking for a different
# set of MKL files and try to link against those.
find_library(mkl_core mkl_core ${mkl_search_path})
find_library(mkl_thread mkl_intel_thread ${mkl_search_path})
find_library(mkl_iomp iomp5 ${mkl_search_path})
find_library(mkl_pthread pthread ${mkl_search_path})
mark_as_advanced( mkl_intel mkl_core mkl_thread mkl_iomp mkl_pthread)
# If we found the MKL
if (mkl_intel AND mkl_core AND mkl_thread AND mkl_iomp AND mkl_pthread)
set(mkl_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread})
set(blas_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread})
set(lapack_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread})
set(blas_found 1)
set(lapack_found 1)
set(found_intel_mkl 1)
message(STATUS "Found Intel MKL BLAS/LAPACK library")
endif()
endif()
if (found_intel_mkl AND mkl_include_dir)
set(found_intel_mkl_headers 1)
endif()
# try to find some other LAPACK libraries if we didn't find the MKL
set(extra_paths
/usr/lib64
/usr/lib64/atlas-sse3
/usr/lib64/atlas-sse2
/usr/lib64/atlas
/usr/lib
/usr/lib/atlas-sse3
/usr/lib/atlas-sse2
/usr/lib/atlas
/usr/lib/openblas-base
/opt/OpenBLAS/lib
$ENV{OPENBLAS_HOME}/lib
)
INCLUDE (CheckFunctionExists)
if (NOT blas_found)
find_library(cblas_lib openblas PATHS ${extra_paths})
if (cblas_lib)
set(blas_libraries ${cblas_lib})
set(blas_found 1)
message(STATUS "Found OpenBLAS library")
set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries})
# If you compiled OpenBLAS with LAPACK in it then it should have the
# sgetrf_single function in it. So if we find that function in
# OpenBLAS then just use OpenBLAS's LAPACK.
CHECK_FUNCTION_EXISTS(sgetrf_single OPENBLAS_HAS_LAPACK)
if (OPENBLAS_HAS_LAPACK)
message(STATUS "Using OpenBLAS's built in LAPACK")
# set(lapack_libraries gfortran)
set(lapack_found 1)
endif()
endif()
mark_as_advanced( cblas_lib)
endif()
if (NOT lapack_found)
find_library(lapack_lib NAMES lapack lapack-3 PATHS ${extra_paths})
if (lapack_lib)
set(lapack_libraries ${lapack_lib})
set(lapack_found 1)
message(STATUS "Found LAPACK library")
endif()
mark_as_advanced( lapack_lib)
endif()
# try to find some other BLAS libraries if we didn't find the MKL
if (NOT blas_found)
find_library(atlas_lib atlas PATHS ${extra_paths})
find_library(cblas_lib cblas PATHS ${extra_paths})
if (atlas_lib AND cblas_lib)
set(blas_libraries ${atlas_lib} ${cblas_lib})
set(blas_found 1)
message(STATUS "Found ATLAS BLAS library")
endif()
mark_as_advanced( atlas_lib cblas_lib)
endif()
# CentOS 7 atlas
if (NOT blas_found)
find_library(tatlas_lib tatlas PATHS ${extra_paths})
find_library(satlas_lib satlas PATHS ${extra_paths})
if (tatlas_lib AND satlas_lib )
set(blas_libraries ${tatlas_lib} ${satlas_lib})
set(blas_found 1)
message(STATUS "Found ATLAS BLAS library")
endif()
mark_as_advanced( tatlas_lib satlas_lib)
endif()
if (NOT blas_found)
find_library(cblas_lib cblas PATHS ${extra_paths})
if (cblas_lib)
set(blas_libraries ${cblas_lib})
set(blas_found 1)
message(STATUS "Found CBLAS library")
endif()
mark_as_advanced( cblas_lib)
endif()
if (NOT blas_found)
find_library(generic_blas blas PATHS ${extra_paths})
if (generic_blas)
set(blas_libraries ${generic_blas})
set(blas_found 1)
message(STATUS "Found BLAS library")
endif()
mark_as_advanced( generic_blas)
endif()
# Make sure we really found a CBLAS library. That is, it needs to expose
# the proper cblas link symbols. So here we test if one of them is present
# and assume everything is good if it is. Note that we don't do this check if
# we found the Intel MKL since for some reason CHECK_FUNCTION_EXISTS doesn't work
# with it. But it's fine since the MKL should always have cblas.
if (blas_found AND NOT found_intel_mkl)
set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries})
CHECK_FUNCTION_EXISTS(cblas_ddot HAVE_CBLAS)
if (NOT HAVE_CBLAS)
message(STATUS "BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK")
set(blas_found 0)
set(lapack_found 0)
endif()
endif()
elseif(WIN32 AND NOT MINGW)
message(STATUS "Searching for BLAS and LAPACK")
include(CheckTypeSize)
check_type_size( "void*" SIZE_OF_VOID_PTR)
if (SIZE_OF_VOID_PTR EQUAL 8)
set( mkl_search_path
"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/intel64"
"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/intel64"
"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/intel64"
"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/intel64"
"C:/Program Files (x86)/Intel/Composer XE/mkl/lib/intel64"
"C:/Program Files (x86)/Intel/Composer XE/compiler/lib/intel64"
"C:/Program Files/Intel/Composer XE/mkl/lib/intel64"
"C:/Program Files/Intel/Composer XE/compiler/lib/intel64"
)
find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path})
else()
set( mkl_search_path
"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/ia32"
"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/ia32"
"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/ia32"
"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/ia32"
"C:/Program Files (x86)/Intel/Composer XE/mkl/lib/ia32"
"C:/Program Files (x86)/Intel/Composer XE/compiler/lib/ia32"
"C:/Program Files/Intel/Composer XE/mkl/lib/ia32"
"C:/Program Files/Intel/Composer XE/compiler/lib/ia32"
)
find_library(mkl_intel mkl_intel_c ${mkl_search_path})
endif()
INCLUDE (CheckFunctionExists)
# Search for the needed libraries from the MKL.
find_library(mkl_core mkl_core ${mkl_search_path})
find_library(mkl_thread mkl_intel_thread ${mkl_search_path})
find_library(mkl_iomp libiomp5md ${mkl_search_path})
mark_as_advanced( mkl_intel mkl_core mkl_thread mkl_iomp)
# If we found the MKL
if (mkl_intel AND mkl_core AND mkl_thread AND mkl_iomp )
set(blas_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} )
set(lapack_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} )
set(blas_found 1)
set(lapack_found 1)
message(STATUS "Found Intel MKL BLAS/LAPACK library")
# Make sure the version of the Intel MKL we found is compatible with
# the compiler we are using. One way to do this check is to see if we can
# link to it right now.
set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries})
CHECK_FUNCTION_EXISTS(cblas_ddot HAVE_CBLAS)
if (NOT HAVE_CBLAS)
message("BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK")
set(blas_found 0)
set(lapack_found 0)
endif()
endif()
endif()
# When all else fails use CMake's built in functions to find BLAS and LAPACK
if (NOT blas_found)
find_package(BLAS QUIET)
if (${BLAS_FOUND})
set(blas_libraries ${BLAS_LIBRARIES})
set(blas_found 1)
if (NOT lapack_found)
find_package(LAPACK QUIET)
if (${LAPACK_FOUND})
set(lapack_libraries ${LAPACK_LIBRARIES})
set(lapack_found 1)
endif()
endif()
endif()
endif()
# If using lapack, determine whether to mangle functions
if (lapack_found)
include(CheckFunctionExists)
include(CheckFortranFunctionExists)
set(CMAKE_REQUIRED_LIBRARIES ${lapack_libraries})
check_function_exists("sgesv" LAPACK_FOUND_C_UNMANGLED)
check_function_exists("sgesv_" LAPACK_FOUND_C_MANGLED)
if (CMAKE_Fortran_COMPILER_LOADED)
check_fortran_function_exists("sgesv" LAPACK_FOUND_FORTRAN_UNMANGLED)
check_fortran_function_exists("sgesv_" LAPACK_FOUND_FORTRAN_MANGLED)
endif ()
if (LAPACK_FOUND_C_MANGLED OR LAPACK_FOUND_FORTRAN_MANGLED)
set(lapack_with_underscore 1)
elseif (LAPACK_FOUND_C_UNMANGLED OR LAPACK_FOUND_FORTRAN_UNMANGLED)
set(lapack_without_underscore 1)
endif ()
endif()
if (UNIX OR MINGW)
if (NOT blas_found)
message(" *****************************************************************************")
message(" *** No BLAS library found so using dlib's built in BLAS. However, if you ***")
message(" *** install an optimized BLAS such as OpenBLAS or the Intel MKL your code ***")
message(" *** will run faster. On Ubuntu you can install OpenBLAS by executing: ***")
message(" *** sudo apt-get install libopenblas-dev liblapack-dev ***")
message(" *** Or you can easily install OpenBLAS from source by downloading the ***")
message(" *** source tar file from http://www.openblas.net, extracting it, and ***")
message(" *** running: ***")
message(" *** make; sudo make install ***")
message(" *****************************************************************************")
endif()
endif()

View File

@@ -0,0 +1,150 @@
cmake_minimum_required(VERSION 2.8.12)
if (POLICY CMP0054)
cmake_policy(SET CMP0054 NEW)
endif()
# Check if we are being built as part of a pybind11 module.
if (COMMAND pybind11_add_module)
# For python users, assume they have SSE4 at least and then if the host machine has AVX use that too.
set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL "Use SSE4 instructions")
include(${CMAKE_CURRENT_LIST_DIR}/check_if_avx_instructions_executable_on_host.cmake)
if (AVX_IS_AVAILABLE_ON_HOST)
set(USE_AVX_INSTRUCTIONS ON CACHE BOOL "Use AVX instructions")
endif()
endif()
set(USING_OLD_VISUAL_STUDIO_COMPILER 0)
if(MSVC AND MSVC_VERSION VERSION_LESS 1900)
message(FATAL_ERROR "C++11 is required to use dlib, but the version of Visual Studio you are using is too old and doesn't support C++11. You need Visual Studio 2015 or newer. ")
elseif(MSVC AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 19.0.24210.0 )
message(STATUS "NOTE: Visual Studio didn't have good enough C++11 support until Visual Studio 2015 update 3 (v19.0.24210.0)")
message(STATUS "So we aren't enabling things that require full C++11 support (e.g. the deep learning tools).")
message(STATUS "Also, be aware that Visual Studio's version naming is confusing, in particular, there are multiple versions of 'update 3'")
message(STATUS "So if you are getting this message you need to update to the newer version of Visual Studio to use full C++11.")
set(USING_OLD_VISUAL_STUDIO_COMPILER 1)
elseif(MSVC AND (MSVC_VERSION EQUAL 1911 OR MSVC_VERSION EQUAL 1910))
message(STATUS "******************************************************************************************")
message(STATUS "Your version of Visual Studio has incomplete C++11 support and is unable to compile the ")
message(STATUS "DNN examples. So we are disabling the deep learning tools. If you want to use the DNN ")
message(STATUS "tools in dlib then update your copy of Visual Studio.")
message(STATUS "******************************************************************************************")
set(USING_OLD_VISUAL_STUDIO_COMPILER 1)
endif()
if(CMAKE_COMPILER_IS_GNUCXX)
execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpversion OUTPUT_VARIABLE GCC_VERSION)
if (GCC_VERSION VERSION_LESS 4.8)
message(FATAL_ERROR "C++11 is required to use dlib, but the version of GCC you are using is too old and doesn't support C++11. You need GCC 4.8 or newer. ")
endif()
endif()
# push USING_OLD_VISUAL_STUDIO_COMPILER to the parent so we can use it in the
# examples CMakeLists.txt file.
get_directory_property(has_parent PARENT_DIRECTORY)
if(has_parent)
set(USING_OLD_VISUAL_STUDIO_COMPILER ${USING_OLD_VISUAL_STUDIO_COMPILER} PARENT_SCOPE)
endif()
set(gcc_like_compilers GNU Clang Intel)
set(intel_archs x86_64 i386 i686 AMD64 amd64 x86)
# Setup some options to allow a user to enable SSE and AVX instruction use.
if ((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND
(";${intel_archs};" MATCHES ";${CMAKE_SYSTEM_PROCESSOR};") AND NOT USE_AUTO_VECTOR)
option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" OFF)
option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF)
option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF)
if(USE_AVX_INSTRUCTIONS)
list(APPEND active_compile_opts -mavx)
message(STATUS "Enabling AVX instructions")
elseif (USE_SSE4_INSTRUCTIONS)
list(APPEND active_compile_opts -msse4)
message(STATUS "Enabling SSE4 instructions")
elseif(USE_SSE2_INSTRUCTIONS)
list(APPEND active_compile_opts -msse2)
message(STATUS "Enabling SSE2 instructions")
endif()
elseif (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # else if using Visual Studio
# Use SSE2 by default when using Visual Studio.
option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" ON)
option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF)
option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF)
include(CheckTypeSize)
check_type_size( "void*" SIZE_OF_VOID_PTR)
if(USE_AVX_INSTRUCTIONS)
list(APPEND active_compile_opts /arch:AVX)
message(STATUS "Enabling AVX instructions")
elseif (USE_SSE4_INSTRUCTIONS)
# Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes.
# So only give it when we are doing a 32 bit build.
if (SIZE_OF_VOID_PTR EQUAL 4)
list(APPEND active_compile_opts /arch:SSE2)
endif()
message(STATUS "Enabling SSE4 instructions")
list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2")
list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE3")
list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE41")
elseif(USE_SSE2_INSTRUCTIONS)
# Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes.
# So only give it when we are doing a 32 bit build.
if (SIZE_OF_VOID_PTR EQUAL 4)
list(APPEND active_compile_opts /arch:SSE2)
endif()
message(STATUS "Enabling SSE2 instructions")
list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2")
endif()
elseif((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND
("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "^arm"))
option(USE_NEON_INSTRUCTIONS "Compile your program with ARM-NEON instructions" OFF)
if(USE_NEON_INSTRUCTIONS)
list(APPEND active_compile_opts -mfpu=neon)
message(STATUS "Enabling ARM-NEON instructions")
endif()
endif()
if (CMAKE_COMPILER_IS_GNUCXX)
# By default, g++ won't warn or error if you forget to return a value in a
# function which requires you to do so. This option makes it give a warning
# for doing this.
list(APPEND active_compile_opts "-Wreturn-type")
endif()
if ("Clang" MATCHES ${CMAKE_CXX_COMPILER_ID})
# Increase clang's default tempalte recurision depth so the dnn examples don't error out.
list(APPEND active_compile_opts "-ftemplate-depth=500")
endif()
if (MSVC)
# By default Visual Studio does not support .obj files with more than 65k sections.
# However, code generated by file_to_code_ex and code using DNN module can have
# them. So this flag enables > 65k sections, but produces .obj files
# that will not be readable by VS 2005.
list(APPEND active_compile_opts "/bigobj")
# Build dlib with all cores. Don't propagate the setting to client programs
# though since they might compile large translation units that use too much
# RAM.
list(APPEND active_compile_opts_private "/MP")
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 3.3)
# Clang can compile all Dlib's code at Windows platform. Tested with Clang 5
list(APPEND active_compile_opts "-Xclang -fcxx-exceptions")
endif()
endif()

View File

@@ -0,0 +1,19 @@
# Including this cmake script into your cmake project will cause visual studio
# to build your project against the static C runtime.
cmake_minimum_required(VERSION 2.8.12)
if (POLICY CMP0054)
cmake_policy(SET CMP0054 NEW)
endif()
if (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
if(${flag_var} MATCHES "/MD")
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
endif()
endforeach(flag_var)
endif()

View File

@@ -0,0 +1,113 @@
# This script creates a function, enable_cpp11_for_target(), which checks if your
# compiler has C++11 support and enables it if it does.
cmake_minimum_required(VERSION 2.8.12)
if (POLICY CMP0054)
cmake_policy(SET CMP0054 NEW)
endif()
set(_where_is_cmake_utils_dir ${CMAKE_CURRENT_LIST_DIR})
function(enable_cpp11_for_target target_name)
# Set to false unless we find out otherwise in the code below.
set(COMPILER_CAN_DO_CPP_11 0)
macro(test_compiler_for_cpp11)
message(STATUS "Building a C++11 test project to see if your compiler supports C++11")
try_compile(test_for_cpp11_worked ${PROJECT_BINARY_DIR}/cpp11_test_build
${_where_is_cmake_utils_dir}/test_for_cpp11 cpp11_test)
if (test_for_cpp11_worked)
message(STATUS "C++11 activated.")
set(COMPILER_CAN_DO_CPP_11 1)
else()
set(COMPILER_CAN_DO_CPP_11 0)
message(STATUS "********** Your compiler failed to build a C++11 project. C++11 is required to use all parts of dlib! **********")
endif()
endmacro()
# Now turn on the appropriate compiler switch to enable C++11 if you have a
# C++11 compiler. In CMake 3.1 there is a simple flag you can set, but earlier
# verions of CMake are not so convenient.
if (CMAKE_VERSION VERSION_LESS "3.1.2")
if(CMAKE_COMPILER_IS_GNUCXX)
execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpversion OUTPUT_VARIABLE GCC_VERSION)
if (GCC_VERSION VERSION_GREATER 4.8 OR GCC_VERSION VERSION_EQUAL 4.8)
message(STATUS "C++11 activated.")
target_compile_options(${target_name} PUBLIC "-std=gnu++11")
set(COMPILER_CAN_DO_CPP_11 1)
endif()
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
execute_process( COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string )
string (REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION ${clang_full_version_string})
if (CLANG_VERSION VERSION_GREATER 3.3)
message(STATUS "C++11 activated.")
target_compile_options(${target_name} PUBLIC "-std=c++11")
set(COMPILER_CAN_DO_CPP_11 1)
endif()
else()
# Since we don't know what compiler this is just try to build a c++11 project and see if it compiles.
test_compiler_for_cpp11()
endif()
else()
# Set a flag if the compiler you are using is capable of providing C++11 features.
get_property(cxx_features GLOBAL PROPERTY CMAKE_CXX_KNOWN_FEATURES)
if (";${cxx_features};" MATCHES ";cxx_rvalue_references;" AND
";${cxx_features};" MATCHES ";cxx_variadic_templates;" AND
";${cxx_features};" MATCHES ";cxx_lambdas;" AND
";${cxx_features};" MATCHES ";cxx_defaulted_move_initializers;" AND
";${cxx_features};" MATCHES ";cxx_delegating_constructors;" AND
";${cxx_features};" MATCHES ";cxx_thread_local;" AND
";${cxx_features};" MATCHES ";cxx_constexpr;" AND
";${cxx_features};" MATCHES ";cxx_decltype_incomplete_return_types;" AND
";${cxx_features};" MATCHES ";cxx_auto_type;")
set(COMPILER_CAN_DO_CPP_11 1)
# Tell cmake that we need C++11 for dlib
target_compile_features(${target_name}
PUBLIC
cxx_rvalue_references
cxx_variadic_templates
cxx_lambdas
cxx_defaulted_move_initializers
cxx_delegating_constructors
cxx_thread_local
cxx_constexpr
# cxx_decltype_incomplete_return_types # purposfully commented out because cmake errors out on this when using visual studio and cmake 3.8.0
cxx_auto_type
)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
# Sometimes clang will lie and report that it supports C++11 when
# really it doesn't support thread_local. So check for that.
test_compiler_for_cpp11()
else()
message(STATUS "C++11 activated.")
endif()
endif()
endif()
# Always enable whatever partial C++11 support we have, even if it isn't full
# support, and just hope for the best.
if (NOT COMPILER_CAN_DO_CPP_11)
include(CheckCXXCompilerFlag)
CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11)
CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORTS_CXX0X)
if(COMPILER_SUPPORTS_CXX11)
message(STATUS "C++11 activated (compiler doesn't have full C++11 support).")
target_compile_options(${target_name} PUBLIC "-std=c++11")
elseif(COMPILER_SUPPORTS_CXX0X)
message(STATUS "C++0x activated (compiler doesn't have full C++11 support).")
target_compile_options(${target_name} PUBLIC "-std=c++0x")
endif()
endif()
endfunction()

View File

@@ -305,7 +305,7 @@ namespace dlib
bool move_next (
) const { return options.move_next(); }
unsigned long size (
size_t size (
) const { return options.size(); }
private:

View File

@@ -10,7 +10,7 @@
#include <string>
#include <sstream>
#include <map>
#include "../smart_pointers.h"
#include <memory>
namespace dlib
{
@@ -105,7 +105,7 @@ namespace dlib
// Make a separate ostringstream for each option group. We are going to write
// the output for each group to a separate ostringstream so that we can keep
// them grouped together in the final output.
std::map<string,shared_ptr<ostringstream> > groups;
std::map<string,std::shared_ptr<ostringstream> > groups;
this->reset();
while(this->move_next())
{
@@ -173,7 +173,7 @@ namespace dlib
out << _dT(ct,"Options:");
// Now print everything out
typename std::map<string,shared_ptr<ostringstream> >::iterator i;
typename std::map<string,std::shared_ptr<ostringstream> >::iterator i;
for (i = groups.begin(); i != groups.end(); ++i)
{
// print the group name if we have one

View File

@@ -7,15 +7,28 @@
// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle
// automatically depending on the state of certain other macros, which is not what you want
// when creating a shared library.
//#define ENABLE_ASSERTS // asserts always enabled
//#define DLIB_DISABLE_ASSERTS // asserts always disabled
/* #undef ENABLE_ASSERTS */
#define DLIB_DISABLE_ASSERTS // asserts always disabled
/* #undef DLIB_ISO_CPP_ONLY */
#define DLIB_NO_GUI_SUPPORT
/* #undef DLIB_ENABLE_STACK_TRACE */
/* #undef LAPACK_FORCE_UNDERSCORE */
/* #undef LAPACK_FORCE_NOUNDERSCORE */
// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA,
// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines.
/* #undef DLIB_JPEG_SUPPORT */
/* #undef DLIB_PNG_SUPPORT */
/* #undef DLIB_GIF_SUPPORT */
/* #undef DLIB_USE_FFTW */
/* #undef DLIB_USE_BLAS */
/* #undef DLIB_USE_LAPACK */
/* #undef DLIB_USE_CUDA */
/* #undef DLIB_USE_MKL_FFT */
// This variable allows dlib/test_for_odr_violations.h to catch people who mistakenly use
// headers from one version of dlib with a compiled dlib binary from a different dlib version.
#define DLIB_CHECK_FOR_VERSION_MISMATCH DLIB_VERSION_MISMATCH_CHECK__EXPECTED_VERSION_19_13_0
// You should also consider telling dlib to link against libjpeg, libpng, fftw, and a BLAS
// and LAPACK library. To do this you need to uncomment the following #defines.
// #define DLIB_JPEG_SUPPORT
// #define DLIB_PNG_SUPPORT
// #define DLIB_USE_FFTW
// #define DLIB_USE_BLAS
// #define DLIB_USE_LAPACK

View File

@@ -65,7 +65,8 @@ namespace dlib
!*/
inline bool print_status (
double cur
double cur,
bool always_print = false
);
/*!
ensures
@@ -74,10 +75,13 @@ namespace dlib
remaining until cur becomes equal to target().
- prints a status message to the screen which indicates how much
more time is left until cur is equal to target()
- This function throttles the printing so that at most 1 message is printed
each second. Note that it won't print anything to the screen until about
one second has elapsed. This means that the first call to print_status()
never prints to the screen.
- if (always_print) then
- This function prints to the screen each time it is called.
- else
- This function throttles the printing so that at most 1 message is
printed each second. Note that it won't print anything to the screen
until about one second has elapsed. This means that the first call
to print_status() never prints to the screen.
- This function returns true if it prints to the screen and false
otherwise.
!*/
@@ -115,7 +119,8 @@ namespace dlib
bool console_progress_indicator::
print_status (
double cur
double cur,
bool always_print
)
{
const time_t cur_time = std::time(0);
@@ -132,7 +137,7 @@ namespace dlib
return false;
}
if (cur_time != last_time)
if (cur_time != last_time || always_print)
{
last_time = cur_time;
double delta_t = static_cast<double>(cur_time - start_time);
@@ -152,17 +157,17 @@ namespace dlib
if (seconds < 60)
{
ss = std::cout.precision(0);
std::cout << "Time remaining: " << seconds << " seconds. \r" << std::flush;
std::cout << "Time remaining: " << seconds << " seconds. \r" << std::flush;
}
else if (seconds < 60*60)
{
ss = std::cout.precision(2);
std::cout << "Time remaining: " << seconds/60 << " minutes. \r" << std::flush;
std::cout << "Time remaining: " << seconds/60 << " minutes. \r" << std::flush;
}
else
{
ss = std::cout.precision(2);
std::cout << "Time remaining: " << seconds/60/60 << " hours. \r" << std::flush;
std::cout << "Time remaining: " << seconds/60/60 << " hours. \r" << std::flush;
}
// restore previous output flags and precision settings

View File

@@ -0,0 +1,11 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CONTRoL_
#define DLIB_CONTRoL_
#include "control/lspi.h"
#include "control/mpc.h"
#endif // DLIB_CONTRoL_

View File

@@ -0,0 +1,128 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
#include "approximate_linear_models_abstract.h"
#include "../matrix.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
struct process_sample
{
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
process_sample(){}
process_sample(
const state_type& s,
const action_type& a,
const state_type& n,
const double& r
) : state(s), action(a), next_state(n), reward(r) {}
state_type state;
action_type action;
state_type next_state;
double reward;
};
template < typename feature_extractor >
void serialize (const process_sample<feature_extractor>& item, std::ostream& out)
{
serialize(item.state, out);
serialize(item.action, out);
serialize(item.next_state, out);
serialize(item.reward, out);
}
template < typename feature_extractor >
void deserialize (process_sample<feature_extractor>& item, std::istream& in)
{
deserialize(item.state, in);
deserialize(item.action, in);
deserialize(item.next_state, in);
deserialize(item.reward, in);
}
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class policy
{
public:
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
policy (
)
{
w.set_size(fe.num_features());
w = 0;
}
policy (
const matrix<double,0,1>& weights_,
const feature_extractor& fe_
) : w(weights_), fe(fe_) {}
action_type operator() (
const state_type& state
) const
{
return fe.find_best_action(state,w);
}
const feature_extractor& get_feature_extractor (
) const { return fe; }
const matrix<double,0,1>& get_weights (
) const { return w; }
private:
matrix<double,0,1> w;
feature_extractor fe;
};
template < typename feature_extractor >
inline void serialize(const policy<feature_extractor>& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.get_feature_extractor(), out);
serialize(item.get_weights(), out);
}
template < typename feature_extractor >
inline void deserialize(policy<feature_extractor>& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::policy object.");
feature_extractor fe;
matrix<double,0,1> w;
deserialize(fe, in);
deserialize(w, in);
item = policy<feature_extractor>(w,fe);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_

View File

@@ -0,0 +1,213 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_
#ifdef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_
#include "../matrix.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
struct example_feature_extractor
{
/*!
WHAT THIS OBJECT REPRESENTS
This object defines the interface a feature extractor must implement if it
is to be used with the process_sample and policy objects defined at the
bottom of this file. Moreover, it is meant to represent the core part
of a model used in a reinforcement learning algorithm.
In particular, this object models a Q(state,action) function where
Q(state,action) == dot(w, PSI(state,action))
where PSI(state,action) is a feature vector and w is a parameter
vector.
Therefore, a feature extractor defines how the PSI(x,y) feature vector is
calculated. It also defines the types used to represent the state and
action objects.
THREAD SAFETY
Instances of this object are required to be threadsafe, that is, it should
be safe for multiple threads to make concurrent calls to the member
functions of this object.
!*/
// The state and actions can be any types so long as you provide typedefs for them.
typedef T state_type;
typedef U action_type;
// We can also say that the last element in the weight vector w must be 1. This
// can be useful for including a prior into your model.
const static bool force_last_weight_to_1 = false;
example_feature_extractor(
);
/*!
ensures
- this object is properly initialized.
!*/
unsigned long num_features(
) const;
/*!
ensures
- returns the dimensionality of the PSI() feature vector.
!*/
action_type find_best_action (
const state_type& state,
const matrix<double,0,1>& w
) const;
/*!
ensures
- returns the action A that maximizes Q(state,A) = dot(w,PSI(state,A)).
That is, this function finds the best action to take in the given state
when our model is parameterized by the given weight vector w.
!*/
void get_features (
const state_type& state,
const action_type& action,
matrix<double,0,1>& feats
) const;
/*!
ensures
- #feats.size() == num_features()
- #feats == PSI(state,action)
!*/
};
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
struct process_sample
{
/*!
REQUIREMENTS ON feature_extractor
feature_extractor should implement the example_feature_extractor interface
defined at the top of this file.
WHAT THIS OBJECT REPRESENTS
This object holds a training sample for a reinforcement learning algorithm.
In particular, it should be a sample from some process where the process
was in state this->state, then took this->action action which resulted in
receiving this->reward and ending up in the state this->next_state.
!*/
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
process_sample(){}
process_sample(
const state_type& s,
const action_type& a,
const state_type& n,
const double& r
) : state(s), action(a), next_state(n), reward(r) {}
state_type state;
action_type action;
state_type next_state;
double reward;
};
template < typename feature_extractor >
void serialize (const process_sample<feature_extractor>& item, std::ostream& out);
template < typename feature_extractor >
void deserialize (process_sample<feature_extractor>& item, std::istream& in);
/*!
provides serialization support.
!*/
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class policy
{
/*!
REQUIREMENTS ON feature_extractor
feature_extractor should implement the example_feature_extractor interface
defined at the top of this file.
WHAT THIS OBJECT REPRESENTS
This is a policy based on the supplied feature_extractor model. In
particular, it maps from feature_extractor::state_type to the best action
to take in that state.
!*/
public:
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
policy (
);
/*!
ensures
- #get_feature_extractor() == feature_extractor()
(i.e. it will have its default value)
- #get_weights().size() == #get_feature_extractor().num_features()
- #get_weights() == 0
!*/
policy (
const matrix<double,0,1>& weights,
const feature_extractor& fe
);
/*!
requires
- fe.num_features() == weights.size()
ensures
- #get_feature_extractor() == fe
- #get_weights() == weights
!*/
action_type operator() (
const state_type& state
) const;
/*!
ensures
- returns get_feature_extractor().find_best_action(state,w);
!*/
const feature_extractor& get_feature_extractor (
) const;
/*!
ensures
- returns the feature extractor used by this object
!*/
const matrix<double,0,1>& get_weights (
) const;
/*!
ensures
- returns the parameter vector (w) associated with this object. The length
of the vector is get_feature_extractor().num_features().
!*/
};
template < typename feature_extractor >
void serialize(const policy<feature_extractor>& item, std::ostream& out);
template < typename feature_extractor >
void deserialize(policy<feature_extractor>& item, std::istream& in);
/*!
provides serialization support.
!*/
// ----------------------------------------------------------------------------------------
#endif // DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_

View File

@@ -0,0 +1,188 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_LSPI_Hh_
#define DLIB_LSPI_Hh_
#include "lspi_abstract.h"
#include "approximate_linear_models.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class lspi
{
public:
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
explicit lspi(
const feature_extractor& fe_
) : fe(fe_)
{
init();
}
lspi(
)
{
init();
}
double get_discount (
) const { return discount; }
void set_discount (
double value
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < value && value <= 1,
"\t void lspi::set_discount(value)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t value: " << value
);
discount = value;
}
const feature_extractor& get_feature_extractor (
) const { return fe; }
void be_verbose (
)
{
verbose = true;
}
void be_quiet (
)
{
verbose = false;
}
void set_epsilon (
double eps_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(eps_ > 0,
"\t void lspi::set_epsilon(eps_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t eps_: " << eps_
);
eps = eps_;
}
double get_epsilon (
) const
{
return eps;
}
void set_lambda (
double lambda_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(lambda_ >= 0,
"\t void lspi::set_lambda(lambda_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda_: " << lambda_
);
lambda = lambda_;
}
double get_lambda (
) const
{
return lambda;
}
void set_max_iterations (
unsigned long max_iter
) { max_iterations = max_iter; }
unsigned long get_max_iterations (
) { return max_iterations; }
template <typename vector_type>
policy<feature_extractor> train (
const vector_type& samples
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(samples.size() > 0,
"\t policy lspi::train(samples)"
<< "\n\t invalid inputs were given to this function"
);
matrix<double,0,1> w(fe.num_features());
w = 0;
matrix<double,0,1> prev_w, b, f1, f2;
matrix<double> A;
double change;
unsigned long iter = 0;
do
{
A = identity_matrix<double>(fe.num_features())*lambda;
b = 0;
for (unsigned long i = 0; i < samples.size(); ++i)
{
fe.get_features(samples[i].state, samples[i].action, f1);
fe.get_features(samples[i].next_state,
fe.find_best_action(samples[i].next_state,w),
f2);
A += f1*trans(f1 - discount*f2);
b += f1*samples[i].reward;
}
prev_w = w;
if (feature_extractor::force_last_weight_to_1)
w = join_cols(pinv(colm(A,range(0,A.nc()-2)))*(b-colm(A,A.nc()-1)),mat(1.0));
else
w = pinv(A)*b;
change = length(w-prev_w);
++iter;
if (verbose)
std::cout << "iteration: " << iter << "\tchange: " << change << std::endl;
} while(change > eps && iter < max_iterations);
return policy<feature_extractor>(w,fe);
}
private:
void init()
{
lambda = 0.01;
discount = 0.8;
eps = 0.01;
verbose = false;
max_iterations = 100;
}
double lambda;
double discount;
double eps;
bool verbose;
unsigned long max_iterations;
feature_extractor fe;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_LSPI_Hh_

View File

@@ -0,0 +1,193 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_LSPI_ABSTRACT_Hh_
#ifdef DLIB_LSPI_ABSTRACT_Hh_
#include "approximate_linear_models_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class lspi
{
/*!
REQUIREMENTS ON feature_extractor
feature_extractor should implement the example_feature_extractor interface
defined at the top of dlib/control/approximate_linear_models_abstract.h
WHAT THIS OBJECT REPRESENTS
This object is an implementation of the reinforcement learning algorithm
described in the following paper:
Lagoudakis, Michail G., and Ronald Parr. "Least-squares policy
iteration." The Journal of Machine Learning Research 4 (2003):
1107-1149.
This means that it takes a bunch of training data in the form of
process_samples and outputs a policy that hopefully performs well when run
on the process that generated those samples.
!*/
public:
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
explicit lspi(
const feature_extractor& fe_
);
/*!
ensures
- #get_feature_extractor() == fe_
- #get_lambda() == 0.01
- #get_discount == 0.8
- #get_epsilon() == 0.01
- is not verbose
- #get_max_iterations() == 100
!*/
lspi(
);
/*!
ensures
- #get_feature_extractor() == feature_extractor()
(i.e. it will have its default value)
- #get_lambda() == 0.01
- #get_discount == 0.8
- #get_epsilon() == 0.01
- is not verbose
- #get_max_iterations() == 100
!*/
double get_discount (
) const;
/*!
ensures
- returns the discount applied to the sum of rewards in the Bellman
equation.
!*/
void set_discount (
double value
);
/*!
requires
- 0 < value <= 1
ensures
- #get_discount() == value
!*/
const feature_extractor& get_feature_extractor (
) const;
/*!
ensures
- returns the feature extractor used by this object
!*/
void be_verbose (
);
/*!
ensures
- This object will print status messages to standard out so that a
user can observe the progress of the algorithm.
!*/
void be_quiet (
);
/*!
ensures
- this object will not print anything to standard out
!*/
void set_epsilon (
double eps
);
/*!
requires
- eps > 0
ensures
- #get_epsilon() == eps
!*/
double get_epsilon (
) const;
/*!
ensures
- returns the error epsilon that determines when training should stop.
Smaller values may result in a more accurate solution but take longer to
train.
!*/
void set_lambda (
double lambda_
);
/*!
requires
- lambda >= 0
ensures
- #get_lambda() == lambda
!*/
double get_lambda (
) const;
/*!
ensures
- returns the regularization parameter. It is the parameter that
determines the trade off between trying to fit the training data
exactly or allowing more errors but hopefully improving the
generalization ability of the resulting function. Smaller values
encourage exact fitting while larger values of lambda may encourage
better generalization.
!*/
void set_max_iterations (
unsigned long max_iter
);
/*!
ensures
- #get_max_iterations() == max_iter
!*/
unsigned long get_max_iterations (
);
/*!
ensures
- returns the maximum number of iterations the SVM optimizer is allowed to
run before it is required to stop and return a result.
!*/
template <
typename vector_type
>
policy<feature_extractor> train (
const vector_type& samples
) const;
/*!
requires
- samples.size() > 0
- samples is something with an interface that looks like
std::vector<process_sample<feature_extractor>>. That is, it should
be some kind of array of process_sample objects.
ensures
- Trains a policy based on the given data and returns the results. The
idea is to find a policy that will obtain the largest possible reward
when run on the process that generated the samples. In particular,
if the returned policy is P then:
- P(S) == the best action to take when in state S.
- if (feature_extractor::force_last_weight_to_1) then
- The last element of P.get_weights() is 1.
!*/
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_LSPI_ABSTRACT_Hh_

View File

@@ -0,0 +1,370 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MPC_Hh_
#define DLIB_MPC_Hh_
#include "mpc_abstract.h"
#include "../matrix.h"
#include "../algs.h"
namespace dlib
{
template <
long S_,
long I_,
unsigned long horizon_
>
class mpc
{
public:
const static long S = S_;
const static long I = I_;
const static unsigned long horizon = horizon_;
mpc(
)
{
A = 0;
B = 0;
C = 0;
Q = 0;
R = 0;
lower = 0;
upper = 0;
max_iterations = 0;
eps = 0.01;
for (unsigned long i = 0; i < horizon; ++i)
{
target[i].set_size(A.nr());
target[i] = 0;
controls[i].set_size(B.nc());
controls[i] = 0;
}
lambda = 0;
}
mpc (
const matrix<double,S,S>& A_,
const matrix<double,S,I>& B_,
const matrix<double,S,1>& C_,
const matrix<double,S,1>& Q_,
const matrix<double,I,1>& R_,
const matrix<double,I,1>& lower_,
const matrix<double,I,1>& upper_
) : A(A_), B(B_), C(C_), Q(Q_), R(R_), lower(lower_), upper(upper_)
{
// make sure requires clause is not broken
DLIB_ASSERT(A.nr() > 0 && B.nc() > 0,
"\t mpc::mpc()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t A.nr(): " << A.nr()
<< "\n\t B.nc(): " << B.nc()
);
DLIB_ASSERT(A.nr() == A.nc() &&
A.nr() == B.nr() &&
A.nr() == C.nr() &&
A.nr() == Q.nr(),
"\t mpc::mpc()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t A.nr(): " << A.nr()
<< "\n\t A.nc(): " << A.nc()
<< "\n\t B.nr(): " << B.nr()
<< "\n\t C.nr(): " << C.nr()
<< "\n\t Q.nr(): " << Q.nr()
);
DLIB_ASSERT(
B.nc() == R.nr() &&
B.nc() == lower.nr() &&
B.nc() == upper.nr() ,
"\t mpc::mpc()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t B.nr(): " << B.nr()
<< "\n\t B.nc(): " << B.nc()
<< "\n\t lower.nr(): " << lower.nr()
<< "\n\t upper.nr(): " << upper.nr()
);
DLIB_ASSERT(min(Q) >= 0 &&
min(R) > 0 &&
min(upper-lower) >= 0,
"\t mpc::mpc()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t min(Q): " << min(Q)
<< "\n\t min(R): " << min(R)
<< "\n\t min(upper-lower): " << min(upper-lower)
);
max_iterations = 10000;
eps = 0.01;
for (unsigned long i = 0; i < horizon; ++i)
{
target[i].set_size(A.nr());
target[i] = 0;
controls[i].set_size(B.nc());
controls[i] = 0;
}
// Bound the maximum eigenvalue of the hessian by computing the trace of the
// hessian matrix.
lambda = sum(R)*horizon;
matrix<double,S,S> temp = diagm(Q);
for (unsigned long c = 0; c < horizon; ++c)
{
lambda += trace(trans(B)*temp*B);
Q_diag[horizon-c-1] = diag(trans(B)*temp*B);
temp = trans(A)*temp*A + diagm(Q);
}
}
const matrix<double,S,S>& get_A (
) const { return A; }
const matrix<double,S,I>& get_B (
) const { return B; }
const matrix<double,S,1>& get_C (
) const { return C; }
const matrix<double,S,1>& get_Q (
) const { return Q; }
const matrix<double,I,1>& get_R (
) const { return R; }
const matrix<double,I,1>& get_lower_constraints (
) const { return lower; }
const matrix<double,I,1>& get_upper_constraints (
) const { return upper; }
void set_target (
const matrix<double,S,1>& val,
const unsigned long time
)
{
DLIB_ASSERT(time < horizon,
"\t void mpc::set_target(eps_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t time: " << time
<< "\n\t horizon: " << horizon
);
target[time] = val;
}
void set_target (
const matrix<double,S,1>& val
)
{
for (unsigned long i = 0; i < horizon; ++i)
target[i] = val;
}
void set_last_target (
const matrix<double,S,1>& val
)
{
set_target(val, horizon-1);
}
const matrix<double,S,1>& get_target (
const unsigned long time
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(time < horizon,
"\t matrix mpc::get_target(eps_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t time: " << time
<< "\n\t horizon: " << horizon
);
return target[time];
}
unsigned long get_max_iterations (
) const { return max_iterations; }
void set_max_iterations (
unsigned long max_iter
)
{
max_iterations = max_iter;
}
void set_epsilon (
double eps_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(eps_ > 0,
"\t void mpc::set_epsilon(eps_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t eps_: " << eps_
);
eps = eps_;
}
double get_epsilon (
) const
{
return eps;
}
matrix<double,I,1> operator() (
const matrix<double,S,1>& current_state
)
{
// make sure requires clause is not broken
DLIB_ASSERT(min(R) > 0 && A.nr() == current_state.size(),
"\t matrix mpc::operator(current_state)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t min(R): " << min(R)
<< "\n\t A.nr(): " << A.nr()
<< "\n\t current_state.size(): " << current_state.size()
);
// Shift the inputs over by one time step so we can use them to warm start the
// optimizer.
for (unsigned long i = 1; i < horizon; ++i)
controls[i-1] = controls[i];
solve_linear_mpc(current_state);
for (unsigned long i = 1; i < horizon; ++i)
target[i-1] = target[i];
return controls[0];
}
private:
// These temporary variables here just to avoid reallocating them on each call to
// operator().
matrix<double,S,1> M[horizon];
matrix<double,I,1> MM[horizon];
matrix<double,I,1> df[horizon];
matrix<double,I,1> v[horizon];
matrix<double,I,1> v_old[horizon];
void solve_linear_mpc (
const matrix<double,S,1>& initial_state
)
{
// make it so MM == trans(K)*Q*(M-target)
M[0] = A*initial_state + C;
for (unsigned long i = 1; i < horizon; ++i)
M[i] = A*M[i-1] + C;
for (unsigned long i = 0; i < horizon; ++i)
M[i] = diagm(Q)*(M[i]-target[i]);
for (long i = (long)horizon-2; i >= 0; --i)
M[i] += trans(A)*M[i+1];
for (unsigned long i = 0; i < horizon; ++i)
MM[i] = trans(B)*M[i];
unsigned long iter = 0;
for (; iter < max_iterations; ++iter)
{
// compute current gradient and put it into df.
// df == H*controls + MM;
M[0] = B*controls[0];
for (unsigned long i = 1; i < horizon; ++i)
M[i] = A*M[i-1] + B*controls[i];
for (unsigned long i = 0; i < horizon; ++i)
M[i] = diagm(Q)*M[i];
for (long i = (long)horizon-2; i >= 0; --i)
M[i] += trans(A)*M[i+1];
for (unsigned long i = 0; i < horizon; ++i)
df[i] = MM[i] + trans(B)*M[i] + diagm(R)*controls[i];
// Check the stopping condition, which is the magnitude of the largest element
// of the gradient.
double max_df = 0;
unsigned long max_t = 0;
long max_v = 0;
for (unsigned long i = 0; i < horizon; ++i)
{
for (long j = 0; j < controls[i].size(); ++j)
{
// if this variable isn't an active constraint then we care about it's
// derivative.
if (!((controls[i](j) <= lower(j) && df[i](j) > 0) ||
(controls[i](j) >= upper(j) && df[i](j) < 0)))
{
if (std::abs(df[i](j)) > max_df)
{
max_df = std::abs(df[i](j));
max_t = i;
max_v = j;
}
}
}
}
if (max_df < eps)
break;
// We will start out by doing a little bit of coordinate descent because it
// allows us to optimize individual variables exactly. Since we are warm
// starting each iteration with a really good solution this helps speed
// things up a lot.
const unsigned long smo_iters = 50;
if (iter < smo_iters)
{
if (Q_diag[max_t](max_v) == 0) continue;
// Take the optimal step but just for one variable.
controls[max_t](max_v) = -(df[max_t](max_v)-Q_diag[max_t](max_v)*controls[max_t](max_v))/Q_diag[max_t](max_v);
controls[max_t](max_v) = put_in_range(lower(max_v), upper(max_v), controls[max_t](max_v));
// If this is the last SMO iteration then don't forget to initialize v
// for the gradient steps.
if (iter+1 == smo_iters)
{
for (unsigned long i = 0; i < horizon; ++i)
v[i] = controls[i];
}
}
else
{
// Take a projected gradient step.
for (unsigned long i = 0; i < horizon; ++i)
{
v_old[i] = v[i];
v[i] = dlib::clamp(controls[i] - 1.0/lambda * df[i], lower, upper);
controls[i] = dlib::clamp(v[i] + (std::sqrt(lambda)-1)/(std::sqrt(lambda)+1)*(v[i]-v_old[i]), lower, upper);
}
}
}
}
unsigned long max_iterations;
double eps;
matrix<double,S,S> A;
matrix<double,S,I> B;
matrix<double,S,1> C;
matrix<double,S,1> Q;
matrix<double,I,1> R;
matrix<double,I,1> lower;
matrix<double,I,1> upper;
matrix<double,S,1> target[horizon];
double lambda; // abound on the largest eigenvalue of the hessian matrix.
matrix<double,I,1> Q_diag[horizon];
matrix<double,I,1> controls[horizon];
};
}
#endif // DLIB_MPC_Hh_

View File

@@ -0,0 +1,276 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_MPC_ABSTRACT_Hh_
#ifdef DLIB_MPC_ABSTRACT_Hh_
#include "../matrix.h"
namespace dlib
{
template <
long S_,
long I_,
unsigned long horizon_
>
class mpc
{
/*!
REQUIREMENTS ON horizon_
horizon_ > 0
REQUIREMENTS ON S_
S_ >= 0
REQUIREMENTS ON I_
I_ >= 0
WHAT THIS OBJECT REPRESENTS
This object implements a linear model predictive controller. To explain
what that means, suppose you have some process you want to control and the
process dynamics are described by the linear equation:
x_{i+1} = A*x_i + B*u_i + C
That is, the next state the system goes into is a linear function of its
current state (x_i) and the current control (u_i) plus some constant bias
or disturbance.
A model predictive controller can find the control (u) you should apply to
drive the state (x) to some reference value, or alternatively to make the
state track some reference time-varying sequence. It does this by
simulating the process for horizon_ time steps and selecting the control
that leads to the best performance over the next horizon_ steps.
To be precise, each time you ask this object for a control, it solves the
following quadratic program:
min sum_i trans(x_i-target_i)*Q*(x_i-target_i) + trans(u_i)*R*u_i
x_i,u_i
such that: x_0 == current_state
x_{i+1} == A*x_i + B*u_i + C
lower <= u_i <= upper
0 <= i < horizon_
and reports u_0 as the control you should take given that you are currently
in current_state. Q and R are user supplied matrices that define how we
penalize variations away from the target state as well as how much we want
to avoid generating large control signals.
Finally, the algorithm we use to solve this quadratic program is based
largely on the method described in:
A Fast Gradient method for embedded linear predictive control (2011)
by Markus Kogel and Rolf Findeisen
!*/
public:
const static long S = S_;
const static long I = I_;
const static unsigned long horizon = horizon_;
mpc(
);
/*!
ensures
- #get_max_iterations() == 0
- The A,B,C,Q,R,lower, and upper parameter matrices are filled with zeros.
Therefore, to use this object you must initialize it via the constructor
that supplies these parameters.
!*/
mpc (
const matrix<double,S,S>& A,
const matrix<double,S,I>& B,
const matrix<double,S,1>& C,
const matrix<double,S,1>& Q,
const matrix<double,I,1>& R,
const matrix<double,I,1>& lower,
const matrix<double,I,1>& upper
);
/*!
requires
- A.nr() > 0
- B.nc() > 0
- A.nr() == A.nc() == B.nr() == C.nr() == Q.nr()
- B.nc() == R.nr() == lower.nr() == upper.nr()
- min(Q) >= 0
- min(R) > 0
- min(upper-lower) >= 0
ensures
- #get_A() == A
- #get_B() == B
- #get_C() == C
- #get_Q() == Q
- #get_R() == R
- #get_lower_constraints() == lower
- #get_upper_constraints() == upper
- for all valid i:
- get_target(i) == a vector of all zeros
- get_target(i).size() == A.nr()
- #get_max_iterations() == 10000
- #get_epsilon() == 0.01
!*/
const matrix<double,S,S>& get_A (
) const;
/*!
ensures
- returns the A matrix from the quadratic program defined above.
!*/
const matrix<double,S,I>& get_B (
) const;
/*!
ensures
- returns the B matrix from the quadratic program defined above.
!*/
const matrix<double,S,1>& get_C (
) const;
/*!
ensures
- returns the C matrix from the quadratic program defined above.
!*/
const matrix<double,S,1>& get_Q (
) const;
/*!
ensures
- returns the diagonal of the Q matrix from the quadratic program defined
above.
!*/
const matrix<double,I,1>& get_R (
) const;
/*!
ensures
- returns the diagonal of the R matrix from the quadratic program defined
above.
!*/
const matrix<double,I,1>& get_lower_constraints (
) const;
/*!
ensures
- returns the lower matrix from the quadratic program defined above. All
controls generated by this object will have values no less than this
lower bound. That is, any control u will satisfy min(u-lower) >= 0.
!*/
const matrix<double,I,1>& get_upper_constraints (
) const;
/*!
ensures
- returns the upper matrix from the quadratic program defined above. All
controls generated by this object will have values no larger than this
upper bound. That is, any control u will satisfy min(upper-u) >= 0.
!*/
const matrix<double,S,1>& get_target (
const unsigned long time
) const;
/*!
requires
- time < horizon
ensures
- This object will try to find the control sequence that results in the
process obtaining get_target(time) state at the indicated time. Note
that the next time instant after "right now" is time 0.
!*/
void set_target (
const matrix<double,S,1>& val,
const unsigned long time
);
/*!
requires
- time < horizon
ensures
- #get_target(time) == val
!*/
void set_target (
const matrix<double,S,1>& val
);
/*!
ensures
- for all valid t:
- #get_target(t) == val
!*/
void set_last_target (
const matrix<double,S,1>& val
);
/*!
ensures
- performs: set_target(val, horizon-1)
!*/
unsigned long get_max_iterations (
) const;
/*!
ensures
- When operator() is called it solves an optimization problem to
get_epsilon() precision to determine the next control action. In
particular, we run the optimizer until the magnitude of each element of
the gradient vector is less than get_epsilon() or until
get_max_iterations() solver iterations have been executed.
!*/
void set_max_iterations (
unsigned long max_iter
);
/*!
ensures
- #get_max_iterations() == max_iter
!*/
void set_epsilon (
double eps
);
/*!
requires
- eps > 0
ensures
- #get_epsilon() == eps
!*/
double get_epsilon (
) const;
/*!
ensures
- When operator() is called it solves an optimization problem to
get_epsilon() precision to determine the next control action. In
particular, we run the optimizer until the magnitude of each element of
the gradient vector is less than get_epsilon() or until
get_max_iterations() solver iterations have been executed. This means
that smaller epsilon values will give more accurate outputs but may take
longer to compute.
!*/
matrix<double,I,1> operator() (
const matrix<double,S,1>& current_state
);
/*!
requires
- min(R) > 0
- A.nr() == current_state.size()
ensures
- Solves the model predictive control problem defined by the arguments to
this objects constructor, assuming that the starting state is given by
current_state. Then we return the control that should be taken in the
current state that best optimizes the quadratic objective function
defined above.
- We also shift over the target states so that you only need to update the
last one (if you are using non-zero target states) via a call to
set_last_target()). In particular, for all valid t, it will be the case
that:
- #get_target(t) == get_target(t+1)
- #get_target(horizon-1) == get_target(horizon-1)
!*/
};
}
#endif // DLIB_MPC_ABSTRACT_Hh_

View File

@@ -5,6 +5,7 @@
#include "../algs.h"
#include <string>
#include <vector>
#include "crc32_kernel_abstract.h"
namespace dlib
@@ -15,11 +16,9 @@ namespace dlib
/*!
INITIAL VALUE
checksum == 0xFFFFFFFF
table == crc table
CONVENTION
get_checksum() == checksum ^ 0xFFFFFFFF
table == crc table
!*/
public:
@@ -34,6 +33,10 @@ namespace dlib
const std::string& item
);
inline crc32 (
const std::vector<char>& item
);
inline virtual ~crc32 (
);
@@ -48,6 +51,13 @@ namespace dlib
const std::string& item
);
inline void add (
const std::vector<char>& item
);
inline operator unsigned long (
) const { return get_checksum(); }
inline unsigned long get_checksum (
) const;
@@ -61,12 +71,67 @@ namespace dlib
private:
inline void fill_crc_table(
);
unsigned long checksum;
unsigned long table[256];
inline unsigned long table (
unsigned int idx
) const
{
/*
// This code generates the crc_table used below.
unsigned long crc_table[256];
for (unsigned long i = 0; i < 256; ++i)
{
unsigned long temp = i;
for (unsigned long j = 0; j < 8; ++j)
{
if (temp&1)
temp = (temp>>1)^0xedb88320;
else
temp >>= 1;
}
crc_table[i] = temp;
std::cout << std::hex << crc_table[i] << std::endl;
}
*/
const static unsigned long crc_table[256] = {
0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x76dc419, 0x706af48f, 0xe963a535, 0x9e6495a3,
0xedb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x9b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91,
0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7,
0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5,
0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b,
0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59,
0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f,
0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d,
0x76dc4190, 0x1db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x6b6b51f, 0x9fbfe4a5, 0xe8b8d433,
0x7807c9a2, 0xf00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x86d3d2d, 0x91646c97, 0xe6635c01,
0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457,
0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65,
0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb,
0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9,
0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f,
0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad,
0xedb88320, 0x9abfb3b6, 0x3b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x4db2615, 0x73dc1683,
0xe3630b12, 0x94643b84, 0xd6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0xa00ae27, 0x7d079eb1,
0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7,
0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5,
0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b,
0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79,
0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f,
0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d,
0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x26d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x5005713,
0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0xcb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0xbdbdf21,
0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777,
0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45,
0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db,
0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9,
0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf,
0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d
};
return crc_table[idx];
}
};
@@ -79,29 +144,6 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void crc32::
fill_crc_table (
)
{
unsigned long temp;
// fill out the crc table
for (unsigned long i = 0; i < 256; ++i)
{
temp = i;
for (unsigned long j = 0; j < 8; ++j)
{
if (temp&1)
temp = (temp>>1)^0xedb88320;
else
temp >>= 1;
}
table[i] = temp;
}
}
// ----------------------------------------------------------------------------------------
crc32::
@@ -109,7 +151,6 @@ namespace dlib
)
{
checksum = 0xFFFFFFFF;
fill_crc_table();
}
// ----------------------------------------------------------------------------------------
@@ -120,7 +161,17 @@ namespace dlib
)
{
checksum = 0xFFFFFFFF;
fill_crc_table();
add(item);
}
// ----------------------------------------------------------------------------------------
crc32::
crc32 (
const std::vector<char>& item
)
{
checksum = 0xFFFFFFFF;
add(item);
}
@@ -148,7 +199,7 @@ namespace dlib
unsigned char item
)
{
checksum = (checksum>>8) ^ table[(checksum^item) & 0xFF];
checksum = (checksum>>8) ^ table((checksum^item) & 0xFF);
}
// ----------------------------------------------------------------------------------------
@@ -159,7 +210,18 @@ namespace dlib
)
{
for (std::string::size_type i = 0; i < item.size(); ++i)
checksum = (checksum>>8) ^ table[(checksum^item[i]) & 0xFF];
checksum = (checksum>>8) ^ table((checksum^item[i]) & 0xFF);
}
// ----------------------------------------------------------------------------------------
void crc32::
add (
const std::vector<char>& item
)
{
for (unsigned long i = 0; i < item.size(); ++i)
checksum = (checksum>>8) ^ table((checksum^item[i]) & 0xFF);
}
// ----------------------------------------------------------------------------------------

View File

@@ -5,6 +5,7 @@
#include "../algs.h"
#include <string>
#include <vector>
namespace dlib
{
@@ -41,6 +42,17 @@ namespace dlib
constructor and then calling add() on item)
!*/
crc32 (
const std::vector<char>& item
);
/*!
ensures
- #*this is properly initialized
- calls this->add(item).
(i.e. Using this constructor is the same as using the default
constructor and then calling add() on item)
!*/
virtual ~crc32 (
);
/*!
@@ -73,6 +85,15 @@ namespace dlib
concatenated with item.
!*/
void add (
const std::vector<char>& item
);
/*!
ensures
- #get_checksum() == The checksum of all items added to *this previously
concatenated with item.
!*/
unsigned long get_checksum (
) const;
/*!
@@ -80,6 +101,13 @@ namespace dlib
- returns the current checksum
!*/
operator unsigned long (
) const;
/*!
ensures
- returns get_checksum()
!*/
void swap (
crc32& item
);

View File

@@ -1 +0,0 @@
#include "dlib_include_path_tutorial.txt"

View File

@@ -0,0 +1,505 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CPU_H_
#define DLIB_DNN_CPU_H_
// This file contains CPU implementations of the GPU based functions in cuda_dlib.h
// and cudnn_dlibapi.h
#include "tensor.h"
#include "../geometry/rectangle.h"
namespace dlib
{
namespace cpu
{
// -----------------------------------------------------------------------------------
void multiply (
bool add_to,
tensor& dest,
const tensor& src1,
const tensor& src2
);
void multiply_conv (
bool add_to,
tensor& dest,
const tensor& src1,
const tensor& src2
);
void multiply_zero_padded (
bool add_to,
tensor& dest,
const tensor& src1,
const tensor& src2
);
void scale_channels (
bool add_to,
tensor& dest,
const tensor& src,
const tensor& scales
);
void add(
float beta,
tensor& dest,
float alpha,
const tensor& src
);
void assign_bias_gradient (
tensor& grad,
const tensor& gradient_input
);
void add (
tensor& dest,
const tensor& src1,
const tensor& src2
);
void assign_conv_bias_gradient (
tensor& grad,
const tensor& gradient_input
);
// -----------------------------------------------------------------------------------
void affine_transform(
tensor& dest,
const tensor& src,
const float A,
const float B
);
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B,
const float C
);
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
const float A,
const float B,
const float C,
const float D
);
void affine_transform_range(
size_t begin,
size_t end,
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
const float A,
const float B,
const float C
);
// -----------------------------------------------------------------------------------
void affine_transform(
tensor& dest,
const tensor& src,
const tensor& A,
const tensor& B
);
// -----------------------------------------------------------------------------------
void affine_transform_conv(
tensor& dest,
const tensor& src,
const tensor& A,
const tensor& B
);
// -----------------------------------------------------------------------------------
void affine_transform(
const rectangle& rect,
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
float A,
float B,
float C
);
// -----------------------------------------------------------------------------------
void compute_adam_update (
size_t begin,
size_t end,
tensor& s,
tensor& m,
tensor& v,
const float t,
const float learning_rate,
const float weight_decay,
const float momentum1,
const float momentum2,
const tensor& params,
const tensor& params_grad
);
// -----------------------------------------------------------------------------------
void batch_normalize_inference (
const double eps,
resizable_tensor& dest,
const tensor& src,
const tensor& gamma,
const tensor& beta,
const tensor& running_means,
const tensor& running_variances
);
void batch_normalize (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const double averaging_factor,
resizable_tensor& running_means,
resizable_tensor& running_variances,
const tensor& src,
const tensor& gamma,
const tensor& beta
);
void batch_normalize_gradient (
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
void batch_normalize_conv_inference (
const double eps,
resizable_tensor& dest,
const tensor& src,
const tensor& gamma,
const tensor& beta,
const tensor& running_means,
const tensor& running_variances
);
void batch_normalize_conv (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const double averaging_factor,
resizable_tensor& running_means,
resizable_tensor& running_variances,
const tensor& src,
const tensor& gamma,
const tensor& beta
);
void batch_normalize_conv_gradient (
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
// -----------------------------------------------------------------------------------
void threshold (
tensor& data,
float thresh
);
void dot (
const tensor& a,
const tensor& b,
tensor& result,
size_t idx
);
// -----------------------------------------------------------------------------------
void softmax (
tensor& dest,
const tensor& src
);
void softmax_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
// ------------------------------------------------------------------------------------
void softmax_all (
tensor& dest,
const tensor& src
);
void softmax_all_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
// ------------------------------------------------------------------------------------
void sigmoid (
tensor& dest,
const tensor& src
);
void sigmoid_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
// ------------------------------------------------------------------------------------
void relu (
tensor& dest,
const tensor& src
);
void relu_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
// ----------------------------------------------------------------------------------------
void prelu (
tensor& dest,
const tensor& src,
const tensor& param
);
void prelu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input,
const tensor& param,
tensor& params_grad
);
// ------------------------------------------------------------------------------------
void tanh (
tensor& dest,
const tensor& src
);
void tanh_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
// ----------------------------------------------------------------------------------------
void resize_bilinear (
tensor& dest,
long dest_row_stride,
long dest_channel_stride,
const tensor& src,
long src_row_stride,
long src_channel_stride
);
void resize_bilinear_gradient (
tensor& grad,
long grad_row_stride,
long grad_channel_stride,
const tensor& gradient_input,
long gradient_input_row_stride,
long gradient_input_channel_stride
);
inline void resize_bilinear (
tensor& dest,
const tensor& src
) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); }
inline void resize_bilinear_gradient (
tensor& grad,
const tensor& gradient_input
) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
// -----------------------------------------------------------------------------------
class pooling
{
public:
pooling(const pooling&) = delete;
pooling& operator=(const pooling&) = delete;
pooling (
);
void clear(
);
void setup_max_pooling(
int window_height,
int window_width,
int stride_y,
int stride_x,
int padding_y,
int padding_x
);
void setup_avg_pooling(
int window_height,
int window_width,
int stride_y,
int stride_x,
int padding_y,
int padding_x
);
bool does_max_pooling(
) const { return do_max_pooling; }
void operator() (
resizable_tensor& dest,
const tensor& src
);
void get_gradient(
const tensor& gradient_input,
const tensor& dest,
const tensor& src,
tensor& grad
);
private:
int window_height;
int window_width;
int stride_y;
int stride_x;
int padding_y;
int padding_x;
bool do_max_pooling;
};
// -----------------------------------------------------------------------------------
class tensor_conv
{
public:
tensor_conv(const tensor_conv&) = delete;
tensor_conv& operator=(const tensor_conv&) = delete;
tensor_conv() {}
void clear(
) {}
void setup(
const tensor& data, /* not used but required for interface */
const tensor& filters, /* not used but required for interface */
int stride_y,
int stride_x,
int padding_y,
int padding_x
)
{
(void)data; /* silence compiler */
DLIB_CASSERT(stride_y > 0 && stride_x > 0);
DLIB_CASSERT(0 <= padding_y && padding_y < filters.nr());
DLIB_CASSERT(0 <= padding_x && padding_x < filters.nc());
last_stride_y = stride_y;
last_stride_x = stride_x;
last_padding_y = padding_y;
last_padding_x = padding_x;
}
void operator() (
const bool add_to_output,
resizable_tensor& output,
const tensor& data,
const tensor& filters
);
void operator() (
const bool add_to_output,
tensor& output,
const tensor& data,
const tensor& filters
);
void get_gradient_for_data (
const bool add_to_output,
const tensor& gradient_input,
const tensor& filters,
tensor& data_gradient
);
void get_gradient_for_filters (
const bool add_to_output,
const tensor& gradient_input,
const tensor& data,
tensor& filters_gradient
);
private:
long last_stride_y = 0;
long last_stride_x = 0;
long last_padding_y = 0;
long last_padding_x = 0;
};
// -----------------------------------------------------------------------------------
void copy_tensor(
bool add_to,
tensor& dest,
size_t dest_k_offset,
const tensor& src,
size_t src_k_offset,
size_t count_k
);
// -----------------------------------------------------------------------------------
}
}
#ifdef NO_MAKEFILE
#include "cpu_dlib.cpp"
#endif
#endif // DLIB_DNN_CPU_H_

View File

@@ -0,0 +1,50 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuBLAS_H_
#define DLIB_DNN_CuBLAS_H_
#ifdef DLIB_USE_CUDA
#include "tensor.h"
#include "cuda_errors.h"
namespace dlib
{
namespace cuda
{
// -----------------------------------------------------------------------------------
void gemm (
float beta,
tensor& dest,
float alpha,
const tensor& lhs,
bool trans_lhs,
const tensor& rhs,
bool trans_rhs
);
/*!
requires
- The dimensions of lhs and rhs must be compatible for matrix
multiplication. In particular:
- Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
- Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
- Let D == mat(dest)
- D.nr() == L.nr() && D.nc() == R.nc()
(i.e. dest must be preallocated and have the correct output dimensions)
- L.nc() == R.nr()
ensures
- performs: dest = alpha*L*R + beta*mat(dest)
!*/
// ------------------------------------------------------------------------------------
}
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuBLAS_H_

View File

@@ -0,0 +1,256 @@
// Copyright (C) 2017 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDA_DATA_PTR_H_
#define DLIB_DNN_CuDA_DATA_PTR_H_
#ifdef DLIB_USE_CUDA
#include <memory>
#include <vector>
#include "../assert.h"
namespace dlib
{
namespace cuda
{
// ------------------------------------------------------------------------------------
class cuda_data_void_ptr
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a block of memory on a CUDA device.
!*/
public:
cuda_data_void_ptr() = default;
cuda_data_void_ptr(size_t n);
/*!
ensures
- This object will allocate a device memory buffer of n bytes.
- #size() == n
!*/
void* data() { return pdata.get(); }
const void* data() const { return pdata.get(); }
operator void*() { return pdata.get(); }
operator const void*() const { return pdata.get(); }
void reset() { pdata.reset(); }
size_t size() const { return num; }
/*!
ensures
- returns the length of this buffer, in bytes.
!*/
cuda_data_void_ptr operator+ (size_t offset) const
/*!
requires
- offset < size()
ensures
- returns a pointer that is offset by the given amount.
!*/
{
DLIB_CASSERT(offset < num);
cuda_data_void_ptr temp;
temp.num = num-offset;
temp.pdata = std::shared_ptr<void>(pdata, ((char*)pdata.get())+offset);
return temp;
}
private:
size_t num = 0;
std::shared_ptr<void> pdata;
};
inline cuda_data_void_ptr operator+(size_t offset, const cuda_data_void_ptr& rhs) { return rhs+offset; }
// ------------------------------------------------------------------------------------
void memcpy(
void* dest,
const cuda_data_void_ptr& src
);
/*!
requires
- dest == a pointer to at least src.size() bytes on the host machine.
ensures
- copies the GPU data from src into dest.
- This routine is equivalent to performing: memcpy(dest,src,src.size())
!*/
void memcpy(
void* dest,
const cuda_data_void_ptr& src,
const size_t num
);
/*!
requires
- dest == a pointer to at least num bytes on the host machine.
- num <= src.size()
ensures
- copies the GPU data from src into dest. Copies only the first num bytes
of src to dest.
!*/
// ------------------------------------------------------------------------------------
void memcpy(
cuda_data_void_ptr dest,
const void* src
);
/*!
requires
- dest == a pointer to at least src.size() bytes on the host machine.
ensures
- copies the host data from src to the GPU memory buffer dest.
- This routine is equivalent to performing: memcpy(dest,src,dest.size())
!*/
void memcpy(
cuda_data_void_ptr dest,
const void* src,
const size_t num
);
/*!
requires
- dest == a pointer to at least num bytes on the host machine.
- num <= dest.size()
ensures
- copies the host data from src to the GPU memory buffer dest. Copies only
the first num bytes of src to dest.
!*/
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
template <typename T>
class cuda_data_ptr
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a block of memory on a CUDA device. It is just a type safe
version of cuda_data_void_ptr.
!*/
public:
static_assert(std::is_standard_layout<T>::value, "You can only create basic standard layout types on the GPU");
cuda_data_ptr() = default;
cuda_data_ptr(size_t n) : num(n)
/*!
ensures
- This object will allocate a device memory buffer of n T objects.
- #size() == n
!*/
{
if (n == 0)
return;
pdata = cuda_data_void_ptr(n*sizeof(T));
}
T* data() { return (T*)pdata.data(); }
const T* data() const { return (T*)pdata.data(); }
operator T*() { return (T*)pdata.data(); }
operator const T*() const { return (T*)pdata.data(); }
void reset() { pdata.reset(); }
size_t size() const { return num; }
friend void memcpy(
std::vector<T>& dest,
const cuda_data_ptr& src
)
{
dest.resize(src.size());
if (src.size() != 0)
memcpy(dest.data(), src.pdata);
}
friend void memcpy(
cuda_data_ptr& src,
const std::vector<T>& dest
)
{
if (dest.size() != src.size())
dest = cuda_data_ptr<T>(src.size());
if (src.size() != 0)
memcpy(src.pdata, dest.data());
}
private:
size_t num = 0;
cuda_data_void_ptr pdata;
};
// ------------------------------------------------------------------------------------
class resizable_cuda_buffer
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a block of memory on a CUDA device that will be automatically
resized if requested size is larger than allocated.
!*/
public:
cuda_data_void_ptr get(size_t size)
/*!
ensures
- This object will return the buffer of requested size or larger.
- buffer.size() >= size
- Client code should not hold the returned cuda_data_void_ptr for long
durations, but instead should call get() whenever the buffer is
needed. Doing so ensures that multiple buffers are not kept around
in the event of a resize.
!*/
{
if (buffer.size() < size)
{
buffer.reset();
buffer = cuda_data_void_ptr(size);
}
return buffer;
}
private:
cuda_data_void_ptr buffer;
};
// ----------------------------------------------------------------------------------------
std::shared_ptr<resizable_cuda_buffer> device_global_buffer(
);
/*!
ensures
- Returns a pointer to a globally shared CUDA memory buffer on the
currently selected CUDA device. The buffer is also thread local. So
each host thread will get its own buffer. You can use this global buffer
as scratch space for CUDA computations that all take place on the default
stream. Using it in this way ensures that there aren't any race conditions
involving the use of the buffer.
- The global buffer is deallocated once all references to it are
destructed. It will be reallocated as required. So if you want to avoid
these reallocations then hold a copy of the shared_ptr returned by this
function.
!*/
// ----------------------------------------------------------------------------------------
}
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuDA_DATA_PTR_H_

View File

@@ -0,0 +1,530 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDA_H_
#define DLIB_DNN_CuDA_H_
#include "tensor.h"
#include "../geometry/rectangle.h"
namespace dlib
{
namespace cuda
{
// ----------------------------------------------------------------------------------------
void set_device (
int dev
);
int get_device (
);
int get_num_devices (
);
std::string get_device_name (
int device
);
void set_current_device_blocking_sync(
);
bool can_access_peer (int device_id, int peer_device_id);
bool can_access_peer (const tensor& device, const tensor& peer_device);
void device_synchronize (int dev);
void device_synchronize (const tensor& dev);
class raii_set_device
{
public:
raii_set_device() = delete;
raii_set_device(const raii_set_device&) = delete;
raii_set_device& operator=(const raii_set_device&) = delete;
raii_set_device(int dev)
{
prev_dev = get_device();
set_device(dev);
}
raii_set_device(const tensor& dev)
{
prev_dev = get_device();
set_device(dev.device_id());
}
void operator() (int dev)
{
set_device(dev);
}
void operator() (const tensor& dev)
{
set_device(dev.device_id());
}
~raii_set_device() noexcept(false)
{
set_device(prev_dev);
}
private:
int prev_dev;
};
#ifdef DLIB_USE_CUDA
class enable_peer_access
{
public:
enable_peer_access() = delete;
enable_peer_access(const enable_peer_access&) = delete;
enable_peer_access& operator=(const enable_peer_access&) = delete;
enable_peer_access(
int device_id,
int peer_device_id
);
enable_peer_access(
const tensor& device,
const tensor& peer_device
) : enable_peer_access(device.device_id(), peer_device.device_id())
{}
~enable_peer_access() noexcept(false);
private:
bool call_disable;
int device_id;
int peer_device_id;
};
// -----------------------------------------------------------------------------------
void inverse_norms (
resizable_tensor& invnorms,
const tensor& data,
const double eps
);
void dot_prods (
resizable_tensor& out,
const tensor& lhs,
const tensor& rhs
);
void dot_prods (
bool add_to,
tensor& out,
const tensor& lhs,
const tensor& rhs
);
void scale_columns (
tensor& out,
const tensor& m,
const tensor& v
);
void scale_rows (
tensor& out,
const tensor& m,
const tensor& v
);
void scale_rows2 (
float beta,
tensor& out,
const tensor& m1,
const tensor& m2,
const tensor& v1,
const tensor& v2
);
void exp (
tensor& dest,
const tensor& src
);
void log (
tensor& dest,
const tensor& src
);
void log10 (
tensor& dest,
const tensor& src
);
// ------------------------------------------------------------------------------------
void set_tensor (
tensor& t,
float value
);
void scale_tensor (
tensor& t,
float value
);
// ------------------------------------------------------------------------------------
void multiply (
bool add_to,
tensor& dest,
const tensor& src1,
const tensor& src2
);
void multiply_conv (
bool add_to,
tensor& dest,
const tensor& src1,
const tensor& src2
);
void multiply_zero_padded (
bool add_to,
tensor& dest,
const tensor& src1,
const tensor& src2
);
void scale_channels (
bool add_to,
tensor& dest,
const tensor& src,
const tensor& scales
);
void add (
tensor& dest,
const tensor& src1,
const tensor& src2
);
// -----------------------------------------------------------------------------------
void affine_transform(
tensor& dest,
const tensor& src,
const float A,
const float B
);
void affine_transform(
tensor& dest,
const tensor& src,
const float A
);
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B,
const float C
);
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B
);
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
const float A,
const float B,
const float C,
const float D
);
void affine_transform_range(
size_t begin,
size_t end,
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
const float A,
const float B,
const float C
);
void affine_transform(
const rectangle& rect,
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
float A,
float B,
float C
);
// Note that this function isn't in the tt:: namespace because add_scaled() is
// called by cuda::add() so we don't need a tt:: version of add_scaled().
void add_scaled(
tensor& dest,
const float scale,
const tensor& src
);
void add_cv_to_all_columns(
float beta,
tensor& dest,
float alpha,
const tensor& src
);
// -----------------------------------------------------------------------------------
void affine_transform(
tensor& dest,
const tensor& src,
const tensor& A,
const tensor& B
);
// -----------------------------------------------------------------------------------
void affine_transform_conv(
tensor& dest,
const tensor& src,
const tensor& A,
const tensor& B
);
// ----------------------------------------------------------------------------------------
void compute_adam_update (
size_t begin,
size_t end,
tensor& s,
tensor& m,
tensor& v,
const float t,
const float learning_rate,
const float weight_decay,
const float momentum1,
const float momentum2,
const tensor& params,
const tensor& params_grad
);
// -----------------------------------------------------------------------------------
void assign_bias_gradient (
tensor& grad,
const tensor& gradient_input
);
// -----------------------------------------------------------------------------------
void threshold (
tensor& data,
float thresh
);
// ----------------------------------------------------------------------------------------
void dot (
const tensor& a,
const tensor& b,
tensor& result,
size_t idx
);
// ----------------------------------------------------------------------------------------
void prelu (
tensor& dest,
const tensor& src,
const tensor& param
);
void prelu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input,
const tensor& param,
tensor& params_grad
);
// ----------------------------------------------------------------------------------------
void resize_bilinear (
tensor& dest,
long dest_row_stride,
long dest_channel_stride,
const tensor& src,
long src_row_stride,
long src_channel_stride
);
void resize_bilinear_gradient (
tensor& grad,
long grad_row_stride,
long grad_channel_stride,
const tensor& gradient_input,
long gradient_input_row_stride,
long gradient_input_channel_stride
);
inline void resize_bilinear (
tensor& dest,
const tensor& src
) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); }
inline void resize_bilinear_gradient (
tensor& grad,
const tensor& gradient_input
) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
// ----------------------------------------------------------------------------------------
void copy_tensor(
bool add_to,
tensor& dest,
size_t dest_k_offset,
const tensor& src,
size_t src_k_offset,
size_t count_k
);
// ----------------------------------------------------------------------------------------
class compute_loss_multiclass_log_per_pixel
{
/*!
The point of this class is to compute the loss computed by
loss_multiclass_log_per_pixel, but to do so with CUDA.
!*/
public:
compute_loss_multiclass_log_per_pixel(
)
{
work = device_global_buffer();
}
template <
typename const_label_iterator
>
void operator() (
const_label_iterator truth,
const tensor& subnetwork_output,
tensor& gradient,
double& loss
) const
{
const size_t bytes_per_plane = subnetwork_output.nr()*subnetwork_output.nc()*sizeof(uint16_t);
// Allocate a cuda buffer to store all the truth images and also one float
// for the scalar loss output.
cuda_data_void_ptr buf = work->get(subnetwork_output.num_samples()*bytes_per_plane + sizeof(float));
cuda_data_void_ptr loss_buf = buf;
buf = buf+sizeof(float);
// copy the truth data into a cuda buffer.
for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth)
{
const matrix<uint16_t>& t = *truth;
DLIB_ASSERT(t.nr() == subnetwork_output.nr());
DLIB_ASSERT(t.nc() == subnetwork_output.nc());
memcpy(buf + i*bytes_per_plane, &t(0,0), bytes_per_plane);
}
do_work(static_cast<float*>(loss_buf.data()), static_cast<uint16_t*>(buf.data()), subnetwork_output, gradient, loss);
}
private:
static void do_work(
float* loss_cuda_work_buffer,
const uint16_t* truth_buffer,
const tensor& subnetwork_output,
tensor& gradient,
double& loss
);
std::shared_ptr<resizable_cuda_buffer> work;
};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
#else // if DLIB_USE_CUDA NOT DEFINED
inline void set_device (
int id
)
{
DLIB_CASSERT(id == 0, "dlib::cuda::set_device(id) called with an invalid device id.");
}
inline int get_device (
){ return 0; }
inline int get_num_devices (
) { return 1; }
inline std::string get_device_name (
int device
)
{
DLIB_CASSERT(device == 0, "dlib::cuda::set_device(id) called with an invalid device id.");
return "CUDA_DISABLED";
}
inline void set_current_device_blocking_sync(
) {}
inline bool can_access_peer (int , int )
{ return false; }
inline bool can_access_peer (const tensor& , const tensor& )
{ return false; }
inline void device_synchronize (int ){}
inline void device_synchronize (const tensor& ){}
class enable_peer_access
{
public:
enable_peer_access() = delete;
enable_peer_access(const enable_peer_access&) = delete;
enable_peer_access& operator=(const enable_peer_access&) = delete;
enable_peer_access( int, int ){}
enable_peer_access( const tensor&, const tensor& ) {}
};
#endif // DLIB_USE_CUDA
}
}
#endif // DLIB_DNN_CuDA_H_

View File

@@ -0,0 +1,70 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CUDA_ERRORs_H_
#define DLIB_CUDA_ERRORs_H_
#include "../error.h"
namespace dlib
{
struct cuda_error : public error
{
/*!
WHAT THIS OBJECT REPRESENTS
This is the exception thrown if any calls to the NVIDIA CUDA runtime
returns an error.
!*/
cuda_error(const std::string& message): error(message) {}
};
struct cudnn_error : public cuda_error
{
/*!
WHAT THIS OBJECT REPRESENTS
This is the exception thrown if any calls to the NVIDIA cuDNN library
returns an error.
!*/
cudnn_error(const std::string& message): cuda_error(message) {}
};
struct curand_error : public cuda_error
{
/*!
WHAT THIS OBJECT REPRESENTS
This is the exception thrown if any calls to the NVIDIA cuRAND library
returns an error.
!*/
curand_error(const std::string& message): cuda_error(message) {}
};
struct cublas_error : public cuda_error
{
/*!
WHAT THIS OBJECT REPRESENTS
This is the exception thrown if any calls to the NVIDIA cuBLAS library
returns an error.
!*/
cublas_error(const std::string& message): cuda_error(message) {}
};
struct cusolver_error : public cuda_error
{
/*!
WHAT THIS OBJECT REPRESENTS
This is the exception thrown if any calls to the NVIDIA cuSolver library
returns an error.
!*/
cusolver_error(const std::string& message): cuda_error(message) {}
};
}
#endif // DLIB_CUDA_ERRORs_H_

View File

@@ -0,0 +1,413 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CUDA_UtILS_H_
#define DLIB_CUDA_UtILS_H_
#ifndef DLIB_USE_CUDA
#error "This file shouldn't be #included unless DLIB_USE_CUDA is #defined"
#endif
#include "cuda_errors.h"
#include "../algs.h"
#include <cmath>
#include <cuda_runtime.h>
#include <sstream>
#include <iostream>
#include <memory>
#include <vector>
#include <type_traits>
// Check the return value of a call to the CUDA runtime for an error condition.
#define CHECK_CUDA(call) \
do{ \
const cudaError_t error = call; \
if (error != cudaSuccess) \
{ \
std::ostringstream sout; \
sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
sout << "code: " << error << ", reason: " << cudaGetErrorString(error);\
throw dlib::cuda_error(sout.str()); \
} \
}while(false)
// ----------------------------------------------------------------------------------------
#ifdef __CUDACC__
namespace dlib
{
namespace cuda
{
// ------------------------------------------------------------------------------------
__inline__ __device__ size_t pack_idx (
size_t dim_size3,
size_t dim_size2,
size_t dim_size1,
size_t idx4,
size_t idx3,
size_t idx2,
size_t idx1
)
/*!
ensures
- Converts a 4D array index into a 1D index assuming row major layout. To
understand precisely what this function does, imagine we had an array
declared like this:
int ARRAY[anything][dim_size3][dim_size2][dim_size1];
Then we could index it like this:
ARRAY[idx4][idx3][idx2][idx1]
or equivalently like this:
((int*)ARRAY)[pack_idx(dim_size3,dim_size2,dim_size1, idx4,idx3,idx2,idx1)]
!*/
{
return ((idx4*dim_size3 + idx3)*dim_size2 + idx2)*dim_size1 + idx1;
}
__inline__ __device__ void unpack_idx (
size_t idx,
size_t dim_size3,
size_t dim_size2,
size_t dim_size1,
size_t& idx4,
size_t& idx3,
size_t& idx2,
size_t& idx1
)
/*!
ensures
- This function computes the inverse of pack_idx(). Therefore,
if PACKED == pack_idx(dim_size3,dim_size2,dim_size1, idx4,idx3,idx2,idx1)
then unpack_idx(PACKED,dim_size3,dim_size2,dim_size1, IDX4,IDX3,IDX2,IDX1)
results in:
- IDX1 == idx1
- IDX2 == idx2
- IDX3 == idx3
- IDX4 == idx4
!*/
{
idx1 = idx%dim_size1;
idx /= dim_size1;
idx2 = idx%dim_size2;
idx /= dim_size2;
idx3 = idx%dim_size3;
idx /= dim_size3;
idx4 = idx;
}
// ------------------------------------------------------------------------------------
// This function is from the article:
// http://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
__inline__ __device__ float warp_reduce_sum(float val)
{
for (int offset = warpSize/2; offset > 0; offset /= 2)
#if CUDART_VERSION >= 9000
val += __shfl_down_sync(0xFFFFFFFF,val, offset);
#else
val += __shfl_down(val, offset);
#endif
return val;
}
__inline__ __device__ bool is_first_thread_in_warp()
{
return (threadIdx.x & (warpSize - 1)) == 0;
}
__inline__ __device__ void warp_reduce_atomic_add(
float& out,
float val
)
/*!
ensures
- Atomically adds all the val variables in the current warp to out.
See this page for an extended discussion:
http://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
!*/
{
val = warp_reduce_sum(val);
if (is_first_thread_in_warp())
atomicAdd(&out, val);
}
// ------------------------------------------------------------------------------------
struct max_jobs
{
max_jobs(int x) : num_x(x) {}
max_jobs(int x, int y) : num_x(x), num_y(y) {}
int num_x;
int num_y = 1;
};
template <typename Kernel, typename... T>
void launch_kernel (
Kernel K,
T ...args
)
/*!
ensures
- launches the given kernel K(args...). The point of this function is to
automatically set the kernel launch parameters to something reasonable
based on the properties of the kernel and the current GPU card.
!*/
{
int num_blocks, num_threads;
CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K));
K<<<num_blocks,num_threads>>>(args...);
}
template <typename Kernel, typename... T>
void launch_kernel (
Kernel K,
max_jobs m,
T ...args
)
/*!
ensures
- This function is just like launch_kernel(K,args...) except that you can
additionally supply a max_jobs number that tells it how many possible
total threads could be used. This is useful when launching potentially
small jobs that might not need the number of threads suggested by
launch_kernel().
!*/
{
if (m.num_x == 0 || m.num_y == 0)
return;
int num_blocks, num_threads;
CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K));
// Check if the job is really small and we don't really need to launch a kernel
// with this many blocks and threads.
if (num_blocks*num_threads > m.num_x*m.num_y)
num_blocks = (m.num_x*m.num_y+num_threads-1)/num_threads;
if (m.num_y == 1)
{
K<<<num_blocks,num_threads>>>(args...);
}
else
{
/*
In general, the reason m.num_y!=1 (i.e. the reason you are in this
code path) is because we are using nested grid-stride loops. There are
two important things to note about what we are doing here. To
illustrate them we will talk about this little CUDA code snippet:
// initialize out before we begin.
for (auto i : grid_stride_range_y(0, nr))
for (auto j : grid_stride_range(0, 1))
out[i] = 0;
__syncthreads(); // synchronize threads in block
// loop over some 2D thing and sum and store things into out.
for (auto i : grid_stride_range_y(0, nr))
{
float temp = 0;
for (auto j : grid_stride_range(0, nc))
temp += whatever[i*nc+j];
// store the sum into out[i]
warp_reduce_atomic_add(out[i], temp);
}
First, we make sure the number of x threads is a multiple of 32 so that
you can use warp_reduce_atomic_add() inside the y loop.
Second, we put the x block size to 1 so inter-block synchronization is
easier. For example, if the number of x blocks wasn't 1 the above code
would have a race condition in it. This is because the execution of
out[i]=0 would be done by blocks with blockIdx.x==0, but then in the
second set of loops, *all* the x blocks use out[i]. Since
__syncthreads() doesn't do any synchronization between blocks some of
the blocks might begin before the out[i]=0 statements finished and that
would be super bad.
*/
// Try and make sure that the ratio of x to y threads is reasonable based
// on the respective size of our loops.
int x_threads = 32;
int y_threads = num_threads/32;
const int ratio = static_cast<int>(std::round(put_in_range(1, y_threads, m.num_x/(double)m.num_y)));
x_threads *= ratio;
y_threads /= ratio;
dim3 blocks(1,num_blocks);
dim3 threads(x_threads,y_threads);
K<<<blocks,threads>>>(args...);
}
}
// ------------------------------------------------------------------------------------
class grid_stride_range
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a tool for making a for loop that loops over an entire block of
memory inside a kernel, but doing so in a way that parallelizes
appropriately across all the threads in a kernel launch. For example,
the following kernel would add the vector a to the vector b and store
the output in out (assuming all vectors are of dimension n):
__global__ void add_arrays(
const float* a,
const float* b,
float* out,
size_t n
)
{
for (auto i : grid_stride_range(0, n))
{
out[i] = a[i]+b[i];
}
}
!*/
public:
__device__ grid_stride_range(
size_t ibegin_,
size_t iend_
) :
ibegin(ibegin_),
iend(iend_)
{}
class iterator
{
public:
__device__ iterator() {}
__device__ iterator(size_t pos_) : pos(pos_) {}
__device__ size_t operator*() const
{
return pos;
}
__device__ iterator& operator++()
{
pos += gridDim.x * blockDim.x;
return *this;
}
__device__ bool operator!=(const iterator& item) const
{ return pos < item.pos; }
private:
size_t pos;
};
__device__ iterator begin() const
{
return iterator(ibegin+blockDim.x * blockIdx.x + threadIdx.x);
}
__device__ iterator end() const
{
return iterator(iend);
}
private:
size_t ibegin;
size_t iend;
};
// ------------------------------------------------------------------------------------
class grid_stride_range_y
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is just like grid_stride_range except that it looks at
CUDA's y thread index (e.g. threadIdx.y) instead of the x index.
Therefore, if you launch a cuda kernel with a statement like:
dim3 blocks(1,10);
dim3 threads(32,32); // You need to have x and y not equal to 1 to get parallelism over both loops.
add_arrays<<<blocks,threads>>>(a,b,out,nr,nc);
You can perform a nested 2D parallel for loop rather than doing just a
1D for loop.
So the code in the kernel would look like this if you wanted to add two
2D matrices:
__global__ void add_arrays(
const float* a,
const float* b,
float* out,
size_t nr,
size_t nc
)
{
for (auto r : grid_stride_range_y(0, nr))
{
for (auto c : grid_stride_range(0, nc))
{
auto i = r*nc+c;
out[i] = a[i]+b[i];
}
}
}
!*/
public:
__device__ grid_stride_range_y(
size_t ibegin_,
size_t iend_
) :
ibegin(ibegin_),
iend(iend_)
{}
class iterator
{
public:
__device__ iterator() {}
__device__ iterator(size_t pos_) : pos(pos_) {}
__device__ size_t operator*() const
{
return pos;
}
__device__ iterator& operator++()
{
pos += gridDim.y * blockDim.y;
return *this;
}
__device__ bool operator!=(const iterator& item) const
{ return pos < item.pos; }
private:
size_t pos;
};
__device__ iterator begin() const
{
return iterator(ibegin+blockDim.y * blockIdx.y + threadIdx.y);
}
__device__ iterator end() const
{
return iterator(iend);
}
private:
size_t ibegin;
size_t iend;
};
// ------------------------------------------------------------------------------------
}
}
#endif // __CUDACC__
// ----------------------------------------------------------------------------------------
#endif // DLIB_CUDA_UtILS_H_

View File

@@ -0,0 +1,518 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDNN_H_
#define DLIB_DNN_CuDNN_H_
#ifdef DLIB_USE_CUDA
#include "cuda_errors.h"
#include <memory>
#include "cuda_data_ptr.h"
namespace dlib
{
class tensor;
class resizable_tensor;
namespace cuda
{
// -----------------------------------------------------------------------------------
class tensor_descriptor
{
/*!
Each tensor object will carry a tensor_descriptor in it when compiled with
CUDA.
!*/
public:
// not copyable
tensor_descriptor(const tensor_descriptor&) = delete;
tensor_descriptor& operator=(const tensor_descriptor&) = delete;
// but is movable
tensor_descriptor(tensor_descriptor&& item) : tensor_descriptor() { swap(item); }
tensor_descriptor& operator=(tensor_descriptor&& item) { swap(item); return *this; }
tensor_descriptor();
~tensor_descriptor();
void set_size(
int n,
int k,
int nr,
int nc
);
/*!
ensures
- if any of the arguments are 0 then they are all set to 0 in the tensor.
!*/
void get_size (
int& n,
int& k,
int& nr,
int& nc
) const;
const void* get_handle (
) const { return handle; }
private:
void swap(tensor_descriptor& item) { std::swap(handle, item.handle); }
void* handle;
};
// ------------------------------------------------------------------------------------
void add(
float beta,
tensor& dest,
float alpha,
const tensor& src
);
/*!
requires
- One of the following is true:
- have_same_dimensions(src, dest)
- src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1
- src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()
- src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()
- is_same_object(src,dest) == false
ensures
- performs: dest = beta*dest + alpha*src
However, how the addition happens depends on the dimensions of src. In
particular, this function adds the scaled values of one src tensor to
dest. Each dimension of the src tensor must match the corresponding
dimension of the dest tensor or must be equal to 1. In the latter case,
the same value from the src tensor, for those dimensions, will be used to
add into the dest tensor.
!*/
// ------------------------------------------------------------------------------------
void assign_conv_bias_gradient (
tensor& grad,
const tensor& gradient_input
);
/*!
requires
- grad.num_samples() == 1
- grad.k() >= 1
- grad.nr() == 1
- grad.nc() == 1
- gradient_input.k() == grad.k()
- gradient_input.size() > 0
- is_same_object(grad,gradient_input) == false
ensures
- let BIAS be a tensor with all dimensions equal to 1 except for k which is >= 1.
- let OUT be the output of add(1,OUT,1,BIAS)
- let f(gradient_input,BIAS) == dot(gradient_input,OUT)
- Then this function computes the gradient of f() with respect to BIAS and
assigns it to grad.
!*/
// ------------------------------------------------------------------------------------
void batch_normalize_inference (
const double eps,
resizable_tensor& dest,
const tensor& src,
const tensor& gamma,
const tensor& beta,
const tensor& running_means,
const tensor& running_variances
);
void batch_normalize (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const double averaging_factor,
resizable_tensor& running_means,
resizable_tensor& running_variances,
const tensor& src,
const tensor& gamma,
const tensor& beta
);
void batch_normalize_gradient(
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
// ------------------------------------------------------------------------------------
void batch_normalize_conv_inference (
const double eps,
resizable_tensor& dest,
const tensor& src,
const tensor& gamma,
const tensor& beta,
const tensor& running_means,
const tensor& running_variances
);
void batch_normalize_conv (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const double averaging_factor,
resizable_tensor& running_means,
resizable_tensor& running_variances,
const tensor& src,
const tensor& gamma,
const tensor& beta
);
void batch_normalize_conv_gradient(
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
// ------------------------------------------------------------------------------------
class tensor_conv
{
public:
tensor_conv(const tensor_conv&) = delete;
tensor_conv& operator=(const tensor_conv&) = delete;
tensor_conv();
void clear(
);
~tensor_conv (
);
void operator() (
const bool add_to_output,
tensor& output,
const tensor& data,
const tensor& filters
);
void operator() (
const bool add_to_output,
resizable_tensor& output,
const tensor& data,
const tensor& filters
);
void get_gradient_for_data (
const bool add_to_output,
const tensor& gradient_input,
const tensor& filters,
tensor& data_gradient
);
void get_gradient_for_filters (
const bool add_to_output,
const tensor& gradient_input,
const tensor& data,
tensor& filters_gradient
);
void setup(
const tensor& data,
const tensor& filters,
int stride_y,
int stride_x,
int padding_y,
int padding_x
);
private:
// These variables record the type of data given to the last call to setup().
int stride_y;
int stride_x;
int padding_y;
int padding_x;
long data_num_samples, data_k, data_nr, data_nc;
long filters_num_samples, filters_k, filters_nr, filters_nc;
void* filter_handle;
void* conv_handle;
// dimensions of the output tensor from operator()
int out_num_samples;
int out_k;
int out_nr;
int out_nc;
int forward_algo;
int backward_data_algo;
int backward_filters_algo;
size_t forward_workspace_size_in_bytes;
size_t backward_data_workspace_size_in_bytes;
size_t backward_filters_workspace_size_in_bytes;
std::shared_ptr<resizable_cuda_buffer> workspace;
cuda_data_void_ptr forward_workspace;
cuda_data_void_ptr backward_data_workspace;
cuda_data_void_ptr backward_filters_workspace;
};
// ------------------------------------------------------------------------------------
class pooling
{
public:
pooling(const pooling&) = delete;
pooling& operator=(const pooling&) = delete;
pooling (
);
~pooling(
);
void clear(
);
void setup_max_pooling(
int window_height,
int window_width,
int stride_y,
int stride_x,
int padding_y,
int padding_x
);
void setup_avg_pooling(
int window_height,
int window_width,
int stride_y,
int stride_x,
int padding_y,
int padding_x
);
bool does_max_pooling(
) const { return do_max_pooling; }
void operator() (
resizable_tensor& dest,
const tensor& src
);
void get_gradient(
const tensor& gradient_input,
const tensor& dest,
const tensor& src,
tensor& grad
);
private:
void setup(
int window_height,
int window_width,
int stride_y,
int stride_x,
int padding_y,
int padding_x,
int pooling_mode
);
void* handle;
int window_height;
int window_width;
int stride_y;
int stride_x;
int padding_y;
int padding_x;
bool do_max_pooling;
};
// ------------------------------------------------------------------------------------
void softmax (
tensor& dest,
const tensor& src
);
/*!
requires
- have_same_dimensions(dest, src) == true
ensures
- Note that the softmax function is a vector valued function:
s(x) == exp(x)/sum(exp(x))
- Computes the softmax function on src and writes the results to dest. The
softmax is computed per spatial location across the different channels at
each location. That is, softmax() outputs a new tensor, #dest, where
each of the spatial locations in dest (i.e. image idx, row idx, and
column idx) contains the output of s() evaluated over the channel values
at each location.
- This function supports in-place operation, i.e. having
is_same_object(dest, src)==true
!*/
void softmax_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
/*!
requires
- have_same_dimensions(dest,gradient_input) == true
- have_same_dimensions(dest,grad) == true
- is_same_object(grad, dest)==false
ensures
- We interpret dest as the output of softmax(dest,SRC) for some SRC tensor.
Then let f(SRC) == dot(gradient_input,dest) Then this function computes
the gradient of f() with respect to SRC and assigns it to grad.
- This function supports in-place operation, i.e. having
is_same_object(grad, gradient_input)==true
!*/
// ------------------------------------------------------------------------------------
void softmax_all (
tensor& dest,
const tensor& src
);
void softmax_all_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
// ------------------------------------------------------------------------------------
void sigmoid (
tensor& dest,
const tensor& src
);
/*!
requires
- have_same_dimensions(dest, src) == true
ensures
- for all valid i:
- #dest.host()[i] == 1/(1+std::exp(-src.host()[i]))
- This function supports in-place operation, i.e. having
is_same_object(dest, src)==true
!*/
void sigmoid_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
/*!
requires
- have_same_dimensions(dest,gradient_input) == true
- have_same_dimensions(dest,grad) == true
- is_same_object(grad,dest) == false
ensures
- Recalling that dest is the output of sigmoid(dest,SRC) for some SRC tensor,
let f(SRC) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to SRC and
assigns it to grad.
- This function supports in-place operation, i.e. having
is_same_object(grad, gradient_input)==true
!*/
// ------------------------------------------------------------------------------------
void relu (
tensor& dest,
const tensor& src
);
/*!
requires
- have_same_dimensions(dest, src) == true
ensures
- for all valid i:
- #dest.host()[i] == std::max(0,src.host()[i])
- This function supports in-place operation, i.e. having
is_same_object(dest, src)==true
!*/
void relu_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
/*!
requires
- have_same_dimensions(dest,gradient_input) == true
- have_same_dimensions(dest,grad) == true
- is_same_object(grad,dest) == false
ensures
- Recalling that dest is the output of relu(dest,SRC) for some SRC tensor,
let f(SRC) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to SRC and
assigns it to grad.
- This function supports in-place operation, i.e. having
is_same_object(grad, gradient_input)==true
!*/
// ------------------------------------------------------------------------------------
void tanh (
tensor& dest,
const tensor& src
);
/*!
requires
- have_same_dimensions(dest, src) == true
ensures
- for all valid i:
- #dest.host()[i] == std::tanh(src.host()[i])
- This function supports in-place operation, i.e. having
is_same_object(dest, src)==true
!*/
void tanh_gradient (
tensor& grad,
const tensor& dest,
const tensor& gradient_input
);
/*!
requires
- have_same_dimensions(dest,gradient_input) == true
- have_same_dimensions(dest,grad) == true
- is_same_object(grad,dest) == false
ensures
- Recalling that dest is the output of tanh(dest,SRC) for some SRC tensor,
let f(SRC) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to SRC and
assigns it to grad.
- This function supports in-place operation, i.e. having
is_same_object(grad, gradient_input)==true
!*/
// ------------------------------------------------------------------------------------
}
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuDNN_H_

View File

@@ -0,0 +1,75 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuRAND_H_
#define DLIB_DNN_CuRAND_H_
#ifdef DLIB_USE_CUDA
#include "tensor.h"
#include "cuda_errors.h"
#include "cuda_data_ptr.h"
namespace dlib
{
namespace cuda
{
// -----------------------------------------------------------------------------------
class curand_generator
{
public:
// not copyable
curand_generator(const curand_generator&) = delete;
curand_generator& operator=(const curand_generator&) = delete;
curand_generator() : curand_generator(0) {}
curand_generator(unsigned long long seed);
~curand_generator();
void fill (
cuda_data_ptr<unsigned int>& data
);
/*!
ensures
- Fills data with random 32-bit unsigned integers.
!*/
void fill_gaussian (
tensor& data,
float mean = 0,
float stddev = 1
);
/*!
requires
- data.size()%2 == 0
- stddev >= 0
ensures
- Fills data with random numbers drawn from a Gaussian distribution
with the given mean and standard deviation.
!*/
void fill_uniform (
tensor& data
);
/*!
ensures
- Fills data with uniform random numbers in the range (0.0, 1.0].
!*/
private:
void* handle;
};
// -----------------------------------------------------------------------------------
}
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuRAND_H_

View File

@@ -0,0 +1,75 @@
// Copyright (C) 2017 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuSOLVER_H_
#define DLIB_DNN_CuSOLVER_H_
#ifdef DLIB_USE_CUDA
#include "tensor.h"
#include "cuda_errors.h"
#include "cuda_data_ptr.h"
#include "../noncopyable.h"
namespace dlib
{
namespace cuda
{
// -----------------------------------------------------------------------------------
class inv : noncopyable
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a functor for doing matrix inversion on the GPU. The only
reason it's an object is to avoid the reallocation of some GPU memory
blocks if you want to do a bunch of matrix inversions in a row.
!*/
public:
inv() = default;
~inv();
void operator() (
const tensor& m,
resizable_tensor& out
);
/*!
requires
- m.size() == m.num_samples()*m.num_samples()
(i.e. mat(m) must be a square matrix)
ensures
- out == inv(mat(m));
!*/
int get_last_status(
);
/*!
ensures
- returns 0 if the last matrix inversion was successful and != 0
otherwise.
!*/
private:
void sync_if_needed();
bool did_work_lately = false;
resizable_tensor m;
cuda_data_ptr<float> workspace;
cuda_data_ptr<int> Ipiv;
cuda_data_ptr<int> info;
};
// ------------------------------------------------------------------------------------
}
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuSOLVER_H_

View File

@@ -0,0 +1,266 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_GPU_DaTA_H_
#define DLIB_GPU_DaTA_H_
#include "gpu_data_abstract.h"
#include <memory>
#include <cstring>
#include "cuda_errors.h"
#include "../serialize.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class gpu_data
{
/*!
CONVENTION
- if (size() != 0) then
- data_host == a pointer to size() floats in CPU memory.
- if (data_device) then
- data_device == a pointer to size() floats in device memory.
- if (there might be an active async transfer from host to device) then
- have_active_transfer == true
- We use the host_current and device_current bools to keep track of which
copy of the data (or both) are most current. e.g. if the CPU has
modified the data and it hasn't been copied to the device yet then
host_current==true and device_current==false.
Similarly, we use device_in_use==true to indicate that device() has been
called and no operation to wait for all CUDA kernel completion has been
executed. So if device_in_use==true then there might be a CUDA kernel
executing that is using the device memory block contained in this object.
!*/
public:
gpu_data(
) : data_size(0), host_current(true), device_current(true),have_active_transfer(false),device_in_use(false), the_device_id(0)
{
}
// Not copyable
gpu_data(const gpu_data&) = delete;
gpu_data& operator=(const gpu_data&) = delete;
// but is movable
gpu_data(gpu_data&& item) : gpu_data() { swap(item); }
gpu_data& operator=(gpu_data&& item) { swap(item); return *this; }
int device_id() const { return the_device_id; }
#ifdef DLIB_USE_CUDA
void async_copy_to_device() const;
void set_size(size_t new_size);
#else
// Note that calls to host() or device() will block until any async transfers are complete.
void async_copy_to_device() const{}
void set_size(size_t new_size)
{
if (new_size == 0)
{
data_size = 0;
host_current = true;
device_current = true;
device_in_use = false;
data_host.reset();
data_device.reset();
}
else if (new_size != data_size)
{
data_size = new_size;
host_current = true;
device_current = true;
device_in_use = false;
data_host.reset(new float[new_size], std::default_delete<float[]>());
data_device.reset();
}
}
#endif
const float* host() const
{
copy_to_host();
return data_host.get();
}
float* host()
{
copy_to_host();
device_current = false;
return data_host.get();
}
float* host_write_only()
{
host_current = true;
device_current = false;
return data_host.get();
}
const float* device() const
{
#ifndef DLIB_USE_CUDA
DLIB_CASSERT(false, "CUDA NOT ENABLED");
#endif
copy_to_device();
device_in_use = true;
return data_device.get();
}
float* device()
{
#ifndef DLIB_USE_CUDA
DLIB_CASSERT(false, "CUDA NOT ENABLED");
#endif
copy_to_device();
host_current = false;
device_in_use = true;
return data_device.get();
}
float* device_write_only()
{
#ifndef DLIB_USE_CUDA
DLIB_CASSERT(false, "CUDA NOT ENABLED");
#endif
wait_for_transfer_to_finish();
host_current = false;
device_current = true;
device_in_use = true;
return data_device.get();
}
bool host_ready (
) const { return host_current; }
bool device_ready (
) const { return device_current && !have_active_transfer; }
size_t size() const { return data_size; }
void swap (gpu_data& item)
{
std::swap(data_size, item.data_size);
std::swap(host_current, item.host_current);
std::swap(device_current, item.device_current);
std::swap(have_active_transfer, item.have_active_transfer);
std::swap(data_host, item.data_host);
std::swap(data_device, item.data_device);
std::swap(cuda_stream, item.cuda_stream);
std::swap(the_device_id, item.the_device_id);
}
private:
#ifdef DLIB_USE_CUDA
void copy_to_device() const;
void copy_to_host() const;
void wait_for_transfer_to_finish() const;
#else
void copy_to_device() const{}
void copy_to_host() const{}
void wait_for_transfer_to_finish() const{}
#endif
size_t data_size;
mutable bool host_current;
mutable bool device_current;
mutable bool have_active_transfer;
mutable bool device_in_use;
std::shared_ptr<float> data_host;
std::shared_ptr<float> data_device;
std::shared_ptr<void> cuda_stream;
int the_device_id;
};
inline void serialize(const gpu_data& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.size(), out);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
serialize(data[i], out);
}
inline void deserialize(gpu_data& item, std::istream& in)
{
int version;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::gpu_data.");
size_t s;
deserialize(s, in);
item.set_size(s);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
deserialize(data[i], in);
}
#ifdef DLIB_USE_CUDA
void memcpy (gpu_data& dest, const gpu_data& src);
void memcpy (
gpu_data& dest,
size_t dest_offset,
const gpu_data& src,
size_t src_offset,
size_t num
);
#else
inline void memcpy (gpu_data& dest, const gpu_data& src)
{
DLIB_CASSERT(dest.size() == src.size());
if (src.size() == 0 || &dest == &src)
return;
std::memcpy(dest.host_write_only(), src.host(), sizeof(float)*src.size());
}
inline void memcpy (
gpu_data& dest,
size_t dest_offset,
const gpu_data& src,
size_t src_offset,
size_t num
)
{
DLIB_CASSERT(dest_offset + num <= dest.size());
DLIB_CASSERT(src_offset + num <= src.size());
if (num == 0)
return;
if (&dest == &src && std::max(dest_offset, src_offset) < std::min(dest_offset,src_offset)+num)
{
// if they perfectly alias each other then there is nothing to do
if (dest_offset == src_offset)
return;
else
std::memmove(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num);
}
else
{
// if we write to the entire thing then we can use host_write_only()
if (dest_offset == 0 && num == dest.size())
std::memcpy(dest.host_write_only(), src.host()+src_offset, sizeof(float)*num);
else
std::memcpy(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num);
}
}
#endif
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_GPU_DaTA_H_

View File

@@ -0,0 +1,266 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_GPU_DaTA_ABSTRACT_H_
#ifdef DLIB_GPU_DaTA_ABSTRACT_H_
#include "cuda_errors.h"
#include "../serialize.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class gpu_data
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is a block of size() floats, all stored contiguously in memory.
Importantly, it keeps two copies of the floats, one on the host CPU side
and another on the GPU device side. It automatically performs the necessary
host/device transfers to keep these two copies of the data in sync.
All transfers to the device happen asynchronously with respect to the
default CUDA stream so that CUDA kernel computations can overlap with data
transfers. However, any transfers from the device to the host happen
synchronously in the default CUDA stream. Therefore, you should perform
all your CUDA kernel launches on the default stream so that transfers back
to the host do not happen before the relevant computations have completed.
If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all.
Instead, it will simply store one host side memory block of floats.
THREAD SAFETY
Instances of this object are not thread-safe. So don't touch one from
multiple threads at the same time.
!*/
public:
gpu_data(
);
/*!
ensures
- #size() == 0
- #host() == nullptr
- #device() == nullptr
- #host_ready() == true
- #device_ready() == true
- #device_id() == 0
!*/
// This object is not copyable, however, it is movable.
gpu_data(const gpu_data&) = delete;
gpu_data& operator=(const gpu_data&) = delete;
gpu_data(gpu_data&& item);
gpu_data& operator=(gpu_data&& item);
int device_id(
) const;
/*!
ensures
- returns the ID of the CUDA device that allocated this memory. I.e. the
number returned by cudaGetDevice() when the memory was allocated.
- If CUDA is not being used then this function always returns 0.
!*/
void async_copy_to_device(
);
/*!
ensures
- if (!device_ready()) then
- Begins asynchronously copying host data to the device once it is safe
to do so. I.e. This function will wait until any previously
scheduled CUDA kernels, which are using the device() memory block,
have completed before transferring the new data to the device.
- A call to device() that happens before the transfer completes will
block until the transfer is complete. That is, it is safe to call
async_copy_to_device() and then immediately call device().
!*/
void set_size(
size_t new_size
);
/*!
ensures
- #size() == new_size
!*/
bool host_ready (
) const;
/*!
ensures
- returns true if and only if the host's copy of the data is current. The
host's data is current if there aren't any modifications to the data
which were made on the device side that have yet to be copied to the
host.
!*/
bool device_ready (
) const;
/*!
ensures
- returns true if and only if the device's copy of the data is current.
The device's data is current if there aren't any modifications to the
data which were made on the host side that have yet to be copied to the
device.
!*/
const float* host(
) const;
/*!
ensures
- returns a pointer to the host memory block of size() contiguous float
values or nullptr if size()==0.
- if (!host_ready()) then
- copies the data from the device to the host, while this is happening
the call to host() blocks.
- #host_ready() == true
!*/
float* host(
);
/*!
ensures
- returns a pointer to the host memory block of size() contiguous float
values or nullptr if size()==0.
- if (!host_ready()) then
- copies the data from the device to the host, while this is happening
the call to host() blocks.
- #host_ready() == true
- #device_ready() == false
I.e. Marks the device side data as out of date so that the next call to
device() will perform a host to device transfer. If you want to begin
the transfer immediately then you can call async_copy_to_device() after
calling host().
!*/
float* host_write_only(
);
/*!
ensures
- This function returns the same pointer as host(), except that it never
performs a device to host memory copy. Instead, it immediately marks the
device side data as out of date, effectively discarding it. Therefore,
the values in the data pointed to by host_write_only() are undefined and
you should only call host_write_only() if you are going to assign to
every memory location in the returned memory block.
- #host_ready() == true
- #device_ready() == false
!*/
const float* device(
) const;
/*!
requires
- DLIB_USE_CUDA is #defined
ensures
- returns a pointer to the device memory block of size() contiguous float
values or nullptr if size()==0.
- if (!device_ready()) then
- copies the data from the host to the device, while this is happening
the call to device() blocks.
- #device_ready() == true
!*/
float* device(
);
/*!
requires
- DLIB_USE_CUDA is #defined
ensures
- returns a pointer to the device memory block of size() contiguous float
values or nullptr if size()==0.
- if (!device_ready()) then
- copies the data from the host to the device, while this is happening
the call to device() blocks.
- #host_ready() == false
- #device_ready() == true
!*/
float* device_write_only(
);
/*!
requires
- DLIB_USE_CUDA is #defined
ensures
- This function returns the same pointer as device(), except that it never
performs a host to device memory copy. Instead, it immediately marks the
host side data as out of date, effectively discarding it. Therefore, the
values in the data pointed to by device_write_only() are undefined and
you should only call device_write_only() if you are going to assign to
every memory location in the returned memory block.
- #host_ready() == false
- #device_ready() == true
!*/
size_t size(
) const;
/*!
ensures
- returns the number of floats contained in this object.
!*/
void swap (
gpu_data& item
);
/*!
ensures
- swaps the state of *this and item
!*/
};
void serialize(const gpu_data& item, std::ostream& out);
void deserialize(gpu_data& item, std::istream& in);
/*!
provides serialization support
!*/
void memcpy (
gpu_data& dest,
const gpu_data& src
);
/*!
requires
- dest.size() == src.size()
ensures
- Copies the data in src to dest. If the device data is current (i.e.
device_ready()==true) on both src and dest then the copy will happen entirely
on the device side.
- It doesn't matter what GPU device is selected by cudaSetDevice(). You can
always copy gpu_data objects to and from each other regardless.
- This function blocks until the copy has completed.
!*/
void memcpy (
gpu_data& dest,
size_t dest_offset,
const gpu_data& src,
size_t src_offset,
size_t num
);
/*!
requires
- dest_offset + num <= dest.size()
- src_offset + num <= src.size()
ensures
- Copies the data in src to dest, but only copies data in the range
[src.host()+src_offset, src.host()+src_offset+num) to
[dest.host()+dest_offset, dest.host()+dest_offset+num). Therefore, it is
just like the above memcpy() except that you can specify some subset of data
in a gpu_data object to be copied.
- Like the above version of memcpy(), the copy will happen in the most
efficient way, automatically using the appropriate type of host/device
transfers based on where data is currently resident.
- It doesn't matter what GPU device is selected by cudaSetDevice(). You can
always copy gpu_data objects to and from each other regardless.
- This function blocks until the copy has completed.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_GPU_DaTA_ABSTRACT_H_

View File

@@ -0,0 +1,686 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_TENSOR_H_
#define DLIB_DNn_TENSOR_H_
#include "tensor_abstract.h"
#include <cstring>
#include "../matrix.h"
#include "cudnn_dlibapi.h"
#include "gpu_data.h"
#include "../byte_orderer.h"
#include <memory>
#include "../any.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class tensor;
namespace cuda
{
void set_tensor (
tensor& t,
float value
);
void scale_tensor (
tensor& t,
float value
);
}
// ----------------------------------------------------------------------------------------
class tensor
{
public:
tensor (
) :
m_n(0), m_k(0), m_nr(0), m_nc(0), m_size(0)
{
}
virtual ~tensor() {}
long long num_samples() const { return m_n; }
long long k() const { return m_k; }
long long nr() const { return m_nr; }
long long nc() const { return m_nc; }
size_t size() const { return m_size; }
typedef float* iterator;
typedef const float* const_iterator;
iterator begin() { return host(); }
const_iterator begin() const { return host(); }
iterator end() { return host()+size(); }
const_iterator end() const { return host()+size(); }
void async_copy_to_device() const
{
data().async_copy_to_device();
}
virtual const float* host() const = 0;
virtual float* host() = 0;
virtual float* host_write_only() = 0;
virtual const float* device() const = 0;
virtual float* device() = 0;
virtual float* device_write_only() = 0;
virtual const any& annotation() const = 0;
virtual any& annotation() = 0;
int device_id() const { return data().device_id(); }
tensor& operator= (float val)
{
#ifdef DLIB_USE_CUDA
// If you are using CUDA then presumably you will be mostly using tensors on
// the GPU. So unless you seem to be actively working with the host side's
// data then we do this initialization on the device side since this avoids a
// host to device transfer that would likely immediately follow.
if (data().device_ready())
{
cuda::set_tensor(*this, val);
return *this;
}
#endif
auto d = host_write_only();
for (size_t i = 0; i < size(); ++i)
d[i] = val;
return *this;
}
tensor& operator*= (float val)
{
#ifdef DLIB_USE_CUDA
cuda::scale_tensor(*this, val);
return *this;
#else
for (auto& d : *this)
d *= val;
return *this;
#endif
}
tensor& operator/= (float val)
{
*this *= 1.0/val;
return *this;
}
template <typename EXP>
tensor& operator= (const matrix_exp<EXP>& item)
{
DLIB_CASSERT(num_samples() == item.nr() &&
nr()*nc()*k() == item.nc());
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(host_write_only(), m_n, m_nr*m_nc*m_k) = item;
return *this;
}
template <typename EXP>
tensor& operator+= (const matrix_exp<EXP>& item)
{
DLIB_CASSERT(num_samples() == item.nr() &&
nr()*nc()*k() == item.nc());
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(host(), m_n, m_nr*m_nc*m_k) += item;
return *this;
}
template <typename EXP>
tensor& operator-= (const matrix_exp<EXP>& item)
{
DLIB_CASSERT(num_samples() == item.nr() &&
nr()*nc()*k() == item.nc());
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(host(), m_n, m_nr*m_nc*m_k) -= item;
return *this;
}
template <typename EXP>
void set_sample (
unsigned long long idx,
const matrix_exp<EXP>& item
)
{
DLIB_CASSERT(idx < (unsigned long long)num_samples());
DLIB_CASSERT(item.size() == nr()*nc()*k());
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) = item;
}
template <typename EXP>
void add_to_sample (
unsigned long long idx,
const matrix_exp<EXP>& item
)
{
DLIB_CASSERT(idx < (unsigned long long)num_samples());
DLIB_CASSERT(item.size() == nr()*nc()*k());
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) += item;
}
#ifdef DLIB_USE_CUDA
virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor (
) const = 0;
#endif
friend void memcpy (
tensor& dest,
const tensor& src
)
{
DLIB_CASSERT(dest.size() == src.size());
memcpy(dest.data(), dest.get_alias_offset(),
src.data(), src.get_alias_offset(),
src.size());
}
protected:
friend class alias_tensor;
virtual gpu_data& data() = 0;
virtual const gpu_data& data() const = 0;
virtual size_t get_alias_offset() const { return 0; } // needed by alias_tensor.
long long m_n;
long long m_k;
long long m_nr;
long long m_nc;
long long m_size; // always equal to m_n*m_k*m_nr*m_nc
};
// ----------------------------------------------------------------------------------------
inline bool is_vector (
const tensor& t
)
{
return t.size() == (size_t)t.num_samples() ||
t.size() == (size_t)t.k() ||
t.size() == (size_t)t.nr() ||
t.size() == (size_t)t.nc();
}
// ----------------------------------------------------------------------------------------
inline const matrix_op<op_pointer_to_mat<float> > mat (
const tensor& t,
long long nr,
long long nc
)
{
DLIB_ASSERT(nr >= 0 && nc >= 0 ,
"\tconst matrix_exp mat(tensor, nr, nc)"
<< "\n\t nr and nc must be >= 0"
<< "\n\t nr: " << nr
<< "\n\t nc: " << nc
);
DLIB_ASSERT(nr*nc == (long long)t.size() ,
"\tconst matrix_exp mat(tensor, nr, nc)"
<< "\n\t The sizes don't match up."
<< "\n\t nr*nc: " << nr*nc
<< "\n\t t.size(): " << t.size()
);
typedef op_pointer_to_mat<float> op;
return matrix_op<op>(op(t.host(),nr,nc));
}
inline const matrix_op<op_pointer_to_mat<float> > mat (
const tensor& t
)
{
if (t.size() != 0)
return mat(t, t.num_samples(), t.size()/t.num_samples());
else
return mat((float*)0,0,0);
}
inline const matrix_op<op_pointer_to_mat<float> > image_plane (
const tensor& t,
long long sample = 0,
long long k = 0
)
{
DLIB_ASSERT(0 <= sample && sample < t.num_samples() &&
0 <= k && k < t.k() &&
t.size() != 0,
"\tconst matrix_exp image_plane(tensor,sample,k)"
<< "\n\t Invalid arguments were given to this function."
<< "\n\t sample: " << sample
<< "\n\t k: " << k
<< "\n\t t.num_samples(): " << t.num_samples()
<< "\n\t t.k(): " << t.k()
<< "\n\t t.size(): " << t.size()
);
typedef op_pointer_to_mat<float> op;
return matrix_op<op>(op(t.host() + ((sample*t.k() + k)*t.nr())*t.nc(),
t.nr(),
t.nc()));
}
// ----------------------------------------------------------------------------------------
inline bool have_same_dimensions (
const tensor& a,
const tensor& b
)
{
return a.num_samples() == b.num_samples() &&
a.k() == b.k() &&
a.nr() == b.nr() &&
a.nc() == b.nc();
}
// ----------------------------------------------------------------------------------------
class resizable_tensor : public tensor
{
public:
resizable_tensor(
)
{}
template <typename EXP>
resizable_tensor(
const matrix_exp<EXP>& item
)
{
set_size(item.nr(), item.nc());
*this = item;
}
explicit resizable_tensor(
long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
)
{
DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0);
set_size(n_,k_,nr_,nc_);
}
resizable_tensor(const resizable_tensor& item) : _annotation(item.annotation())
{
copy_size(item);
memcpy(*this, item);
}
resizable_tensor(const tensor& item) : _annotation(item.annotation())
{
copy_size(item);
memcpy(*this, item);
}
resizable_tensor(resizable_tensor&& item) { swap(item); }
resizable_tensor& operator=(resizable_tensor&& item) { swap(item); return *this; }
virtual const float* host() const { return data_instance.host(); }
virtual float* host() { return data_instance.host(); }
virtual float* host_write_only() { return data_instance.host_write_only(); }
virtual const float* device() const { return data_instance.device(); }
virtual float* device() { return data_instance.device(); }
virtual float* device_write_only() { return data_instance.device_write_only(); }
virtual const any& annotation() const { return _annotation; }
virtual any& annotation() { return _annotation; }
void clear(
)
{
set_size(0,0,0,0);
_annotation.clear();
// free underlying memory
data_instance.set_size(0);
}
void copy_size (
const tensor& item
)
{
set_size(item.num_samples(), item.k(), item.nr(), item.nc());
}
resizable_tensor& operator= (float val)
{
tensor::operator=(val);
return *this;
}
template <typename EXP>
resizable_tensor& operator= (
const matrix_exp<EXP>& item
)
{
if (!(num_samples() == item.nr() && k()*nr()*nc() == item.nc()))
set_size(item.nr(), item.nc());
tensor::operator=(item);
return *this;
}
void set_size(
long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
)
{
DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0);
m_n = n_;
m_k = k_;
m_nr = nr_;
m_nc = nc_;
m_size = n_*k_*nr_*nc_;
if ((long long)data_instance.size() < m_size)
data_instance.set_size(m_size);
#ifdef DLIB_USE_CUDA
cudnn_descriptor.set_size(m_n,m_k,m_nr,m_nc);
#endif
}
resizable_tensor& operator= (const resizable_tensor& item)
{
resizable_tensor temp(item);
temp.swap(*this);
return *this;
}
resizable_tensor& operator= (const tensor& item)
{
resizable_tensor temp(item);
temp.swap(*this);
return *this;
}
void swap(resizable_tensor& item)
{
std::swap(m_n, item.m_n);
std::swap(m_k, item.m_k);
std::swap(m_nr, item.m_nr);
std::swap(m_nc, item.m_nc);
std::swap(m_size, item.m_size);
std::swap(data_instance, item.data_instance);
std::swap(_annotation, item._annotation);
#ifdef DLIB_USE_CUDA
std::swap(cudnn_descriptor, item.cudnn_descriptor);
#endif
}
#ifdef DLIB_USE_CUDA
virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor (
) const { return cudnn_descriptor; }
#endif
private:
#ifdef DLIB_USE_CUDA
cuda::tensor_descriptor cudnn_descriptor;
#endif
gpu_data data_instance;
any _annotation;
virtual gpu_data& data() { return data_instance; }
virtual const gpu_data& data() const { return data_instance; }
};
inline void serialize(const tensor& item, std::ostream& out)
{
int version = 2;
serialize(version, out);
serialize(item.num_samples(), out);
serialize(item.k(), out);
serialize(item.nr(), out);
serialize(item.nc(), out);
byte_orderer bo;
auto sbuf = out.rdbuf();
for (auto d : item)
{
// Write out our data as 4byte little endian IEEE floats rather than using
// dlib's default float serialization. We do this because it will result in
// more compact outputs. It's slightly less portable but it seems doubtful
// that any CUDA enabled platform isn't going to use IEEE floats. But if one
// does we can just update the serialization code here to handle it if such a
// platform is encountered.
bo.host_to_little(d);
static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats");
sbuf->sputn((char*)&d, sizeof(d));
}
}
inline void deserialize(resizable_tensor& item, std::istream& in)
{
int version;
deserialize(version, in);
if (version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor.");
long long num_samples=0, k=0, nr=0, nc=0;
deserialize(num_samples, in);
deserialize(k, in);
deserialize(nr, in);
deserialize(nc, in);
item.set_size(num_samples, k, nr, nc);
byte_orderer bo;
auto sbuf = in.rdbuf();
for (auto& d : item)
{
static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats");
if (sbuf->sgetn((char*)&d,sizeof(d)) != sizeof(d))
{
in.setstate(std::ios::badbit);
throw serialization_error("Error reading data while deserializing dlib::resizable_tensor.");
}
bo.little_to_host(d);
}
}
// ----------------------------------------------------------------------------------------
inline double dot(
const tensor& a,
const tensor& b
)
{
DLIB_CASSERT(a.size() == b.size());
const float* da = a.host();
const float* db = b.host();
double sum = 0;
for (size_t i = 0; i < a.size(); ++i)
sum += da[i]*db[i];
return sum;
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class alias_tensor_instance : public tensor
{
alias_tensor_instance(
) : data_instance(0), _annotation(0), data_offset(0) {}
public:
friend class alias_tensor;
friend class alias_tensor_const_instance;
alias_tensor_instance& operator= (float val)
{
tensor::operator=(val);
return *this;
}
template <typename EXP>
alias_tensor_instance& operator= (const matrix_exp<EXP>& item)
{
tensor::operator=(item);
return *this;
}
virtual const float* host() const { return data_instance->host()+data_offset; }
virtual float* host() { return data_instance->host()+data_offset; }
virtual float* host_write_only() { return data_instance->host()+data_offset; }
virtual const float* device() const { return data_instance->device()+data_offset; }
virtual float* device() { return data_instance->device()+data_offset; }
virtual float* device_write_only() { return data_instance->device()+data_offset; }
virtual const any& annotation() const { return *_annotation; }
virtual any& annotation() { return *_annotation; }
#ifdef DLIB_USE_CUDA
virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor (
) const { return *cudnn_descriptor; }
#endif
private:
virtual size_t get_alias_offset() const { return data_offset; }
#ifdef DLIB_USE_CUDA
std::shared_ptr<cuda::tensor_descriptor> cudnn_descriptor;
#endif
gpu_data* data_instance;
any* _annotation;
size_t data_offset;
virtual gpu_data& data() { return *data_instance; }
virtual const gpu_data& data() const { return *data_instance; }
};
// ----------------------------------------------------------------------------------------
class alias_tensor_const_instance
{
public:
const tensor& get() const { return inst; }
operator const tensor& () { return inst; }
alias_tensor_const_instance(const alias_tensor_instance& item) : inst(item) {}
private:
alias_tensor_instance inst;
friend class alias_tensor;
alias_tensor_const_instance() {}
};
// ----------------------------------------------------------------------------------------
class alias_tensor
{
public:
alias_tensor (
) {}
alias_tensor (
long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
)
{
DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0);
inst.m_n = n_;
inst.m_k = k_;
inst.m_nr = nr_;
inst.m_nc = nc_;
inst.m_size = n_*k_*nr_*nc_;
}
long long num_samples(
) const { return inst.m_n; }
long long k(
) const { return inst.m_k; }
long long nr(
) const { return inst.m_nr; }
long long nc(
) const { return inst.m_nc; }
size_t size(
) const { return inst.m_size; }
alias_tensor_instance operator() (
tensor& t,
size_t offset = 0
) const
{
DLIB_CASSERT(offset+size() <= t.size(),
"offset: "<<offset <<"\n"<<
"size(): "<<size() <<"\n"<<
"t.size(): "<<t.size() <<"\n");
#ifdef DLIB_USE_CUDA
if (!inst.cudnn_descriptor)
{
inst.cudnn_descriptor = std::make_shared<cuda::tensor_descriptor>();
inst.cudnn_descriptor->set_size(inst.m_n, inst.m_k, inst.m_nr, inst.m_nc);
}
#endif
inst.data_instance = &t.data();
inst._annotation = &t.annotation();
// Note that t might already be an aliasing tensor so we need to take that into
// account.
inst.data_offset = t.get_alias_offset()+offset;
return inst;
}
alias_tensor_const_instance operator() (
const tensor& t,
size_t offset = 0
) const
{
alias_tensor_const_instance temp;
temp.inst = (*this)(const_cast<tensor&>(t),offset);
return temp;
}
private:
mutable alias_tensor_instance inst;
};
inline void serialize(const alias_tensor& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.num_samples(), out);
serialize(item.k(), out);
serialize(item.nr(), out);
serialize(item.nc(), out);
}
inline void deserialize(alias_tensor& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::alias_tensor.");
long long num_samples, k, nr, nc;
deserialize(num_samples, in);
deserialize(k, in);
deserialize(nr, in);
deserialize(nc, in);
item = alias_tensor(num_samples, k, nr, nc);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_TENSOR_H_

View File

@@ -0,0 +1,727 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_DNn_TENSOR_ABSTRACT_H_
#ifdef DLIB_DNn_TENSOR_ABSTRACT_H_
#include "../matrix.h"
#include "../any/any_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class tensor
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents a 4D array of float values, all stored contiguously
in memory. Importantly, it keeps two copies of the floats, one on the host
CPU side and another on the GPU device side. It automatically performs the
necessary host/device transfers to keep these two copies of the data in
sync.
All transfers to the device happen asynchronously with respect to the
default CUDA stream so that CUDA kernel computations can overlap with data
transfers. However, any transfers from the device to the host happen
synchronously in the default CUDA stream. Therefore, you should perform
all your CUDA kernel launches on the default stream so that transfers back
to the host do not happen before the relevant computations have completed.
If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all.
Instead, it will simply store one host side memory block of floats.
Finally, the convention in dlib code is to interpret the tensor as a set of
num_samples() 3D arrays, each of dimension k() by nr() by nc(). Also,
while this class does not specify a memory layout, the convention is to
assume that indexing into an element at coordinates (sample,k,r,c) can be
accomplished via:
host()[((sample*t.k() + k)*t.nr() + r)*t.nc() + c]
THREAD SAFETY
Instances of this object are not thread-safe. So don't touch one from
multiple threads at the same time.
!*/
public:
virtual ~tensor();
long long num_samples(
) const;
/*!
ensures
- returns the number of 3D arrays of dimension k() by nr() by nc() there
are in this object.
!*/
long long k(
) const;
/*!
ensures
- returns the k dimension of this tensor. Generally, we think of a tensor
as containing num_samples() images of nr() by nc() rows and columns, each
with k() channels.
!*/
long long nr(
) const;
/*!
ensures
- returns the number of rows in this tensor.
!*/
long long nc(
) const;
/*!
ensures
- returns the number of columns in this tensor.
!*/
size_t size(
) const;
/*!
ensures
- returns num_samples()*k()*nr()*nc()
(i.e. the total number of floats in this tensor)
!*/
void async_copy_to_device(
) const;
/*!
ensures
- This function does not block.
- if (the host version of the data is newer than the device's copy) then
- Begins asynchronously copying host data to the device.
- A call to device() that happens before the transfer completes will
block until the transfer is complete. That is, it is safe to call
async_copy_to_device() and then immediately call device().
!*/
typedef float* iterator;
typedef const float* const_iterator;
iterator begin() { return host(); }
const_iterator begin() const { return host(); }
iterator end() { return host()+size(); }
const_iterator end() const { return host()+size(); }
/*!
ensures
- makes a tensor iterable just like the STL containers.
!*/
virtual const float* host(
) const = 0;
/*!
ensures
- returns a pointer to the host memory block of size() contiguous float
values or nullptr if size()==0.
- if (the host's copy of the data is out of date) then
- copies the data from the device to the host, while this is happening
the call to host() blocks.
!*/
virtual float* host(
) = 0;
/*!
ensures
- returns a pointer to the host memory block of size() contiguous float
values or nullptr if size()==0.
- if (the host's copy of the data is out of date) then
- copies the data from the device to the host, while this is happening
the call to host() blocks.
- Marks the device side data as out of date so that the next call to
device() will perform a host to device transfer. If you want to begin
the transfer immediately then you can call async_copy_to_device() after
calling host().
!*/
virtual float* host_write_only(
) = 0;
/*!
ensures
- This function returns the same pointer as host(), except that it never
performs a device to host memory copy. Instead, it immediately marks the
device side data as out of date, effectively discarding it. Therefore,
the values in the data pointed to by host_write_only() are undefined and
you should only call host_write_only() if you are going to assign to
every memory location in the returned memory block.
!*/
virtual const float* device(
) const = 0;
/*!
requires
- DLIB_USE_CUDA is #defined
ensures
- returns a pointer to the device memory block of size() contiguous float
values or nullptr if size()==0.
- if (the device's copy of the data is out of date) then
- copies the data from the host to the device, while this is happening
the call to device() blocks.
!*/
virtual float* device(
) = 0;
/*!
requires
- DLIB_USE_CUDA is #defined
ensures
- returns a pointer to the device memory block of size() contiguous float
values or nullptr if size()==0.
- if (the device's copy of the data is out of date) then
- copies the data from the host to the device, while this is happening
the call to device() blocks.
- Marks the host side data as out of date so that the next call to
host() will perform a device to host transfer.
!*/
virtual float* device_write_only(
) = 0;
/*!
requires
- DLIB_USE_CUDA is #defined
ensures
- This function returns the same pointer as device(), except that it never
performs a host to device memory copy. Instead, it immediately marks the
host side data as out of date, effectively discarding it. Therefore, the
values in the data pointed to by device_write_only() are undefined and
you should only call device_write_only() if you are going to assign to
every memory location in the returned memory block.
!*/
virtual const any& annotation(
) const = 0;
/*!
ensures
- returns a const reference to the any object in this tensor. The any
object can be used to store any additional annotation you like in a
tensor. However, it should be noted that the annotation() is ignored by
serialize() and therefore not saved when a tensor is serialized.
!*/
virtual any& annotation(
) = 0;
/*!
ensures
- returns a non-const reference to the any object in this tensor. The any
object can be used to store any additional annotation you like in a
tensor. However, it should be noted that the annotation() is ignored by
serialize() and therefore not saved when a tensor is serialized.
!*/
int device_id(
) const;
/*!
ensures
- returns the ID of the CUDA device that allocated this memory. I.e. the
number returned by cudaGetDevice() when the memory was allocated.
- If CUDA is not being used then this function always returns 0.
!*/
tensor& operator= (
float val
);
/*!
ensures
- sets all elements of this tensor equal to val.
- returns *this
!*/
tensor& operator*= (
float val
);
/*!
ensures
- pointwise multiplies all elements of *this tensor with val.
- returns *this
!*/
tensor& operator/= (
float val
);
/*!
ensures
- pointwise divides all elements of *this tensor with val.
- returns *this
!*/
template <typename EXP>
tensor& operator= (
const matrix_exp<EXP>& item
);
/*!
requires
- num_samples() == item.nr()
- k()*nr()*nc() == item.nc()
- item contains float values
ensures
- Assigns item to *this tensor by performing:
set_ptrm(host(), num_samples(), k()*nr()*nc()) = item;
!*/
template <typename EXP>
tensor& operator+= (
const matrix_exp<EXP>& item
);
/*!
requires
- num_samples() == item.nr()
- k()*nr()*nc() == item.nc()
- item contains float values
ensures
- Adds item to *this tensor by performing:
set_ptrm(host(), num_samples(), k()*nr()*nc()) += item;
!*/
template <typename EXP>
tensor& operator-= (
const matrix_exp<EXP>& item
);
/*!
requires
- num_samples() == item.nr()
- k()*nr()*nc() == item.nc()
- item contains float values
ensures
- Subtracts item from *this tensor by performing:
set_ptrm(host(), num_samples(), k()*nr()*nc()) -= item;
!*/
template <typename EXP>
void set_sample (
unsigned long long idx,
const matrix_exp<EXP>& item
);
/*!
requires
- idx < num_samples()
- k()*nr()*nc() == item.size()
- item contains float values
ensures
- Assigns item to the idx'th sample in *this by performing:
set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) = item;
!*/
template <typename EXP>
void add_to_sample (
unsigned long long idx,
const matrix_exp<EXP>& item
);
/*!
requires
- idx < num_samples()
- k()*nr()*nc() == item.size()
- item contains float values
ensures
- Adds item to the idx'th sample in *this by performing:
set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) += item;
!*/
protected:
// You can't move or copy another tensor into *this since that might modify the
// tensor's dimensions. If you want to do that sort of thing then use a
// resizable_tensor.
tensor(const tensor& item);
tensor& operator= (const tensor& item);
tensor(tensor&& item);
tensor& operator=(tensor&& item);
};
// ----------------------------------------------------------------------------------------
void memcpy (
tensor& dest,
const tensor& src
);
/*!
requires
- dest.size() == src.size()
ensures
- Copies the data in src to dest. If the device data is current on both src
and dest then the copy will happen entirely on the device side.
- It doesn't matter what GPU device is selected by cudaSetDevice(). You can
always copy tensor objects to and from each other regardless.
- This function blocks until the copy has completed.
!*/
// ----------------------------------------------------------------------------------------
bool is_vector (
const tensor& t
);
/*!
ensures
- returns true if and only if one of the following is true:
- t.size() == t.num_samples()
- t.size() == t.k()
- t.size() == t.nr()
- t.size() == t.nc()
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp mat (
const tensor& t,
long long nr,
long long nc
);
/*!
requires
- nr >= 0
- nc >= 0
- nr*nc == t.size()
ensures
- returns a matrix M such that:
- M.nr() == nr
- m.nc() == nc
- for all valid r and c:
M(r,c) == t.host()[r*nc + c]
(i.e. the tensor is interpreted as a matrix laid out in memory
in row major order)
!*/
const matrix_exp mat (
const tensor& t
);
/*!
ensures
- if (t.size() != 0) then
- returns mat(t, t.num_samples(), t.size()/t.num_samples())
- else
- returns an empty matrix.
!*/
const matrix_exp image_plane (
const tensor& t,
long long sample = 0,
long long k = 0
);
/*!
requires
- t.size() != 0
- 0 <= sample < t.num_samples()
- 0 <= k < t.k()
ensures
- returns the k-th image plane from the sample-th image in t. That is,
returns a matrix M such that:
- M contains float valued elements.
- M.nr() == t.nr()
- M.nc() == t.nc()
- for all valid r and c:
- M(r,c) == t.host()[((sample*t.k() + k)*t.nr() + r)*t.nc() + c]
!*/
// ----------------------------------------------------------------------------------------
bool have_same_dimensions (
const tensor& a,
const tensor& b
);
/*!
ensures
- returns true if and only if all of the fallowing are satisfied:
- a.num_samples() == b.num_samples()
- a.k() == b.k()
- a.nr() == b.nr()
- a.nc() == b.nc()
!*/
// ----------------------------------------------------------------------------------------
class resizable_tensor : public tensor
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is just a tensor with the additional ability to be resized.
!*/
public:
resizable_tensor(
);
/*!
ensures
- #size() == 0
- #num_samples() == 0
- #k() == 0
- #nr() == 0
- #nc() == 0
- #capacity() == 0
!*/
template <typename EXP>
resizable_tensor(
const matrix_exp<EXP>& item
);
/*!
requires
- item contains float values
ensures
- #num_samples() == item.nr()
- #k() == item.nc()
- #nr() == 1
- #nc() == 1
- Assigns item to *this tensor by performing:
set_ptrm(host(), num_samples(), k()*nr()*nc()) = item;
- #capacity() == size()
!*/
explicit resizable_tensor(
long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
);
/*!
requires
- n_ >= 0
- k_ >= 0
- nr_ >= 0
- nc_ >= 0
ensures
- #size() == n_*k_*nr_*nc_
- #num_samples() == n_
- #k() == k_
- #nr() == nr_
- #nc() == nc_
- #capacity() == size()
!*/
// This object is copyable and movable
resizable_tensor(const resizable_tensor&) = default;
resizable_tensor(resizable_tensor&&) = default;
resizable_tensor& operator= (const resizable_tensor&) = default;
resizable_tensor& operator= (resizable_tensor&&) = default;
size_t capacity (
) const;
/*!
ensures
- returns the total number of floats allocated. This might be different
from the size() since calls to set_size() that make a tensor smaller
don't trigger reallocations. They simply adjust the nominal dimensions
while keeping the same allocated memory block. This makes calls to
set_size() very fast. If you need to deallocate a tensor then use
clear().
!*/
void clear(
);
/*!
ensures
- #size() == 0
- #num_samples() == 0
- #k() == 0
- #nr() == 0
- #nc() == 0
- #annotation().is_empty() == true
- #capacity() == 0
!*/
void copy_size (
const tensor& item
);
/*!
ensures
- resizes *this so that: have_same_dimensions(#*this, item)==true
!*/
void set_size(
long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
);
/*!
requires
- n_ >= 0
- k_ >= 0
- nr_ >= 0
- nc_ >= 0
ensures
- #size() == n_*k_*nr_*nc_
- #num_samples() == n_
- #k() == k_
- #nr() == nr_
- #nc() == nc_
- #capacity() == max(#size(), capacity())
(i.e. capacity() never goes down when calling set_size().)
!*/
template <typename EXP>
resizable_tensor& operator= (
const matrix_exp<EXP>& item
);
/*!
requires
- item contains float values
ensures
- if (num_samples() == item.nr() && k()*nr()*nc() == item.nc()) then
- the dimensions of this tensor are not changed
- else
- #num_samples() == item.nr()
- #k() == item.nc()
- #nr() == 1
- #nc() == 1
- Assigns item to *this tensor by performing:
set_ptrm(host(), num_samples(), k()*nr()*nc()) = item;
!*/
};
void serialize(const tensor& item, std::ostream& out);
void deserialize(resizable_tensor& item, std::istream& in);
/*!
provides serialization support for tensor and resizable_tensor. Note that you can
serialize to/from any combination of tenor and resizable_tensor objects.
!*/
// ----------------------------------------------------------------------------------------
double dot(
const tensor& a,
const tensor& b
);
/*!
requires
- a.size() == b.size()
ensures
- returns the dot product between a and b when they are both treated as
a.size() dimensional vectors. That is, this function pointwise multiplies
the vectors together, then sums the result and returns it.
!*/
// ----------------------------------------------------------------------------------------
class alias_tensor_instance : public tensor
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is a tensor that aliases another tensor. That is, it doesn't
have its own block of memory but instead simply holds pointers to the
memory of another tensor object. It therefore allows you to efficiently
break a tensor into pieces and pass those pieces into functions.
An alias_tensor_instance doesn't own the resources it points to in any sense.
So it is important to make sure that the underlying owning tensor doesn't get
destructed before any alias tensors which point to it are destructed.
!*/
// You can't default initialize this object. You can only get instances of it from
// alias_tensor::operator().
alias_tensor_instance(
);
};
class alias_tensor_const_instance
{
/*!
WHAT THIS OBJECT REPRESENTS
This is essentially a const version of alias_tensor_instance and therefore
represents a tensor. However, due to the mechanics of C++, this object
can't inherit from tensor. So instead it provides a get() and an implicit
conversion to const tensor.
!*/
public:
// non-const alias tensors are convertible to const ones.
alias_tensor_const_instance(const alias_tensor_instance& item);
// Methods that cast the alias to a tensor.
const tensor& get() const;
operator const tensor& ();
private:
// You can't default initialize this object. You can only get instances of it from
// alias_tensor::operator().
alias_tensor_const_instance();
};
class alias_tensor
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a tool for creating tensor objects that alias other tensor objects.
That is, it allows you to make a tensor that references the memory space of
another tensor object rather than owning its own memory. This allows you
to do things like interpret a single tensor in different ways or even as a
group of multiple tensors.
!*/
public:
alias_tensor (
);
/*!
ensures
- #size() == 0
- #num_samples() == 0
- #k() == 0
- #nr() == 0
- #nc() == 0
!*/
alias_tensor (
long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
);
/*!
requires
- n_ >= 0
- k_ >= 0
- nr_ >= 0
- nc_ >= 0
ensures
- #size() == n_*k_*nr_*nc_
- #num_samples() == n_
- #k() == k_
- #nr() == nr_
- #nc() == nc_
!*/
long long num_samples() const;
long long k() const;
long long nr() const;
long long nc() const;
size_t size() const;
alias_tensor_instance operator() (
tensor& t,
size_t offset = 0
) const;
/*!
requires
- offset+size() <= t.size()
ensures
- Returns a tensor that simply aliases the elements of t beginning with t's
offset'th element. Specifically, this function returns an aliasing
tensor T such that:
- T.size() == size()
- T.num_samples() == num_samples()
- T.k() == k()
- T.nr() == nr()
- T.nc() == nc()
- T.host() == t.host()+offset
- T.device() == t.device()+offset
- &T.annotation() == &t.annotation()
!*/
alias_tensor_const_instance operator() (
const tensor& t,
size_t offset = 0
) const;
/*!
requires
- offset+size() <= t.size()
ensures
- This function is identical to the above version of operator() except that
it takes and returns const tensors instead of non-const tensors.
!*/
};
void serialize(const alias_tensor& item, std::ostream& out);
void deserialize(alias_tensor& item, std::istream& in);
/*!
provides serialization support for alias_tensor.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_TENSOR_ABSTRACT_H_

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@
#include "data_io/libsvm_io.h"
#include "data_io/image_dataset_metadata.h"
#include "data_io/mnist.h"
#ifndef DLIB_ISO_CPP_ONLY
#include "data_io/load_image_dataset.h"

View File

@@ -1,387 +0,0 @@
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_IMAGE_DAtASET_METADATA_CPPh_
#define DLIB_IMAGE_DAtASET_METADATA_CPPh_
#include "image_dataset_metadata.h"
#include <fstream>
#include <sstream>
#include "../compress_stream.h"
#include "../base64.h"
#include "../xml_parser.h"
#include "../string.h"
// ----------------------------------------------------------------------------------------
namespace dlib
{
namespace image_dataset_metadata
{
// ------------------------------------------------------------------------------------
const std::string get_decoded_string();
void create_image_metadata_stylesheet_file(const std::string& main_filename)
{
std::string path;
std::string::size_type pos = main_filename.find_last_of("/\\");
if (pos != std::string::npos)
path = main_filename.substr(0,pos+1);
std::ofstream fout((path + "image_metadata_stylesheet.xsl").c_str());
if (!fout)
throw dlib::error("ERROR: Unable to open image_metadata_stylesheet.xsl for writing.");
fout << get_decoded_string();
if (!fout)
throw dlib::error("ERROR: Unable to write to image_metadata_stylesheet.xsl.");
}
void save_image_dataset_metadata (
const dataset& meta,
const std::string& filename
)
{
create_image_metadata_stylesheet_file(filename);
const std::vector<image>& images = meta.images;
std::ofstream fout(filename.c_str());
if (!fout)
throw dlib::error("ERROR: Unable to open " + filename + " for writing.");
fout << "<?xml version='1.0' encoding='ISO-8859-1'?>\n";
fout << "<?xml-stylesheet type='text/xsl' href='image_metadata_stylesheet.xsl'?>\n";
fout << "<dataset>\n";
fout << "<name>" << meta.name << "</name>\n";
fout << "<comment>" << meta.comment << "</comment>\n";
fout << "<images>\n";
for (unsigned long i = 0; i < images.size(); ++i)
{
fout << " <image file='" << images[i].filename << "'>\n";
// save all the boxes
for (unsigned long j = 0; j < images[i].boxes.size(); ++j)
{
const box& b = images[i].boxes[j];
fout << " <box top='" << b.rect.top() << "' "
<< "left='" << b.rect.left() << "' "
<< "width='" << b.rect.width() << "' "
<< "height='" << b.rect.height() << "'";
if (b.difficult)
fout << " difficult='" << b.difficult << "'";
if (b.truncated)
fout << " truncated='" << b.truncated << "'";
if (b.occluded)
fout << " occluded='" << b.occluded << "'";
if (b.ignore)
fout << " ignore='" << b.ignore << "'";
if (b.angle != 0)
fout << " angle='" << b.angle << "'";
if (b.has_label() || b.parts.size() != 0)
{
fout << ">\n";
if (b.has_label())
fout << " <label>" << b.label << "</label>\n";
// save all the parts
std::map<std::string,point>::const_iterator itr;
for (itr = b.parts.begin(); itr != b.parts.end(); ++itr)
{
fout << " <part name='"<< itr->first << "' x='"<< itr->second.x() <<"' y='"<< itr->second.y() <<"'/>\n";
}
fout << " </box>\n";
}
else
{
fout << "/>\n";
}
}
fout << " </image>\n";
if (!fout)
throw dlib::error("ERROR: Unable to write to " + filename + ".");
}
fout << "</images>\n";
fout << "</dataset>";
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
class doc_handler : public document_handler
{
std::vector<std::string> ts;
image temp_image;
box temp_box;
dataset& meta;
public:
doc_handler(
dataset& metadata_
):
meta(metadata_)
{}
virtual void start_document (
)
{
meta = dataset();
ts.clear();
temp_image = image();
temp_box = box();
}
virtual void end_document (
)
{
}
virtual void start_element (
const unsigned long line_number,
const std::string& name,
const dlib::attribute_list& atts
)
{
try
{
if (ts.size() == 0)
{
if (name != "dataset")
{
std::ostringstream sout;
sout << "Invalid XML document. Root tag must be <dataset>. Found <" << name << "> instead.";
throw dlib::error(sout.str());
}
else
{
ts.push_back(name);
return;
}
}
if (name == "box")
{
if (atts.is_in_list("top")) temp_box.rect.top() = sa = atts["top"];
else throw dlib::error("<box> missing required attribute 'top'");
if (atts.is_in_list("left")) temp_box.rect.left() = sa = atts["left"];
else throw dlib::error("<box> missing required attribute 'left'");
if (atts.is_in_list("width")) temp_box.rect.right() = sa = atts["width"];
else throw dlib::error("<box> missing required attribute 'width'");
if (atts.is_in_list("height")) temp_box.rect.bottom() = sa = atts["height"];
else throw dlib::error("<box> missing required attribute 'height'");
if (atts.is_in_list("difficult")) temp_box.difficult = sa = atts["difficult"];
if (atts.is_in_list("truncated")) temp_box.truncated = sa = atts["truncated"];
if (atts.is_in_list("occluded")) temp_box.occluded = sa = atts["occluded"];
if (atts.is_in_list("ignore")) temp_box.ignore = sa = atts["ignore"];
if (atts.is_in_list("angle")) temp_box.angle = sa = atts["angle"];
temp_box.rect.bottom() += temp_box.rect.top()-1;
temp_box.rect.right() += temp_box.rect.left()-1;
}
else if (name == "part" && ts.back() == "box")
{
point temp;
if (atts.is_in_list("x")) temp.x() = sa = atts["x"];
else throw dlib::error("<part> missing required attribute 'x'");
if (atts.is_in_list("y")) temp.y() = sa = atts["y"];
else throw dlib::error("<part> missing required attribute 'y'");
if (atts.is_in_list("name"))
{
if (temp_box.parts.count(atts["name"])==0)
{
temp_box.parts[atts["name"]] = temp;
}
else
{
throw dlib::error("<part> with name '" + atts["name"] + "' is defined more than one time in a single box.");
}
}
else
{
throw dlib::error("<part> missing required attribute 'name'");
}
}
else if (name == "image")
{
temp_image.boxes.clear();
if (atts.is_in_list("file")) temp_image.filename = atts["file"];
else throw dlib::error("<image> missing required attribute 'file'");
}
ts.push_back(name);
}
catch (error& e)
{
throw dlib::error("Error on line " + cast_to_string(line_number) + ": " + e.what());
}
}
virtual void end_element (
const unsigned long ,
const std::string& name
)
{
ts.pop_back();
if (ts.size() == 0)
return;
if (name == "box" && ts.back() == "image")
{
temp_image.boxes.push_back(temp_box);
temp_box = box();
}
else if (name == "image" && ts.back() == "images")
{
meta.images.push_back(temp_image);
temp_image = image();
}
}
virtual void characters (
const std::string& data
)
{
if (ts.size() == 2 && ts[1] == "name")
{
meta.name = trim(data);
}
else if (ts.size() == 2 && ts[1] == "comment")
{
meta.comment = trim(data);
}
else if (ts.size() >= 2 && ts[ts.size()-1] == "label" &&
ts[ts.size()-2] == "box")
{
temp_box.label = trim(data);
}
}
virtual void processing_instruction (
const unsigned long ,
const std::string& ,
const std::string&
)
{
}
};
// ----------------------------------------------------------------------------------------
class xml_error_handler : public error_handler
{
public:
virtual void error (
const unsigned long
) { }
virtual void fatal_error (
const unsigned long line_number
)
{
std::ostringstream sout;
sout << "There is a fatal error on line " << line_number << " so parsing will now halt.";
throw dlib::error(sout.str());
}
};
// ------------------------------------------------------------------------------------
void load_image_dataset_metadata (
dataset& meta,
const std::string& filename
)
{
xml_error_handler eh;
doc_handler dh(meta);
std::ifstream fin(filename.c_str());
if (!fin)
throw dlib::error("ERROR: unable to open " + filename + " for reading.");
xml_parser parser;
parser.add_document_handler(dh);
parser.add_error_handler(eh);
parser.parse(fin);
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// This function returns the contents of the file 'images.xsl'
const std::string get_decoded_string()
{
dlib::base64 base64_coder;
dlib::compress_stream::kernel_1ea compressor;
std::ostringstream sout;
std::istringstream sin;
// The base64 encoded data from the file 'image_metadata_stylesheet.xsl' we want to decode and return.
sout << "PFWfgmWfCHr1DkV63lbjjeY2dCc2FbHDOVh0Kd7dkvaOfRYrOG24f0x77/5iMVq8FtE3UBxtGwSd";
sout << "1ZHOHRSHgieNoeBv8ssJQ75RRxYtFKRY3OTPX5eKQoCN9jUaUnHnR4QZtEHgmKqXSs50Yrdd+2Ah";
sout << "gNyarPZCiR6nvqNvCjtP2MP5FxleqNf8Fylatm2KdsXmrv5K87LYVN7i7JMkmZ++cTXYSOxDmxZi";
sout << "OiCH8funXUdF9apDW547gCjz9HOQUI6dkz5dYUeFjfp6dFugpnaJyyprFLKq048Qk7+QiL4CNF/G";
sout << "7e0VpBw8dMpiyRNi2fSQGSZGfIAUQKKT6+rPwQoRH2spdjsdXVWj4XQAqBX87nmqMnqjMhn/Vd1s";
sout << "W5aoC0drwRGu3Xe3gn9vBL8hBkRXcJvEy6q/lb9bYnsLemhE5Zp/+nTmTBjfT9UFYLcsmgsjON9M";
sout << "gbE5Q8tCa+WXXXsVP1ai5tLU3G1RUjctr/VtV55PKl2xKjjgb4zDldHKrKFQ23NkQR94PFHG25WM";
sout << "a/VSFVSzLJWdeV/SK3uDq/zUdwQ1JohWp2i+0vJuTXNGCmyT3zHxqtue3HcEw7OpGIDQ+EN0nPCV";
sout << "90Seu55zuS14zuWdRfXln3/g/hiA7Jj72Ah8Kiz3F3gwCfFbyFaMDYTbT4sda0fDkx1M9sKJ2pN8";
sout << "3Jd7T8SU+Rk2/oDc8RuTTbFaRvulLWHfdLGPuIJpLT7FUkxGpdlIvxPypjGf0wVA8kgcYGgoKLIX";
sout << "uUgWFEvqwDJtxvOYVApV2foeOMgfw53TRiFDqwxmaYC41gB32cGgKYuC90mmqGY1MWAD0KLl5+bF";
sout << "GQiRmckXmDmowK5cxibnB5nTyJX1LmXaqkHFFNGPfidznTHoSqtlAF4wnCyBMuCAdgJgC0AF5gr4";
sout << "1KNWDg042CVs3li6Nep6G9arGOkEcL7vWamNC9vvkYwOWidjDFqINBxEWGTRQCyG9RDzPX2dckEh";
sout << "jWYwrXDOFyBNbac16Yym1ftn322+sE+RnaXq9WIoTGnrK/A1paSzdCjpfiIAAizaRnwoa6Ue9xnZ";
sout << "HvSSQetmzyOErvK6IOWu2VwvqO3aOC28RP63JEztmiT7pF+Zl0NMHVWgW13WejABamVXvjDAlMSA";
sout << "iBKSBqTuyC0YbuNk14G2MfQE0pg1QrAHhOi9u2KsTRN56381lxxqAhEEGvI/h+ONsveGuuDjXgcy";
sout << "wvObjIKOawnh820yMrPBzDOx/ExSJtwqbWXBc0MGZxLXA3OgfeKsoaGB/OSB3AznJd40B1ktnmXO";
sout << "pThos8Tl3Cs6xxFdFhob0vf3ml6WumTtNnAA";
// Put the data into the istream sin
sin.str(sout.str());
sout.str("");
// Decode the base64 text into its compressed binary form
base64_coder.decode(sin,sout);
sin.clear();
sin.str(sout.str());
sout.str("");
// Decompress the data into its original form
compressor.decompress(sin,sout);
// Return the decoded and decompressed data
return sout.str();
}
}
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_IMAGE_DAtASET_METADATA_CPPh_

View File

@@ -14,6 +14,15 @@ namespace dlib
namespace image_dataset_metadata
{
// ------------------------------------------------------------------------------------
enum gender_t
{
UNKNOWN,
MALE,
FEMALE
};
// ------------------------------------------------------------------------------------
struct box
@@ -34,7 +43,11 @@ namespace dlib
truncated(false),
occluded(false),
ignore(false),
angle(0)
pose(0),
detection_score(0),
angle(0),
gender(UNKNOWN),
age(0)
{}
box (
@@ -45,7 +58,11 @@ namespace dlib
truncated(false),
occluded(false),
ignore(false),
angle(0)
pose(0),
detection_score(0),
angle(0),
gender(UNKNOWN),
age(0)
{}
rectangle rect;
@@ -58,6 +75,8 @@ namespace dlib
bool truncated;
bool occluded;
bool ignore;
double pose;
double detection_score;
// The angle of the object in radians. Positive values indicate that the
// object at the center of the box is rotated clockwise by angle radians. A
@@ -66,6 +85,9 @@ namespace dlib
// image counter-clockwise by angle radians.
double angle;
gender_t gender;
double age;
bool has_label() const { return label.size() != 0; }
/*!
ensures

View File

@@ -85,12 +85,14 @@ namespace dlib
if (sin.get() != ':')
throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name);
sin >> value >> ws;
sin >> value;
if (sin && value != 0)
{
sample.insert(sample.end(), make_pair(key, value));
}
sin >> ws;
}
samples.push_back(sample);

View File

@@ -14,6 +14,9 @@
#include <string>
#include <set>
#include "../image_processing/full_object_detection.h"
#include <utility>
#include <limits>
#include "../image_transforms/image_pyramid.h"
namespace dlib
@@ -29,6 +32,7 @@ namespace dlib
_skip_empty_images = false;
_have_parts = false;
_filename = filename;
_box_area_thresh = std::numeric_limits<double>::infinity();
}
image_dataset_file boxes_match_label(
@@ -56,6 +60,15 @@ namespace dlib
return temp;
}
image_dataset_file shrink_big_images(
double new_box_area_thresh = 150*150
) const
{
image_dataset_file temp(*this);
temp._box_area_thresh = new_box_area_thresh;
return temp;
}
bool should_load_box (
const image_dataset_metadata::box& box
) const
@@ -72,6 +85,7 @@ namespace dlib
const std::string& get_filename() const { return _filename; }
bool should_skip_empty_images() const { return _skip_empty_images; }
bool should_boxes_have_parts() const { return _have_parts; }
double box_area_thresh() const { return _box_area_thresh; }
const std::set<std::string>& get_selected_box_labels() const { return _labels; }
private:
@@ -79,23 +93,23 @@ namespace dlib
std::set<std::string> _labels;
bool _skip_empty_images;
bool _have_parts;
double _box_area_thresh;
};
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<rectangle> >& object_locations,
const image_dataset_file& source
)
{
images.clear();
object_locations.clear();
const std::string old_working_dir = get_current_dir();
std::vector<std::vector<rectangle> > ignored_rects;
@@ -106,15 +120,17 @@ namespace dlib
// Set the current directory to be the one that contains the
// metadata file. We do this because the file might contain
// file paths which are relative to this folder.
set_current_dir(get_parent_directory(file(source.get_filename())));
locally_change_current_dir chdir(get_parent_directory(file(source.get_filename())));
typedef typename array_type::value_type image_type;
image_type img;
std::vector<rectangle> rects, ignored;
for (unsigned long i = 0; i < data.images.size(); ++i)
{
double min_rect_size = std::numeric_limits<double>::infinity();
rects.clear();
ignored.clear();
for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j)
@@ -122,25 +138,147 @@ namespace dlib
if (source.should_load_box(data.images[i].boxes[j]))
{
if (data.images[i].boxes[j].ignore)
{
ignored.push_back(data.images[i].boxes[j].rect);
}
else
{
rects.push_back(data.images[i].boxes[j].rect);
min_rect_size = std::min<double>(min_rect_size, rects.back().area());
}
}
}
if (!source.should_skip_empty_images() || rects.size() != 0)
{
load_image(img, data.images[i].filename);
if (rects.size() != 0)
{
// if shrinking the image would still result in the smallest box being
// bigger than the box area threshold then shrink the image.
while(min_rect_size/2/2 > source.box_area_thresh())
{
pyramid_down<2> pyr;
pyr(img);
min_rect_size *= (1.0/2.0)*(1.0/2.0);
for (auto&& r : rects)
r = pyr.rect_down(r);
for (auto&& r : ignored)
r = pyr.rect_down(r);
}
while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh())
{
pyramid_down<3> pyr;
pyr(img);
min_rect_size *= (2.0/3.0)*(2.0/3.0);
for (auto&& r : rects)
r = pyr.rect_down(r);
for (auto&& r : ignored)
r = pyr.rect_down(r);
}
}
images.push_back(img);
object_locations.push_back(rects);
ignored_rects.push_back(ignored);
load_image(img, data.images[i].filename);
images.push_back(img);
}
}
set_current_dir(old_working_dir);
return ignored_rects;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
inline size_t num_non_ignored_boxes (const std::vector<mmod_rect>& rects)
{
size_t cnt = 0;
for (auto& b : rects)
{
if (!b.ignore)
cnt++;
}
return cnt;
}
}
template <
typename array_type
>
void load_image_dataset (
array_type& images,
std::vector<std::vector<mmod_rect> >& object_locations,
const image_dataset_file& source
)
{
images.clear();
object_locations.clear();
using namespace dlib::image_dataset_metadata;
dataset data;
load_image_dataset_metadata(data, source.get_filename());
// Set the current directory to be the one that contains the
// metadata file. We do this because the file might contain
// file paths which are relative to this folder.
locally_change_current_dir chdir(get_parent_directory(file(source.get_filename())));
typedef typename array_type::value_type image_type;
image_type img;
std::vector<mmod_rect> rects;
for (unsigned long i = 0; i < data.images.size(); ++i)
{
double min_rect_size = std::numeric_limits<double>::infinity();
rects.clear();
for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j)
{
if (source.should_load_box(data.images[i].boxes[j]))
{
if (data.images[i].boxes[j].ignore)
{
rects.push_back(ignored_mmod_rect(data.images[i].boxes[j].rect));
}
else
{
rects.push_back(mmod_rect(data.images[i].boxes[j].rect));
min_rect_size = std::min<double>(min_rect_size, rects.back().rect.area());
}
rects.back().label = data.images[i].boxes[j].label;
}
}
if (!source.should_skip_empty_images() || impl::num_non_ignored_boxes(rects) != 0)
{
load_image(img, data.images[i].filename);
if (rects.size() != 0)
{
// if shrinking the image would still result in the smallest box being
// bigger than the box area threshold then shrink the image.
while(min_rect_size/2/2 > source.box_area_thresh())
{
pyramid_down<2> pyr;
pyr(img);
min_rect_size *= (1.0/2.0)*(1.0/2.0);
for (auto&& r : rects)
r.rect = pyr.rect_down(r.rect);
}
while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh())
{
pyramid_down<3> pyr;
pyr(img);
min_rect_size *= (2.0/3.0)*(2.0/3.0);
for (auto&& r : rects)
r.rect = pyr.rect_down(r.rect);
}
}
images.push_back(std::move(img));
object_locations.push_back(std::move(rects));
}
}
}
// ----------------------------------------------------------------------------------------
// ******* THIS FUNCTION IS DEPRECATED, you should use another version of load_image_dataset() *******
@@ -167,11 +305,10 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<rectangle> >& object_locations,
const std::string& filename
)
@@ -179,25 +316,38 @@ namespace dlib
return load_image_dataset(images, object_locations, image_dataset_file(filename));
}
// ----------------------------------------------------------------------------------------
template <
typename array_type
>
void load_image_dataset (
array_type& images,
std::vector<std::vector<mmod_rect>>& object_locations,
const std::string& filename
)
{
load_image_dataset(images, object_locations, image_dataset_file(filename));
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<full_object_detection> >& object_locations,
const image_dataset_file& source,
std::vector<std::string>& parts_list
)
{
typedef typename array_type::value_type image_type;
parts_list.clear();
images.clear();
object_locations.clear();
const std::string old_working_dir = get_current_dir();
using namespace dlib::image_dataset_metadata;
dataset data;
@@ -206,7 +356,7 @@ namespace dlib
// Set the current directory to be the one that contains the
// metadata file. We do this because the file might contain
// file paths which are relative to this folder.
set_current_dir(get_parent_directory(file(source.get_filename())));
locally_change_current_dir chdir(get_parent_directory(file(source.get_filename())));
std::set<std::string> all_parts;
@@ -243,6 +393,7 @@ namespace dlib
std::vector<full_object_detection> object_dets;
for (unsigned long i = 0; i < data.images.size(); ++i)
{
double min_rect_size = std::numeric_limits<double>::infinity();
object_dets.clear();
ignored.clear();
for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j)
@@ -266,20 +417,57 @@ namespace dlib
}
object_dets.push_back(full_object_detection(data.images[i].boxes[j].rect, partlist));
min_rect_size = std::min<double>(min_rect_size, object_dets.back().get_rect().area());
}
}
}
if (!source.should_skip_empty_images() || object_dets.size() != 0)
{
load_image(img, data.images[i].filename);
if (object_dets.size() != 0)
{
// if shrinking the image would still result in the smallest box being
// bigger than the box area threshold then shrink the image.
while(min_rect_size/2/2 > source.box_area_thresh())
{
pyramid_down<2> pyr;
pyr(img);
min_rect_size *= (1.0/2.0)*(1.0/2.0);
for (auto&& r : object_dets)
{
r.get_rect() = pyr.rect_down(r.get_rect());
for (unsigned long k = 0; k < r.num_parts(); ++k)
r.part(k) = pyr.point_down(r.part(k));
}
for (auto&& r : ignored)
{
r = pyr.rect_down(r);
}
}
while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh())
{
pyramid_down<3> pyr;
pyr(img);
min_rect_size *= (2.0/3.0)*(2.0/3.0);
for (auto&& r : object_dets)
{
r.get_rect() = pyr.rect_down(r.get_rect());
for (unsigned long k = 0; k < r.num_parts(); ++k)
r.part(k) = pyr.point_down(r.part(k));
}
for (auto&& r : ignored)
{
r = pyr.rect_down(r);
}
}
}
images.push_back(img);
object_locations.push_back(object_dets);
ignored_rects.push_back(ignored);
load_image(img, data.images[i].filename);
images.push_back(img);
}
}
set_current_dir(old_working_dir);
return ignored_rects;
}
@@ -287,11 +475,10 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<full_object_detection> >& object_locations,
const image_dataset_file& source
)
@@ -303,11 +490,10 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<full_object_detection> >& object_locations,
const std::string& filename
)

View File

@@ -36,6 +36,7 @@ namespace dlib
This means that, initially, all boxes will be loaded. Therefore, for all
possible boxes B we have:
- #should_load_box(B) == true
- #box_area_thresh() == infinity
!*/
const std::string& get_filename(
@@ -50,8 +51,9 @@ namespace dlib
) const;
/*!
ensures
- returns true if we are supposed to skip images that don't have any boxes
to load when loading an image dataset using load_image_dataset().
- returns true if we are supposed to skip images that don't have any
non-ignored boxes to load when loading an image dataset using
load_image_dataset().
!*/
image_dataset_file boxes_match_label(
@@ -115,24 +117,45 @@ namespace dlib
- returns false
!*/
image_dataset_file shrink_big_images(
double new_box_area_thresh = 150*150
) const;
/*!
ensures
- returns a copy of *this that is identical in all respects to *this except
that #box_area_thresh() == new_box_area_thresh
!*/
double box_area_thresh(
) const;
/*!
ensures
- If the smallest non-ignored rectangle in an image has an area greater
than box_area_thresh() then we will shrink the image until the area of
the box is about equal to box_area_thresh(). This is useful if you have
a dataset containing very high resolution images and you don't want to
load it in its native high resolution. Setting the box_area_thresh()
allows you to control the resolution of the loaded images.
!*/
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<rectangle> >& object_locations,
const image_dataset_file& source
);
/*!
requires
- image_type == is an implementation of array2d/array2d_kernel_abstract.h
- pixel_traits<typename image_type::type> is defined
- array_type == An array of images. This is anything with an interface that
looks like std::vector<some generic image type> where a "generic image" is
anything that implements the generic image interface defined in
dlib/image_processing/generic_image.h.
ensures
- This routine loads the images and their associated object boxes from the
image metadata file indicated by source.get_filename(). This metadata file
@@ -162,40 +185,88 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<rectangle> >& object_locations,
const std::string& filename
);
/*!
requires
- image_type == is an implementation of array2d/array2d_kernel_abstract.h
- pixel_traits<typename image_type::type> is defined
- array_type == An array of images. This is anything with an interface that
looks like std::vector<some generic image type> where a "generic image" is
anything that implements the generic image interface defined in
dlib/image_processing/generic_image.h.
ensures
- performs: return load_image_dataset(images, object_locations, image_dataset_file(filename));
(i.e. it ignores box labels and therefore loads all the boxes in the dataset)
!*/
// ----------------------------------------------------------------------------------------
template <
typename array_type
>
void load_image_dataset (
array_type& images,
std::vector<std::vector<mmod_rect> >& object_locations,
const image_dataset_file& source
);
/*!
requires
- array_type == An array of images. This is anything with an interface that
looks like std::vector<some generic image type> where a "generic image" is
anything that implements the generic image interface defined in
dlib/image_processing/generic_image.h.
ensures
- This function has essentially the same behavior as the above
load_image_dataset() routines, except here we output to a vector of
mmod_rects instead of rectangles. In this case, both ignore and non-ignore
rectangles go into object_locations since mmod_rect has an ignore boolean
field that records the ignored/non-ignored state of each rectangle. We also store
a each box's string label into the mmod_rect::label field as well.
!*/
// ----------------------------------------------------------------------------------------
template <
typename array_type
>
void load_image_dataset (
array_type& images,
std::vector<std::vector<mmod_rect> >& object_locations,
const std::string& filename
);
/*!
requires
- array_type == An array of images. This is anything with an interface that
looks like std::vector<some generic image type> where a "generic image" is
anything that implements the generic image interface defined in
dlib/image_processing/generic_image.h.
ensures
- performs: load_image_dataset(images, object_locations, image_dataset_file(filename));
(i.e. it ignores box labels and therefore loads all the boxes in the dataset)
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<full_object_detection> >& object_locations,
const image_dataset_file& source,
std::vector<std::string>& parts_list
);
/*!
requires
- image_type == is an implementation of array2d/array2d_kernel_abstract.h
- pixel_traits<typename image_type::type> is defined
- array_type == An array of images. This is anything with an interface that
looks like std::vector<some generic image type> where a "generic image" is
anything that implements the generic image interface defined in
dlib/image_processing/generic_image.h.
ensures
- This routine loads the images and their associated object locations from the
image metadata file indicated by source.get_filename(). This metadata file
@@ -237,18 +308,19 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<full_object_detection> >& object_locations,
const image_dataset_file& source
);
/*!
requires
- image_type == is an implementation of array2d/array2d_kernel_abstract.h
- pixel_traits<typename image_type::type> is defined
- array_type == An array of images. This is anything with an interface that
looks like std::vector<some generic image type> where a "generic image" is
anything that implements the generic image interface defined in
dlib/image_processing/generic_image.h.
ensures
- performs: return load_image_dataset(images, object_locations, source, parts_list);
(i.e. this function simply calls the above function and discards the output
@@ -259,18 +331,19 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename image_type,
typename MM
typename array_type
>
std::vector<std::vector<rectangle> > load_image_dataset (
array<image_type,MM>& images,
array_type& images,
std::vector<std::vector<full_object_detection> >& object_locations,
const std::string& filename
);
/*!
requires
- image_type == is an implementation of array2d/array2d_kernel_abstract.h
- pixel_traits<typename image_type::type> is defined
- array_type == An array of images. This is anything with an interface that
looks like std::vector<some generic image type> where a "generic image" is
anything that implements the generic image interface defined in
dlib/image_processing/generic_image.h.
ensures
- performs: return load_image_dataset(images, object_locations, image_dataset_file(filename));
(i.e. it ignores box labels and therefore loads all the boxes in the dataset)

View File

@@ -0,0 +1,32 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MNIST_Hh_
#define DLIB_MNIST_Hh_
#include "mnist_abstract.h"
#include <string>
#include <vector>
#include "../matrix.h"
// ----------------------------------------------------------------------------------------
namespace dlib
{
void load_mnist_dataset (
const std::string& folder_name,
std::vector<matrix<unsigned char> >& training_images,
std::vector<unsigned long>& training_labels,
std::vector<matrix<unsigned char> >& testing_images,
std::vector<unsigned long>& testing_labels
);
}
// ----------------------------------------------------------------------------------------
#ifdef NO_MAKEFILE
#include "mnist.cpp"
#endif
#endif // DLIB_MNIST_Hh_

View File

@@ -0,0 +1,46 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_MNIST_ABSTRACT_Hh_
#ifdef DLIB_MNIST_ABSTRACT_Hh_
#include <string>
#include <vector>
#include "../matrix.h"
// ----------------------------------------------------------------------------------------
namespace dlib
{
void load_mnist_dataset (
const std::string& folder_name,
std::vector<matrix<unsigned char> >& training_images,
std::vector<unsigned long>& training_labels,
std::vector<matrix<unsigned char> >& testing_images,
std::vector<unsigned long>& testing_labels
);
/*!
ensures
- Attempts to load the MNIST dataset from the hard drive. This is the dataset
of handwritten digits available from http://yann.lecun.com/exdb/mnist/. In
particular, the 4 files comprising the MNIST dataset should be present in the
folder indicated by folder_name. These four files are:
- train-images-idx3-ubyte
- train-labels-idx1-ubyte
- t10k-images-idx3-ubyte
- t10k-labels-idx1-ubyte
- #training_images == The 60,000 training images from the dataset.
- #training_labels == The labels for the contents of #training_images.
I.e. #training_labels[i] is the label of #training_images[i].
- #testing_images == The 10,000 testing images from the dataset.
- #testing_labels == The labels for the contents of #testing_images.
I.e. #testing_labels[i] is the label of #testing_images[i].
throws
- dlib::error if some problem prevents us from loading the data or the files
can't be found.
!*/
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_MNIST_ABSTRACT_Hh_

View File

@@ -1,87 +0,0 @@
// Copyright (C) 2009 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DIR_NAV_EXTENSIONs_CPP_
#define DLIB_DIR_NAV_EXTENSIONs_CPP_
#include "dir_nav_extensions.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace implementation_details
{
void get_all_sub_dirs (
const directory& top_of_tree,
unsigned long max_depth,
std::vector<directory>& result,
std::vector<directory>& temp
)
{
if (max_depth > 0)
{
top_of_tree.get_dirs(temp);
const unsigned long start = result.size();
result.insert(result.end(), temp.begin(), temp.end());
const unsigned long end = start + temp.size();
for (unsigned long i = start; i < end; ++i)
{
get_all_sub_dirs(result[i], max_depth-1, result, temp);
}
}
}
}
// ----------------------------------------------------------------------------------------
bool file_exists (
const std::string& filename
)
{
try
{
dlib::file temp(filename);
return true;
}
catch (file::file_not_found&)
{
return false;
}
}
// ----------------------------------------------------------------------------------------
directory get_parent_directory (
const directory& dir
)
{
return dir.get_parent();
}
// ----------------------------------------------------------------------------------------
directory get_parent_directory (
const file& f
)
{
if (f.full_name().size() == 0)
return directory();
std::string::size_type pos = f.full_name().find_last_of("\\/");
if (pos == std::string::npos)
return directory();
return directory(f.full_name().substr(0,pos));
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DIR_NAV_EXTENSIONs_CPP_

View File

@@ -146,6 +146,20 @@ namespace dlib
const file& f
);
// ----------------------------------------------------------------------------------------
std::string select_oldest_file (
const std::string& filename1,
const std::string& filename2
);
// ----------------------------------------------------------------------------------------
std::string select_newest_file (
const std::string& filename1,
const std::string& filename2
);
// ----------------------------------------------------------------------------------------
}

View File

@@ -165,6 +165,35 @@ namespace dlib
- returns a default initialized directory (i.e. directory())
!*/
// ----------------------------------------------------------------------------------------
std::string select_oldest_file (
const std::string& filename1,
const std::string& filename2
);
/*!
ensures
- Checks the last modification times of the two given files and returns the
filename of the oldest file, i.e., the file that has gone longest since being
modified. Ties are broken arbitrarily.
- For the purpose of comparison, a file that doesn't exist is presumed to have
a last modification time of -infinity (i.e. very far in the past).
!*/
// ----------------------------------------------------------------------------------------
std::string select_newest_file (
const std::string& filename1,
const std::string& filename2
);
/*!
ensures
- Checks the last modification times of the two given files and returns the
filename that was most recently modified. Ties are broken arbitrarily.
- For the purpose of comparison, a file that doesn't exist is presumed to have
a last modification time of -infinity (i.e. very far in the past).
!*/
// ----------------------------------------------------------------------------------------
}

View File

@@ -1,252 +0,0 @@
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DIR_NAV_KERNEL_1_CPp_
#define DLIB_DIR_NAV_KERNEL_1_CPp_
#include "../platform.h"
#ifdef WIN32
#include "dir_nav_kernel_1.h"
#include "../string.h"
#ifdef __BORLANDC__
// Apparently the borland compiler doesn't define this.
#define INVALID_FILE_ATTRIBUTES ((DWORD)-1)
#endif
namespace dlib
{
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// file object implementation
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void file::
init (
const std::string& name
)
{
using namespace std;
char buf[3000];
char* str;
if (GetFullPathNameA(name.c_str(),sizeof(buf),buf,&str) == 0)
{
// the file was not found
throw file_not_found("Unable to find file " + name);
}
state.full_name = buf;
string::size_type pos = state.full_name.find_last_of(directory::get_separator());
if (pos == string::npos)
{
// no valid full path has no separator characters.
throw file_not_found("Unable to find file " + name);
}
state.name = state.full_name.substr(pos+1);
// now find the size of this file
WIN32_FIND_DATAA data;
HANDLE ffind = FindFirstFileA(state.full_name.c_str(), &data);
if (ffind == INVALID_HANDLE_VALUE ||
(data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) != 0)
{
throw file_not_found("Unable to find file " + name);
}
else
{
uint64 temp = data.nFileSizeHigh;
temp <<= 32;
temp |= data.nFileSizeLow;
state.file_size = temp;
FindClose(ffind);
}
}
// ----------------------------------------------------------------------------------------
bool file::
operator == (
const file& rhs
) const
{
using namespace std;
if (state.full_name.size() != rhs.state.full_name.size())
return false;
// compare the strings but ignore the case because file names
// are not case sensitive on windows
return tolower(state.full_name) == tolower(rhs.state.full_name);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// directory object implementation
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void directory::
init (
const std::string& name
)
{
using namespace std;
char buf[3000];
char* str;
if (GetFullPathNameA(name.c_str(),sizeof(buf),buf,&str) == 0)
{
// the directory was not found
throw dir_not_found("Unable to find directory " + name);
}
state.full_name = buf;
const char sep = get_separator();
if (is_root_path(state.full_name) == false)
{
// ensure that thre is not a trialing separator
if (state.full_name[state.full_name.size()-1] == sep)
state.full_name.erase(state.full_name.size()-1);
// pick out the directory name
string::size_type pos = state.full_name.find_last_of(sep);
state.name = state.full_name.substr(pos+1);
}
else
{
// ensure that there is a trailing separator
if (state.full_name[state.full_name.size()-1] != sep)
state.full_name += sep;
}
// now check that this is actually a valid directory
DWORD attribs = GetFileAttributesA(state.full_name.c_str());
if (attribs == INVALID_FILE_ATTRIBUTES ||
(attribs&FILE_ATTRIBUTE_DIRECTORY) == 0)
{
// the directory was not found
throw dir_not_found("Unable to find directory " + name);
}
}
// ----------------------------------------------------------------------------------------
char directory::
get_separator (
)
{
return '\\';
}
// ----------------------------------------------------------------------------------------
bool directory::
operator == (
const directory& rhs
) const
{
using namespace std;
if (state.full_name.size() != rhs.state.full_name.size())
return false;
// compare the strings but ignore the case because file names
// are not case sensitive on windows
return tolower(state.full_name) == tolower(rhs.state.full_name);
}
// ----------------------------------------------------------------------------------------
const directory directory::
get_parent (
) const
{
using namespace std;
// if *this is the root then just return *this
if (is_root())
{
return *this;
}
else
{
directory temp;
const char sep = get_separator();
string::size_type pos = state.full_name.find_last_of(sep);
temp.state.full_name = state.full_name.substr(0,pos);
if ( is_root_path(temp.state.full_name))
{
temp.state.full_name += sep;
}
else
{
pos = temp.state.full_name.find_last_of(sep);
if (pos != string::npos)
{
temp.state.name = temp.state.full_name.substr(pos+1);
}
else
{
temp.state.full_name += sep;
}
}
return temp;
}
}
// ----------------------------------------------------------------------------------------
bool directory::
is_root_path (
const std::string& path
) const
{
using namespace std;
const char sep = get_separator();
bool root_path = false;
if (path.size() > 2 && path[0] == sep && path[1] == sep)
{
// in this case this is a windows share path
string::size_type pos = path.find_first_of(sep,2);
if (pos != string::npos)
{
pos = path.find_first_of(sep,pos+1);
if (pos == string::npos && path[path.size()-1] != sep)
root_path = true;
else if (pos == path.size()-1)
root_path = true;
}
}
else if ( (path.size() == 2 || path.size() == 3) && path[1] == ':')
{
// if this is a valid windows path then it must be a root path
root_path = true;
}
return root_path;
}
// ----------------------------------------------------------------------------------------
}
#endif // WIN32
#endif // DLIB_DIR_NAV_KERNEL_1_CPp_

View File

@@ -21,6 +21,7 @@
#include "../stl_checked.h"
#include "../enable_if.h"
#include "../queue.h"
#include <chrono>
namespace dlib
{
@@ -39,11 +40,13 @@ namespace dlib
state.name == name()
state.full_name == full_name()
state.file_size == size()
state.last_modified == last_modified()
CONVENTION
state.name == name()
state.full_name == full_name()
state.file_size == size()
state.last_modified == last_modified()
!*/
@@ -54,6 +57,7 @@ namespace dlib
uint64 file_size;
std::string name;
std::string full_name;
std::chrono::time_point<std::chrono::system_clock> last_modified;
};
@@ -66,12 +70,14 @@ namespace dlib
const std::string& name,
const std::string& full_name,
const uint64 file_size,
const std::chrono::time_point<std::chrono::system_clock>& last_modified,
private_constructor
)
{
state.file_size = file_size;
state.name = name;
state.full_name = full_name;
state.last_modified = last_modified;
}
@@ -107,6 +113,9 @@ namespace dlib
inline uint64 size (
) const { return state.file_size; }
inline std::chrono::time_point<std::chrono::system_clock> last_modified (
) const { return state.last_modified; }
bool operator == (
const file& rhs
) const;
@@ -413,8 +422,15 @@ namespace dlib
uint64 file_size = data.nFileSizeHigh;
file_size <<= 32;
file_size |= data.nFileSizeLow;
ULARGE_INTEGER ull;
ull.LowPart = data.ftLastWriteTime.dwLowDateTime;
ull.HighPart = data.ftLastWriteTime.dwHighDateTime;
std::chrono::nanoseconds epoch(100 * (ull.QuadPart - 116444736000000000));
auto last_modified = std::chrono::time_point<std::chrono::system_clock>(std::chrono::duration_cast<std::chrono::system_clock::duration>(epoch));
// this is a file so add it to the queue
file temp(data.cFileName,path+data.cFileName,file_size, private_constructor());
file temp(data.cFileName,path+data.cFileName,file_size, last_modified, private_constructor());
files.enqueue(temp);
}

View File

@@ -1,248 +0,0 @@
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DIR_NAV_KERNEL_2_CPp_
#define DLIB_DIR_NAV_KERNEL_2_CPp_
#include "../platform.h"
#ifdef POSIX
#include "dir_nav_kernel_2.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// file object implementation
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void file::
init (
const std::string& name
)
{
using namespace std;
char buf[PATH_MAX];
if (realpath(name.c_str(),buf) == 0)
{
// the file was not found
throw file_not_found("Unable to find file " + name);
}
state.full_name = buf;
string::size_type pos = state.full_name.find_last_of(directory::get_separator());
if (pos == string::npos)
{
// no valid full path has no separtor characters.
throw file_not_found("Unable to find file " + name);
}
state.name = state.full_name.substr(pos+1);
// now find the size of this file
struct stat64 buffer;
if (::stat64(state.full_name.c_str(), &buffer) ||
S_ISDIR(buffer.st_mode))
{
// there was an error during the call to stat64 or
// name is actually a directory
throw file_not_found("Unable to find file " + name);
}
else
{
state.file_size = static_cast<uint64>(buffer.st_size);
}
}
// ----------------------------------------------------------------------------------------
bool file::
operator == (
const file& rhs
) const
{
using namespace std;
if (state.full_name.size() == 0 && rhs.state.full_name.size() == 0)
return true;
// These files might have different names but actually represent the same
// file due to the presence of symbolic links.
char buf[PATH_MAX];
string left, right;
if (realpath(state.full_name.c_str(),buf) == 0)
return false;
left = buf;
if (realpath(rhs.state.full_name.c_str(),buf) == 0)
return false;
right = buf;
return (left == right);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// directory object implementation
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void directory::
init (
const std::string& name
)
{
using namespace std;
char buf[PATH_MAX];
if (realpath(name.c_str(),buf) == 0)
{
// the directory was not found
throw dir_not_found("Unable to find directory " + name);
}
state.full_name = buf;
const char sep = get_separator();
if (is_root_path(state.full_name) == false)
{
// ensure that thre is not a trialing separator
if (state.full_name[state.full_name.size()-1] == sep)
state.full_name.erase(state.full_name.size()-1);
// pick out the directory name
string::size_type pos = state.full_name.find_last_of(sep);
state.name = state.full_name.substr(pos+1);
}
else
{
// ensure that there is a trailing separator
if (state.full_name[state.full_name.size()-1] != sep)
state.full_name += sep;
}
struct stat64 buffer;
// now check that this is actually a valid directory
if (::stat64(state.full_name.c_str(),&buffer))
{
// the directory was not found
throw dir_not_found("Unable to find directory " + name);
}
else if (S_ISDIR(buffer.st_mode) == 0)
{
// It is not a directory
throw dir_not_found("Unable to find directory " + name);
}
}
// ----------------------------------------------------------------------------------------
char directory::
get_separator (
)
{
return '/';
}
// ----------------------------------------------------------------------------------------
bool directory::
operator == (
const directory& rhs
) const
{
using namespace std;
if (state.full_name.size() == 0 && rhs.state.full_name.size() == 0)
return true;
// These directories might have different names but actually represent the same
// directory due to the presence of symbolic links.
char buf[PATH_MAX];
string left, right;
if (realpath(state.full_name.c_str(),buf) == 0)
return false;
left = buf;
if (realpath(rhs.state.full_name.c_str(),buf) == 0)
return false;
right = buf;
return (left == right);
}
// ----------------------------------------------------------------------------------------
const directory directory::
get_parent (
) const
{
using namespace std;
// if *this is the root then just return *this
if (is_root())
{
return *this;
}
else
{
directory temp;
const char sep = get_separator();
string::size_type pos = state.full_name.find_last_of(sep);
temp.state.full_name = state.full_name.substr(0,pos);
if ( is_root_path(temp.state.full_name))
{
temp.state.full_name += sep;
}
else
{
pos = temp.state.full_name.find_last_of(sep);
if (pos != string::npos)
{
temp.state.name = temp.state.full_name.substr(pos+1);
}
else
{
temp.state.full_name += sep;
}
}
return temp;
}
}
// ----------------------------------------------------------------------------------------
bool directory::
is_root_path (
const std::string& path
) const
{
const char sep = get_separator();
if (path.size() == 1 && path[0] == sep)
return true;
else
return false;
}
// ----------------------------------------------------------------------------------------
}
#endif // POSIX
#endif // DLIB_DIR_NAV_KERNEL_2_CPp_

View File

@@ -22,6 +22,7 @@
#include <sys/stat.h>
#include <errno.h>
#include <stdlib.h>
#include <chrono>
#if !defined(__USE_LARGEFILE64 ) && !defined(_LARGEFILE64_SOURCE)
#define stat64 stat
@@ -48,11 +49,13 @@ namespace dlib
state.name == name()
state.full_name == full_name()
state.file_size == size()
state.last_modified == last_modified()
CONVENTION
state.name == name()
state.full_name == full_name()
state.file_size == size()
state.last_modified == last_modified()
!*/
@@ -63,6 +66,7 @@ namespace dlib
uint64 file_size;
std::string name;
std::string full_name;
std::chrono::time_point<std::chrono::system_clock> last_modified;
};
void init(const std::string& name);
@@ -74,12 +78,14 @@ namespace dlib
const std::string& name,
const std::string& full_name,
const uint64 file_size,
const std::chrono::time_point<std::chrono::system_clock>& last_modified,
private_constructor
)
{
state.file_size = file_size;
state.name = name;
state.full_name = full_name;
state.last_modified = last_modified;
}
@@ -110,6 +116,9 @@ namespace dlib
inline uint64 size (
) const { return state.file_size; }
inline std::chrono::time_point<std::chrono::system_clock> last_modified (
) const { return state.last_modified; }
operator std::string (
) const { return full_name(); }
@@ -383,6 +392,10 @@ namespace dlib
{
file_size = static_cast<uint64>(buffer.st_size);
}
auto last_modified = std::chrono::system_clock::from_time_t(buffer.st_mtime);
#ifdef _BSD_SOURCE
last_modified += std::chrono::duration_cast<std::chrono::system_clock::duration>(std::chrono::nanoseconds(buffer.st_atim.tv_nsec));
#endif
if (S_ISDIR(buffer.st_mode) == 0)
{
@@ -391,6 +404,7 @@ namespace dlib
data->d_name,
path+data->d_name,
file_size,
last_modified,
file::private_constructor()
);
files.enqueue(temp);

View File

@@ -7,6 +7,7 @@
#include <vector>
#include "../uintn.h"
#include "../algs.h"
#include <chrono>
namespace dlib
{
@@ -139,6 +140,13 @@ namespace dlib
- returns the size of this file in bytes.
!*/
std::chrono::time_point<std::chrono::system_clock> last_modified (
) const;
/*!
ensures
- returns the time the file was last modified.
!*/
operator std::string (
) const;
/*!

View File

@@ -3,12 +3,13 @@
#ifndef DLIB_DIRECTED_GRAPH_KERNEl_1_
#define DLIB_DIRECTED_GRAPH_KERNEl_1_
#include <memory>
#include <vector>
#include "../serialize.h"
#include "../noncopyable.h"
#include "../std_allocator.h"
#include "../smart_pointers.h"
#include "../algs.h"
#include <vector>
#include "directed_graph_kernel_abstract.h"
#include "../is_kind.h"
@@ -357,18 +358,18 @@ namespace dlib
private:
friend class directed_graph_kernel_1;
typedef std_allocator<node_type*,mem_manager> alloc_type;
typedef std_allocator<shared_ptr<E>,mem_manager> alloc_edge_type;
typedef std_allocator<std::shared_ptr<E>,mem_manager> alloc_edge_type;
std::vector<node_type*,alloc_type> parents;
std::vector<node_type*,alloc_type> children;
std::vector<shared_ptr<E>,alloc_edge_type> edge_parents;
std::vector<shared_ptr<E>,alloc_edge_type> edge_children;
std::vector<std::shared_ptr<E>,alloc_edge_type> edge_parents;
std::vector<std::shared_ptr<E>,alloc_edge_type> edge_children;
unsigned long idx;
};
private:
typedef std_allocator<shared_ptr<node_type>,mem_manager> alloc_type;
typedef std::vector<shared_ptr<node_type>, alloc_type> vector_type;
typedef std_allocator<std::shared_ptr<node_type>,mem_manager> alloc_type;
typedef std::vector<std::shared_ptr<node_type>, alloc_type> vector_type;
vector_type nodes;
};
@@ -574,7 +575,7 @@ namespace dlib
p.children.push_back(&c);
c.parents.push_back(&p);
p.edge_children.push_back(shared_ptr<E>(new E));
p.edge_children.push_back(std::shared_ptr<E>(new E));
c.edge_parents.push_back(p.edge_children.back());
}
catch (...)
@@ -632,7 +633,7 @@ namespace dlib
{
try
{
shared_ptr<node_type> n(new node_type);
std::shared_ptr<node_type> n(new node_type);
n->idx = nodes.size();
nodes.push_back(n);
return n->idx;

View File

@@ -5,6 +5,7 @@
#include "disjoint_subsets/disjoint_subsets.h"
#include "disjoint_subsets/disjoint_subsets_sized.h"
#endif // DLIB_DISJOINt_SUBSETS_

View File

@@ -17,7 +17,7 @@ namespace dlib
public:
void clear (
)
) noexcept
{
items.clear();
}
@@ -34,22 +34,22 @@ namespace dlib
}
}
unsigned long size (
) const
size_t size (
) const noexcept
{
return items.size();
}
unsigned long find_set (
unsigned long item
unsigned long item
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(item < size(),
DLIB_ASSERT(item < size(),
"\t unsigned long disjoint_subsets::find_set()"
<< "\n\t item must be less than size()"
<< "\n\t item: " << item
<< "\n\t size(): " << size()
<< "\n\t item: " << item
<< "\n\t size(): " << size()
<< "\n\t this: " << this
);
@@ -88,16 +88,16 @@ namespace dlib
// make sure requires clause is not broken
DLIB_ASSERT(a != b &&
a < size() &&
b < size() &&
b < size() &&
find_set(a) == a &&
find_set(b) == b,
"\t unsigned long disjoint_subsets::merge_sets(a,b)"
<< "\n\t invalid arguments were given to this function"
<< "\n\t a: " << a
<< "\n\t b: " << b
<< "\n\t size(): " << size()
<< "\n\t find_set(a): " << find_set(a)
<< "\n\t find_set(b): " << find_set(b)
<< "\n\t a: " << a
<< "\n\t b: " << b
<< "\n\t size(): " << size()
<< "\n\t find_set(a): " << find_set(a)
<< "\n\t find_set(b): " << find_set(b)
<< "\n\t this: " << this
);
@@ -139,4 +139,3 @@ namespace dlib
}
#endif // DLIB_DISJOINT_SUBsETS_Hh_

View File

@@ -20,13 +20,13 @@ namespace dlib
WHAT THIS OBJECT REPRESENTS
This object represents a set of integers which is partitioned into
a number of disjoint subsets. It supports the two fundamental operations
of finding which subset a particular integer belongs to as well as
of finding which subset a particular integer belongs to as well as
merging subsets.
!*/
public:
void clear (
);
) noexcept;
/*!
ensures
- #size() == 0
@@ -44,29 +44,29 @@ namespace dlib
(i.e. this object contains new_size subsets, each containing exactly one element)
!*/
unsigned long size (
) const;
size_t size (
) const noexcept;
/*!
ensures
- returns the total number of integer elements represented
by this object.
by this object.
!*/
unsigned long find_set (
unsigned long item
unsigned long item
) const;
/*!
requires
- item < size()
ensures
- Each disjoint subset can be represented by any of its elements (since
the sets are all disjoint). In particular, for each subset we define
a special "representative element" which is used to represent it.
Therefore, this function returns the representative element for the
- Each disjoint subset can be represented by any of its elements (since
the sets are all disjoint). In particular, for each subset we define
a special "representative element" which is used to represent it.
Therefore, this function returns the representative element for the
set which contains item.
- find_set(find_set(item)) == find_set(item)
- Note that if A and B are both elements of the same subset then we always
have find_set(A) == find_set(B).
have find_set(A) == find_set(B).
!*/
unsigned long merge_sets (
@@ -87,7 +87,6 @@ namespace dlib
(i.e. merges the set's containing a and b)
- returns #find_set(a)
!*/
};
// ----------------------------------------------------------------------------------------
@@ -95,5 +94,3 @@ namespace dlib
}
#endif // DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_

Some files were not shown because too many files have changed in this diff Show More