Change to use encrypt/decrypt final functions in mz_aes.

This commit is contained in:
Nathan Moinvaziri 2023-04-26 19:56:52 -07:00
parent c958ac8f7c
commit 780459d80c
6 changed files with 110 additions and 93 deletions

View File

@ -36,9 +36,9 @@ void mz_crypt_sha_delete(void **handle);
void mz_crypt_aes_reset(void *handle);
int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size);
int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size);
int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size);
int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size);
int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length);
int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size);
int32_t mz_crypt_aes_set_encrypt_key(void *handle, const void *key, int32_t key_length,
const void *iv, int32_t iv_length);
int32_t mz_crypt_aes_set_decrypt_key(void *handle, const void *key, int32_t key_length,

View File

@ -228,6 +228,24 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) {
return size;
}
int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
if (!aes || !tag || !tag_size || aes->mode != MZ_AES_MODE_GCM)
return MZ_PARAM_ERROR;
aes->error = CCCryptorGCMEncrypt(aes->crypt, buf, size, buf);
if (aes->error != kCCSuccess)
return MZ_CRYPT_ERROR;
aes->error = CCCryptorGCMFinal(aes->crypt, tag, (size_t *)&tag_size);
if (aes->error != kCCSuccess)
return MZ_CRYPT_ERROR;
return size;
}
int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
size_t data_moved = 0;
@ -246,21 +264,7 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
return size;
}
int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
if (!aes || !tag || !tag_size)
return MZ_PARAM_ERROR;
aes->error = CCCryptorGCMFinal(aes->crypt, tag, (size_t *)&tag_size);
if (aes->error != kCCSuccess)
return MZ_CRYPT_ERROR;
return MZ_OK;
}
int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) {
int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_length) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
uint8_t tag_actual_buf[MZ_AES_BLOCK_SIZE];
size_t tag_actual_len = sizeof(tag_actual_buf);
@ -268,9 +272,13 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length)
int32_t c = tag_length;
int32_t is_ok = 0;
if (!aes || !tag || !tag_length)
if (!aes || !tag || !tag_length || aes->mode != MZ_AES_MODE_GCM)
return MZ_PARAM_ERROR;
aes->error = CCCryptorGCMDecrypt(aes->crypt, buf, size, buf);
if (aes->error != kCCSuccess)
return MZ_CRYPT_ERROR;
/* CCCryptorGCMFinal does not verify tag */
aes->error = CCCryptorGCMFinal(aes->crypt, tag_actual, &tag_actual_len);
@ -286,7 +294,7 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length)
if (is_ok)
return MZ_CRYPT_ERROR;
return MZ_OK;
return size;
}
static int32_t mz_crypt_aes_set_key(void *handle, const void *key, int32_t key_length,

View File

@ -278,7 +278,9 @@ void mz_crypt_aes_reset(void *handle) {
int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
if (!aes || !buf || size % MZ_AES_BLOCK_SIZE != 0)
if (!aes || !buf)
return MZ_PARAM_ERROR;
if (aes->mode != MZ_AES_MODE_GCM && size % MZ_AES_BLOCK_SIZE != 0)
return MZ_PARAM_ERROR;
#if OPENSSL_VERSION_NUMBER < 0x00900070L
@ -302,10 +304,41 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) {
return size;
}
int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size) {
#if OPENSSL_VERSION_NUMBER < 0x00900070L
return MZ_SUPPORT_ERROR;
#else
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
int result = 0;
int out_len = 0;
if (!aes || !tag || !tag_size || aes->mode != MZ_AES_MODE_GCM)
return MZ_PARAM_ERROR;
if (buf && size) {
if (!EVP_EncryptUpdate(aes->ctx, buf, &size, buf, size))
return MZ_CRYPT_ERROR;
}
/* Must call EncryptFinal for tag to be calculated */
result = EVP_EncryptFinal_ex(aes->ctx, NULL, &out_len);
if (result)
result = EVP_CIPHER_CTX_ctrl(aes->ctx, EVP_CTRL_GCM_GET_TAG, tag_size, tag);
if (!result) {
aes->error = ERR_get_error();
return MZ_CRYPT_ERROR;
}
return size;
#endif
}
int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
if (!aes || !buf || size != MZ_AES_BLOCK_SIZE)
if (!aes || !buf || size % MZ_AES_BLOCK_SIZE != 0)
return MZ_PARAM_ERROR;
#if OPENSSL_VERSION_NUMBER < 0x00900070L
@ -329,45 +362,21 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
return size;
}
int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size) {
int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_length) {
#if OPENSSL_VERSION_NUMBER < 0x00900070L
return MZ_SUPPORT_ERROR;
#else
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
uint8_t temp[MZ_AES_BLOCK_SIZE];
int temp_len = sizeof(temp);
int result = 0;
int out_len = 0;
if (!aes || !tag || !tag_size)
if (!aes || !tag || !tag_length || aes->mode != MZ_AES_MODE_GCM)
return MZ_PARAM_ERROR;
/* Must call EncryptFinal for tag to be calculated */
result = EVP_EncryptFinal_ex(aes->ctx, NULL, &temp_len);
if (result)
result = EVP_CIPHER_CTX_ctrl(aes->ctx, EVP_CTRL_GCM_GET_TAG, tag_size, tag);
if (!result) {
aes->error = ERR_get_error();
if (buf && size) {
if (!EVP_DecryptUpdate(aes->ctx, buf, &size, buf, size))
return MZ_CRYPT_ERROR;
}
return MZ_OK;
#endif
}
int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) {
#if OPENSSL_VERSION_NUMBER < 0x00900070L
return MZ_SUPPORT_ERROR;
#else
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
uint8_t temp[MZ_AES_BLOCK_SIZE];
int temp_len = sizeof(temp);
int result = 0;
if (!aes || !tag || !tag_length)
return MZ_PARAM_ERROR;
/* Set expected tag */
if (!EVP_CIPHER_CTX_ctrl(aes->ctx, EVP_CTRL_GCM_SET_TAG, tag_length, tag)) {
aes->error = ERR_get_error();
@ -375,14 +384,12 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length)
}
/* Must call DecryptFinal for tag verification */
result = EVP_DecryptFinal_ex(aes->ctx, temp, &temp_len);
if (!result) {
if (!EVP_DecryptFinal_ex(aes->ctx, NULL, &out_len)) {
aes->error = ERR_get_error();
return MZ_CRYPT_ERROR;
}
return MZ_OK;
return size;
#endif
}
@ -430,7 +437,7 @@ static int32_t mz_crypt_aes_set_key(void *handle, const void *key, int32_t key_l
return MZ_HASH_ERROR;
}
EVP_CIPHER_CTX_set_padding(aes->ctx, 0);
EVP_CIPHER_CTX_set_padding(aes->ctx, aes->mode == MZ_AES_MODE_GCM);
return MZ_OK;
}

View File

@ -240,6 +240,8 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) {
if (!aes || !buf || size % MZ_AES_BLOCK_SIZE != 0)
return MZ_PARAM_ERROR;
if (aes->mode == MZ_AES_MODE_GCM && !aes->auth_info)
return MZ_PARAM_ERROR;
status = BCryptEncrypt(aes->key, buf, size, aes->auth_info, aes->iv, aes->iv_length, buf, size,
&output_size, 0);
@ -251,6 +253,30 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) {
return size;
}
int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
NTSTATUS status = 0;
ULONG output_size = 0;
if (!aes || !tag || !tag_size || aes->mode != MZ_AES_MODE_GCM || !aes->auth_info)
return MZ_PARAM_ERROR;
aes->auth_info->pbTag = tag;
aes->auth_info->cbTag = tag_size;
aes->auth_info->dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
status = BCryptEncrypt(aes->key, buf, size, aes->auth_info, aes->iv, aes->iv_length, buf, size,
&output_size, 0);
if (!NT_SUCCESS(status)) {
aes->error = status;
return MZ_CRYPT_ERROR;
}
return size;
}
int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
ULONG output_size = 0;
@ -258,6 +284,8 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
if (!aes || !buf || size % MZ_AES_BLOCK_SIZE != 0)
return MZ_PARAM_ERROR;
if (aes->mode == MZ_AES_MODE_GCM && !aes->auth_info)
return MZ_PARAM_ERROR;
status = BCryptDecrypt(aes->key, buf, size, aes->auth_info, aes->iv, aes->iv_length, buf, size,
&output_size, 0);
@ -269,36 +297,12 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
return size;
}
int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size) {
int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_length) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
NTSTATUS status = 0;
ULONG output_size = 0;
if (!aes || !tag || !tag_size || !aes->auth_info)
return MZ_PARAM_ERROR;
aes->auth_info->pbTag = tag;
aes->auth_info->cbTag = tag_size;
aes->auth_info->dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
status = BCryptEncrypt(aes->key, NULL, 0, aes->auth_info, aes->iv, aes->iv_length, NULL, 0,
&output_size, 0);
if (!NT_SUCCESS(status)) {
aes->error = status;
return MZ_CRYPT_ERROR;
}
return MZ_OK;
}
int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
NTSTATUS status = 0;
ULONG output_size = 0;
if (!aes || !tag || !tag_length || !aes->auth_info)
if (!aes || !tag || !tag_length || aes->mode != MZ_AES_MODE_GCM || !aes->auth_info)
return MZ_PARAM_ERROR;
aes->auth_info->pbTag = tag;
@ -306,7 +310,7 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length)
aes->auth_info->dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
status = BCryptDecrypt(aes->key, NULL, 0, aes->auth_info, aes->iv, aes->iv_length, NULL, 0,
status = BCryptDecrypt(aes->key, buf, size, aes->auth_info, aes->iv, aes->iv_length, buf, size,
&output_size, 0);
if (!NT_SUCCESS(status)) {
@ -314,7 +318,7 @@ int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length)
return MZ_CRYPT_ERROR;
}
return MZ_OK;
return size;
}
static int32_t mz_crypt_aes_set_key(void *handle, const void *key, int32_t key_length,

View File

@ -227,6 +227,10 @@ int32_t mz_crypt_aes_encrypt(void *handle, uint8_t *buf, int32_t size) {
return size;
}
int32_t mz_crypt_aes_encrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_size) {
return MZ_SUPPORT_ERROR;
}
int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
mz_crypt_aes *aes = (mz_crypt_aes *)handle;
int32_t result = 0;
@ -240,11 +244,7 @@ int32_t mz_crypt_aes_decrypt(void *handle, uint8_t *buf, int32_t size) {
return size;
}
int32_t mz_crypt_aes_get_tag(void *handle, uint8_t *tag, int32_t tag_size) {
return MZ_SUPPORT_ERROR;
}
int32_t mz_crypt_aes_verify_tag(void *handle, uint8_t *tag, int32_t tag_length) {
int32_t mz_crypt_aes_decrypt_final(void *handle, uint8_t *buf, int32_t size, uint8_t *tag, int32_t tag_length) {
return MZ_SUPPORT_ERROR;
}

View File

@ -255,8 +255,7 @@ TEST(crypt, aes128_gcm) {
mz_crypt_aes_set_mode(aes, MZ_AES_MODE_GCM);
EXPECT_EQ(mz_crypt_aes_set_encrypt_key(aes, key, key_length, iv, iv_length), MZ_OK);
EXPECT_EQ(mz_crypt_aes_encrypt(aes, buf, test_length), test_length);
EXPECT_EQ(mz_crypt_aes_encrypt(aes, buf + test_length - 1, test_length - 1), test_length - 1);
EXPECT_EQ(mz_crypt_aes_get_tag(aes, tag, sizeof(tag)), MZ_OK);
EXPECT_EQ(mz_crypt_aes_encrypt_final(aes, buf + test_length, test_length - 1, tag, sizeof(tag)), test_length - 1);
mz_crypt_aes_delete(&aes);
EXPECT_STRNE((char*)buf, test);
@ -266,8 +265,7 @@ TEST(crypt, aes128_gcm) {
mz_crypt_aes_set_mode(aes, MZ_AES_MODE_GCM);
EXPECT_EQ(mz_crypt_aes_set_decrypt_key(aes, key, key_length, iv, iv_length), MZ_OK);
EXPECT_EQ(mz_crypt_aes_decrypt(aes, buf, test_length), test_length);
EXPECT_EQ(mz_crypt_aes_decrypt(aes, buf + test_length - 1, test_length - 1), test_length - 1);
EXPECT_EQ(mz_crypt_aes_verify_tag(aes, tag, sizeof(tag)), MZ_OK);
EXPECT_EQ(mz_crypt_aes_decrypt_final(aes, buf + test_length, test_length - 1, tag, sizeof(tag)), test_length - 1);
mz_crypt_aes_delete(&aes);
EXPECT_EQ(memcmp(buf, test, test_length), 0);