diff --git a/lib/dictBuilder/zdict.c b/lib/dictBuilder/zdict.c index befa616b87f..23f155a760d 100644 --- a/lib/dictBuilder/zdict.c +++ b/lib/dictBuilder/zdict.c @@ -99,6 +99,20 @@ unsigned ZDICT_isError(size_t errorCode) { return ERR_isError(errorCode); } const char* ZDICT_getErrorName(size_t errorCode) { return ERR_getErrorName(errorCode); } +static int ZDICT_size_t_addOverflow(size_t a, size_t b, size_t* result) +{ + if (a > (size_t)(-1) - b) return 1; + *result = a + b; + return 0; +} + +static int ZDICT_size_t_mulOverflow(size_t a, size_t b, size_t* result) +{ + if (a != 0 && b > (size_t)(-1) / a) return 1; + *result = a * b; + return 0; +} + unsigned ZDICT_getDictID(const void* dictBuffer, size_t dictSize) { if (dictSize < 8) return 0; @@ -471,11 +485,17 @@ static size_t ZDICT_trainBuffer_legacy(dictItem* dictList, U32 dictListSize, const size_t* fileSizes, unsigned nbFiles, unsigned minRatio, U32 notificationLevel) { - unsigned* const suffix0 = (unsigned*)malloc((bufferSize+2)*sizeof(*suffix0)); - unsigned* const suffix = suffix0+1; - U32* reverseSuffix = (U32*)malloc((bufferSize)*sizeof(*reverseSuffix)); - BYTE* doneMarks = (BYTE*)malloc((bufferSize+16)*sizeof(*doneMarks)); /* +16 for overflow security */ - U32* filePos = (U32*)malloc(nbFiles * sizeof(*filePos)); + unsigned* suffix0; + unsigned* suffix; + U32* reverseSuffix; + BYTE* doneMarks; + U32* filePos; + size_t suffix0Count; + size_t suffix0Bytes; + size_t reverseSuffixBytes; + size_t doneMarksCount; + size_t doneMarksBytes; + size_t filePosBytes; size_t result = 0; clock_t displayClock = 0; clock_t const refreshRate = CLOCKS_PER_SEC * 3 / 10; @@ -494,12 +514,26 @@ static size_t ZDICT_trainBuffer_legacy(dictItem* dictList, U32 dictListSize, /* init */ DISPLAYLEVEL(2, "\r%70s\r", ""); /* clean display line */ + if (ZDICT_size_t_addOverflow(bufferSize, 2, &suffix0Count) + || ZDICT_size_t_mulOverflow(suffix0Count, sizeof(*suffix0), &suffix0Bytes) + || ZDICT_size_t_mulOverflow(bufferSize, sizeof(*reverseSuffix), &reverseSuffixBytes) + || ZDICT_size_t_addOverflow(bufferSize, 16, &doneMarksCount) + || ZDICT_size_t_mulOverflow(doneMarksCount, sizeof(*doneMarks), &doneMarksBytes) + || ZDICT_size_t_mulOverflow((size_t)nbFiles, sizeof(*filePos), &filePosBytes)) { + return ERROR(memory_allocation); + } + + suffix0 = (unsigned*)malloc(suffix0Bytes); + reverseSuffix = (U32*)malloc(reverseSuffixBytes); + doneMarks = (BYTE*)malloc(doneMarksBytes); /* +16 for overflow security */ + filePos = (U32*)malloc(filePosBytes); if (!suffix0 || !reverseSuffix || !doneMarks || !filePos) { result = ERROR(memory_allocation); goto _cleanup; } + suffix = suffix0 + 1; if (minRatio < MINRATIO) minRatio = MINRATIO; - memset(doneMarks, 0, bufferSize+16); + memset(doneMarks, 0, doneMarksCount); /* limit sample set size (divsufsort limitation)*/ if (bufferSize > ZDICT_MAX_SAMPLES_SIZE) DISPLAYLEVEL(3, "sample set too large : reduced to %u MB ...\n", (unsigned)(ZDICT_MAX_SAMPLES_SIZE>>20));