diff options
author | Chris Robinson <[email protected]> | 2020-06-19 16:43:09 -0700 |
---|---|---|
committer | Chris Robinson <[email protected]> | 2020-06-19 16:43:09 -0700 |
commit | a01dbeb09f0a6cdb1f2946b9fdf6c16c8b979066 (patch) | |
tree | 4ba291551c5bde6c97043aeec77a5e08c0c87af8 | |
parent | bc91e4d302b8acb882a38948458c828a8e0e8726 (diff) |
Use multiple threads for magnitude calculations
Also allow user-declarable thread count.
-rw-r--r-- | utils/makemhr/loadsofa.cpp | 101 | ||||
-rw-r--r-- | utils/makemhr/loadsofa.h | 4 | ||||
-rw-r--r-- | utils/makemhr/makemhr.cpp | 48 |
3 files changed, 107 insertions, 46 deletions
diff --git a/utils/makemhr/loadsofa.cpp b/utils/makemhr/loadsofa.cpp index f53a0b6b..ec000d72 100644 --- a/utils/makemhr/loadsofa.cpp +++ b/utils/makemhr/loadsofa.cpp @@ -28,6 +28,7 @@ #include <chrono> #include <cmath> #include <cstdio> +#include <functional> #include <future> #include <iterator> #include <memory> @@ -335,8 +336,46 @@ static bool LoadResponses(MYSOFA_HRTF *sofaHrtf, HrirDataT *hData) } -bool LoadSofaFile(const char *filename, const uint fftSize, const uint truncSize, - const ChannelModeT chanMode, HrirDataT *hData) +/* Calculates the frequency magnitudes of the HRIR set. Work is delegated to + * this struct, which runs asynchronously on one or more threads (sharing the + * same calculator object). + */ +struct MagCalculator { + const uint mFftSize{}; + const uint mIrPoints{}; + std::vector<double*> mIrs{}; + std::atomic<size_t> mCurrent{}; + std::atomic<size_t> mDone{}; + + void Worker() + { + auto htemp = std::vector<complex_d>(mFftSize); + + while(1) + { + /* Load the current index to process. */ + size_t idx{mCurrent.load()}; + do { + /* If the index is at the end, we're done. */ + if(idx >= mIrs.size()) + return; + /* Otherwise, increment the current index atomically so other + * threads know to go to the next one. If this call fails, the + * current index was just changed by another thread and the new + * value is loaded into idx, which we'll recheck. + */ + } while(!mCurrent.compare_exchange_weak(idx, idx+1, std::memory_order_relaxed)); + + CalcHrirMagnitude(mIrPoints, mFftSize, htemp, mIrs[idx]); + + /* Increment the number of IRs done. */ + mDone.fetch_add(1); + } + } +}; + +bool LoadSofaFile(const char *filename, const uint numThreads, const uint fftSize, + const uint truncSize, const ChannelModeT chanMode, HrirDataT *hData) { int err; MySofaHrtfPtr sofaHrtf{mysofa_load(filename, &err)}; @@ -480,28 +519,6 @@ bool LoadSofaFile(const char *filename, const uint fftSize, const uint truncSize } return true; }; - auto magnitude_proc = [hData,channels,&hrir_done]() -> bool - { - /* Temporary buffers used to calculate the IR's frequency magnitudes. */ - auto htemp = std::vector<complex_d>(hData->mFftSize); - - for(uint fi{0u};fi < hData->mFdCount;fi++) - { - for(uint ei{hData->mFds[fi].mEvStart};ei < hData->mFds[fi].mEvCount;ei++) - { - for(uint ai{0};ai < hData->mFds[fi].mEvs[ei].mAzCount;ai++) - { - HrirAzT &azd = hData->mFds[fi].mEvs[ei].mAzs[ai]; - for(uint ti{0};ti < channels;ti++) - { - hrir_done.fetch_add(1u, std::memory_order_acq_rel); - CalcHrirMagnitude(hData->mIrPoints, hData->mFftSize, htemp, azd.mIrs[ti]); - } - } - } - } - return true; - }; std::future_status load_status{}; auto load_future = std::async(std::launch::async, onset_proc); @@ -514,14 +531,38 @@ bool LoadSofaFile(const char *filename, const uint fftSize, const uint truncSize if(!load_future.get()) return false; - hrir_done.store(0u, std::memory_order_relaxed); - load_future = std::async(std::launch::async, magnitude_proc); + MagCalculator calculator{hData->mFftSize, hData->mIrPoints}; + for(uint fi{0u};fi < hData->mFdCount;fi++) + { + for(uint ei{hData->mFds[fi].mEvStart};ei < hData->mFds[fi].mEvCount;ei++) + { + for(uint ai{0};ai < hData->mFds[fi].mEvs[ei].mAzCount;ai++) + { + HrirAzT &azd = hData->mFds[fi].mEvs[ei].mAzs[ai]; + for(uint ti{0};ti < channels;ti++) + calculator.mIrs.push_back(azd.mIrs[ti]); + } + } + } + + std::vector<std::thread> thrds; + thrds.reserve(numThreads); + for(size_t i{0};i < numThreads;++i) + thrds.emplace_back(std::mem_fn(&MagCalculator::Worker), &calculator); + size_t count; do { - load_status = load_future.wait_for(std::chrono::milliseconds{50}); - printf("\rCalculating HRIR magnitudes... %zu of %zu", hrir_done.load(), hrir_total); + std::this_thread::sleep_for(std::chrono::milliseconds{50}); + count = calculator.mDone.load(); + + printf("\rCalculating HRIR magnitudes... %zu of %zu", count, calculator.mIrs.size()); fflush(stdout); - } while(load_status != std::future_status::ready); + } while(count != calculator.mIrs.size()); fputc('\n', stdout); - return load_future.get(); + for(auto &thrd : thrds) + { + if(thrd.joinable()) + thrd.join(); + } + return true; } diff --git a/utils/makemhr/loadsofa.h b/utils/makemhr/loadsofa.h index 93bf1704..803bcf88 100644 --- a/utils/makemhr/loadsofa.h +++ b/utils/makemhr/loadsofa.h @@ -4,7 +4,7 @@ #include "makemhr.h" -bool LoadSofaFile(const char *filename, const uint fftSize, const uint truncSize, - const ChannelModeT chanMode, HrirDataT *hData); +bool LoadSofaFile(const char *filename, const uint numThreads, const uint fftSize, + const uint truncSize, const ChannelModeT chanMode, HrirDataT *hData); #endif /* LOADSOFA_H */ diff --git a/utils/makemhr/makemhr.cpp b/utils/makemhr/makemhr.cpp index 4d83cd6d..e9e950fb 100644 --- a/utils/makemhr/makemhr.cpp +++ b/utils/makemhr/makemhr.cpp @@ -846,7 +846,7 @@ struct HrirReconstructor { } }; -static void ReconstructHrirs(const HrirDataT *hData) +static void ReconstructHrirs(const HrirDataT *hData, const uint numThreads) { const uint channels{(hData->mChannelType == CT_STEREO) ? 2u : 1u}; @@ -885,9 +885,11 @@ static void ReconstructHrirs(const HrirDataT *hData) } } - /* Launch two threads to work on reconstruction. */ - std::thread thrd1{std::mem_fn(&HrirReconstructor::Worker), &reconstructor}; - std::thread thrd2{std::mem_fn(&HrirReconstructor::Worker), &reconstructor}; + /* Launch threads to work on reconstruction. */ + std::vector<std::thread> thrds; + thrds.reserve(numThreads); + for(size_t i{0};i < numThreads;++i) + thrds.emplace_back(std::mem_fn(&HrirReconstructor::Worker), &reconstructor); /* Keep track of the number of IRs done, periodically reporting it. */ size_t count; @@ -902,8 +904,11 @@ static void ReconstructHrirs(const HrirDataT *hData) } while(count != total); fputc('\n', stdout); - if(thrd2.joinable()) thrd2.join(); - if(thrd1.joinable()) thrd1.join(); + for(auto &thrd : thrds) + { + if(thrd.joinable()) + thrd.join(); + } } /* Given field and elevation indices and an azimuth, calculate the indices of @@ -1371,13 +1376,14 @@ int PrepareHrirData(const uint fdCount, const double (&distances)[MAX_FD_COUNT], * from standard input. */ static int ProcessDefinition(const char *inName, const uint outRate, const ChannelModeT chanMode, - const bool farfield, const uint fftSize, const int equalize, const int surface, - const double limit, const uint truncSize, const HeadModelT model, const double radius, - const char *outName) + const bool farfield, const uint numThreads, const uint fftSize, const int equalize, + const int surface, const double limit, const uint truncSize, const HeadModelT model, + const double radius, const char *outName) { char rateStr[8+1], expName[MAX_PATH_LEN]; HrirDataT hData; + fprintf(stdout, "Using %u thread%s.\n", numThreads, (numThreads==1)?"":"s"); if(!inName) { inName = "stdin"; @@ -1408,7 +1414,7 @@ static int ProcessDefinition(const char *inName, const uint outRate, const Chann { input = nullptr; fprintf(stdout, "Reading HRTF data from %s...\n", inName); - if(!LoadSofaFile(inName, fftSize, truncSize, chanMode, &hData)) + if(!LoadSofaFile(inName, numThreads, fftSize, truncSize, chanMode, &hData)) return 0; } else @@ -1455,7 +1461,7 @@ static int ProcessDefinition(const char *inName, const uint outRate, const Chann ResampleHrirs(outRate, &hData); } fprintf(stdout, "Performing minimum phase reconstruction...\n"); - ReconstructHrirs(&hData); + ReconstructHrirs(&hData, numThreads); fprintf(stdout, "Truncating minimum-phase HRIRs...\n"); hData.mIrPoints = truncSize; fprintf(stdout, "Synthesizing missing elevations...\n"); @@ -1481,6 +1487,7 @@ static void PrintHelp(const char *argv0, FILE *ofile) fprintf(ofile, " -m Change the data set to mono, mirroring the left ear for the\n"); fprintf(ofile, " right ear.\n"); fprintf(ofile, " -a Change the data set to single field, using the farthest field.\n"); + fprintf(ofile, " -j <threads> Number of threads used to process HRIRs (default: 2).\n"); fprintf(ofile, " -f <points> Override the FFT window size (default: %u).\n", DEFAULT_FFTSIZE); fprintf(ofile, " -e {on|off} Toggle diffuse-field equalization (default: %s).\n", (DEFAULT_EQUALIZE ? "on" : "off")); fprintf(ofile, " -s {on|off} Toggle surface-weighted diffuse-field average (default: %s).\n", (DEFAULT_SURFACE ? "on" : "off")); @@ -1505,6 +1512,7 @@ int main(int argc, char *argv[]) char *end = nullptr; ChannelModeT chanMode; HeadModelT model; + uint numThreads; uint truncSize; double radius; bool farfield; @@ -1527,12 +1535,13 @@ int main(int argc, char *argv[]) equalize = DEFAULT_EQUALIZE; surface = DEFAULT_SURFACE; limit = DEFAULT_LIMIT; + numThreads = 2; truncSize = DEFAULT_TRUNCSIZE; model = DEFAULT_HEAD_MODEL; radius = DEFAULT_CUSTOM_RADIUS; farfield = false; - while((opt=getopt(argc, argv, "r:maf:e:s:l:w:d:c:e:i:o:h")) != -1) + while((opt=getopt(argc, argv, "r:maj:f:e:s:l:w:d:c:e:i:o:h")) != -1) { switch(opt) { @@ -1553,6 +1562,17 @@ int main(int argc, char *argv[]) farfield = true; break; + case 'j': + numThreads = static_cast<uint>(strtoul(optarg, &end, 10)); + if(end[0] != '\0' || numThreads > 64) + { + fprintf(stderr, "\nError: Got unexpected value \"%s\" for option -%c, expected between %u to %u.\n", optarg, opt, 0, 64); + exit(EXIT_FAILURE); + } + if(numThreads == 0) + numThreads = std::thread::hardware_concurrency(); + break; + case 'f': fftSize = static_cast<uint>(strtoul(optarg, &end, 10)); if(end[0] != '\0' || (fftSize&(fftSize-1)) || fftSize < MIN_FFTSIZE || fftSize > MAX_FFTSIZE) @@ -1648,8 +1668,8 @@ int main(int argc, char *argv[]) } } - int ret = ProcessDefinition(inName, outRate, chanMode, farfield, fftSize, equalize, surface, - limit, truncSize, model, radius, outName); + int ret = ProcessDefinition(inName, outRate, chanMode, farfield, numThreads, fftSize, equalize, + surface, limit, truncSize, model, radius, outName); if(!ret) return -1; fprintf(stdout, "Operation completed.\n"); |