fix: cap parallelism to 4 for cuda to avoid oom (#601)
parent 7877d300ab
commit 5a822c03b6
@@ -116,8 +116,8 @@ std::shared_ptr<TextInferenceEngine> create_engine(
   const size_t num_cpus = std::thread::hardware_concurrency();
   if (loader.device == ctranslate2::Device::CUDA) {
-    // When device is cuda, set parallelism to be number of thread.
-    loader.num_replicas_per_device = num_cpus;
+    // When device is cuda, set parallelism to be number of thread, capped to 4 to avoid VRAM oom.
+    loader.num_replicas_per_device = std::min<int32_t>(num_cpus, 4);
   } else if (loader.device == ctranslate2::Device::CPU){
     // When device is cpu, adjust the number based on threads per replica.
     // https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77

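For context, a minimal self-contained sketch of the replica-count selection after this change. The replicas_per_device helper and the bool is_cuda parameter are illustrative stand-ins for the actual create_engine / ctranslate2::Device code, and the CPU branch is simplified; only the cap-to-4 logic mirrors the diff above.

// Illustrative sketch only -- the names below are hypothetical, not the repo's API.
#include <algorithm>
#include <cstdint>
#include <thread>

// Each CUDA replica keeps its own copy of the model weights in VRAM, so an
// uncapped replica count (e.g. 32 on a 32-core host) can run the GPU out of
// memory. This commit caps it at 4.
int32_t replicas_per_device(bool is_cuda) {
  const auto num_cpus =
      static_cast<int32_t>(std::thread::hardware_concurrency());
  if (is_cuda) {
    return std::min<int32_t>(num_cpus, 4);  // cap to avoid VRAM OOM
  }
  // CPU path: the real code instead adjusts this based on threads per replica
  // (see the CTranslate2 utils.cc link above); returning num_cpus here is a
  // simplification for the sketch.
  return num_cpus;
}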