diff --git a/meson.build b/meson.build index 34da08e..3c7a067 100644 --- a/meson.build +++ b/meson.build @@ -11,6 +11,7 @@ project('tqftpserv', prefix = get_option('prefix') zstd_dep = dependency('libzstd') +add_project_arguments('-DHAVE_ZSTD', language : 'c') # Not required to build the executable, only to install unit file systemd = dependency('systemd', required : false) diff --git a/tqftpserv.c b/tqftpserv.c index d6fa673..3f8a4c4 100644 --- a/tqftpserv.c +++ b/tqftpserv.c @@ -15,7 +15,6 @@ #include "list.h" #include "translate.h" -#include "zstd-decompress.h" #define MAX(x, y) ((x) > (y) ? (x) : (y)) @@ -571,8 +570,6 @@ int main(int argc, char **argv) exit(1); } - zstd_init(); - for (;;) { FD_ZERO(&rfds); FD_SET(fd, &rfds); @@ -674,7 +671,6 @@ int main(int argc, char **argv) } close(fd); - zstd_free(); return 0; } diff --git a/zstd-decompress.c b/zstd-decompress.c index 90b9219..83cf4c8 100644 --- a/zstd-decompress.c +++ b/zstd-decompress.c @@ -1,6 +1,7 @@ // SPDX-License-Identifier: BSD-3-Clause /* * Copyright (c) 2024, Stefan Hansson + * Copyright (c) 2024, Emil Velikov */ /* For memfd_create */ @@ -17,24 +18,6 @@ #include "zstd-decompress.h" -static ZSTD_DCtx *zstd_context = NULL; - -/** - * zstd_init() - set up state for decompression. Needs to be called before zstd_decompress_file() - */ -void zstd_init() -{ - zstd_context = ZSTD_createDCtx(); -} - -/** - * zstd_free() - free state used for decompression. zstd_decompress_file() may not be called after this - */ -void zstd_free() -{ - ZSTD_freeDCtx(zstd_context); -} - /** * zstd_decompress_file() - decompress a zstd-compressed file * @filename: path to a file to decompress @@ -52,7 +35,7 @@ int zstd_decompress_file(const char *filename) const size_t file_size = file_stat.st_size; - const int input_file_fd = open(filename, 0); + const int input_file_fd = open(filename, O_RDONLY); if (input_file_fd == -1) { perror("open failed"); return -1; @@ -69,34 +52,43 @@ int zstd_decompress_file(const char *filename) const unsigned long long decompressed_size = ZSTD_getFrameContentSize(compressed_buffer, file_size); if (decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) { fprintf(stderr, "Content size could not be determined for %s\n", filename); + munmap(compressed_buffer, file_size); return -1; } if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) { fprintf(stderr, "Error getting content size for %s\n", filename); + munmap(compressed_buffer, file_size); return -1; } void* const decompressed_buffer = malloc((size_t)decompressed_size); if (decompressed_buffer == NULL) { perror("malloc failed"); + munmap(compressed_buffer, file_size); return -1; } - const size_t return_size = ZSTD_decompressDCtx(zstd_context, decompressed_buffer, decompressed_size, compressed_buffer, file_size); + const size_t return_size = ZSTD_decompress(decompressed_buffer, decompressed_size, compressed_buffer, file_size); if (ZSTD_isError(return_size)) { fprintf(stderr, "ZSTD_decompress failed: %s\n", ZSTD_getErrorName(return_size)); + free(decompressed_buffer); + munmap(compressed_buffer, file_size); return -1; } const int output_file_fd = memfd_create(filename, 0); if (output_file_fd == -1) { perror("memfd_create failed"); + free(decompressed_buffer); + munmap(compressed_buffer, file_size); return -1; } if (write(output_file_fd, decompressed_buffer, decompressed_size) != decompressed_size) { perror("write failed"); close(output_file_fd); + free(decompressed_buffer); + munmap(compressed_buffer, file_size); return -1; } diff --git a/zstd-decompress.h b/zstd-decompress.h index 52e088e..2747ed4 100644 --- a/zstd-decompress.h +++ b/zstd-decompress.h @@ -8,8 +8,14 @@ #include -void zstd_init(); -void zstd_free(); +#ifdef HAVE_ZSTD int zstd_decompress_file(const char *filename); +#else +static int zstd_decompress_file(const char *filename) +{ + fprintf(stderr, "Built without ZSTD support\n"); + return -1; +} +#endif #endif