From 780459d80ccb1f0118ecccfb96d08e316156a4bd Mon Sep 17 00:00:00 2001 From: Nathan Moinvaziri Date: Wed, 26 Apr 2023 19:56:52 -0700 Subject: [PATCH] Change to use encrypt/decrypt final functions in mz_aes. --- mz_crypt.h | 4 +-- mz_crypt_apple.c | 42 +++++++++++++---------- mz_crypt_openssl.c | 81 ++++++++++++++++++++++++--------------------- mz_crypt_winvista.c | 60 +++++++++++++++++---------------- mz_crypt_winxp.c | 10 +++--- test/test_crypt.cc | 6 ++-- 6 files changed, 110 insertions(+), 93 deletions(-) diff --git a/mz_crypt.h b/mz_crypt.h index caedc08..0b1fd83 100644 --- a/mz_crypt.h +++ b/mz_crypt.h @@ -36,9 +36,9 @@ void mz_crypt_sha_delete(void **handle); void mz_crypt_aes_reset(void *handle); int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size); +int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size); int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size); -int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size); -int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length); +int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size); int32_t mz_crypt_aes_set_encrypt_key(void *handle, const void *key, int32_t key_length, const void *iv, int32_t iv_length); int32_t mz_crypt_aes_set_decrypt_key(void *handle, const void *key, int32_t key_length, diff --git a/mz_crypt_apple.c b/mz_crypt_apple.c index 6e8ecff..fd402fb 100644 --- a/mz_crypt_apple.c +++ b/mz_crypt_apple.c @@ -228,6 +228,24 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) { return size; } +int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size) { + mz_crypt_aes *aes = (mz_crypt_aes *)handle; + + if (!aes || !tag || !tag_size || aes->mode != MZ_AES_MODE_GCM) + return MZ_PARAM_ERROR; + + aes->error = CCCryptorGCMEncrypt(aes->crypt, buf, size, buf); + if (aes->error != kCCSuccess) + return MZ_CRYPT_ERROR; + + aes->error = CCCryptorGCMFinal(aes->crypt, tag, (size_t *)&tag_size); + + if (aes->error != kCCSuccess) + return MZ_CRYPT_ERROR; + + return size; +} + int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { mz_crypt_aes *aes = (mz_crypt_aes *)handle; size_t data_moved = 0; @@ -246,21 +264,7 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { return size; } -int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size) { - mz_crypt_aes *aes = (mz_crypt_aes *)handle; - - if (!aes || !tag || !tag_size) - return MZ_PARAM_ERROR; - - aes->error = CCCryptorGCMFinal(aes->crypt, tag, (size_t *)&tag_size); - - if (aes->error != kCCSuccess) - return MZ_CRYPT_ERROR; - - return MZ_OK; -} - -int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) { +int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_length) { mz_crypt_aes *aes = (mz_crypt_aes *)handle; uint8_t tag_actual_buf[MZ_AES_BLOCK_SIZE]; size_t tag_actual_len = sizeof(tag_actual_buf); @@ -268,9 +272,13 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) int32_t c = tag_length; int32_t is_ok = 0; - if (!aes || !tag || !tag_length) + if (!aes || !tag || !tag_length || aes->mode != MZ_AES_MODE_GCM) return MZ_PARAM_ERROR; + aes->error = CCCryptorGCMDecrypt(aes->crypt, buf, size, buf); + if (aes->error != kCCSuccess) + return MZ_CRYPT_ERROR; + /* CCCryptorGCMFinal does not verify tag */ aes->error = CCCryptorGCMFinal(aes->crypt, tag_actual, &tag_actual_len); @@ -286,7 +294,7 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) if (is_ok) return MZ_CRYPT_ERROR; - return MZ_OK; + return size; } static int32_t mz_crypt_aes_set_key(void *handle, const void *key, int32_t key_length, diff --git a/mz_crypt_openssl.c b/mz_crypt_openssl.c index af7a0f1..1ff5c0c 100644 --- a/mz_crypt_openssl.c +++ b/mz_crypt_openssl.c @@ -278,7 +278,9 @@ void mz_crypt_aes_reset(void *handle) { int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) { mz_crypt_aes *aes = (mz_crypt_aes *)handle; - if (!aes || !buf || size % MZ_AES_BLOCK_SIZE != 0) + if (!aes || !buf) + return MZ_PARAM_ERROR; + if (aes->mode != MZ_AES_MODE_GCM && size % MZ_AES_BLOCK_SIZE != 0) return MZ_PARAM_ERROR; #if OPENSSL_VERSION_NUMBER < 0x00900070L @@ -302,10 +304,41 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) { return size; } +int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size) { +#if OPENSSL_VERSION_NUMBER < 0x00900070L + return MZ_SUPPORT_ERROR; +#else + mz_crypt_aes *aes = (mz_crypt_aes *)handle; + int result = 0; + int out_len = 0; + + if (!aes || !tag || !tag_size || aes->mode != MZ_AES_MODE_GCM) + return MZ_PARAM_ERROR; + + if (buf && size) { + if (!EVP_EncryptUpdate(aes->ctx, buf, &size, buf, size)) + return MZ_CRYPT_ERROR; + } + + /* Must call EncryptFinal for tag to be calculated */ + result = EVP_EncryptFinal_ex(aes->ctx, NULL, &out_len); + + if (result) + result = EVP_CIPHER_CTX_ctrl(aes->ctx, EVP_CTRL_GCM_GET_TAG, tag_size, tag); + + if (!result) { + aes->error = ERR_get_error(); + return MZ_CRYPT_ERROR; + } + + return size; +#endif +} + int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { mz_crypt_aes *aes = (mz_crypt_aes *)handle; - if (!aes || !buf || size != MZ_AES_BLOCK_SIZE) + if (!aes || !buf || size % MZ_AES_BLOCK_SIZE != 0) return MZ_PARAM_ERROR; #if OPENSSL_VERSION_NUMBER < 0x00900070L @@ -329,45 +362,21 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { return size; } -int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size) { +int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_length) { #if OPENSSL_VERSION_NUMBER < 0x00900070L return MZ_SUPPORT_ERROR; #else mz_crypt_aes *aes = (mz_crypt_aes *)handle; - uint8_t temp[MZ_AES_BLOCK_SIZE]; - int temp_len = sizeof(temp); - int result = 0; + int out_len = 0; - if (!aes || !tag || !tag_size) + if (!aes || !tag || !tag_length || aes->mode != MZ_AES_MODE_GCM) return MZ_PARAM_ERROR; - /* Must call EncryptFinal for tag to be calculated */ - result = EVP_EncryptFinal_ex(aes->ctx, NULL, &temp_len); - - if (result) - result = EVP_CIPHER_CTX_ctrl(aes->ctx, EVP_CTRL_GCM_GET_TAG, tag_size, tag); - - if (!result) { - aes->error = ERR_get_error(); - return MZ_CRYPT_ERROR; + if (buf && size) { + if (!EVP_DecryptUpdate(aes->ctx, buf, &size, buf, size)) + return MZ_CRYPT_ERROR; } - return MZ_OK; -#endif -} - -int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) { -#if OPENSSL_VERSION_NUMBER < 0x00900070L - return MZ_SUPPORT_ERROR; -#else - mz_crypt_aes *aes = (mz_crypt_aes *)handle; - uint8_t temp[MZ_AES_BLOCK_SIZE]; - int temp_len = sizeof(temp); - int result = 0; - - if (!aes || !tag || !tag_length) - return MZ_PARAM_ERROR; - /* Set expected tag */ if (!EVP_CIPHER_CTX_ctrl(aes->ctx, EVP_CTRL_GCM_SET_TAG, tag_length, tag)) { aes->error = ERR_get_error(); @@ -375,14 +384,12 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) } /* Must call DecryptFinal for tag verification */ - result = EVP_DecryptFinal_ex(aes->ctx, temp, &temp_len); - - if (!result) { + if (!EVP_DecryptFinal_ex(aes->ctx, NULL, &out_len)) { aes->error = ERR_get_error(); return MZ_CRYPT_ERROR; } - return MZ_OK; + return size; #endif } @@ -430,7 +437,7 @@ static int32_t mz_crypt_aes_set_key(void *handle, const void *key, int32_t key_l return MZ_HASH_ERROR; } - EVP_CIPHER_CTX_set_padding(aes->ctx, 0); + EVP_CIPHER_CTX_set_padding(aes->ctx, aes->mode == MZ_AES_MODE_GCM); return MZ_OK; } diff --git a/mz_crypt_winvista.c b/mz_crypt_winvista.c index 9e74420..4dd014f 100644 --- a/mz_crypt_winvista.c +++ b/mz_crypt_winvista.c @@ -240,6 +240,8 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) { if (!aes || !buf || size % MZ_AES_BLOCK_SIZE != 0) return MZ_PARAM_ERROR; + if (aes->mode == MZ_AES_MODE_GCM && !aes->auth_info) + return MZ_PARAM_ERROR; status = BCryptEncrypt(aes->key, buf, size, aes->auth_info, aes->iv, aes->iv_length, buf, size, &output_size, 0); @@ -251,6 +253,30 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) { return size; } +int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size) { + mz_crypt_aes *aes = (mz_crypt_aes *)handle; + NTSTATUS status = 0; + ULONG output_size = 0; + + if (!aes || !tag || !tag_size || aes->mode != MZ_AES_MODE_GCM || !aes->auth_info) + return MZ_PARAM_ERROR; + + aes->auth_info->pbTag = tag; + aes->auth_info->cbTag = tag_size; + + aes->auth_info->dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG; + + status = BCryptEncrypt(aes->key, buf, size, aes->auth_info, aes->iv, aes->iv_length, buf, size, + &output_size, 0); + + if (!NT_SUCCESS(status)) { + aes->error = status; + return MZ_CRYPT_ERROR; + } + + return size; +} + int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { mz_crypt_aes *aes = (mz_crypt_aes *)handle; ULONG output_size = 0; @@ -258,6 +284,8 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { if (!aes || !buf || size % MZ_AES_BLOCK_SIZE != 0) return MZ_PARAM_ERROR; + if (aes->mode == MZ_AES_MODE_GCM && !aes->auth_info) + return MZ_PARAM_ERROR; status = BCryptDecrypt(aes->key, buf, size, aes->auth_info, aes->iv, aes->iv_length, buf, size, &output_size, 0); @@ -269,36 +297,12 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { return size; } -int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size) { +int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_length) { mz_crypt_aes *aes = (mz_crypt_aes *)handle; NTSTATUS status = 0; ULONG output_size = 0; - if (!aes || !tag || !tag_size || !aes->auth_info) - return MZ_PARAM_ERROR; - - aes->auth_info->pbTag = tag; - aes->auth_info->cbTag = tag_size; - - aes->auth_info->dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG; - - status = BCryptEncrypt(aes->key, NULL, 0, aes->auth_info, aes->iv, aes->iv_length, NULL, 0, - &output_size, 0); - - if (!NT_SUCCESS(status)) { - aes->error = status; - return MZ_CRYPT_ERROR; - } - - return MZ_OK; -} - -int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) { - mz_crypt_aes *aes = (mz_crypt_aes *)handle; - NTSTATUS status = 0; - ULONG output_size = 0; - - if (!aes || !tag || !tag_length || !aes->auth_info) + if (!aes || !tag || !tag_length || aes->mode != MZ_AES_MODE_GCM || !aes->auth_info) return MZ_PARAM_ERROR; aes->auth_info->pbTag = tag; @@ -306,7 +310,7 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) aes->auth_info->dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG; - status = BCryptDecrypt(aes->key, NULL, 0, aes->auth_info, aes->iv, aes->iv_length, NULL, 0, + status = BCryptDecrypt(aes->key, buf, size, aes->auth_info, aes->iv, aes->iv_length, buf, size, &output_size, 0); if (!NT_SUCCESS(status)) { @@ -314,7 +318,7 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) return MZ_CRYPT_ERROR; } - return MZ_OK; + return size; } static int32_t mz_crypt_aes_set_key(void *handle, const void *key, int32_t key_length, diff --git a/mz_crypt_winxp.c b/mz_crypt_winxp.c index 3144a3b..86ccc00 100644 --- a/mz_crypt_winxp.c +++ b/mz_crypt_winxp.c @@ -227,6 +227,10 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) { return size; } +int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size) { + return MZ_SUPPORT_ERROR; +} + int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { mz_crypt_aes *aes = (mz_crypt_aes *)handle; int32_t result = 0; @@ -240,11 +244,7 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) { return size; } -int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size) { - return MZ_SUPPORT_ERROR; -} - -int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) { +int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_length) { return MZ_SUPPORT_ERROR; } diff --git a/test/test_crypt.cc b/test/test_crypt.cc index aabd192..7508731 100644 --- a/test/test_crypt.cc +++ b/test/test_crypt.cc @@ -255,8 +255,7 @@ TEST(crypt, aes128_gcm) { mz_crypt_aes_set_mode(aes, MZ_AES_MODE_GCM); EXPECT_EQ(mz_crypt_aes_set_encrypt_key(aes, key, key_length, iv, iv_length), MZ_OK); EXPECT_EQ(mz_crypt_aes_encrypt(aes, buf, test_length), test_length); - EXPECT_EQ(mz_crypt_aes_encrypt(aes, buf + test_length - 1, test_length - 1), test_length - 1); - EXPECT_EQ(mz_crypt_aes_get_tag(aes, tag, sizeof(tag)), MZ_OK); + EXPECT_EQ(mz_crypt_aes_encrypt_final(aes, buf + test_length, test_length - 1, tag, sizeof(tag)), test_length - 1); mz_crypt_aes_delete(&aes); EXPECT_STRNE((char*)buf, test); @@ -266,8 +265,7 @@ TEST(crypt, aes128_gcm) { mz_crypt_aes_set_mode(aes, MZ_AES_MODE_GCM); EXPECT_EQ(mz_crypt_aes_set_decrypt_key(aes, key, key_length, iv, iv_length), MZ_OK); EXPECT_EQ(mz_crypt_aes_decrypt(aes, buf, test_length), test_length); - EXPECT_EQ(mz_crypt_aes_decrypt(aes, buf + test_length - 1, test_length - 1), test_length - 1); - EXPECT_EQ(mz_crypt_aes_verify_tag(aes, tag, sizeof(tag)), MZ_OK); + EXPECT_EQ(mz_crypt_aes_decrypt_final(aes, buf + test_length, test_length - 1, tag, sizeof(tag)), test_length - 1); mz_crypt_aes_delete(&aes); EXPECT_EQ(memcmp(buf, test, test_length), 0);