fundamental: Add overflow-safe math helpers

ADD_SAFE/SUB_SAFE/MUL_SAFE do addition/subtraction/multiplication
respectively with an overflow check. If an overflow occurs these return
false, otherwise true. Example: (c = a + b) would become ADD_SAFE(&c, a,
b)

INC_SAFE/DEC_SAFE/MUL_ASSIGN_SAFE are like above but they also reassign
the first argument. Example: (a += b) would become INC_SAFE(&a, b)
This commit is contained in:
Adrian Vovk
2024-01-17 14:48:45 -05:00
parent 6d55e3a364
commit a7a67dfd9d
6 changed files with 179 additions and 8 deletions

View File

@@ -2288,7 +2288,7 @@ static EFI_STATUS initrd_prepare(
continue;
size_t new_size, read_size = info->FileSize;
if (__builtin_add_overflow(size, read_size, &new_size))
if (!ADD_SAFE(&new_size, size, read_size))
return EFI_OUT_OF_RESOURCES;
initrd = xrealloc(initrd, size, new_size);

View File

@@ -372,9 +372,9 @@ bool efi_fnmatch(const char16_t *pattern, const char16_t *haystack) {
\
uint64_t u = 0; \
while (*s >= '0' && *s <= '9') { \
if (__builtin_mul_overflow(u, 10, &u)) \
if (!MUL_ASSIGN_SAFE(&u, 10)) \
return false; \
if (__builtin_add_overflow(u, *s - '0', &u)) \
if (!INC_SAFE(&u, *s - '0')) \
return false; \
s++; \
} \
@@ -593,13 +593,13 @@ typedef struct {
static void grow_buf(FormatContext *ctx, size_t need) {
assert(ctx);
assert_se(!__builtin_add_overflow(ctx->n, need, &need));
assert_se(INC_SAFE(&need, ctx->n));
if (need < ctx->n_buf)
return;
/* Greedily allocate if we can. */
if (__builtin_mul_overflow(need, 2, &ctx->n_buf))
if (!MUL_SAFE(&ctx->n_buf, need, 2))
ctx->n_buf = need;
/* We cannot use realloc here as ctx->buf may be ctx->stack_buf, which we cannot free. */

View File

@@ -37,7 +37,7 @@ static inline void *xmalloc(size_t size) {
_malloc_ _alloc_(1, 2) _returns_nonnull_ _warn_unused_result_
static inline void *xmalloc_multiply(size_t n, size_t size) {
assert_se(!__builtin_mul_overflow(size, n, &size));
assert_se(MUL_ASSIGN_SAFE(&size, n));
return xmalloc(size);
}

View File

@@ -249,6 +249,30 @@
CONST_ISPOWEROF2(_x); \
}))
#define ADD_SAFE(ret, a, b) (!__builtin_add_overflow(a, b, ret))
#define INC_SAFE(a, b) __INC_SAFE(UNIQ, a, b)
#define __INC_SAFE(q, a, b) \
({ \
const typeof(a) UNIQ_T(A, q) = (a); \
ADD_SAFE(UNIQ_T(A, q), *UNIQ_T(A, q), b); \
})
#define SUB_SAFE(ret, a, b) (!__builtin_sub_overflow(a, b, ret))
#define DEC_SAFE(a, b) __DEC_SAFE(UNIQ, a, b)
#define __DEC_SAFE(q, a, b) \
({ \
const typeof(a) UNIQ_T(A, q) = (a); \
SUB_SAFE(UNIQ_T(A, q), *UNIQ_T(A, q), b); \
})
#define MUL_SAFE(ret, a, b) (!__builtin_mul_overflow(a, b, ret))
#define MUL_ASSIGN_SAFE(a, b) __MUL_ASSIGN_SAFE(UNIQ, a, b)
#define __MUL_ASSIGN_SAFE(q, a, b) \
({ \
const typeof(a) UNIQ_T(A, q) = (a); \
MUL_SAFE(UNIQ_T(A, q), *UNIQ_T(A, q), b); \
})
#define LESS_BY(a, b) __LESS_BY(UNIQ, (a), UNIQ, (b))
#define __LESS_BY(aq, a, bq, b) \
({ \
@@ -298,7 +322,7 @@
const typeof(y) UNIQ_T(A, q) = (y); \
const typeof(x) UNIQ_T(B, q) = DIV_ROUND_UP((x), UNIQ_T(A, q)); \
typeof(x) UNIQ_T(C, q); \
__builtin_mul_overflow(UNIQ_T(B, q), UNIQ_T(A, q), &UNIQ_T(C, q)) ? (typeof(x)) -1 : UNIQ_T(C, q); \
MUL_SAFE(&UNIQ_T(C, q), UNIQ_T(B, q), UNIQ_T(A, q)) ? UNIQ_T(C, q) : (typeof(x)) -1; \
})
#define ROUND_UP(x, y) __ROUND_UP(UNIQ, (x), (y))

View File

@@ -540,7 +540,7 @@ static int module_callback(Dwfl_Module *mod, void **userdata, const char *name,
continue;
/* Check that the end of segment is a valid address. */
if (__builtin_add_overflow(program_header->p_vaddr, program_header->p_memsz, &end_of_segment)) {
if (!ADD_SAFE(&end_of_segment, program_header->p_vaddr, program_header->p_memsz)) {
log_error("Abort due to corrupted core dump, end of segment address %#zx + %#zx overflows", (size_t)program_header->p_vaddr, (size_t)program_header->p_memsz);
return DWARF_CB_ABORT;
}

View File

@@ -159,6 +159,153 @@ TEST(container_of) {
#pragma GCC diagnostic pop
#define TEST_OVERFLOW_MATH_BY_TYPE(type, min, max, lit) \
({ \
type x; \
\
assert_se(ADD_SAFE(&x, lit(5), lit(10))); \
assert_se(x == lit(15)); \
if (IS_SIGNED_INTEGER_TYPE(type)) { \
assert_se(ADD_SAFE(&x, lit(5), lit(-10))); \
assert_se(x == lit(-5)); \
} \
assert_se(ADD_SAFE(&x, min, lit(0))); \
assert_se(x == min); \
assert_se(ADD_SAFE(&x, max, lit(0))); \
assert_se(x == max); \
if (IS_SIGNED_INTEGER_TYPE(type)) \
assert_se(!ADD_SAFE(&x, min, lit(-1))); \
assert_se(!ADD_SAFE(&x, max, lit(1))); \
\
x = lit(5); \
assert_se(INC_SAFE(&x, lit(10))); \
assert_se(x == lit(15)); \
if (IS_SIGNED_INTEGER_TYPE(type)) { \
assert_se(INC_SAFE(&x, lit(-20))); \
assert_se(x == lit(-5)); \
} \
x = min; \
assert_se(INC_SAFE(&x, lit(0))); \
assert_se(x == min); \
if (IS_SIGNED_INTEGER_TYPE(type)) \
assert_se(!INC_SAFE(&x, lit(-1))); \
x = max; \
assert_se(INC_SAFE(&x, lit(0))); \
assert_se(x == max); \
assert_se(!INC_SAFE(&x, lit(1))); \
\
assert_se(SUB_SAFE(&x, lit(10), lit(5))); \
assert_se(x == lit(5)); \
if (IS_SIGNED_INTEGER_TYPE(type)) { \
assert_se(SUB_SAFE(&x, lit(5), lit(10))); \
assert_se(x == lit(-5)); \
\
assert_se(SUB_SAFE(&x, lit(5), lit(-10))); \
assert_se(x == lit(15)); \
} else \
assert_se(!SUB_SAFE(&x, lit(5), lit(10))); \
assert_se(SUB_SAFE(&x, min, lit(0))); \
assert_se(x == min); \
assert_se(SUB_SAFE(&x, max, lit(0))); \
assert_se(x == max); \
assert_se(!SUB_SAFE(&x, min, lit(1))); \
if (IS_SIGNED_INTEGER_TYPE(type)) \
assert_se(!SUB_SAFE(&x, max, lit(-1))); \
\
x = lit(10); \
assert_se(DEC_SAFE(&x, lit(5))); \
assert_se(x == lit(5)); \
if (IS_SIGNED_INTEGER_TYPE(type)) { \
assert_se(DEC_SAFE(&x, lit(10))); \
assert_se(x == lit(-5)); \
\
x = lit(5); \
assert_se(DEC_SAFE(&x, lit(-10))); \
assert_se(x == lit(15)); \
} else \
assert_se(!DEC_SAFE(&x, lit(10))); \
x = min; \
assert_se(DEC_SAFE(&x, lit(0))); \
assert_se(x == min); \
assert_se(!DEC_SAFE(&x, lit(1))); \
x = max; \
assert_se(DEC_SAFE(&x, lit(0))); \
if (IS_SIGNED_INTEGER_TYPE(type)) \
assert_se(!DEC_SAFE(&x, lit(-1))); \
\
assert_se(MUL_SAFE(&x, lit(2), lit(4))); \
assert_se(x == lit(8)); \
if (IS_SIGNED_INTEGER_TYPE(type)) { \
assert_se(MUL_SAFE(&x, lit(2), lit(-4))); \
assert_se(x == lit(-8)); \
} \
assert_se(MUL_SAFE(&x, lit(5), lit(0))); \
assert_se(x == lit(0)); \
assert_se(MUL_SAFE(&x, min, lit(1))); \
assert_se(x == min); \
if (IS_SIGNED_INTEGER_TYPE(type)) \
assert_se(!MUL_SAFE(&x, min, lit(2))); \
assert_se(MUL_SAFE(&x, max, lit(1))); \
assert_se(x == max); \
assert_se(!MUL_SAFE(&x, max, lit(2))); \
\
x = lit(2); \
assert_se(MUL_ASSIGN_SAFE(&x, lit(4))); \
assert_se(x == lit(8)); \
if (IS_SIGNED_INTEGER_TYPE(type)) { \
assert_se(MUL_ASSIGN_SAFE(&x, lit(-1))); \
assert_se(x == lit(-8)); \
} \
assert_se(MUL_ASSIGN_SAFE(&x, lit(0))); \
assert_se(x == lit(0)); \
x = min; \
assert_se(MUL_ASSIGN_SAFE(&x, lit(1))); \
assert_se(x == min); \
if IS_SIGNED_INTEGER_TYPE(type) \
assert_se(!MUL_ASSIGN_SAFE(&x, lit(2))); \
x = max; \
assert_se(MUL_ASSIGN_SAFE(&x, lit(1))); \
assert_se(x == max); \
assert_se(!MUL_ASSIGN_SAFE(&x, lit(2))); \
})
TEST(overflow_safe_math) {
int64_t i;
uint64_t u, *p;
/* basic tests */
TEST_OVERFLOW_MATH_BY_TYPE(int8_t, INT8_MIN, INT8_MAX, INT8_C);
TEST_OVERFLOW_MATH_BY_TYPE(int16_t, INT16_MIN, INT16_MAX, INT16_C);
TEST_OVERFLOW_MATH_BY_TYPE(int32_t, INT32_MIN, INT32_MAX, INT32_C);
TEST_OVERFLOW_MATH_BY_TYPE(int64_t, INT64_MIN, INT64_MAX, INT64_C);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtype-limits" /* Otherwise the compiler complains about comparisons to negative numbers always being false */
TEST_OVERFLOW_MATH_BY_TYPE(uint8_t, UINT8_C(0), UINT8_MAX, UINT8_C);
TEST_OVERFLOW_MATH_BY_TYPE(uint16_t, UINT16_C(0), UINT16_MAX, UINT16_C);
TEST_OVERFLOW_MATH_BY_TYPE(uint32_t, UINT32_C(0), UINT32_MAX, UINT32_C);
TEST_OVERFLOW_MATH_BY_TYPE(uint64_t, UINT64_C(0), UINT64_MAX, UINT64_C);
#pragma GCC diagnostic pop
/* make sure we handle pointers correctly */
p = &u;
assert_se(ADD_SAFE(p, 35, 15) && (u == 50));
assert_se(SUB_SAFE(p, 35, 15) && (u == 20));
assert_se(MUL_SAFE(p, 5, 10) && (u == 50));
assert_se(INC_SAFE(p, 10) && (u == 60));
assert_se(DEC_SAFE(p, 10) && (u == 50));
assert_se(MUL_ASSIGN_SAFE(p, 3) && (u == 150));
assert_se(!ADD_SAFE(p, UINT64_MAX, 1));
assert_se(!SUB_SAFE(p, 0, 1));
/* cross-type sanity checks */
assert_se(ADD_SAFE(&i, INT32_MAX, 1));
assert_se(SUB_SAFE(&i, INT32_MIN, 1));
assert_se(!ADD_SAFE(&i, UINT64_MAX, 0));
assert_se(ADD_SAFE(&u, INT32_MAX, 1));
assert_se(MUL_SAFE(&u, INT32_MAX, 2));
}
TEST(DIV_ROUND_UP) {
int div;