From c33e43317f1b942db0606286c38412dccda1ec2c Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 02:37:20 -0800 Subject: [PATCH 01/12] feat: wip/tinygrad --- aes.py | 19 ++++++------ main.py | 93 +++++++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 77 insertions(+), 35 deletions(-) diff --git a/aes.py b/aes.py index 724c872..a3acc16 100644 --- a/aes.py +++ b/aes.py @@ -121,6 +121,7 @@ def change_key(self, master_key): ^ self.round_keys[i - 1][j] self.round_keys[i].append(byte) + print(f'round_keys: {self.round_keys}') # print self.round_keys def encrypt(self, plaintext): @@ -166,9 +167,15 @@ def __add_round_key(self, s, k): def __round_encrypt(self, state_matrix, key_matrix): self.__sub_bytes(state_matrix) + print("After sub_bytes:", state_matrix) self.__shift_rows(state_matrix) + print("After shift_rows:", state_matrix) self.__mix_columns(state_matrix) + print("After mix_columns:", state_matrix) + print(f'[ROUND_ENCRYPT] state_matrix: {state_matrix}') + print(f'[ROUND_ENCRYPT] key_matrix: {key_matrix}') self.__add_round_key(state_matrix, key_matrix) + print("After add_round_key:", state_matrix) def __round_decrypt(self, state_matrix, key_matrix): @@ -229,12 +236,6 @@ def __inv_mix_columns(self, s): if __name__ == "__main__": - # aes = AES(0x2b) - # print(hex(aes.encrypt(0x3243f6a8885a308d313198a2e0370734))) - # print(hex(aes.decrypt(0x3925841d02dc09fbdc118597196a0b32))) - - pt = 0x3243f6a8885a308d313198a2e0370734 - mat = text2matrix(pt) - recovered = matrix2text(mat) - print(mat) - print(hex(recovered)) \ No newline at end of file + aes = AES(0x2b7e151628aed2a6abf7158809cf4f3c) + print(hex(aes.encrypt(0x3243f6a8885a308d313198a2e0370734))) + print(hex(aes.decrypt(0x3925841d02dc09fbdc118597196a0b32))) diff --git a/main.py b/main.py index 899bc4d..7b8f757 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ from tinygrad.tensor import Tensor from tinygrad import dtypes -Sbox = [ +Sbox = Tensor([ 0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76, 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0, 0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15, @@ -18,23 +18,55 @@ 0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E, 0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF, 0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16, -] - -Rcon = [ +], dtype=dtypes.uint8) + + +InvSbox = Tensor([ + 0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB, + 0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB, + 0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E, + 0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25, + 0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92, + 0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84, + 0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06, + 0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B, + 0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73, + 0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E, + 0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B, + 0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4, + 0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F, + 0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF, + 0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61, + 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, +], dtype=dtypes.uint8) + +Rcon = Tensor([ 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, 0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39, -] +], dtype=dtypes.uint8) def xtime(a: int) -> int: return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) def xtime_tensor(a: Tensor) -> Tensor: - high_bits = (a.bitwise_and(0x80)).cast(dtypes.uint64) + # Get high bit mask (equivalent to a & 0x80) + high_bit_mask = a.bitwise_and(0x80) + + # Compute left shift for all values shifted = a.lshift(1) - return (shifted.xor(high_bits * 0x1B)).bitwise_and(0xFF) + + # Create the conditional xor with 0x1B + condition = high_bit_mask != 0 + result = condition.where( + shifted.xor(0x1B), # when high bit is set + shifted # when high bit is not set + ) + + # Mask to keep only bottom 8 bits + return result.bitwise_and(0xFF) def text2matrix(text: int) -> Tensor: @@ -56,35 +88,35 @@ def __init__(self, master_key): self.change_key(master_key) def change_key(self, master_key): - key_matrix = text2matrix(master_key) - self.round_keys = Tensor.zeros((4, 44), dtype=dtypes.uint64).contiguous() - self.round_keys[:, :4] = key_matrix - for i in range(4, 44): + self.round_keys = Tensor.zeros((44, 4), dtype=dtypes.uint8).contiguous() + self.round_keys[:4] = text2matrix(master_key) + + for i in range(4, 4 * 11): if i % 4 == 0: - temp = self.round_keys[:, i-1] - temp = Tensor([Sbox[x] for x in temp.roll(-1, dims=0).numpy().astype(int)], dtype=dtypes.uint64) - rcon = Tensor([Rcon[i // 4], 0, 0, 0], dtype=dtypes.uint64) - self.round_keys[:, i] = self.round_keys[:, i-4].xor(temp).xor(rcon) + self.round_keys[i, 0] = self.round_keys[i-4][0] ^ Sbox[self.round_keys[i-1][1]] ^ Rcon[i//4].item() + bytes = self.round_keys[i-4][1:].xor(Sbox[self.round_keys[i-1][1:].roll(-1, dims=0)]) + self.round_keys[i, 1:] = bytes else: - self.round_keys[:, i] = self.round_keys[:, i-1].xor(self.round_keys[:, i-4]) - + self.round_keys[i] = self.round_keys[i-4].xor(self.round_keys[i-1]) + def encrypt(self, plaintext: int) -> int: print("Plaintext:", plaintext) self.plain_state = text2matrix(plaintext) print("After text2matrix:", self.plain_state.numpy()) + print(f'round_keys: {self.round_keys.numpy()}') - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:, :4]) + self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4]) print("After initial add_round_key:", self.plain_state.numpy()) for i in range(1, 10): - self.plain_state = self.__round_encrypt(self.plain_state, self.round_keys[:, 4 * i : 4 * (i + 1)]) + self.plain_state = self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]) print(f"After round {i}:", self.plain_state.numpy()) self.plain_state = self.__sub_bytes(self.plain_state) print("After final sub_bytes:", self.plain_state.numpy()) self.plain_state = self.__shift_rows(self.plain_state) print("After final shift_rows:", self.plain_state.numpy()) - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:, 40:44]) + self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[40:]) print("After final add_round_key:", self.plain_state.numpy()) return matrix2text(self.plain_state) @@ -106,9 +138,15 @@ def decrypt(self, ciphertext: int) -> int: def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: state_matrix = self.__sub_bytes(state_matrix) + print("After sub_bytes:", state_matrix.numpy()) state_matrix = self.__shift_rows(state_matrix) + print("After shift_rows:", state_matrix.numpy()) state_matrix = self.__mix_columns(state_matrix) + print("After mix_columns:", state_matrix.numpy()) + print(f'[ADD_ROUND_KEY] state_matrix: {state_matrix.numpy()}') + print(f'[ADD_ROUND_KEY] key_matrix: {key_matrix.numpy()}') state_matrix = self.__add_round_key(state_matrix, key_matrix) + print("After add_round_key:", state_matrix.numpy()) return state_matrix def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: @@ -122,10 +160,10 @@ def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: return s.xor(k) def __sub_bytes(self, s: Tensor) -> Tensor: - return Tensor([Sbox[int(x.item())] for x in s.flatten()], dtype=dtypes.uint64).reshape(4, 4) + return Sbox[s] def __inv_sub_bytes(self, s: Tensor) -> Tensor: - return Tensor([Sbox.index(int(x.item())) for x in s.flatten()], dtype=dtypes.uint64).reshape(4, 4) + return InvSbox[s] def __shift_rows(self, s: Tensor) -> Tensor: state = s.clone() @@ -145,9 +183,12 @@ def __inv_shift_rows(self, s: Tensor) -> Tensor: def __mix_columns(self, state: Tensor) -> Tensor: t = state[0].xor(state[1]).xor(state[2]).xor(state[3]) - shifted = state.roll(-1, dims=0) - xt_pairs = xtime_tensor(state ^ shifted) - result = state.xor(t.reshape(1, 4)).xor(xt_pairs) + result = state.clone() + + result[0] = state[0].xor(t).xor(xtime_tensor(state[0].xor(state[1]))) + result[1] = state[1].xor(t).xor(xtime_tensor(state[1].xor(state[2]))) + result[2] = state[2].xor(t).xor(xtime_tensor(state[2].xor(state[3]))) + result[3] = state[3].xor(t).xor(xtime_tensor(state[3].xor(state[0]))) return result @@ -160,6 +201,6 @@ def __inv_mix_columns(self, state: Tensor) -> Tensor: return self.__mix_columns(state) if __name__ == "__main__": - aes = AES(0x2b) + aes = AES(0x2b7e151628aed2a6abf7158809cf4f3c) print(hex(aes.encrypt(0x3243f6a8885a308d313198a2e0370734))) print(hex(aes.decrypt(0x3925841d02dc09fbdc118597196a0b32))) From a8ac3610f97bba6e4d2cb35288230a88b6b35e4f Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 18:47:24 -0800 Subject: [PATCH 02/12] feat: working --- aes.py | 21 ------------------- main.py | 64 ++++++++++++++++++++++++--------------------------------- 2 files changed, 27 insertions(+), 58 deletions(-) diff --git a/aes.py b/aes.py index a3acc16..b0afc21 100644 --- a/aes.py +++ b/aes.py @@ -75,7 +75,6 @@ def text2matrix(text): - print(f'text: {text}') matrix = [] for i in range(16): byte = (text >> (8 * (15 - i))) & 0xFF @@ -83,7 +82,6 @@ def text2matrix(text): matrix.append([byte]) else: matrix[i // 4].append(byte) - print(f'matrix: {matrix}') return matrix @@ -101,7 +99,6 @@ def __init__(self, master_key): def change_key(self, master_key): self.round_keys = text2matrix(master_key) - # print self.round_keys for i in range(4, 4 * 11): self.round_keys.append([]) @@ -121,27 +118,17 @@ def change_key(self, master_key): ^ self.round_keys[i - 1][j] self.round_keys[i].append(byte) - print(f'round_keys: {self.round_keys}') - # print self.round_keys - def encrypt(self, plaintext): - print("Plaintext:", plaintext) self.plain_state = text2matrix(plaintext) - print("After text2matrix:", self.plain_state) self.__add_round_key(self.plain_state, self.round_keys[:4]) - print("After initial add_round_key:", self.plain_state) for i in range(1, 10): self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]) - print(f"After round {i}:", self.plain_state) self.__sub_bytes(self.plain_state) - print("After final sub_bytes:", self.plain_state) self.__shift_rows(self.plain_state) - print("After final shift_rows:", self.plain_state) self.__add_round_key(self.plain_state, self.round_keys[40:]) - print("After final add_round_key:", self.plain_state) return matrix2text(self.plain_state) @@ -167,16 +154,9 @@ def __add_round_key(self, s, k): def __round_encrypt(self, state_matrix, key_matrix): self.__sub_bytes(state_matrix) - print("After sub_bytes:", state_matrix) self.__shift_rows(state_matrix) - print("After shift_rows:", state_matrix) self.__mix_columns(state_matrix) - print("After mix_columns:", state_matrix) - print(f'[ROUND_ENCRYPT] state_matrix: {state_matrix}') - print(f'[ROUND_ENCRYPT] key_matrix: {key_matrix}') self.__add_round_key(state_matrix, key_matrix) - print("After add_round_key:", state_matrix) - def __round_decrypt(self, state_matrix, key_matrix): self.__add_round_key(state_matrix, key_matrix) @@ -223,7 +203,6 @@ def __mix_columns(self, s): def __inv_mix_columns(self, s): - # see Sec 4.1.3 in The Design of Rijndael for i in range(4): u = xtime(xtime(s[i][0] ^ s[i][2])) v = xtime(xtime(s[i][1] ^ s[i][3])) diff --git a/main.py b/main.py index 7b8f757..ce84ea2 100644 --- a/main.py +++ b/main.py @@ -90,63 +90,49 @@ def __init__(self, master_key): def change_key(self, master_key): self.round_keys = Tensor.zeros((44, 4), dtype=dtypes.uint8).contiguous() self.round_keys[:4] = text2matrix(master_key) - for i in range(4, 4 * 11): if i % 4 == 0: - self.round_keys[i, 0] = self.round_keys[i-4][0] ^ Sbox[self.round_keys[i-1][1]] ^ Rcon[i//4].item() - bytes = self.round_keys[i-4][1:].xor(Sbox[self.round_keys[i-1][1:].roll(-1, dims=0)]) - self.round_keys[i, 1:] = bytes + self.round_keys[i, 0] = (self.round_keys[i-4, 0] ^ + Sbox[self.round_keys[i-1, 1].item()] ^ + Rcon[i//4].item()) + + shifted_indices = Tensor([2,3,0], dtype=dtypes.uint8) + sboxed = Sbox[self.round_keys[i-1][shifted_indices]] + self.round_keys[i, 1:] = self.round_keys[i-4, 1:].xor(sboxed) else: self.round_keys[i] = self.round_keys[i-4].xor(self.round_keys[i-1]) def encrypt(self, plaintext: int) -> int: - print("Plaintext:", plaintext) self.plain_state = text2matrix(plaintext) - print("After text2matrix:", self.plain_state.numpy()) - print(f'round_keys: {self.round_keys.numpy()}') - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4]) - print("After initial add_round_key:", self.plain_state.numpy()) for i in range(1, 10): self.plain_state = self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]) - print(f"After round {i}:", self.plain_state.numpy()) self.plain_state = self.__sub_bytes(self.plain_state) - print("After final sub_bytes:", self.plain_state.numpy()) self.plain_state = self.__shift_rows(self.plain_state) - print("After final shift_rows:", self.plain_state.numpy()) self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[40:]) - print("After final add_round_key:", self.plain_state.numpy()) return matrix2text(self.plain_state) def decrypt(self, ciphertext: int) -> int: self.cipher_state = text2matrix(ciphertext) - - self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:, 40:44]) + self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[40:]) self.cipher_state = self.__inv_shift_rows(self.cipher_state) self.cipher_state = self.__inv_sub_bytes(self.cipher_state) - for i in range(9, 0, -1): - self.cipher_state = self.__round_decrypt(self.cipher_state, self.round_keys[:, 4 * i : 4 * (i + 1)]) + self.cipher_state = self.__round_decrypt(self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)]) - self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:, :4]) + self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:4]) return matrix2text(self.cipher_state) def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: state_matrix = self.__sub_bytes(state_matrix) - print("After sub_bytes:", state_matrix.numpy()) state_matrix = self.__shift_rows(state_matrix) - print("After shift_rows:", state_matrix.numpy()) state_matrix = self.__mix_columns(state_matrix) - print("After mix_columns:", state_matrix.numpy()) - print(f'[ADD_ROUND_KEY] state_matrix: {state_matrix.numpy()}') - print(f'[ADD_ROUND_KEY] key_matrix: {key_matrix.numpy()}') state_matrix = self.__add_round_key(state_matrix, key_matrix) - print("After add_round_key:", state_matrix.numpy()) return state_matrix def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: @@ -169,7 +155,7 @@ def __shift_rows(self, s: Tensor) -> Tensor: state = s.clone() for i in range(1, 4): - state[i] = state[i].roll(-i, dims=0) + state[:, i] = state[:, i].roll(-i, dims=0) return state @@ -177,28 +163,32 @@ def __inv_shift_rows(self, s: Tensor) -> Tensor: state = s.clone() for i in range(1, 4): - state[i] = state[i].roll(i, dims=0) + state[:, i] = state[:, i].roll(i, dims=0) return state def __mix_columns(self, state: Tensor) -> Tensor: - t = state[0].xor(state[1]).xor(state[2]).xor(state[3]) + t = state[:, 0].xor(state[:, 1]).xor(state[:, 2]).xor(state[:, 3]) result = state.clone() - result[0] = state[0].xor(t).xor(xtime_tensor(state[0].xor(state[1]))) - result[1] = state[1].xor(t).xor(xtime_tensor(state[1].xor(state[2]))) - result[2] = state[2].xor(t).xor(xtime_tensor(state[2].xor(state[3]))) - result[3] = state[3].xor(t).xor(xtime_tensor(state[3].xor(state[0]))) + result[:, 0] = state[:, 0].xor(t).xor(xtime_tensor(state[:, 0].xor(state[:, 1]))) + result[:, 1] = state[:, 1].xor(t).xor(xtime_tensor(state[:, 1].xor(state[:, 2]))) + result[:, 2] = state[:, 2].xor(t).xor(xtime_tensor(state[:, 2].xor(state[:, 3]))) + result[:, 3] = state[:, 3].xor(t).xor(xtime_tensor(state[:, 3].xor(state[:, 0]))) return result - + def __inv_mix_columns(self, state: Tensor) -> Tensor: - u_v = xtime_tensor(xtime_tensor( - state[::2].xor(state[1::2]) - )).repeat_interleave(2, dim=0) + u = xtime_tensor(xtime_tensor(state[:, 0].xor(state[:, 2]))) + v = xtime_tensor(xtime_tensor(state[:, 1].xor(state[:, 3]))) + + out = state.clone() + out[:, 0] = state[:, 0].xor(u) + out[:, 1] = state[:, 1].xor(v) + out[:, 2] = state[:, 2].xor(u) + out[:, 3] = state[:, 3].xor(v) - state = state.xor(u_v) - return self.__mix_columns(state) + return self.__mix_columns(out) if __name__ == "__main__": aes = AES(0x2b7e151628aed2a6abf7158809cf4f3c) From 067b679410597392e4d10cf426365868bdeb9796 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 19:20:47 -0800 Subject: [PATCH 03/12] feat: working --- .gitignore | 4 +- aes.py | 275 +++++++++++++++++---------------------- bench.py | 21 +++ main.py | 196 ---------------------------- reference_aes.py | 220 +++++++++++++++++++++++++++++++ tests/__init__.py | 0 tests/test_comparison.py | 42 ++++++ tests/tests_aes.py | 50 +++++++ 8 files changed, 457 insertions(+), 351 deletions(-) create mode 100644 bench.py delete mode 100644 main.py create mode 100644 reference_aes.py create mode 100644 tests/__init__.py create mode 100644 tests/test_comparison.py create mode 100644 tests/tests_aes.py diff --git a/.gitignore b/.gitignore index eba74f4..cd39c30 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -venv/ \ No newline at end of file +.venv/ +__pycache__/ +.pytest_cache/ \ No newline at end of file diff --git a/aes.py b/aes.py index b0afc21..2c3526a 100644 --- a/aes.py +++ b/aes.py @@ -1,29 +1,7 @@ -#!/usr/bin/env python +from tinygrad.tensor import Tensor +from tinygrad import dtypes - -""" - Copyright (C) 2012 Bo Zhu http://about.bozhu.me - - Permission is hereby granted, free of charge, to any person obtaining a - copy of this software and associated documentation files (the "Software"), - to deal in the Software without restriction, including without limitation - the rights to use, copy, modify, merge, publish, distribute, sublicense, - and/or sell copies of the Software, and to permit persons to whom the - Software is furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL - THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - DEALINGS IN THE SOFTWARE. -""" - -Sbox = ( +Sbox = Tensor([ 0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76, 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0, 0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15, @@ -40,9 +18,10 @@ 0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E, 0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF, 0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16, -) +], dtype=dtypes.uint8) -InvSbox = ( + +InvSbox = Tensor([ 0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB, 0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB, 0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E, @@ -59,160 +38,148 @@ 0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF, 0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61, 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, -) - - -# learnt from http://cs.ucsb.edu/~koc/cs178/projects/JT/aes.c -xtime = lambda a: (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) +], dtype=dtypes.uint8) - -Rcon = ( +Rcon = Tensor([ 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, 0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39, -) - - -def text2matrix(text): - matrix = [] +], dtype=dtypes.uint8) + + +def xtime(a: int) -> int: + return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) + +def xtime_tensor(a: Tensor) -> Tensor: + high_bit_mask = a.bitwise_and(0x80) + shifted = a.lshift(1) + + condition = high_bit_mask != 0 + result = condition.where( + shifted.xor(0x1B), + shifted + ) + + return result.bitwise_and(0xFF) + + +def text2matrix(text: int) -> Tensor: + return (Tensor([text >> (8 * (15 - i)) for i in range(16)], + dtype=dtypes.uint64) + .bitwise_and(0xFF) + .reshape((4, 4))) + +def matrix2text(matrix: Tensor) -> int: + flat = matrix.flatten() + result = 0 for i in range(16): - byte = (text >> (8 * (15 - i))) & 0xFF - if i % 4 == 0: - matrix.append([byte]) - else: - matrix[i // 4].append(byte) - return matrix - - -def matrix2text(matrix): - text = 0 - for i in range(4): - for j in range(4): - text |= (matrix[i][j] << (120 - 8 * (4 * i + j))) - return text - + byte = int(flat[i].item()) + result = (result << 8) | byte + return result class AES: def __init__(self, master_key): self.change_key(master_key) def change_key(self, master_key): - self.round_keys = text2matrix(master_key) - + self.round_keys = Tensor.zeros((44, 4), dtype=dtypes.uint8).contiguous() + self.round_keys[:4] = text2matrix(master_key) for i in range(4, 4 * 11): - self.round_keys.append([]) if i % 4 == 0: - byte = self.round_keys[i - 4][0] \ - ^ Sbox[self.round_keys[i - 1][1]] \ - ^ Rcon[i // 4] - self.round_keys[i].append(byte) - - for j in range(1, 4): - byte = self.round_keys[i - 4][j] \ - ^ Sbox[self.round_keys[i - 1][(j + 1) % 4]] - self.round_keys[i].append(byte) + self.round_keys[i, 0] = (self.round_keys[i-4, 0] ^ + Sbox[self.round_keys[i-1, 1].item()] ^ + Rcon[i//4].item()) + + shifted_indices = Tensor([2,3,0], dtype=dtypes.uint8) + sboxed = Sbox[self.round_keys[i-1][shifted_indices]] + self.round_keys[i, 1:] = self.round_keys[i-4, 1:].xor(sboxed) else: - for j in range(4): - byte = self.round_keys[i - 4][j] \ - ^ self.round_keys[i - 1][j] - self.round_keys[i].append(byte) - - def encrypt(self, plaintext): + self.round_keys[i] = self.round_keys[i-4].xor(self.round_keys[i-1]) + + def encrypt(self, plaintext: int) -> int: self.plain_state = text2matrix(plaintext) - - self.__add_round_key(self.plain_state, self.round_keys[:4]) + self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4]) for i in range(1, 10): - self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]) + self.plain_state = self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]) - self.__sub_bytes(self.plain_state) - self.__shift_rows(self.plain_state) - self.__add_round_key(self.plain_state, self.round_keys[40:]) + self.plain_state = self.__sub_bytes(self.plain_state) + self.plain_state = self.__shift_rows(self.plain_state) + self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[40:]) return matrix2text(self.plain_state) - def decrypt(self, ciphertext): + def decrypt(self, ciphertext: int) -> int: self.cipher_state = text2matrix(ciphertext) - - self.__add_round_key(self.cipher_state, self.round_keys[40:]) - self.__inv_shift_rows(self.cipher_state) - self.__inv_sub_bytes(self.cipher_state) - + self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[40:]) + self.cipher_state = self.__inv_shift_rows(self.cipher_state) + self.cipher_state = self.__inv_sub_bytes(self.cipher_state) for i in range(9, 0, -1): - self.__round_decrypt(self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)]) + self.cipher_state = self.__round_decrypt(self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)]) - self.__add_round_key(self.cipher_state, self.round_keys[:4]) + self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:4]) return matrix2text(self.cipher_state) - - def __add_round_key(self, s, k): - for i in range(4): - for j in range(4): - s[i][j] ^= k[i][j] - - - def __round_encrypt(self, state_matrix, key_matrix): - self.__sub_bytes(state_matrix) - self.__shift_rows(state_matrix) - self.__mix_columns(state_matrix) - self.__add_round_key(state_matrix, key_matrix) - - def __round_decrypt(self, state_matrix, key_matrix): - self.__add_round_key(state_matrix, key_matrix) - self.__inv_mix_columns(state_matrix) - self.__inv_shift_rows(state_matrix) - self.__inv_sub_bytes(state_matrix) - - def __sub_bytes(self, s): - for i in range(4): - for j in range(4): - s[i][j] = Sbox[s[i][j]] - - - def __inv_sub_bytes(self, s): - for i in range(4): - for j in range(4): - s[i][j] = InvSbox[s[i][j]] - - - def __shift_rows(self, s): - s[0][1], s[1][1], s[2][1], s[3][1] = s[1][1], s[2][1], s[3][1], s[0][1] - s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2] - s[0][3], s[1][3], s[2][3], s[3][3] = s[3][3], s[0][3], s[1][3], s[2][3] - - - def __inv_shift_rows(self, s): - s[0][1], s[1][1], s[2][1], s[3][1] = s[3][1], s[0][1], s[1][1], s[2][1] - s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2] - s[0][3], s[1][3], s[2][3], s[3][3] = s[1][3], s[2][3], s[3][3], s[0][3] - - def __mix_single_column(self, a): - # please see Sec 4.1.2 in The Design of Rijndael - t = a[0] ^ a[1] ^ a[2] ^ a[3] - u = a[0] - a[0] ^= t ^ xtime(a[0] ^ a[1]) - a[1] ^= t ^ xtime(a[1] ^ a[2]) - a[2] ^= t ^ xtime(a[2] ^ a[3]) - a[3] ^= t ^ xtime(a[3] ^ u) - - - def __mix_columns(self, s): - for i in range(4): - self.__mix_single_column(s[i]) - - - def __inv_mix_columns(self, s): - for i in range(4): - u = xtime(xtime(s[i][0] ^ s[i][2])) - v = xtime(xtime(s[i][1] ^ s[i][3])) - s[i][0] ^= u - s[i][1] ^= v - s[i][2] ^= u - s[i][3] ^= v - - self.__mix_columns(s) - + + + def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: + state_matrix = self.__sub_bytes(state_matrix) + state_matrix = self.__shift_rows(state_matrix) + state_matrix = self.__mix_columns(state_matrix) + state_matrix = self.__add_round_key(state_matrix, key_matrix) + return state_matrix + + def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: + state_matrix = self.__add_round_key(state_matrix, key_matrix) + state_matrix = self.__inv_mix_columns(state_matrix) + state_matrix = self.__inv_shift_rows(state_matrix) + state_matrix = self.__inv_sub_bytes(state_matrix) + return state_matrix + + def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: + return s.xor(k) + + def __sub_bytes(self, s: Tensor) -> Tensor: + return Sbox[s] + + def __inv_sub_bytes(self, s: Tensor) -> Tensor: + return InvSbox[s] + + def __shift_rows(self, s: Tensor) -> Tensor: + state = s.clone() + + for i in range(1, 4): + state[:, i] = state[:, i].roll(-i, dims=0) + + return state + + def __inv_shift_rows(self, s: Tensor) -> Tensor: + state = s.clone() + + for i in range(1, 4): + state[:, i] = state[:, i].roll(i, dims=0) + + return state + + def __mix_columns(self, state: Tensor) -> Tensor: + t = state[:, 0].xor(state[:, 1]).xor(state[:, 2]).xor(state[:, 3]) + xtimes = xtime_tensor(state.roll(-1, dims=1).xor(state)) + state = state.xor(t.unsqueeze(1)).xor(xtimes) + + return state + + def __inv_mix_columns(self, state: Tensor) -> Tensor: + u = xtime_tensor(xtime_tensor(state[:, 0].xor(state[:, 2]))) + v = xtime_tensor(xtime_tensor(state[:, 1].xor(state[:, 3]))) + + out = state.clone() + out[:, 0] = state[:, 0].xor(u) + out[:, 1] = state[:, 1].xor(v) + out[:, 2] = state[:, 2].xor(u) + out[:, 3] = state[:, 3].xor(v) + + return self.__mix_columns(out) if __name__ == "__main__": aes = AES(0x2b7e151628aed2a6abf7158809cf4f3c) diff --git a/bench.py b/bench.py new file mode 100644 index 0000000..a161e1b --- /dev/null +++ b/bench.py @@ -0,0 +1,21 @@ +import pytest +from aes import AES as TinyGradAES +from reference_aes import AES as ReferenceAES + +@pytest.mark.parametrize("num_ops", [2, 4, 8]) +@pytest.mark.parametrize("aes_class", [TinyGradAES, ReferenceAES], ids=["TinyGradAES", "ReferenceAES"]) +def test_aes_performance(benchmark, aes_class, num_ops): + key = 0x2b7e151628aed2a6abf7158809cf4f3c + data = 0x3243f6a8885a308d313198a2e0370734 + + aes = aes_class(key) + + def aes_ops(): + for _ in range(num_ops): + c = aes.encrypt(data) + p = aes.decrypt(c) + + benchmark.pedantic(aes_ops, rounds=1, iterations=1) + +if __name__ == "__main__": + pytest.main([__file__, "--benchmark-only"]) diff --git a/main.py b/main.py deleted file mode 100644 index ce84ea2..0000000 --- a/main.py +++ /dev/null @@ -1,196 +0,0 @@ -from tinygrad.tensor import Tensor -from tinygrad import dtypes - -Sbox = Tensor([ - 0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76, - 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0, - 0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15, - 0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75, - 0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84, - 0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF, - 0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8, - 0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2, - 0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73, - 0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB, - 0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79, - 0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08, - 0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A, - 0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E, - 0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF, - 0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16, -], dtype=dtypes.uint8) - - -InvSbox = Tensor([ - 0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB, - 0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB, - 0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E, - 0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25, - 0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92, - 0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84, - 0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06, - 0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B, - 0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73, - 0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E, - 0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B, - 0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4, - 0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F, - 0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF, - 0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61, - 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, -], dtype=dtypes.uint8) - -Rcon = Tensor([ - 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, - 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, - 0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, - 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39, -], dtype=dtypes.uint8) - - -def xtime(a: int) -> int: - return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) - -def xtime_tensor(a: Tensor) -> Tensor: - # Get high bit mask (equivalent to a & 0x80) - high_bit_mask = a.bitwise_and(0x80) - - # Compute left shift for all values - shifted = a.lshift(1) - - # Create the conditional xor with 0x1B - condition = high_bit_mask != 0 - result = condition.where( - shifted.xor(0x1B), # when high bit is set - shifted # when high bit is not set - ) - - # Mask to keep only bottom 8 bits - return result.bitwise_and(0xFF) - - -def text2matrix(text: int) -> Tensor: - return (Tensor([text >> (8 * (15 - i)) for i in range(16)], - dtype=dtypes.uint64) - .bitwise_and(0xFF) - .reshape((4, 4))) - -def matrix2text(matrix: Tensor) -> int: - flat = matrix.flatten() - result = 0 - for i in range(16): - byte = int(flat[i].item()) - result = (result << 8) | byte - return result - -class AES: - def __init__(self, master_key): - self.change_key(master_key) - - def change_key(self, master_key): - self.round_keys = Tensor.zeros((44, 4), dtype=dtypes.uint8).contiguous() - self.round_keys[:4] = text2matrix(master_key) - for i in range(4, 4 * 11): - if i % 4 == 0: - self.round_keys[i, 0] = (self.round_keys[i-4, 0] ^ - Sbox[self.round_keys[i-1, 1].item()] ^ - Rcon[i//4].item()) - - shifted_indices = Tensor([2,3,0], dtype=dtypes.uint8) - sboxed = Sbox[self.round_keys[i-1][shifted_indices]] - self.round_keys[i, 1:] = self.round_keys[i-4, 1:].xor(sboxed) - else: - self.round_keys[i] = self.round_keys[i-4].xor(self.round_keys[i-1]) - - def encrypt(self, plaintext: int) -> int: - self.plain_state = text2matrix(plaintext) - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4]) - - for i in range(1, 10): - self.plain_state = self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]) - - self.plain_state = self.__sub_bytes(self.plain_state) - self.plain_state = self.__shift_rows(self.plain_state) - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[40:]) - - return matrix2text(self.plain_state) - - def decrypt(self, ciphertext: int) -> int: - self.cipher_state = text2matrix(ciphertext) - self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[40:]) - self.cipher_state = self.__inv_shift_rows(self.cipher_state) - self.cipher_state = self.__inv_sub_bytes(self.cipher_state) - for i in range(9, 0, -1): - self.cipher_state = self.__round_decrypt(self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)]) - - self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:4]) - - return matrix2text(self.cipher_state) - - - def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: - state_matrix = self.__sub_bytes(state_matrix) - state_matrix = self.__shift_rows(state_matrix) - state_matrix = self.__mix_columns(state_matrix) - state_matrix = self.__add_round_key(state_matrix, key_matrix) - return state_matrix - - def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: - state_matrix = self.__add_round_key(state_matrix, key_matrix) - state_matrix = self.__inv_mix_columns(state_matrix) - state_matrix = self.__inv_shift_rows(state_matrix) - state_matrix = self.__inv_sub_bytes(state_matrix) - return state_matrix - - def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: - return s.xor(k) - - def __sub_bytes(self, s: Tensor) -> Tensor: - return Sbox[s] - - def __inv_sub_bytes(self, s: Tensor) -> Tensor: - return InvSbox[s] - - def __shift_rows(self, s: Tensor) -> Tensor: - state = s.clone() - - for i in range(1, 4): - state[:, i] = state[:, i].roll(-i, dims=0) - - return state - - def __inv_shift_rows(self, s: Tensor) -> Tensor: - state = s.clone() - - for i in range(1, 4): - state[:, i] = state[:, i].roll(i, dims=0) - - return state - - def __mix_columns(self, state: Tensor) -> Tensor: - t = state[:, 0].xor(state[:, 1]).xor(state[:, 2]).xor(state[:, 3]) - result = state.clone() - - result[:, 0] = state[:, 0].xor(t).xor(xtime_tensor(state[:, 0].xor(state[:, 1]))) - result[:, 1] = state[:, 1].xor(t).xor(xtime_tensor(state[:, 1].xor(state[:, 2]))) - result[:, 2] = state[:, 2].xor(t).xor(xtime_tensor(state[:, 2].xor(state[:, 3]))) - result[:, 3] = state[:, 3].xor(t).xor(xtime_tensor(state[:, 3].xor(state[:, 0]))) - - return result - - def __inv_mix_columns(self, state: Tensor) -> Tensor: - u = xtime_tensor(xtime_tensor(state[:, 0].xor(state[:, 2]))) - v = xtime_tensor(xtime_tensor(state[:, 1].xor(state[:, 3]))) - - out = state.clone() - out[:, 0] = state[:, 0].xor(u) - out[:, 1] = state[:, 1].xor(v) - out[:, 2] = state[:, 2].xor(u) - out[:, 3] = state[:, 3].xor(v) - - return self.__mix_columns(out) - -if __name__ == "__main__": - aes = AES(0x2b7e151628aed2a6abf7158809cf4f3c) - print(hex(aes.encrypt(0x3243f6a8885a308d313198a2e0370734))) - print(hex(aes.decrypt(0x3925841d02dc09fbdc118597196a0b32))) diff --git a/reference_aes.py b/reference_aes.py new file mode 100644 index 0000000..b0afc21 --- /dev/null +++ b/reference_aes.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python + + +""" + Copyright (C) 2012 Bo Zhu http://about.bozhu.me + + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the "Software"), + to deal in the Software without restriction, including without limitation + the rights to use, copy, modify, merge, publish, distribute, sublicense, + and/or sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. +""" + +Sbox = ( + 0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76, + 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0, + 0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15, + 0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75, + 0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84, + 0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF, + 0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8, + 0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2, + 0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73, + 0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB, + 0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79, + 0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08, + 0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A, + 0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E, + 0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF, + 0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16, +) + +InvSbox = ( + 0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB, + 0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB, + 0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E, + 0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25, + 0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92, + 0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84, + 0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06, + 0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B, + 0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73, + 0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E, + 0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B, + 0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4, + 0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F, + 0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF, + 0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61, + 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, +) + + +# learnt from http://cs.ucsb.edu/~koc/cs178/projects/JT/aes.c +xtime = lambda a: (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) + + +Rcon = ( + 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, + 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, + 0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, + 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39, +) + + +def text2matrix(text): + matrix = [] + for i in range(16): + byte = (text >> (8 * (15 - i))) & 0xFF + if i % 4 == 0: + matrix.append([byte]) + else: + matrix[i // 4].append(byte) + return matrix + + +def matrix2text(matrix): + text = 0 + for i in range(4): + for j in range(4): + text |= (matrix[i][j] << (120 - 8 * (4 * i + j))) + return text + + +class AES: + def __init__(self, master_key): + self.change_key(master_key) + + def change_key(self, master_key): + self.round_keys = text2matrix(master_key) + + for i in range(4, 4 * 11): + self.round_keys.append([]) + if i % 4 == 0: + byte = self.round_keys[i - 4][0] \ + ^ Sbox[self.round_keys[i - 1][1]] \ + ^ Rcon[i // 4] + self.round_keys[i].append(byte) + + for j in range(1, 4): + byte = self.round_keys[i - 4][j] \ + ^ Sbox[self.round_keys[i - 1][(j + 1) % 4]] + self.round_keys[i].append(byte) + else: + for j in range(4): + byte = self.round_keys[i - 4][j] \ + ^ self.round_keys[i - 1][j] + self.round_keys[i].append(byte) + + def encrypt(self, plaintext): + self.plain_state = text2matrix(plaintext) + + self.__add_round_key(self.plain_state, self.round_keys[:4]) + + for i in range(1, 10): + self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]) + + self.__sub_bytes(self.plain_state) + self.__shift_rows(self.plain_state) + self.__add_round_key(self.plain_state, self.round_keys[40:]) + + return matrix2text(self.plain_state) + + def decrypt(self, ciphertext): + self.cipher_state = text2matrix(ciphertext) + + self.__add_round_key(self.cipher_state, self.round_keys[40:]) + self.__inv_shift_rows(self.cipher_state) + self.__inv_sub_bytes(self.cipher_state) + + for i in range(9, 0, -1): + self.__round_decrypt(self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)]) + + self.__add_round_key(self.cipher_state, self.round_keys[:4]) + + return matrix2text(self.cipher_state) + + def __add_round_key(self, s, k): + for i in range(4): + for j in range(4): + s[i][j] ^= k[i][j] + + + def __round_encrypt(self, state_matrix, key_matrix): + self.__sub_bytes(state_matrix) + self.__shift_rows(state_matrix) + self.__mix_columns(state_matrix) + self.__add_round_key(state_matrix, key_matrix) + + def __round_decrypt(self, state_matrix, key_matrix): + self.__add_round_key(state_matrix, key_matrix) + self.__inv_mix_columns(state_matrix) + self.__inv_shift_rows(state_matrix) + self.__inv_sub_bytes(state_matrix) + + def __sub_bytes(self, s): + for i in range(4): + for j in range(4): + s[i][j] = Sbox[s[i][j]] + + + def __inv_sub_bytes(self, s): + for i in range(4): + for j in range(4): + s[i][j] = InvSbox[s[i][j]] + + + def __shift_rows(self, s): + s[0][1], s[1][1], s[2][1], s[3][1] = s[1][1], s[2][1], s[3][1], s[0][1] + s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2] + s[0][3], s[1][3], s[2][3], s[3][3] = s[3][3], s[0][3], s[1][3], s[2][3] + + + def __inv_shift_rows(self, s): + s[0][1], s[1][1], s[2][1], s[3][1] = s[3][1], s[0][1], s[1][1], s[2][1] + s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2] + s[0][3], s[1][3], s[2][3], s[3][3] = s[1][3], s[2][3], s[3][3], s[0][3] + + def __mix_single_column(self, a): + # please see Sec 4.1.2 in The Design of Rijndael + t = a[0] ^ a[1] ^ a[2] ^ a[3] + u = a[0] + a[0] ^= t ^ xtime(a[0] ^ a[1]) + a[1] ^= t ^ xtime(a[1] ^ a[2]) + a[2] ^= t ^ xtime(a[2] ^ a[3]) + a[3] ^= t ^ xtime(a[3] ^ u) + + + def __mix_columns(self, s): + for i in range(4): + self.__mix_single_column(s[i]) + + + def __inv_mix_columns(self, s): + for i in range(4): + u = xtime(xtime(s[i][0] ^ s[i][2])) + v = xtime(xtime(s[i][1] ^ s[i][3])) + s[i][0] ^= u + s[i][1] ^= v + s[i][2] ^= u + s[i][3] ^= v + + self.__mix_columns(s) + + +if __name__ == "__main__": + aes = AES(0x2b7e151628aed2a6abf7158809cf4f3c) + print(hex(aes.encrypt(0x3243f6a8885a308d313198a2e0370734))) + print(hex(aes.decrypt(0x3925841d02dc09fbdc118597196a0b32))) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_comparison.py b/tests/test_comparison.py new file mode 100644 index 0000000..3c115b1 --- /dev/null +++ b/tests/test_comparison.py @@ -0,0 +1,42 @@ +import pytest +import random +from reference_aes import AES as ReferenceAES +from aes import AES as TinyGradAES + +def generate_random_128bit(): + return random.getrandbits(128) + +class TestAESImplementations: + def setup_method(self): + self.key = generate_random_128bit() + self.plaintext = generate_random_128bit() + self.ref_aes = ReferenceAES(self.key) + self.tiny_aes = TinyGradAES(self.key) + + def test_encryption_comparison(self): + ref_ciphertext = self.ref_aes.encrypt(self.plaintext) + tiny_ciphertext = self.tiny_aes.encrypt(self.plaintext) + assert ref_ciphertext == tiny_ciphertext + + def test_decryption_comparison(self): + ref_ciphertext = self.ref_aes.encrypt(self.plaintext) + tiny_decrypted = self.tiny_aes.decrypt(ref_ciphertext) + assert tiny_decrypted == self.plaintext + + tiny_ciphertext = self.tiny_aes.encrypt(self.plaintext) + ref_decrypted = self.ref_aes.decrypt(tiny_ciphertext) + assert ref_decrypted == self.plaintext + + def test_multiple_random_values(self): + for _ in range(5): + plaintext = generate_random_128bit() + ref_ciphertext = self.ref_aes.encrypt(plaintext) + tiny_ciphertext = self.tiny_aes.encrypt(plaintext) + assert ref_ciphertext == tiny_ciphertext + + ref_decrypted = self.ref_aes.decrypt(tiny_ciphertext) + tiny_decrypted = self.tiny_aes.decrypt(ref_ciphertext) + assert ref_decrypted == plaintext and tiny_decrypted == plaintext + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/tests_aes.py b/tests/tests_aes.py new file mode 100644 index 0000000..b3de766 --- /dev/null +++ b/tests/tests_aes.py @@ -0,0 +1,50 @@ + + +import pytest +from aes import AES + +class TestAESEncryptDecrypt: + def setup_method(self): + # Standard test vector from NIST FIPS-197 + self.key = 0x2b7e151628aed2a6abf7158809cf4f3c + self.plaintext = 0x3243f6a8885a308d313198a2e0370734 + self.ciphertext = 0x3925841d02dc09fbdc118597196a0b32 + self.aes = AES(self.key) + + def test_encrypt(self): + """Test basic encryption with NIST test vector""" + encrypted = self.aes.encrypt(self.plaintext) + assert encrypted == self.ciphertext + + def test_decrypt(self): + """Test basic decryption with NIST test vector""" + decrypted = self.aes.decrypt(self.ciphertext) + assert decrypted == self.plaintext + + def test_encrypt_decrypt_roundtrip(self): + """Test that encryption followed by decryption returns the original plaintext""" + encrypted = self.aes.encrypt(self.plaintext) + decrypted = self.aes.decrypt(encrypted) + assert decrypted == self.plaintext + + @pytest.mark.parametrize("plaintext,key,ciphertext", [ + # NIST SP 800-38A test vectors + (0x6bc1bee22e409f96e93d7e117393172a, + 0x2b7e151628aed2a6abf7158809cf4f3c, + 0x3ad77bb40d7a3660a89ecaf32466ef97), + (0xae2d8a571e03ac9c9eb76fac45af8e51, + 0x2b7e151628aed2a6abf7158809cf4f3c, + 0xf5d3d58503b9699de785895a96fdbaaf), + # Additional test with all zeros + (0x00000000000000000000000000000000, + 0x00000000000000000000000000000000, + 0x66e94bd4ef8a2c3b884cfa59ca342b2e) + ]) + def test_known_vectors(self, plaintext, key, ciphertext): + """Test encryption and decryption with additional known test vectors""" + aes = AES(key) + assert aes.encrypt(plaintext) == ciphertext + assert aes.decrypt(ciphertext) == plaintext + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file From 990b8c7b39ce0e7805f02ef9bbb54ed6e139e499 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 19:27:49 -0800 Subject: [PATCH 04/12] feat: black, testing, uv --- .github/workflows/test.yml | 32 +++ .python-version | 1 + aes.py | 127 ++++----- bench.py | 14 +- constants.py | 553 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 17 ++ reference_aes.py | 121 +++----- tests/test_comparison.py | 9 +- tests/tests_aes.py | 47 ++-- uv.lock | 181 ++++++++++++ 10 files changed, 911 insertions(+), 191 deletions(-) create mode 100644 .github/workflows/test.yml create mode 100644 .python-version create mode 100644 constants.py create mode 100644 pyproject.toml create mode 100644 uv.lock diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..a378b87 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,32 @@ +name: Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install dependencies using uv + run: | + uv pip install + + - name: Run tests with uv + run: | + uv venv exec pytest tests/ diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..2c07333 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/aes.py b/aes.py index 2c3526a..869805e 100644 --- a/aes.py +++ b/aes.py @@ -1,74 +1,32 @@ from tinygrad.tensor import Tensor from tinygrad import dtypes +from constants import Sbox as Sbox_const, InvSbox as InvSbox_const, Rcon as Rcon_const -Sbox = Tensor([ - 0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76, - 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0, - 0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15, - 0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75, - 0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84, - 0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF, - 0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8, - 0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2, - 0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73, - 0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB, - 0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79, - 0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08, - 0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A, - 0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E, - 0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF, - 0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16, -], dtype=dtypes.uint8) - - -InvSbox = Tensor([ - 0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB, - 0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB, - 0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E, - 0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25, - 0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92, - 0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84, - 0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06, - 0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B, - 0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73, - 0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E, - 0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B, - 0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4, - 0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F, - 0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF, - 0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61, - 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, -], dtype=dtypes.uint8) - -Rcon = Tensor([ - 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, - 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, - 0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, - 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39, -], dtype=dtypes.uint8) - +Sbox = Tensor(Sbox_const, dtype=dtypes.uint8) +InvSbox = Tensor(InvSbox_const, dtype=dtypes.uint8) +Rcon = Tensor(Rcon_const, dtype=dtypes.uint8) def xtime(a: int) -> int: return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) + def xtime_tensor(a: Tensor) -> Tensor: high_bit_mask = a.bitwise_and(0x80) shifted = a.lshift(1) - + condition = high_bit_mask != 0 - result = condition.where( - shifted.xor(0x1B), - shifted - ) - + result = condition.where(shifted.xor(0x1B), shifted) + return result.bitwise_and(0xFF) def text2matrix(text: int) -> Tensor: - return (Tensor([text >> (8 * (15 - i)) for i in range(16)], - dtype=dtypes.uint64) - .bitwise_and(0xFF) - .reshape((4, 4))) + return ( + Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint64) + .bitwise_and(0xFF) + .reshape((4, 4)) + ) + def matrix2text(matrix: Tensor) -> int: flat = matrix.flatten() @@ -78,6 +36,7 @@ def matrix2text(matrix: Tensor) -> int: result = (result << 8) | byte return result + class AES: def __init__(self, master_key): self.change_key(master_key) @@ -87,22 +46,26 @@ def change_key(self, master_key): self.round_keys[:4] = text2matrix(master_key) for i in range(4, 4 * 11): if i % 4 == 0: - self.round_keys[i, 0] = (self.round_keys[i-4, 0] ^ - Sbox[self.round_keys[i-1, 1].item()] ^ - Rcon[i//4].item()) - - shifted_indices = Tensor([2,3,0], dtype=dtypes.uint8) - sboxed = Sbox[self.round_keys[i-1][shifted_indices]] - self.round_keys[i, 1:] = self.round_keys[i-4, 1:].xor(sboxed) + self.round_keys[i, 0] = ( + self.round_keys[i - 4, 0] + ^ Sbox[self.round_keys[i - 1, 1].item()] + ^ Rcon[i // 4].item() + ) + + shifted_indices = Tensor([2, 3, 0], dtype=dtypes.uint8) + sboxed = Sbox[self.round_keys[i - 1][shifted_indices]] + self.round_keys[i, 1:] = self.round_keys[i - 4, 1:].xor(sboxed) else: - self.round_keys[i] = self.round_keys[i-4].xor(self.round_keys[i-1]) - + self.round_keys[i] = self.round_keys[i - 4].xor(self.round_keys[i - 1]) + def encrypt(self, plaintext: int) -> int: self.plain_state = text2matrix(plaintext) self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4]) for i in range(1, 10): - self.plain_state = self.__round_encrypt(self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]) + self.plain_state = self.__round_encrypt( + self.plain_state, self.round_keys[4 * i : 4 * (i + 1)] + ) self.plain_state = self.__sub_bytes(self.plain_state) self.plain_state = self.__shift_rows(self.plain_state) @@ -112,16 +75,19 @@ def encrypt(self, plaintext: int) -> int: def decrypt(self, ciphertext: int) -> int: self.cipher_state = text2matrix(ciphertext) - self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[40:]) + self.cipher_state = self.__add_round_key( + self.cipher_state, self.round_keys[40:] + ) self.cipher_state = self.__inv_shift_rows(self.cipher_state) self.cipher_state = self.__inv_sub_bytes(self.cipher_state) for i in range(9, 0, -1): - self.cipher_state = self.__round_decrypt(self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)]) + self.cipher_state = self.__round_decrypt( + self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)] + ) self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:4]) return matrix2text(self.cipher_state) - def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: state_matrix = self.__sub_bytes(state_matrix) @@ -139,19 +105,19 @@ def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: return s.xor(k) - + def __sub_bytes(self, s: Tensor) -> Tensor: return Sbox[s] - + def __inv_sub_bytes(self, s: Tensor) -> Tensor: return InvSbox[s] - + def __shift_rows(self, s: Tensor) -> Tensor: state = s.clone() for i in range(1, 4): state[:, i] = state[:, i].roll(-i, dims=0) - + return state def __inv_shift_rows(self, s: Tensor) -> Tensor: @@ -159,7 +125,7 @@ def __inv_shift_rows(self, s: Tensor) -> Tensor: for i in range(1, 4): state[:, i] = state[:, i].roll(i, dims=0) - + return state def __mix_columns(self, state: Tensor) -> Tensor: @@ -168,20 +134,21 @@ def __mix_columns(self, state: Tensor) -> Tensor: state = state.xor(t.unsqueeze(1)).xor(xtimes) return state - + def __inv_mix_columns(self, state: Tensor) -> Tensor: u = xtime_tensor(xtime_tensor(state[:, 0].xor(state[:, 2]))) v = xtime_tensor(xtime_tensor(state[:, 1].xor(state[:, 3]))) - + out = state.clone() out[:, 0] = state[:, 0].xor(u) out[:, 1] = state[:, 1].xor(v) out[:, 2] = state[:, 2].xor(u) out[:, 3] = state[:, 3].xor(v) - + return self.__mix_columns(out) + if __name__ == "__main__": - aes = AES(0x2b7e151628aed2a6abf7158809cf4f3c) - print(hex(aes.encrypt(0x3243f6a8885a308d313198a2e0370734))) - print(hex(aes.decrypt(0x3925841d02dc09fbdc118597196a0b32))) + aes = AES(0x2B7E151628AED2A6ABF7158809CF4F3C) + print(hex(aes.encrypt(0x3243F6A8885A308D313198A2E0370734))) + print(hex(aes.decrypt(0x3925841D02DC09FBDC118597196A0B32))) diff --git a/bench.py b/bench.py index a161e1b..f4177b1 100644 --- a/bench.py +++ b/bench.py @@ -2,20 +2,24 @@ from aes import AES as TinyGradAES from reference_aes import AES as ReferenceAES + @pytest.mark.parametrize("num_ops", [2, 4, 8]) -@pytest.mark.parametrize("aes_class", [TinyGradAES, ReferenceAES], ids=["TinyGradAES", "ReferenceAES"]) +@pytest.mark.parametrize( + "aes_class", [TinyGradAES, ReferenceAES], ids=["TinyGradAES", "ReferenceAES"] +) def test_aes_performance(benchmark, aes_class, num_ops): - key = 0x2b7e151628aed2a6abf7158809cf4f3c - data = 0x3243f6a8885a308d313198a2e0370734 - + key = 0x2B7E151628AED2A6ABF7158809CF4F3C + data = 0x3243F6A8885A308D313198A2E0370734 + aes = aes_class(key) def aes_ops(): for _ in range(num_ops): c = aes.encrypt(data) p = aes.decrypt(c) - + benchmark.pedantic(aes_ops, rounds=1, iterations=1) + if __name__ == "__main__": pytest.main([__file__, "--benchmark-only"]) diff --git a/constants.py b/constants.py new file mode 100644 index 0000000..28adf8e --- /dev/null +++ b/constants.py @@ -0,0 +1,553 @@ +Sbox = ( + 0x63, + 0x7C, + 0x77, + 0x7B, + 0xF2, + 0x6B, + 0x6F, + 0xC5, + 0x30, + 0x01, + 0x67, + 0x2B, + 0xFE, + 0xD7, + 0xAB, + 0x76, + 0xCA, + 0x82, + 0xC9, + 0x7D, + 0xFA, + 0x59, + 0x47, + 0xF0, + 0xAD, + 0xD4, + 0xA2, + 0xAF, + 0x9C, + 0xA4, + 0x72, + 0xC0, + 0xB7, + 0xFD, + 0x93, + 0x26, + 0x36, + 0x3F, + 0xF7, + 0xCC, + 0x34, + 0xA5, + 0xE5, + 0xF1, + 0x71, + 0xD8, + 0x31, + 0x15, + 0x04, + 0xC7, + 0x23, + 0xC3, + 0x18, + 0x96, + 0x05, + 0x9A, + 0x07, + 0x12, + 0x80, + 0xE2, + 0xEB, + 0x27, + 0xB2, + 0x75, + 0x09, + 0x83, + 0x2C, + 0x1A, + 0x1B, + 0x6E, + 0x5A, + 0xA0, + 0x52, + 0x3B, + 0xD6, + 0xB3, + 0x29, + 0xE3, + 0x2F, + 0x84, + 0x53, + 0xD1, + 0x00, + 0xED, + 0x20, + 0xFC, + 0xB1, + 0x5B, + 0x6A, + 0xCB, + 0xBE, + 0x39, + 0x4A, + 0x4C, + 0x58, + 0xCF, + 0xD0, + 0xEF, + 0xAA, + 0xFB, + 0x43, + 0x4D, + 0x33, + 0x85, + 0x45, + 0xF9, + 0x02, + 0x7F, + 0x50, + 0x3C, + 0x9F, + 0xA8, + 0x51, + 0xA3, + 0x40, + 0x8F, + 0x92, + 0x9D, + 0x38, + 0xF5, + 0xBC, + 0xB6, + 0xDA, + 0x21, + 0x10, + 0xFF, + 0xF3, + 0xD2, + 0xCD, + 0x0C, + 0x13, + 0xEC, + 0x5F, + 0x97, + 0x44, + 0x17, + 0xC4, + 0xA7, + 0x7E, + 0x3D, + 0x64, + 0x5D, + 0x19, + 0x73, + 0x60, + 0x81, + 0x4F, + 0xDC, + 0x22, + 0x2A, + 0x90, + 0x88, + 0x46, + 0xEE, + 0xB8, + 0x14, + 0xDE, + 0x5E, + 0x0B, + 0xDB, + 0xE0, + 0x32, + 0x3A, + 0x0A, + 0x49, + 0x06, + 0x24, + 0x5C, + 0xC2, + 0xD3, + 0xAC, + 0x62, + 0x91, + 0x95, + 0xE4, + 0x79, + 0xE7, + 0xC8, + 0x37, + 0x6D, + 0x8D, + 0xD5, + 0x4E, + 0xA9, + 0x6C, + 0x56, + 0xF4, + 0xEA, + 0x65, + 0x7A, + 0xAE, + 0x08, + 0xBA, + 0x78, + 0x25, + 0x2E, + 0x1C, + 0xA6, + 0xB4, + 0xC6, + 0xE8, + 0xDD, + 0x74, + 0x1F, + 0x4B, + 0xBD, + 0x8B, + 0x8A, + 0x70, + 0x3E, + 0xB5, + 0x66, + 0x48, + 0x03, + 0xF6, + 0x0E, + 0x61, + 0x35, + 0x57, + 0xB9, + 0x86, + 0xC1, + 0x1D, + 0x9E, + 0xE1, + 0xF8, + 0x98, + 0x11, + 0x69, + 0xD9, + 0x8E, + 0x94, + 0x9B, + 0x1E, + 0x87, + 0xE9, + 0xCE, + 0x55, + 0x28, + 0xDF, + 0x8C, + 0xA1, + 0x89, + 0x0D, + 0xBF, + 0xE6, + 0x42, + 0x68, + 0x41, + 0x99, + 0x2D, + 0x0F, + 0xB0, + 0x54, + 0xBB, + 0x16, +) + +InvSbox = ( + 0x52, + 0x09, + 0x6A, + 0xD5, + 0x30, + 0x36, + 0xA5, + 0x38, + 0xBF, + 0x40, + 0xA3, + 0x9E, + 0x81, + 0xF3, + 0xD7, + 0xFB, + 0x7C, + 0xE3, + 0x39, + 0x82, + 0x9B, + 0x2F, + 0xFF, + 0x87, + 0x34, + 0x8E, + 0x43, + 0x44, + 0xC4, + 0xDE, + 0xE9, + 0xCB, + 0x54, + 0x7B, + 0x94, + 0x32, + 0xA6, + 0xC2, + 0x23, + 0x3D, + 0xEE, + 0x4C, + 0x95, + 0x0B, + 0x42, + 0xFA, + 0xC3, + 0x4E, + 0x08, + 0x2E, + 0xA1, + 0x66, + 0x28, + 0xD9, + 0x24, + 0xB2, + 0x76, + 0x5B, + 0xA2, + 0x49, + 0x6D, + 0x8B, + 0xD1, + 0x25, + 0x72, + 0xF8, + 0xF6, + 0x64, + 0x86, + 0x68, + 0x98, + 0x16, + 0xD4, + 0xA4, + 0x5C, + 0xCC, + 0x5D, + 0x65, + 0xB6, + 0x92, + 0x6C, + 0x70, + 0x48, + 0x50, + 0xFD, + 0xED, + 0xB9, + 0xDA, + 0x5E, + 0x15, + 0x46, + 0x57, + 0xA7, + 0x8D, + 0x9D, + 0x84, + 0x90, + 0xD8, + 0xAB, + 0x00, + 0x8C, + 0xBC, + 0xD3, + 0x0A, + 0xF7, + 0xE4, + 0x58, + 0x05, + 0xB8, + 0xB3, + 0x45, + 0x06, + 0xD0, + 0x2C, + 0x1E, + 0x8F, + 0xCA, + 0x3F, + 0x0F, + 0x02, + 0xC1, + 0xAF, + 0xBD, + 0x03, + 0x01, + 0x13, + 0x8A, + 0x6B, + 0x3A, + 0x91, + 0x11, + 0x41, + 0x4F, + 0x67, + 0xDC, + 0xEA, + 0x97, + 0xF2, + 0xCF, + 0xCE, + 0xF0, + 0xB4, + 0xE6, + 0x73, + 0x96, + 0xAC, + 0x74, + 0x22, + 0xE7, + 0xAD, + 0x35, + 0x85, + 0xE2, + 0xF9, + 0x37, + 0xE8, + 0x1C, + 0x75, + 0xDF, + 0x6E, + 0x47, + 0xF1, + 0x1A, + 0x71, + 0x1D, + 0x29, + 0xC5, + 0x89, + 0x6F, + 0xB7, + 0x62, + 0x0E, + 0xAA, + 0x18, + 0xBE, + 0x1B, + 0xFC, + 0x56, + 0x3E, + 0x4B, + 0xC6, + 0xD2, + 0x79, + 0x20, + 0x9A, + 0xDB, + 0xC0, + 0xFE, + 0x78, + 0xCD, + 0x5A, + 0xF4, + 0x1F, + 0xDD, + 0xA8, + 0x33, + 0x88, + 0x07, + 0xC7, + 0x31, + 0xB1, + 0x12, + 0x10, + 0x59, + 0x27, + 0x80, + 0xEC, + 0x5F, + 0x60, + 0x51, + 0x7F, + 0xA9, + 0x19, + 0xB5, + 0x4A, + 0x0D, + 0x2D, + 0xE5, + 0x7A, + 0x9F, + 0x93, + 0xC9, + 0x9C, + 0xEF, + 0xA0, + 0xE0, + 0x3B, + 0x4D, + 0xAE, + 0x2A, + 0xF5, + 0xB0, + 0xC8, + 0xEB, + 0xBB, + 0x3C, + 0x83, + 0x53, + 0x99, + 0x61, + 0x17, + 0x2B, + 0x04, + 0x7E, + 0xBA, + 0x77, + 0xD6, + 0x26, + 0xE1, + 0x69, + 0x14, + 0x63, + 0x55, + 0x21, + 0x0C, + 0x7D, +) + + +Rcon = ( + 0x00, + 0x01, + 0x02, + 0x04, + 0x08, + 0x10, + 0x20, + 0x40, + 0x80, + 0x1B, + 0x36, + 0x6C, + 0xD8, + 0xAB, + 0x4D, + 0x9A, + 0x2F, + 0x5E, + 0xBC, + 0x63, + 0xC6, + 0x97, + 0x35, + 0x6A, + 0xD4, + 0xB3, + 0x7D, + 0xFA, + 0xEF, + 0xC5, + 0x91, + 0x39, +) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ef33b54 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +[project] +name = "aes256" +version = "0.1.0" +description = "AES-256 implementation in TinyGrad" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "pytest>=8.3.4", + "pytest-benchmark>=5.1.0", + "tinygrad>=0.10.2", +] + +[dependency-groups] +dev = [ + "black>=25.1.0", + "pytest>=8.3.4", +] diff --git a/reference_aes.py b/reference_aes.py index b0afc21..2fd8f21 100644 --- a/reference_aes.py +++ b/reference_aes.py @@ -1,79 +1,33 @@ #!/usr/bin/env python +from constants import Sbox, InvSbox, Rcon """ - Copyright (C) 2012 Bo Zhu http://about.bozhu.me - - Permission is hereby granted, free of charge, to any person obtaining a - copy of this software and associated documentation files (the "Software"), - to deal in the Software without restriction, including without limitation - the rights to use, copy, modify, merge, publish, distribute, sublicense, - and/or sell copies of the Software, and to permit persons to whom the - Software is furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL - THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - DEALINGS IN THE SOFTWARE. +Copyright (C) 2012 Bo Zhu http://about.bozhu.me + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. """ -Sbox = ( - 0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76, - 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0, - 0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15, - 0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75, - 0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84, - 0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF, - 0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8, - 0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2, - 0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73, - 0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB, - 0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79, - 0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08, - 0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A, - 0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E, - 0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF, - 0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16, -) - -InvSbox = ( - 0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB, - 0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB, - 0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E, - 0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25, - 0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92, - 0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84, - 0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06, - 0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B, - 0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73, - 0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E, - 0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B, - 0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4, - 0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F, - 0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF, - 0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61, - 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, -) - - # learnt from http://cs.ucsb.edu/~koc/cs178/projects/JT/aes.c xtime = lambda a: (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) -Rcon = ( - 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, - 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, - 0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, - 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39, -) - - def text2matrix(text): matrix = [] for i in range(16): @@ -89,7 +43,7 @@ def matrix2text(matrix): text = 0 for i in range(4): for j in range(4): - text |= (matrix[i][j] << (120 - 8 * (4 * i + j))) + text |= matrix[i][j] << (120 - 8 * (4 * i + j)) return text @@ -103,19 +57,22 @@ def change_key(self, master_key): for i in range(4, 4 * 11): self.round_keys.append([]) if i % 4 == 0: - byte = self.round_keys[i - 4][0] \ - ^ Sbox[self.round_keys[i - 1][1]] \ - ^ Rcon[i // 4] + byte = ( + self.round_keys[i - 4][0] + ^ Sbox[self.round_keys[i - 1][1]] + ^ Rcon[i // 4] + ) self.round_keys[i].append(byte) for j in range(1, 4): - byte = self.round_keys[i - 4][j] \ - ^ Sbox[self.round_keys[i - 1][(j + 1) % 4]] + byte = ( + self.round_keys[i - 4][j] + ^ Sbox[self.round_keys[i - 1][(j + 1) % 4]] + ) self.round_keys[i].append(byte) else: for j in range(4): - byte = self.round_keys[i - 4][j] \ - ^ self.round_keys[i - 1][j] + byte = self.round_keys[i - 4][j] ^ self.round_keys[i - 1][j] self.round_keys[i].append(byte) def encrypt(self, plaintext): @@ -140,7 +97,9 @@ def decrypt(self, ciphertext): self.__inv_sub_bytes(self.cipher_state) for i in range(9, 0, -1): - self.__round_decrypt(self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)]) + self.__round_decrypt( + self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)] + ) self.__add_round_key(self.cipher_state, self.round_keys[:4]) @@ -151,7 +110,6 @@ def __add_round_key(self, s, k): for j in range(4): s[i][j] ^= k[i][j] - def __round_encrypt(self, state_matrix, key_matrix): self.__sub_bytes(state_matrix) self.__shift_rows(state_matrix) @@ -169,19 +127,16 @@ def __sub_bytes(self, s): for j in range(4): s[i][j] = Sbox[s[i][j]] - def __inv_sub_bytes(self, s): for i in range(4): for j in range(4): s[i][j] = InvSbox[s[i][j]] - def __shift_rows(self, s): s[0][1], s[1][1], s[2][1], s[3][1] = s[1][1], s[2][1], s[3][1], s[0][1] s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2] s[0][3], s[1][3], s[2][3], s[3][3] = s[3][3], s[0][3], s[1][3], s[2][3] - def __inv_shift_rows(self, s): s[0][1], s[1][1], s[2][1], s[3][1] = s[3][1], s[0][1], s[1][1], s[2][1] s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2] @@ -196,12 +151,10 @@ def __mix_single_column(self, a): a[2] ^= t ^ xtime(a[2] ^ a[3]) a[3] ^= t ^ xtime(a[3] ^ u) - def __mix_columns(self, s): for i in range(4): self.__mix_single_column(s[i]) - def __inv_mix_columns(self, s): for i in range(4): u = xtime(xtime(s[i][0] ^ s[i][2])) @@ -215,6 +168,6 @@ def __inv_mix_columns(self, s): if __name__ == "__main__": - aes = AES(0x2b7e151628aed2a6abf7158809cf4f3c) - print(hex(aes.encrypt(0x3243f6a8885a308d313198a2e0370734))) - print(hex(aes.decrypt(0x3925841d02dc09fbdc118597196a0b32))) + aes = AES(0x2B7E151628AED2A6ABF7158809CF4F3C) + print(hex(aes.encrypt(0x3243F6A8885A308D313198A2E0370734))) + print(hex(aes.decrypt(0x3925841D02DC09FBDC118597196A0B32))) diff --git a/tests/test_comparison.py b/tests/test_comparison.py index 3c115b1..2915c11 100644 --- a/tests/test_comparison.py +++ b/tests/test_comparison.py @@ -3,9 +3,11 @@ from reference_aes import AES as ReferenceAES from aes import AES as TinyGradAES + def generate_random_128bit(): return random.getrandbits(128) + class TestAESImplementations: def setup_method(self): self.key = generate_random_128bit() @@ -22,7 +24,7 @@ def test_decryption_comparison(self): ref_ciphertext = self.ref_aes.encrypt(self.plaintext) tiny_decrypted = self.tiny_aes.decrypt(ref_ciphertext) assert tiny_decrypted == self.plaintext - + tiny_ciphertext = self.tiny_aes.encrypt(self.plaintext) ref_decrypted = self.ref_aes.decrypt(tiny_ciphertext) assert ref_decrypted == self.plaintext @@ -33,10 +35,11 @@ def test_multiple_random_values(self): ref_ciphertext = self.ref_aes.encrypt(plaintext) tiny_ciphertext = self.tiny_aes.encrypt(plaintext) assert ref_ciphertext == tiny_ciphertext - + ref_decrypted = self.ref_aes.decrypt(tiny_ciphertext) tiny_decrypted = self.tiny_aes.decrypt(ref_ciphertext) assert ref_decrypted == plaintext and tiny_decrypted == plaintext + if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/tests_aes.py b/tests/tests_aes.py index b3de766..c683bd2 100644 --- a/tests/tests_aes.py +++ b/tests/tests_aes.py @@ -1,14 +1,13 @@ - - import pytest from aes import AES + class TestAESEncryptDecrypt: def setup_method(self): # Standard test vector from NIST FIPS-197 - self.key = 0x2b7e151628aed2a6abf7158809cf4f3c - self.plaintext = 0x3243f6a8885a308d313198a2e0370734 - self.ciphertext = 0x3925841d02dc09fbdc118597196a0b32 + self.key = 0x2B7E151628AED2A6ABF7158809CF4F3C + self.plaintext = 0x3243F6A8885A308D313198A2E0370734 + self.ciphertext = 0x3925841D02DC09FBDC118597196A0B32 self.aes = AES(self.key) def test_encrypt(self): @@ -27,24 +26,34 @@ def test_encrypt_decrypt_roundtrip(self): decrypted = self.aes.decrypt(encrypted) assert decrypted == self.plaintext - @pytest.mark.parametrize("plaintext,key,ciphertext", [ - # NIST SP 800-38A test vectors - (0x6bc1bee22e409f96e93d7e117393172a, - 0x2b7e151628aed2a6abf7158809cf4f3c, - 0x3ad77bb40d7a3660a89ecaf32466ef97), - (0xae2d8a571e03ac9c9eb76fac45af8e51, - 0x2b7e151628aed2a6abf7158809cf4f3c, - 0xf5d3d58503b9699de785895a96fdbaaf), - # Additional test with all zeros - (0x00000000000000000000000000000000, - 0x00000000000000000000000000000000, - 0x66e94bd4ef8a2c3b884cfa59ca342b2e) - ]) + @pytest.mark.parametrize( + "plaintext,key,ciphertext", + [ + # NIST SP 800-38A test vectors + ( + 0x6BC1BEE22E409F96E93D7E117393172A, + 0x2B7E151628AED2A6ABF7158809CF4F3C, + 0x3AD77BB40D7A3660A89ECAF32466EF97, + ), + ( + 0xAE2D8A571E03AC9C9EB76FAC45AF8E51, + 0x2B7E151628AED2A6ABF7158809CF4F3C, + 0xF5D3D58503B9699DE785895A96FDBAAF, + ), + # Additional test with all zeros + ( + 0x00000000000000000000000000000000, + 0x00000000000000000000000000000000, + 0x66E94BD4EF8A2C3B884CFA59CA342B2E, + ), + ], + ) def test_known_vectors(self, plaintext, key, ciphertext): """Test encryption and decryption with additional known test vectors""" aes = AES(key) assert aes.encrypt(plaintext) == ciphertext assert aes.decrypt(ciphertext) == plaintext + if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..272cf06 --- /dev/null +++ b/uv.lock @@ -0,0 +1,181 @@ +version = 1 +revision = 1 +requires-python = ">=3.11" + +[[package]] +name = "aes256" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "pytest" }, + { name = "pytest-benchmark" }, + { name = "tinygrad" }, +] + +[package.dev-dependencies] +dev = [ + { name = "black" }, + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [ + { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest-benchmark", specifier = ">=5.1.0" }, + { name = "tinygrad", specifier = ">=0.10.2" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "black", specifier = ">=25.1.0" }, + { name = "pytest", specifier = ">=8.3.4" }, +] + +[[package]] +name = "black" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "mypy-extensions" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/4f/87f596aca05c3ce5b94b8663dbfe242a12843caaa82dd3f85f1ffdc3f177/black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0", size = 1614372 }, + { url = "https://files.pythonhosted.org/packages/e7/d0/2c34c36190b741c59c901e56ab7f6e54dad8df05a6272a9747ecef7c6036/black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299", size = 1442865 }, + { url = "https://files.pythonhosted.org/packages/21/d4/7518c72262468430ead45cf22bd86c883a6448b9eb43672765d69a8f1248/black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096", size = 1749699 }, + { url = "https://files.pythonhosted.org/packages/58/db/4f5beb989b547f79096e035c4981ceb36ac2b552d0ac5f2620e941501c99/black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2", size = 1428028 }, + { url = "https://files.pythonhosted.org/packages/83/71/3fe4741df7adf015ad8dfa082dd36c94ca86bb21f25608eb247b4afb15b2/black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b", size = 1650988 }, + { url = "https://files.pythonhosted.org/packages/13/f3/89aac8a83d73937ccd39bbe8fc6ac8860c11cfa0af5b1c96d081facac844/black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc", size = 1453985 }, + { url = "https://files.pythonhosted.org/packages/6f/22/b99efca33f1f3a1d2552c714b1e1b5ae92efac6c43e790ad539a163d1754/black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f", size = 1783816 }, + { url = "https://files.pythonhosted.org/packages/18/7e/a27c3ad3822b6f2e0e00d63d58ff6299a99a5b3aee69fa77cd4b0076b261/black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba", size = 1440860 }, + { url = "https://files.pythonhosted.org/packages/98/87/0edf98916640efa5d0696e1abb0a8357b52e69e82322628f25bf14d263d1/black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f", size = 1650673 }, + { url = "https://files.pythonhosted.org/packages/52/e5/f7bf17207cf87fa6e9b676576749c6b6ed0d70f179a3d812c997870291c3/black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3", size = 1453190 }, + { url = "https://files.pythonhosted.org/packages/e3/ee/adda3d46d4a9120772fae6de454c8495603c37c4c3b9c60f25b1ab6401fe/black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171", size = 1782926 }, + { url = "https://files.pythonhosted.org/packages/cc/64/94eb5f45dcb997d2082f097a3944cfc7fe87e071907f677e80788a2d7b7a/black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18", size = 1442613 }, + { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646 }, +] + +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, +] + +[[package]] +name = "packaging" +version = "24.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, +] + +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191 }, +] + +[[package]] +name = "platformdirs" +version = "4.3.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/fc/128cc9cb8f03208bdbf93d3aa862e16d376844a14f9a0ce5cf4507372de4/platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907", size = 21302 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 }, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335 }, +] + +[[package]] +name = "pytest" +version = "8.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, +] + +[[package]] +name = "pytest-benchmark" +version = "5.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py-cpuinfo" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/d0/a8bd08d641b393db3be3819b03e2d9bb8760ca8479080a26a5f6e540e99c/pytest-benchmark-5.1.0.tar.gz", hash = "sha256:9ea661cdc292e8231f7cd4c10b0319e56a2118e2c09d9f50e1b3d150d2aca105", size = 337810 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/d6/b41653199ea09d5969d4e385df9bbfd9a100f28ca7e824ce7c0a016e3053/pytest_benchmark-5.1.0-py3-none-any.whl", hash = "sha256:922de2dfa3033c227c96da942d1878191afa135a29485fb942e85dff1c592c89", size = 44259 }, +] + +[[package]] +name = "tinygrad" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/32/a1be2284e35d3798ba78bab6bdb3874161c614e49e7b798878b8c2105027/tinygrad-0.10.2.tar.gz", hash = "sha256:53e808fcbfe540302c20045b8a53b9fc709fdb96deb669e0fff3949327ad49a7", size = 1870027 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/99/9beb27ae1a5cbdab12fd841d703dd0c2f9b2fea778c5f1877a97e44abf12/tinygrad-0.10.2-py3-none-any.whl", hash = "sha256:4536cbbf5c1247c10241f883b5a425583219d367437358bc00295f8a87e8133f", size = 1761398 }, +] From 71229f3511343a60e7f98371d811ade1eaed8f4a Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 19:28:33 -0800 Subject: [PATCH 05/12] fix: ci --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a378b87..b0365af 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies using uv run: | - uv pip install + uv pip install --requirements pyproject.toml - name: Run tests with uv run: | From 530997ae9c28131466700aebeb617e2416943776 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 19:29:43 -0800 Subject: [PATCH 06/12] fix: ci --- .github/workflows/test.yml | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b0365af..0b87ef3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,24 +9,34 @@ on: jobs: test: runs-on: ubuntu-latest - + + env: + UV_SYSTEM_PYTHON: 1 # Allows `uv` to install dependencies globally + steps: - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.11' - + - name: Install uv run: | curl -LsSf https://astral.sh/uv/install.sh | sh - echo "$HOME/.local/bin" >> $GITHUB_PATH + echo "$HOME/.local/bin" >> $GITHUB_PATH # Ensure `uv` is in PATH + + - name: Set up caching for uv + uses: actions/cache@v3 + with: + path: .uv-cache + key: uv-${{ runner.os }}-${{ hashFiles('pyproject.toml') }} + restore-keys: | + uv-${{ runner.os }}- - name: Install dependencies using uv run: | - uv pip install --requirements pyproject.toml + uv pip install --system --requirements pyproject.toml # Install from pyproject.toml with system Python - name: Run tests with uv run: | - uv venv exec pytest tests/ + pytest tests/ + + - name: Clean up uv cache + run: | + uv cache prune --ci # Reduce cache size after the job From 6e67ac16c7a23eb4c46d392b7a0c3022c0d1982a Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 19:30:28 -0800 Subject: [PATCH 07/12] fix: ci --- .github/workflows/test.yml | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0b87ef3..b6b9868 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,33 +10,16 @@ jobs: test: runs-on: ubuntu-latest - env: - UV_SYSTEM_PYTHON: 1 # Allows `uv` to install dependencies globally - steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install uv - run: | - curl -LsSf https://astral.sh/uv/install.sh | sh - echo "$HOME/.local/bin" >> $GITHUB_PATH # Ensure `uv` is in PATH - - - name: Set up caching for uv - uses: actions/cache@v3 + uses: astral-sh/setup-uv@v5 with: - path: .uv-cache - key: uv-${{ runner.os }}-${{ hashFiles('pyproject.toml') }} - restore-keys: | - uv-${{ runner.os }}- - - - name: Install dependencies using uv - run: | - uv pip install --system --requirements pyproject.toml # Install from pyproject.toml with system Python + enable-cache: true - - name: Run tests with uv - run: | - pytest tests/ + - name: Install dependencies + run: uv sync --all-extras --dev - - name: Clean up uv cache - run: | - uv cache prune --ci # Reduce cache size after the job + - name: Run tests + run: uv run pytest tests/ From 6c7406669d7c57f3e96cea3aeddc3f8830c6b163 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 19:59:28 -0800 Subject: [PATCH 08/12] chore: clean --- .github/workflows/{test.yml => ci.yml} | 0 __init__.py | 0 aes256/__init__.py | 0 aes.py => aes256/aes.py | 2 +- constants.py => aes256/constants.py | 0 pyproject.toml | 6 +++++- tests/__init__.py | 1 + bench.py => tests/bench.py | 5 +++-- tests/reference/__init__.py | 1 + reference_aes.py => tests/reference/aes.py | 2 +- tests/test_comparison.py | 4 ++-- 11 files changed, 14 insertions(+), 7 deletions(-) rename .github/workflows/{test.yml => ci.yml} (100%) create mode 100644 __init__.py create mode 100644 aes256/__init__.py rename aes.py => aes256/aes.py (98%) rename constants.py => aes256/constants.py (100%) rename bench.py => tests/bench.py (83%) create mode 100644 tests/reference/__init__.py rename reference_aes.py => tests/reference/aes.py (99%) diff --git a/.github/workflows/test.yml b/.github/workflows/ci.yml similarity index 100% rename from .github/workflows/test.yml rename to .github/workflows/ci.yml diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aes256/__init__.py b/aes256/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aes.py b/aes256/aes.py similarity index 98% rename from aes.py rename to aes256/aes.py index 869805e..00b30b6 100644 --- a/aes.py +++ b/aes256/aes.py @@ -1,6 +1,6 @@ from tinygrad.tensor import Tensor from tinygrad import dtypes -from constants import Sbox as Sbox_const, InvSbox as InvSbox_const, Rcon as Rcon_const +from aes256.constants import Sbox as Sbox_const, InvSbox as InvSbox_const, Rcon as Rcon_const Sbox = Tensor(Sbox_const, dtype=dtypes.uint8) InvSbox = Tensor(InvSbox_const, dtype=dtypes.uint8) diff --git a/constants.py b/aes256/constants.py similarity index 100% rename from constants.py rename to aes256/constants.py diff --git a/pyproject.toml b/pyproject.toml index ef33b54..f836e8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "aes256" version = "0.1.0" -description = "AES-256 implementation in TinyGrad" +description = "AES-256 Implementation" readme = "README.md" requires-python = ">=3.11" dependencies = [ @@ -15,3 +15,7 @@ dev = [ "black>=25.1.0", "pytest>=8.3.4", ] + + +[tool.hatch.build.targets.wheel] +packages = ["src/aes256"] \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..8685672 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Empty file to make the tests directory a Python package diff --git a/bench.py b/tests/bench.py similarity index 83% rename from bench.py rename to tests/bench.py index f4177b1..6fd2b86 100644 --- a/bench.py +++ b/tests/bench.py @@ -1,8 +1,9 @@ import pytest -from aes import AES as TinyGradAES -from reference_aes import AES as ReferenceAES +from aes256.aes import AES as TinyGradAES +from tests.reference.aes import AES as ReferenceAES +@pytest.mark.benchmark @pytest.mark.parametrize("num_ops", [2, 4, 8]) @pytest.mark.parametrize( "aes_class", [TinyGradAES, ReferenceAES], ids=["TinyGradAES", "ReferenceAES"] diff --git a/tests/reference/__init__.py b/tests/reference/__init__.py new file mode 100644 index 0000000..f616c75 --- /dev/null +++ b/tests/reference/__init__.py @@ -0,0 +1 @@ +# Empty file to make the reference directory a Python package diff --git a/reference_aes.py b/tests/reference/aes.py similarity index 99% rename from reference_aes.py rename to tests/reference/aes.py index 2fd8f21..b4e7992 100644 --- a/reference_aes.py +++ b/tests/reference/aes.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from constants import Sbox, InvSbox, Rcon +from aes256.constants import Sbox, InvSbox, Rcon """ diff --git a/tests/test_comparison.py b/tests/test_comparison.py index 2915c11..2f3fcc0 100644 --- a/tests/test_comparison.py +++ b/tests/test_comparison.py @@ -1,7 +1,7 @@ import pytest import random -from reference_aes import AES as ReferenceAES -from aes import AES as TinyGradAES +from .reference.aes import AES as ReferenceAES +from aes256.aes import AES as TinyGradAES def generate_random_128bit(): From a232df216d284cd897b783a0f906ab8f8f12a8bb Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 20:12:16 -0800 Subject: [PATCH 09/12] fix: testing --- README.md | 17 ++++++++++++++++- pyproject.toml | 4 ++++ tests/bench.py | 1 - tests/{tests_aes.py => test_aes.py} | 2 +- 4 files changed, 21 insertions(+), 3 deletions(-) rename tests/{tests_aes.py => test_aes.py} (98%) diff --git a/README.md b/README.md index 2674e54..5274b28 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,18 @@ # TinyGrad AES -A TinyGrad-based implementation of the Advanced Encryption Standard (AES) algorithm. This implementation is based on [bozhu's Python AES implementation](https://github.com/bozhu/AES-Python/tree/master) but rewritten to use TinyGrad tensors for computation. +A TinyGrad-based implementation of the Advanced Encryption Standard (AES) algorithm. This implementation is based on [bozhu's Python AES implementation](https://github.com/bozhu/AES-Python) but rewritten to use TinyGrad tensors for computation. + +## Benchmarks + +The implementation includes benchmarks comparing the TinyGrad-based implementation with a pure Python reference implementation. Here are the results: + +| Implementation | Operations | Time (μs) | Operations/sec | +|---------------|------------|-----------|----------------| +| Reference | 2 | 178.37 | 5,606.17 | +| Reference | 4 | 304.33 | 3,285.87 | +| Reference | 8 | 602.62 | 1,659.41 | +| TinyGrad | 2 | 616,860.92| 1.62 | +| TinyGrad | 4 | 901,575.75| 1.11 | +| TinyGrad | 8 | 1,828,307.83| 0.55 | + +Each operation consists of one encryption followed by one decryption. The benchmark measures the time taken for different numbers of operations (2, 4, and 8). diff --git a/pyproject.toml b/pyproject.toml index f836e8b..2eeb170 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,10 @@ dev = [ "pytest>=8.3.4", ] +[tool.pytest.ini_options] +pythonpath = ["."] +testpaths = ["tests"] +python_files = ["test_*.py", "bench.py"] [tool.hatch.build.targets.wheel] packages = ["src/aes256"] \ No newline at end of file diff --git a/tests/bench.py b/tests/bench.py index 6fd2b86..4805770 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -2,7 +2,6 @@ from aes256.aes import AES as TinyGradAES from tests.reference.aes import AES as ReferenceAES - @pytest.mark.benchmark @pytest.mark.parametrize("num_ops", [2, 4, 8]) @pytest.mark.parametrize( diff --git a/tests/tests_aes.py b/tests/test_aes.py similarity index 98% rename from tests/tests_aes.py rename to tests/test_aes.py index c683bd2..3ed128c 100644 --- a/tests/tests_aes.py +++ b/tests/test_aes.py @@ -1,5 +1,5 @@ import pytest -from aes import AES +from aes256.aes import AES class TestAESEncryptDecrypt: From c7487694dd7f59c3ffc1ebb409fbc054699748bf Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 20:13:49 -0800 Subject: [PATCH 10/12] fix: readme --- README.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5274b28..2219425 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,19 @@ A TinyGrad-based implementation of the Advanced Encryption Standard (AES) algorithm. This implementation is based on [bozhu's Python AES implementation](https://github.com/bozhu/AES-Python) but rewritten to use TinyGrad tensors for computation. -## Benchmarks +## Usage -The implementation includes benchmarks comparing the TinyGrad-based implementation with a pure Python reference implementation. Here are the results: +Run tests: +```bash +uv run pytest +``` + +Run benchmarks: + ```bash +uv run pytest tests/bench.py +``` + +## Benchmark Results | Implementation | Operations | Time (μs) | Operations/sec | |---------------|------------|-----------|----------------| @@ -15,4 +25,4 @@ The implementation includes benchmarks comparing the TinyGrad-based implementati | TinyGrad | 4 | 901,575.75| 1.11 | | TinyGrad | 8 | 1,828,307.83| 0.55 | -Each operation consists of one encryption followed by one decryption. The benchmark measures the time taken for different numbers of operations (2, 4, and 8). +Each operation is one encryption + decryption. The TinyGrad implementation is slower due to tensor operation overhead. From aad86b78e14c751e0457a39561ce86433a3ff81f Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 20:15:59 -0800 Subject: [PATCH 11/12] fix: readme, formatting --- README.md | 5 +++++ aes256/aes.py | 7 ++++++- pyproject.toml | 3 ++- tests/bench.py | 1 + uv.lock | 2 ++ 5 files changed, 16 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2219425..9b58780 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,11 @@ A TinyGrad-based implementation of the Advanced Encryption Standard (AES) algori ## Usage +Install dependencies: +```bash +uv sync +``` + Run tests: ```bash uv run pytest diff --git a/aes256/aes.py b/aes256/aes.py index 00b30b6..919609f 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -1,11 +1,16 @@ from tinygrad.tensor import Tensor from tinygrad import dtypes -from aes256.constants import Sbox as Sbox_const, InvSbox as InvSbox_const, Rcon as Rcon_const +from aes256.constants import ( + Sbox as Sbox_const, + InvSbox as InvSbox_const, + Rcon as Rcon_const, +) Sbox = Tensor(Sbox_const, dtype=dtypes.uint8) InvSbox = Tensor(InvSbox_const, dtype=dtypes.uint8) Rcon = Tensor(Rcon_const, dtype=dtypes.uint8) + def xtime(a: int) -> int: return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) diff --git a/pyproject.toml b/pyproject.toml index 2eeb170..ee9b5a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "AES-256 Implementation" readme = "README.md" requires-python = ">=3.11" dependencies = [ + "black>=25.1.0", "pytest>=8.3.4", "pytest-benchmark>=5.1.0", "tinygrad>=0.10.2", @@ -22,4 +23,4 @@ testpaths = ["tests"] python_files = ["test_*.py", "bench.py"] [tool.hatch.build.targets.wheel] -packages = ["src/aes256"] \ No newline at end of file +packages = ["src/aes256"] diff --git a/tests/bench.py b/tests/bench.py index 4805770..6fd2b86 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -2,6 +2,7 @@ from aes256.aes import AES as TinyGradAES from tests.reference.aes import AES as ReferenceAES + @pytest.mark.benchmark @pytest.mark.parametrize("num_ops", [2, 4, 8]) @pytest.mark.parametrize( diff --git a/uv.lock b/uv.lock index 272cf06..abe1560 100644 --- a/uv.lock +++ b/uv.lock @@ -7,6 +7,7 @@ name = "aes256" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "black" }, { name = "pytest" }, { name = "pytest-benchmark" }, { name = "tinygrad" }, @@ -20,6 +21,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "black", specifier = ">=25.1.0" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-benchmark", specifier = ">=5.1.0" }, { name = "tinygrad", specifier = ">=0.10.2" }, From a655af2114a61fbedeeb4176a40c7c8ed080461f Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 20:25:07 -0800 Subject: [PATCH 12/12] fix: paths --- __init__.py | 0 tests/bench.py => bench.py | 0 pyproject.toml | 3 +-- 3 files changed, 1 insertion(+), 2 deletions(-) delete mode 100644 __init__.py rename tests/bench.py => bench.py (100%) diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/bench.py b/bench.py similarity index 100% rename from tests/bench.py rename to bench.py diff --git a/pyproject.toml b/pyproject.toml index ee9b5a0..8d4c8a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dev = [ [tool.pytest.ini_options] pythonpath = ["."] testpaths = ["tests"] -python_files = ["test_*.py", "bench.py"] [tool.hatch.build.targets.wheel] -packages = ["src/aes256"] +packages = ["aes256"]