diff --git a/README.md b/README.md new file mode 100644 index 0000000..2f0b099 --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +# tblyler/go-mcrypt + +This is a work in progress and currently just implements an encrypt and decrypt function that is compliant with PHP's rijndael AES 256 specification. +As one might guess, this was to replace existing PHP code that relied on rijndael AES 256. + +## Requirements + * libmcrypt diff --git a/mcrypt.go b/mcrypt.go new file mode 100644 index 0000000..2d962b0 --- /dev/null +++ b/mcrypt.go @@ -0,0 +1,253 @@ +package mcrypt + +/* +#cgo LDFLAGS: -lmcrypt +#include +#include +#include + +#define FAILED_MCRYPT_MODULE 1 +#define INVALID_KEY_LENGTH 2 +#define INVALID_IV_LENGTH 3 +#define FAILED_TO_ENCRYPT_DATA 4 +#define FAILED_TO_DECRYPT_DATA 5 +#define INVALID_DATA_LENGTH 6 + +// getError convert a given error code to its string representation +const char* getError(int err) { + switch (err) { + case FAILED_MCRYPT_MODULE: + return "Failed to open mcrypt module"; + + case INVALID_KEY_LENGTH: + return "Invalid key length"; + + case INVALID_IV_LENGTH: + return "Invalid iv length"; + + case FAILED_TO_ENCRYPT_DATA: + return "Failed to encrypt data"; + + case FAILED_TO_DECRYPT_DATA: + return "Failed to decrypt data"; + + case INVALID_DATA_LENGTH: + return "Invalid data length"; + } + + return mcrypt_strerror(err); +} + +// encrypt encrypt a given data set with rijndael 256 +char* encrypt(void* key, int keyLength, void* iv, int ivLength, char* data, int* length, int* err) { + if (*length <= 0) { + *err = INVALID_DATA_LENGTH; + return NULL; + } + + int i; + + MCRYPT td = mcrypt_module_open("rijndael-256", NULL, "cbc", NULL); + if (td == MCRYPT_FAILED) { + *err = FAILED_MCRYPT_MODULE; + return NULL; + } + + int requiredKeySize = mcrypt_enc_get_key_size(td); + int requiredIvSize = mcrypt_enc_get_iv_size(td); + + // make sure the key and iv are the correct sizes + if (keyLength != requiredKeySize) { + *err = INVALID_KEY_LENGTH; + mcrypt_module_close(td); + return NULL; + } + + if (ivLength != requiredIvSize) { + *err = INVALID_IV_LENGTH; + mcrypt_module_close(td); + return NULL; + } + + *err = mcrypt_generic_init(td, key, keyLength, iv); + if (*err) { + mcrypt_generic_deinit(td); + mcrypt_module_close(td); + return NULL; + } + + // get the block size + int blockSize = mcrypt_enc_get_block_size(td); + + // determine the new length if needed + int newLength = 0; + + // if blockSize is greater than length, expand length to the blockSize + if (blockSize > *length) { + newLength = blockSize; + } else { + int lengthBlockMod = *length % blockSize; + if (lengthBlockMod) { + // if length is not multiple of blockSize, make it the next highest blockSize + newLength = *length - lengthBlockMod + blockSize; + } else { + // we do not need to change the length + newLength = *length; + } + } + + // allocate and copy the data to the output value + char* output = malloc(sizeof *output * newLength); + // append byte zeroes to the output array if needed + for (i = *length; i < newLength; ++i) { + output[i] = 0; + } + + memcpy(output, data, *length); + + // update the length to the reallocated length + *length = newLength; + + // loop through the output data by blockSize at a time + for (i = 0; i < *length; i += blockSize) { + // encrypt the block of output[i] plus blockSize + if (mcrypt_generic(td, output+i, blockSize)) { + *err = FAILED_TO_ENCRYPT_DATA; + mcrypt_generic_deinit(td); + mcrypt_module_close(td); + free(output); + return NULL; + } + } + + // finish up mcrypt + mcrypt_generic_deinit(td); + mcrypt_module_close(td); + + // return the encrypted data + return output; +} + +// decrypt decrypt a given data set with rijndael 256 +char* decrypt(void* key, int keyLength, void* iv, int ivLength, char* data, int* length, int* err) { + int i; + + MCRYPT td = mcrypt_module_open("rijndael-256", NULL, "cbc", NULL); + if (td == MCRYPT_FAILED) { + *err = FAILED_MCRYPT_MODULE; + mcrypt_module_close(td); + return NULL; + } + + int requiredKeySize = mcrypt_enc_get_key_size(td); + int requiredIvSize = mcrypt_enc_get_iv_size(td); + + // make sure the key and iv are the correct sizes + if (keyLength != requiredKeySize) { + *err = INVALID_KEY_LENGTH; + mcrypt_module_close(td); + return NULL; + } + + if (ivLength != requiredIvSize) { + *err = INVALID_IV_LENGTH; + mcrypt_module_close(td); + return NULL; + } + + *err = mcrypt_generic_init(td, key, keyLength, iv); + if (*err) { + mcrypt_generic_deinit(td); + mcrypt_module_close(td); + return NULL; + } + + // get the block size + int blockSize = mcrypt_enc_get_block_size(td); + + if (*length < blockSize || *length % blockSize) { + *err = INVALID_DATA_LENGTH; + mcrypt_generic_deinit(td); + mcrypt_module_close(td); + return NULL; + } + + // allocate and copy the data to the output value + char* output = malloc(sizeof *output * *length); + + memcpy(output, data, *length); + + // loop through the output data by blockSize at a time + for (i = 0; i < *length; i += blockSize) { + // decrypt the block of output[i] plus blockSize + if (mdecrypt_generic(td, output+i, blockSize)) { + *err = FAILED_TO_DECRYPT_DATA; + mcrypt_generic_deinit(td); + mcrypt_module_close(td); + free(output); + return NULL; + } + } + + // finish up mcrypt + mcrypt_generic_deinit(td); + mcrypt_module_close(td); + + // return the decrypted data + return output; +} +*/ +import "C" + +import ( + "errors" + "unsafe" +) + +// Encrypt encrypt something with mcrypt rijndael-256 PHP-style +func Encrypt(key []byte, iv []byte, data []byte) ([]byte, error) { + // keep track of the size of the input data + length := C.int(len(data)) + if length == 0 { + return nil, errors.New("Invalid data size of 0") + } + // keep track of any errors that occur on encryption + err := C.int(0) + // encrypt the data + encryptedData := C.encrypt(unsafe.Pointer(&key[0]), C.int(len(key)), unsafe.Pointer(&iv[0]), C.int(len(iv)), (*C.char)(unsafe.Pointer(&data[0])), (*C.int)(unsafe.Pointer(&length)), (*C.int)(unsafe.Pointer(&err))) + + // if err is not 0, there is an error + if int(err) != 0 { + return nil, errors.New(C.GoString(C.getError(err))) + } + + // ensure that memory is freed on the encrypted data after it is converted to Go bytes + defer C.free(unsafe.Pointer(encryptedData)) + + // return the Go bytes of the encrypted data + return C.GoBytes(unsafe.Pointer(encryptedData), length), nil +} + +// Decrypt decrypt something with mcrypt rijndael-256 PHP-style +func Decrypt(key []byte, iv []byte, data []byte) ([]byte, error) { + // keep track of the size of the input data + length := C.int(len(data)) + if length == 0 { + return nil, errors.New("Invalid data size of 0") + } + // keep track of any errors that occur on decryption + err := C.int(0) + // decrypt the data + decryptedData := C.decrypt(unsafe.Pointer(&key[0]), C.int(len(key)), unsafe.Pointer(&iv[0]), C.int(len(iv)), (*C.char)(unsafe.Pointer(&data[0])), (*C.int)(unsafe.Pointer(&length)), (*C.int)(unsafe.Pointer(&err))) + + // if err is not 0, there is an error + if int(err) != 0 { + return nil, errors.New(C.GoString(C.getError(err))) + } + + // ensure that memory is freed on the decrypted data after it is converted to Go bytes + defer C.free(unsafe.Pointer(decryptedData)) + + // return the Go bytes of the decrypted data + return C.GoBytes(unsafe.Pointer(decryptedData), length), nil +} diff --git a/mcrypt_test.go b/mcrypt_test.go new file mode 100644 index 0000000..8077db5 --- /dev/null +++ b/mcrypt_test.go @@ -0,0 +1,153 @@ +package mcrypt + +import ( + "bytes" + "crypto/rand" + mrand "math/rand" + "testing" +) + +func TestEncrypt(t *testing.T) { + dataSizes := []int{8, 13, 16, 32, 64, 1024, 1048576, 4194304, (mrand.Int() % 26214400) + 1} + + for _, dataSize := range dataSizes { + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + t.Error("Failed to get random data from crypto/rand") + } + iv := make([]byte, 32) + _, err = rand.Read(iv) + if err != nil { + t.Error("Failed to get random data from crypto/rand") + } + data := make([]byte, dataSize) + _, err = rand.Read(data) + if err != nil { + t.Error("Failed to get random data from crypto/rand") + } + + encrypted, err := Encrypt(key, iv, data) + if err != nil { + t.Error("Failed Encrypt with error: " + err.Error()) + } + + if bytes.Equal(encrypted, data) { + t.Error("Failed Encrypt: Encrypted data was the same as input data") + } + + decrypted, err := Decrypt(key, iv, encrypted) + if err != nil { + t.Error("Failed Decrypt with error: " + err.Error()) + } + + cryptLen := len(decrypted) + for i := 0; i < cryptLen; i++ { + if i >= dataSize { + if decrypted[i] != 0 { + t.Error("Failed encryption/decryption: invalid padding") + } + } else if decrypted[i] != data[i] { + t.Error("Failed encryption/decryption: invalid data") + } + } + } + + key := make([]byte, 31) + iv := make([]byte, 32) + data := make([]byte, 32) + + _, err := Encrypt(key, iv, data) + if err == nil { + t.Error("Failed to receive error for invalid key size") + } + + key = make([]byte, 32) + iv = make([]byte, 31) + _, err = Encrypt(key, iv, data) + if err == nil { + t.Error("Failed to receive error for invalid iv size") + } + + key = make([]byte, 32) + iv = make([]byte, 32) + _, err = Encrypt(key, iv, []byte{}) + if err == nil { + t.Error("Failed to receive error for 0 byte data size") + } +} + +func TestDecrypt(t *testing.T) { + dataSizes := []int{8, 13, 16, 32, 64, 1024, 1048576, 4194304, (mrand.Int() % 26214400) + 1} + + for _, dataSize := range dataSizes { + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + t.Error("Failed to get random data from crypto/rand") + } + iv := make([]byte, 32) + _, err = rand.Read(iv) + if err != nil { + t.Error("Failed to get random data from crypto/rand") + } + data := make([]byte, dataSize) + _, err = rand.Read(data) + if err != nil { + t.Error("Failed to get random data from crypto/rand") + } + + encrypted, err := Encrypt(key, iv, data) + if err != nil { + t.Error("Failed Encrypt with error: " + err.Error()) + } + + if bytes.Equal(encrypted, data) { + t.Error("Failed Encrypt: Encrypted data was the same as input data") + } + + decrypted, err := Decrypt(key, iv, encrypted) + if err != nil { + t.Error("Failed Decrypt with error: " + err.Error()) + } + + cryptLen := len(decrypted) + for i := 0; i < cryptLen; i++ { + if i >= dataSize { + if decrypted[i] != 0 { + t.Error("Failed encryption/decryption: invalid padding") + } + } else if decrypted[i] != data[i] { + t.Error("Failed encryption/decryption: invalid data") + } + } + } + + key := make([]byte, 31) + iv := make([]byte, 32) + data := make([]byte, 32) + + _, err := Decrypt(key, iv, data) + if err == nil { + t.Error("Failed to receive error for invalid key size") + } + + key = make([]byte, 32) + iv = make([]byte, 31) + _, err = Decrypt(key, iv, data) + if err == nil { + t.Error("Failed to receive error for invalid iv size") + } + + key = make([]byte, 32) + iv = make([]byte, 32) + data = make([]byte, 31) + if err == nil { + t.Error("Failed to receive error for invalid data size") + } + + data = make([]byte, 0) + if err == nil { + t.Error("Failed to receive error for 0 byte data size") + } +}