From d928ac1507b3386030a21150e777c2c88be47456 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 21:28:54 -0800 Subject: [PATCH 1/9] benching --- aes256/aes.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 87 insertions(+), 3 deletions(-) diff --git a/aes256/aes.py b/aes256/aes.py index 919609f..c01a0d2 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -5,16 +5,81 @@ InvSbox as InvSbox_const, Rcon as Rcon_const, ) +from functools import wraps +import time +from collections import defaultdict +import atexit Sbox = Tensor(Sbox_const, dtype=dtypes.uint8) InvSbox = Tensor(InvSbox_const, dtype=dtypes.uint8) Rcon = Tensor(Rcon_const, dtype=dtypes.uint8) - +# Add these at the top after imports +timing_stats = defaultdict(lambda: {'total_time': 0, 'calls': 0}) + +def format_table(headers, rows): + # Calculate column widths + widths = [len(h) for h in headers] + for row in rows: + for i, cell in enumerate(row): + widths[i] = max(widths[i], len(str(cell))) + + # Create format string for rows + row_format = '| ' + ' | '.join(f'{{:<{w}}}' for w in widths) + ' |' + separator = '+' + '+'.join('-' * (w + 2) for w in widths) + '+' + + # Build table + table = [separator] + table.append(row_format.format(*headers)) + table.append(separator) + for row in rows: + table.append(row_format.format(*row)) + table.append(separator) + + return '\n'.join(table) + +def print_timing_stats(): + if not timing_stats: + return + + # Prepare table data + headers = ['Function', 'Avg Time (ms)', 'Calls', 'Total Time (ms)'] + table_data = [] + for func_name, stats in sorted(timing_stats.items()): + avg_time = (stats['total_time'] * 1000) / stats['calls'] + table_data.append([ + func_name, + f"{avg_time:.2f}", + str(stats['calls']), + f"{(stats['total_time'] * 1000):.2f}" + ]) + + print("\nTiming Statistics:") + print(format_table(headers, table_data)) + +# Register the printing function to run at exit +atexit.register(print_timing_stats) + +def timing_decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.perf_counter() + result = func(*args, **kwargs) + end_time = time.perf_counter() + + # Update statistics + timing_stats[func.__name__]['total_time'] += end_time - start_time + timing_stats[func.__name__]['calls'] += 1 + + return result + return wrapper + +@timing_decorator def xtime(a: int) -> int: return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) +@timing_decorator def xtime_tensor(a: Tensor) -> Tensor: high_bit_mask = a.bitwise_and(0x80) shifted = a.lshift(1) @@ -25,6 +90,7 @@ def xtime_tensor(a: Tensor) -> Tensor: return result.bitwise_and(0xFF) +@timing_decorator def text2matrix(text: int) -> Tensor: return ( Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint64) @@ -33,6 +99,7 @@ def text2matrix(text: int) -> Tensor: ) +@timing_decorator def matrix2text(matrix: Tensor) -> int: flat = matrix.flatten() result = 0 @@ -46,6 +113,7 @@ class AES: def __init__(self, master_key): self.change_key(master_key) + @timing_decorator def change_key(self, master_key): self.round_keys = Tensor.zeros((44, 4), dtype=dtypes.uint8).contiguous() self.round_keys[:4] = text2matrix(master_key) @@ -63,6 +131,7 @@ def change_key(self, master_key): else: self.round_keys[i] = self.round_keys[i - 4].xor(self.round_keys[i - 1]) + @timing_decorator def encrypt(self, plaintext: int) -> int: self.plain_state = text2matrix(plaintext) self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4]) @@ -78,6 +147,7 @@ def encrypt(self, plaintext: int) -> int: return matrix2text(self.plain_state) + @timing_decorator def decrypt(self, ciphertext: int) -> int: self.cipher_state = text2matrix(ciphertext) self.cipher_state = self.__add_round_key( @@ -94,6 +164,7 @@ def decrypt(self, ciphertext: int) -> int: return matrix2text(self.cipher_state) + @timing_decorator def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: state_matrix = self.__sub_bytes(state_matrix) state_matrix = self.__shift_rows(state_matrix) @@ -101,6 +172,7 @@ def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: state_matrix = self.__add_round_key(state_matrix, key_matrix) return state_matrix + @timing_decorator def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: state_matrix = self.__add_round_key(state_matrix, key_matrix) state_matrix = self.__inv_mix_columns(state_matrix) @@ -108,15 +180,19 @@ def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: state_matrix = self.__inv_sub_bytes(state_matrix) return state_matrix + @timing_decorator def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: return s.xor(k) + @timing_decorator def __sub_bytes(self, s: Tensor) -> Tensor: return Sbox[s] + @timing_decorator def __inv_sub_bytes(self, s: Tensor) -> Tensor: return InvSbox[s] + @timing_decorator def __shift_rows(self, s: Tensor) -> Tensor: state = s.clone() @@ -125,6 +201,7 @@ def __shift_rows(self, s: Tensor) -> Tensor: return state + @timing_decorator def __inv_shift_rows(self, s: Tensor) -> Tensor: state = s.clone() @@ -133,6 +210,7 @@ def __inv_shift_rows(self, s: Tensor) -> Tensor: return state + @timing_decorator def __mix_columns(self, state: Tensor) -> Tensor: t = state[:, 0].xor(state[:, 1]).xor(state[:, 2]).xor(state[:, 3]) xtimes = xtime_tensor(state.roll(-1, dims=1).xor(state)) @@ -140,6 +218,7 @@ def __mix_columns(self, state: Tensor) -> Tensor: return state + @timing_decorator def __inv_mix_columns(self, state: Tensor) -> Tensor: u = xtime_tensor(xtime_tensor(state[:, 0].xor(state[:, 2]))) v = xtime_tensor(xtime_tensor(state[:, 1].xor(state[:, 3]))) @@ -155,5 +234,10 @@ def __inv_mix_columns(self, state: Tensor) -> Tensor: if __name__ == "__main__": aes = AES(0x2B7E151628AED2A6ABF7158809CF4F3C) - print(hex(aes.encrypt(0x3243F6A8885A308D313198A2E0370734))) - print(hex(aes.decrypt(0x3925841D02DC09FBDC118597196A0B32))) + pt = 0x3243F6A8885A308D313198A2E0370734 + ct = aes.encrypt(pt) + rec_pt = aes.decrypt(ct) + print(f'pt: {hex(pt)}') + print(f'ct: {hex(ct)}') + print(f'rec_pt: {hex(rec_pt)}') + From 6f0d4b87b31096f888b30c699baac7e15270abb6 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 21:37:03 -0800 Subject: [PATCH 2/9] assert --- aes256/aes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aes256/aes.py b/aes256/aes.py index c01a0d2..3083782 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -241,3 +241,4 @@ def __inv_mix_columns(self, state: Tensor) -> Tensor: print(f'ct: {hex(ct)}') print(f'rec_pt: {hex(rec_pt)}') + assert pt == rec_pt From 17c9d8b1485a5c39a88aae80383ccd3a1f2ba6a2 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 22:02:24 -0800 Subject: [PATCH 3/9] feat(wip): in-place --- aes256/aes.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/aes256/aes.py b/aes256/aes.py index 3083782..dd5efa6 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -93,7 +93,7 @@ def xtime_tensor(a: Tensor) -> Tensor: @timing_decorator def text2matrix(text: int) -> Tensor: return ( - Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint64) + Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint8) .bitwise_and(0xFF) .reshape((4, 4)) ) @@ -134,63 +134,61 @@ def change_key(self, master_key): @timing_decorator def encrypt(self, plaintext: int) -> int: self.plain_state = text2matrix(plaintext) - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4]) + self.__add_round_key(self.plain_state, self.round_keys[:4]) for i in range(1, 10): self.plain_state = self.__round_encrypt( self.plain_state, self.round_keys[4 * i : 4 * (i + 1)] ) - self.plain_state = self.__sub_bytes(self.plain_state) + self.__sub_bytes(self.plain_state) self.plain_state = self.__shift_rows(self.plain_state) - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[40:]) + self.__add_round_key(self.plain_state, self.round_keys[40:]) return matrix2text(self.plain_state) @timing_decorator def decrypt(self, ciphertext: int) -> int: self.cipher_state = text2matrix(ciphertext) - self.cipher_state = self.__add_round_key( - self.cipher_state, self.round_keys[40:] - ) - self.cipher_state = self.__inv_shift_rows(self.cipher_state) - self.cipher_state = self.__inv_sub_bytes(self.cipher_state) + self.__add_round_key(self.cipher_state, self.round_keys[40:]) + self.__inv_shift_rows(self.cipher_state) + self.__inv_sub_bytes(self.cipher_state) for i in range(9, 0, -1): self.cipher_state = self.__round_decrypt( self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)] ) - self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:4]) + self.__add_round_key(self.cipher_state, self.round_keys[:4]) return matrix2text(self.cipher_state) @timing_decorator def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: - state_matrix = self.__sub_bytes(state_matrix) + self.__sub_bytes(state_matrix) state_matrix = self.__shift_rows(state_matrix) state_matrix = self.__mix_columns(state_matrix) - state_matrix = self.__add_round_key(state_matrix, key_matrix) + self.__add_round_key(state_matrix, key_matrix) return state_matrix @timing_decorator def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: - state_matrix = self.__add_round_key(state_matrix, key_matrix) + self.__add_round_key(state_matrix, key_matrix) state_matrix = self.__inv_mix_columns(state_matrix) - state_matrix = self.__inv_shift_rows(state_matrix) - state_matrix = self.__inv_sub_bytes(state_matrix) + self.__inv_shift_rows(state_matrix) + self.__inv_sub_bytes(state_matrix) return state_matrix @timing_decorator def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: - return s.xor(k) + s.assign(s.xor(k)) @timing_decorator def __sub_bytes(self, s: Tensor) -> Tensor: - return Sbox[s] + s.assign(Sbox[s]) @timing_decorator def __inv_sub_bytes(self, s: Tensor) -> Tensor: - return InvSbox[s] + s.assign(InvSbox[s]) @timing_decorator def __shift_rows(self, s: Tensor) -> Tensor: @@ -208,7 +206,7 @@ def __inv_shift_rows(self, s: Tensor) -> Tensor: for i in range(1, 4): state[:, i] = state[:, i].roll(i, dims=0) - return state + s.assign(state) @timing_decorator def __mix_columns(self, state: Tensor) -> Tensor: From 6618cc29af7f851838e7c0f3fd9f9390750793b2 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 22:30:48 -0800 Subject: [PATCH 4/9] feat(wip): inv shift rows --- aes256/aes.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aes256/aes.py b/aes256/aes.py index dd5efa6..09602ae 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -201,12 +201,11 @@ def __shift_rows(self, s: Tensor) -> Tensor: @timing_decorator def __inv_shift_rows(self, s: Tensor) -> Tensor: - state = s.clone() - + _s = s.contiguous() for i in range(1, 4): - state[:, i] = state[:, i].roll(i, dims=0) + _s[:, i] = _s[:, i].roll(i, dims=0) - s.assign(state) + s.assign(_s) @timing_decorator def __mix_columns(self, state: Tensor) -> Tensor: From 80dfd427856fc5f34ee1a1f6a4c562bdce6a3ef0 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 22:32:55 -0800 Subject: [PATCH 5/9] feat: mix columns --- aes256/aes.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/aes256/aes.py b/aes256/aes.py index 09602ae..0e0f637 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -166,7 +166,7 @@ def decrypt(self, ciphertext: int) -> int: def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: self.__sub_bytes(state_matrix) state_matrix = self.__shift_rows(state_matrix) - state_matrix = self.__mix_columns(state_matrix) + self.__mix_columns(state_matrix) self.__add_round_key(state_matrix, key_matrix) return state_matrix @@ -208,12 +208,11 @@ def __inv_shift_rows(self, s: Tensor) -> Tensor: s.assign(_s) @timing_decorator - def __mix_columns(self, state: Tensor) -> Tensor: - t = state[:, 0].xor(state[:, 1]).xor(state[:, 2]).xor(state[:, 3]) - xtimes = xtime_tensor(state.roll(-1, dims=1).xor(state)) - state = state.xor(t.unsqueeze(1)).xor(xtimes) + def __mix_columns(self, s: Tensor) -> Tensor: + t = s[:, 0].xor(s[:, 1]).xor(s[:, 2]).xor(s[:, 3]) + xtimes = xtime_tensor(s.roll(-1, dims=1).xor(s)).contiguous() + s.assign(s.xor(t.unsqueeze(1)).xor(xtimes)) - return state @timing_decorator def __inv_mix_columns(self, state: Tensor) -> Tensor: @@ -226,7 +225,8 @@ def __inv_mix_columns(self, state: Tensor) -> Tensor: out[:, 2] = state[:, 2].xor(u) out[:, 3] = state[:, 3].xor(v) - return self.__mix_columns(out) + self.__mix_columns(out) + return out if __name__ == "__main__": From dc0e69db1e18a462a20eb00fa3dfaff5bf8c8055 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 22:35:36 -0800 Subject: [PATCH 6/9] feat: xtime, --- aes256/aes.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/aes256/aes.py b/aes256/aes.py index 0e0f637..41bbe3b 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -74,29 +74,16 @@ def wrapper(*args, **kwargs): return result return wrapper -@timing_decorator -def xtime(a: int) -> int: - return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) - @timing_decorator -def xtime_tensor(a: Tensor) -> Tensor: - high_bit_mask = a.bitwise_and(0x80) +def xtime(a: Tensor) -> Tensor: shifted = a.lshift(1) - - condition = high_bit_mask != 0 - result = condition.where(shifted.xor(0x1B), shifted) - - return result.bitwise_and(0xFF) + return (a.bitwise_and(0x80) != 0).where(shifted.xor(0x1B), shifted).cast(dtypes.uint8) @timing_decorator def text2matrix(text: int) -> Tensor: - return ( - Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint8) - .bitwise_and(0xFF) - .reshape((4, 4)) - ) + return Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint8).reshape((4, 4)) @timing_decorator @@ -210,14 +197,14 @@ def __inv_shift_rows(self, s: Tensor) -> Tensor: @timing_decorator def __mix_columns(self, s: Tensor) -> Tensor: t = s[:, 0].xor(s[:, 1]).xor(s[:, 2]).xor(s[:, 3]) - xtimes = xtime_tensor(s.roll(-1, dims=1).xor(s)).contiguous() + xtimes = xtime(s.roll(-1, dims=1).xor(s)).contiguous() s.assign(s.xor(t.unsqueeze(1)).xor(xtimes)) @timing_decorator def __inv_mix_columns(self, state: Tensor) -> Tensor: - u = xtime_tensor(xtime_tensor(state[:, 0].xor(state[:, 2]))) - v = xtime_tensor(xtime_tensor(state[:, 1].xor(state[:, 3]))) + u = xtime(xtime(state[:, 0].xor(state[:, 2]))) + v = xtime(xtime(state[:, 1].xor(state[:, 3]))) out = state.clone() out[:, 0] = state[:, 0].xor(u) From cab27f06753c2c1c5cd14e28c3c14753dfaba0d3 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 22:39:19 -0800 Subject: [PATCH 7/9] feat: inv mix columns --- aes256/aes.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/aes256/aes.py b/aes256/aes.py index 41bbe3b..cc8db70 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -141,7 +141,7 @@ def decrypt(self, ciphertext: int) -> int: self.__inv_shift_rows(self.cipher_state) self.__inv_sub_bytes(self.cipher_state) for i in range(9, 0, -1): - self.cipher_state = self.__round_decrypt( + self.__round_decrypt( self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)] ) @@ -160,10 +160,9 @@ def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: @timing_decorator def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: self.__add_round_key(state_matrix, key_matrix) - state_matrix = self.__inv_mix_columns(state_matrix) + self.__inv_mix_columns(state_matrix) self.__inv_shift_rows(state_matrix) self.__inv_sub_bytes(state_matrix) - return state_matrix @timing_decorator def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: @@ -202,18 +201,17 @@ def __mix_columns(self, s: Tensor) -> Tensor: @timing_decorator - def __inv_mix_columns(self, state: Tensor) -> Tensor: - u = xtime(xtime(state[:, 0].xor(state[:, 2]))) - v = xtime(xtime(state[:, 1].xor(state[:, 3]))) - - out = state.clone() - out[:, 0] = state[:, 0].xor(u) - out[:, 1] = state[:, 1].xor(v) - out[:, 2] = state[:, 2].xor(u) - out[:, 3] = state[:, 3].xor(v) - - self.__mix_columns(out) - return out + def __inv_mix_columns(self, s: Tensor) -> Tensor: + even_cols = s[:, [0,2]] + odd_cols = s[:, [1,3]] + + u = xtime(xtime(even_cols[:,0].xor(even_cols[:,1]))) + v = xtime(xtime(odd_cols[:,0].xor(odd_cols[:,1]))) + + s[:, [0,2]] = s[:, [0,2]].xor(u.unsqueeze(1)) + s[:, [1,3]] = s[:, [1,3]].xor(v.unsqueeze(1)) + + self.__mix_columns(s) if __name__ == "__main__": From 5feaa7bb117d1f8f6f311f931bf843a01c745f90 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 22:40:38 -0800 Subject: [PATCH 8/9] feat: shift rows --- aes256/aes.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/aes256/aes.py b/aes256/aes.py index cc8db70..9cb0a5b 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -124,12 +124,12 @@ def encrypt(self, plaintext: int) -> int: self.__add_round_key(self.plain_state, self.round_keys[:4]) for i in range(1, 10): - self.plain_state = self.__round_encrypt( + self.__round_encrypt( self.plain_state, self.round_keys[4 * i : 4 * (i + 1)] ) self.__sub_bytes(self.plain_state) - self.plain_state = self.__shift_rows(self.plain_state) + self.__shift_rows(self.plain_state) self.__add_round_key(self.plain_state, self.round_keys[40:]) return matrix2text(self.plain_state) @@ -152,10 +152,10 @@ def decrypt(self, ciphertext: int) -> int: @timing_decorator def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: self.__sub_bytes(state_matrix) - state_matrix = self.__shift_rows(state_matrix) + self.__shift_rows(state_matrix) self.__mix_columns(state_matrix) self.__add_round_key(state_matrix, key_matrix) - return state_matrix + @timing_decorator def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: @@ -178,16 +178,16 @@ def __inv_sub_bytes(self, s: Tensor) -> Tensor: @timing_decorator def __shift_rows(self, s: Tensor) -> Tensor: - state = s.clone() - + _s = s for i in range(1, 4): - state[:, i] = state[:, i].roll(-i, dims=0) + _s[:, i] = _s[:, i].roll(-i, dims=0) + + s.assign(_s) - return state @timing_decorator def __inv_shift_rows(self, s: Tensor) -> Tensor: - _s = s.contiguous() + _s = s for i in range(1, 4): _s[:, i] = _s[:, i].roll(i, dims=0) From 50c34357d0fc29ba018dc59ea5d36c8b761fb807 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Fri, 21 Feb 2025 22:41:44 -0800 Subject: [PATCH 9/9] chore: remove debugging stuff --- aes256/aes.py | 80 ------------------------------------------ tests/reference/aes.py | 11 ++++-- 2 files changed, 9 insertions(+), 82 deletions(-) diff --git a/aes256/aes.py b/aes256/aes.py index 9cb0a5b..0e3c2a9 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -5,88 +5,20 @@ InvSbox as InvSbox_const, Rcon as Rcon_const, ) -from functools import wraps -import time -from collections import defaultdict -import atexit Sbox = Tensor(Sbox_const, dtype=dtypes.uint8) InvSbox = Tensor(InvSbox_const, dtype=dtypes.uint8) Rcon = Tensor(Rcon_const, dtype=dtypes.uint8) -# Add these at the top after imports -timing_stats = defaultdict(lambda: {'total_time': 0, 'calls': 0}) - -def format_table(headers, rows): - # Calculate column widths - widths = [len(h) for h in headers] - for row in rows: - for i, cell in enumerate(row): - widths[i] = max(widths[i], len(str(cell))) - - # Create format string for rows - row_format = '| ' + ' | '.join(f'{{:<{w}}}' for w in widths) + ' |' - separator = '+' + '+'.join('-' * (w + 2) for w in widths) + '+' - - # Build table - table = [separator] - table.append(row_format.format(*headers)) - table.append(separator) - for row in rows: - table.append(row_format.format(*row)) - table.append(separator) - - return '\n'.join(table) - -def print_timing_stats(): - if not timing_stats: - return - - # Prepare table data - headers = ['Function', 'Avg Time (ms)', 'Calls', 'Total Time (ms)'] - table_data = [] - for func_name, stats in sorted(timing_stats.items()): - avg_time = (stats['total_time'] * 1000) / stats['calls'] - table_data.append([ - func_name, - f"{avg_time:.2f}", - str(stats['calls']), - f"{(stats['total_time'] * 1000):.2f}" - ]) - - print("\nTiming Statistics:") - print(format_table(headers, table_data)) - -# Register the printing function to run at exit -atexit.register(print_timing_stats) - -def timing_decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - start_time = time.perf_counter() - result = func(*args, **kwargs) - end_time = time.perf_counter() - - # Update statistics - timing_stats[func.__name__]['total_time'] += end_time - start_time - timing_stats[func.__name__]['calls'] += 1 - - return result - return wrapper - - -@timing_decorator def xtime(a: Tensor) -> Tensor: shifted = a.lshift(1) return (a.bitwise_and(0x80) != 0).where(shifted.xor(0x1B), shifted).cast(dtypes.uint8) -@timing_decorator def text2matrix(text: int) -> Tensor: return Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint8).reshape((4, 4)) -@timing_decorator def matrix2text(matrix: Tensor) -> int: flat = matrix.flatten() result = 0 @@ -100,7 +32,6 @@ class AES: def __init__(self, master_key): self.change_key(master_key) - @timing_decorator def change_key(self, master_key): self.round_keys = Tensor.zeros((44, 4), dtype=dtypes.uint8).contiguous() self.round_keys[:4] = text2matrix(master_key) @@ -118,7 +49,6 @@ def change_key(self, master_key): else: self.round_keys[i] = self.round_keys[i - 4].xor(self.round_keys[i - 1]) - @timing_decorator def encrypt(self, plaintext: int) -> int: self.plain_state = text2matrix(plaintext) self.__add_round_key(self.plain_state, self.round_keys[:4]) @@ -134,7 +64,6 @@ def encrypt(self, plaintext: int) -> int: return matrix2text(self.plain_state) - @timing_decorator def decrypt(self, ciphertext: int) -> int: self.cipher_state = text2matrix(ciphertext) self.__add_round_key(self.cipher_state, self.round_keys[40:]) @@ -149,7 +78,6 @@ def decrypt(self, ciphertext: int) -> int: return matrix2text(self.cipher_state) - @timing_decorator def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: self.__sub_bytes(state_matrix) self.__shift_rows(state_matrix) @@ -157,26 +85,21 @@ def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: self.__add_round_key(state_matrix, key_matrix) - @timing_decorator def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: self.__add_round_key(state_matrix, key_matrix) self.__inv_mix_columns(state_matrix) self.__inv_shift_rows(state_matrix) self.__inv_sub_bytes(state_matrix) - @timing_decorator def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: s.assign(s.xor(k)) - @timing_decorator def __sub_bytes(self, s: Tensor) -> Tensor: s.assign(Sbox[s]) - @timing_decorator def __inv_sub_bytes(self, s: Tensor) -> Tensor: s.assign(InvSbox[s]) - @timing_decorator def __shift_rows(self, s: Tensor) -> Tensor: _s = s for i in range(1, 4): @@ -185,7 +108,6 @@ def __shift_rows(self, s: Tensor) -> Tensor: s.assign(_s) - @timing_decorator def __inv_shift_rows(self, s: Tensor) -> Tensor: _s = s for i in range(1, 4): @@ -193,14 +115,12 @@ def __inv_shift_rows(self, s: Tensor) -> Tensor: s.assign(_s) - @timing_decorator def __mix_columns(self, s: Tensor) -> Tensor: t = s[:, 0].xor(s[:, 1]).xor(s[:, 2]).xor(s[:, 3]) xtimes = xtime(s.roll(-1, dims=1).xor(s)).contiguous() s.assign(s.xor(t.unsqueeze(1)).xor(xtimes)) - @timing_decorator def __inv_mix_columns(self, s: Tensor) -> Tensor: even_cols = s[:, [0,2]] odd_cols = s[:, [1,3]] diff --git a/tests/reference/aes.py b/tests/reference/aes.py index b4e7992..cb98bcc 100644 --- a/tests/reference/aes.py +++ b/tests/reference/aes.py @@ -169,5 +169,12 @@ def __inv_mix_columns(self, s): if __name__ == "__main__": aes = AES(0x2B7E151628AED2A6ABF7158809CF4F3C) - print(hex(aes.encrypt(0x3243F6A8885A308D313198A2E0370734))) - print(hex(aes.decrypt(0x3925841D02DC09FBDC118597196A0B32))) + pt = 0x3243F6A8885A308D313198A2E0370734 + ct = aes.encrypt(pt) + rec_pt = aes.decrypt(ct) + print(f'pt: {hex(pt)}') + print(f'ct: {hex(ct)}') + print(f'rec_pt: {hex(rec_pt)}') + + assert pt == rec_pt +