diff --git a/aes256/aes.py b/aes256/aes.py index 919609f..0e3c2a9 100644 --- a/aes256/aes.py +++ b/aes256/aes.py @@ -10,27 +10,13 @@ InvSbox = Tensor(InvSbox_const, dtype=dtypes.uint8) Rcon = Tensor(Rcon_const, dtype=dtypes.uint8) - -def xtime(a: int) -> int: - return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) - - -def xtime_tensor(a: Tensor) -> Tensor: - high_bit_mask = a.bitwise_and(0x80) +def xtime(a: Tensor) -> Tensor: shifted = a.lshift(1) - - condition = high_bit_mask != 0 - result = condition.where(shifted.xor(0x1B), shifted) - - return result.bitwise_and(0xFF) + return (a.bitwise_and(0x80) != 0).where(shifted.xor(0x1B), shifted).cast(dtypes.uint8) def text2matrix(text: int) -> Tensor: - return ( - Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint64) - .bitwise_and(0xFF) - .reshape((4, 4)) - ) + return Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint8).reshape((4, 4)) def matrix2text(matrix: Tensor) -> int: @@ -65,95 +51,96 @@ def change_key(self, master_key): def encrypt(self, plaintext: int) -> int: self.plain_state = text2matrix(plaintext) - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4]) + self.__add_round_key(self.plain_state, self.round_keys[:4]) for i in range(1, 10): - self.plain_state = self.__round_encrypt( + self.__round_encrypt( self.plain_state, self.round_keys[4 * i : 4 * (i + 1)] ) - self.plain_state = self.__sub_bytes(self.plain_state) - self.plain_state = self.__shift_rows(self.plain_state) - self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[40:]) + self.__sub_bytes(self.plain_state) + self.__shift_rows(self.plain_state) + self.__add_round_key(self.plain_state, self.round_keys[40:]) return matrix2text(self.plain_state) def decrypt(self, ciphertext: int) -> int: self.cipher_state = text2matrix(ciphertext) - self.cipher_state = self.__add_round_key( - self.cipher_state, self.round_keys[40:] - ) - self.cipher_state = self.__inv_shift_rows(self.cipher_state) - self.cipher_state = self.__inv_sub_bytes(self.cipher_state) + self.__add_round_key(self.cipher_state, self.round_keys[40:]) + self.__inv_shift_rows(self.cipher_state) + self.__inv_sub_bytes(self.cipher_state) for i in range(9, 0, -1): - self.cipher_state = self.__round_decrypt( + self.__round_decrypt( self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)] ) - self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:4]) + self.__add_round_key(self.cipher_state, self.round_keys[:4]) return matrix2text(self.cipher_state) def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: - state_matrix = self.__sub_bytes(state_matrix) - state_matrix = self.__shift_rows(state_matrix) - state_matrix = self.__mix_columns(state_matrix) - state_matrix = self.__add_round_key(state_matrix, key_matrix) - return state_matrix + self.__sub_bytes(state_matrix) + self.__shift_rows(state_matrix) + self.__mix_columns(state_matrix) + self.__add_round_key(state_matrix, key_matrix) + def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor: - state_matrix = self.__add_round_key(state_matrix, key_matrix) - state_matrix = self.__inv_mix_columns(state_matrix) - state_matrix = self.__inv_shift_rows(state_matrix) - state_matrix = self.__inv_sub_bytes(state_matrix) - return state_matrix + self.__add_round_key(state_matrix, key_matrix) + self.__inv_mix_columns(state_matrix) + self.__inv_shift_rows(state_matrix) + self.__inv_sub_bytes(state_matrix) def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor: - return s.xor(k) + s.assign(s.xor(k)) def __sub_bytes(self, s: Tensor) -> Tensor: - return Sbox[s] + s.assign(Sbox[s]) def __inv_sub_bytes(self, s: Tensor) -> Tensor: - return InvSbox[s] + s.assign(InvSbox[s]) def __shift_rows(self, s: Tensor) -> Tensor: - state = s.clone() - + _s = s for i in range(1, 4): - state[:, i] = state[:, i].roll(-i, dims=0) + _s[:, i] = _s[:, i].roll(-i, dims=0) - return state + s.assign(_s) - def __inv_shift_rows(self, s: Tensor) -> Tensor: - state = s.clone() + def __inv_shift_rows(self, s: Tensor) -> Tensor: + _s = s for i in range(1, 4): - state[:, i] = state[:, i].roll(i, dims=0) - - return state - - def __mix_columns(self, state: Tensor) -> Tensor: - t = state[:, 0].xor(state[:, 1]).xor(state[:, 2]).xor(state[:, 3]) - xtimes = xtime_tensor(state.roll(-1, dims=1).xor(state)) - state = state.xor(t.unsqueeze(1)).xor(xtimes) + _s[:, i] = _s[:, i].roll(i, dims=0) - return state + s.assign(_s) - def __inv_mix_columns(self, state: Tensor) -> Tensor: - u = xtime_tensor(xtime_tensor(state[:, 0].xor(state[:, 2]))) - v = xtime_tensor(xtime_tensor(state[:, 1].xor(state[:, 3]))) + def __mix_columns(self, s: Tensor) -> Tensor: + t = s[:, 0].xor(s[:, 1]).xor(s[:, 2]).xor(s[:, 3]) + xtimes = xtime(s.roll(-1, dims=1).xor(s)).contiguous() + s.assign(s.xor(t.unsqueeze(1)).xor(xtimes)) - out = state.clone() - out[:, 0] = state[:, 0].xor(u) - out[:, 1] = state[:, 1].xor(v) - out[:, 2] = state[:, 2].xor(u) - out[:, 3] = state[:, 3].xor(v) - return self.__mix_columns(out) + def __inv_mix_columns(self, s: Tensor) -> Tensor: + even_cols = s[:, [0,2]] + odd_cols = s[:, [1,3]] + + u = xtime(xtime(even_cols[:,0].xor(even_cols[:,1]))) + v = xtime(xtime(odd_cols[:,0].xor(odd_cols[:,1]))) + + s[:, [0,2]] = s[:, [0,2]].xor(u.unsqueeze(1)) + s[:, [1,3]] = s[:, [1,3]].xor(v.unsqueeze(1)) + + self.__mix_columns(s) if __name__ == "__main__": aes = AES(0x2B7E151628AED2A6ABF7158809CF4F3C) - print(hex(aes.encrypt(0x3243F6A8885A308D313198A2E0370734))) - print(hex(aes.decrypt(0x3925841D02DC09FBDC118597196A0B32))) + pt = 0x3243F6A8885A308D313198A2E0370734 + ct = aes.encrypt(pt) + rec_pt = aes.decrypt(ct) + print(f'pt: {hex(pt)}') + print(f'ct: {hex(ct)}') + print(f'rec_pt: {hex(rec_pt)}') + + assert pt == rec_pt diff --git a/tests/reference/aes.py b/tests/reference/aes.py index b4e7992..cb98bcc 100644 --- a/tests/reference/aes.py +++ b/tests/reference/aes.py @@ -169,5 +169,12 @@ def __inv_mix_columns(self, s): if __name__ == "__main__": aes = AES(0x2B7E151628AED2A6ABF7158809CF4F3C) - print(hex(aes.encrypt(0x3243F6A8885A308D313198A2E0370734))) - print(hex(aes.decrypt(0x3925841D02DC09FBDC118597196A0B32))) + pt = 0x3243F6A8885A308D313198A2E0370734 + ct = aes.encrypt(pt) + rec_pt = aes.decrypt(ct) + print(f'pt: {hex(pt)}') + print(f'ct: {hex(ct)}') + print(f'rec_pt: {hex(rec_pt)}') + + assert pt == rec_pt +