Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 55 additions & 68 deletions aes256/aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,13 @@
InvSbox = Tensor(InvSbox_const, dtype=dtypes.uint8)
Rcon = Tensor(Rcon_const, dtype=dtypes.uint8)


def xtime(a: int) -> int:
return (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1)


def xtime_tensor(a: Tensor) -> Tensor:
high_bit_mask = a.bitwise_and(0x80)
def xtime(a: Tensor) -> Tensor:
shifted = a.lshift(1)

condition = high_bit_mask != 0
result = condition.where(shifted.xor(0x1B), shifted)

return result.bitwise_and(0xFF)
return (a.bitwise_and(0x80) != 0).where(shifted.xor(0x1B), shifted).cast(dtypes.uint8)


def text2matrix(text: int) -> Tensor:
return (
Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint64)
.bitwise_and(0xFF)
.reshape((4, 4))
)
return Tensor([text >> (8 * (15 - i)) for i in range(16)], dtype=dtypes.uint8).reshape((4, 4))


def matrix2text(matrix: Tensor) -> int:
Expand Down Expand Up @@ -65,95 +51,96 @@ def change_key(self, master_key):

def encrypt(self, plaintext: int) -> int:
self.plain_state = text2matrix(plaintext)
self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[:4])
self.__add_round_key(self.plain_state, self.round_keys[:4])

for i in range(1, 10):
self.plain_state = self.__round_encrypt(
self.__round_encrypt(
self.plain_state, self.round_keys[4 * i : 4 * (i + 1)]
)

self.plain_state = self.__sub_bytes(self.plain_state)
self.plain_state = self.__shift_rows(self.plain_state)
self.plain_state = self.__add_round_key(self.plain_state, self.round_keys[40:])
self.__sub_bytes(self.plain_state)
self.__shift_rows(self.plain_state)
self.__add_round_key(self.plain_state, self.round_keys[40:])

return matrix2text(self.plain_state)

def decrypt(self, ciphertext: int) -> int:
self.cipher_state = text2matrix(ciphertext)
self.cipher_state = self.__add_round_key(
self.cipher_state, self.round_keys[40:]
)
self.cipher_state = self.__inv_shift_rows(self.cipher_state)
self.cipher_state = self.__inv_sub_bytes(self.cipher_state)
self.__add_round_key(self.cipher_state, self.round_keys[40:])
self.__inv_shift_rows(self.cipher_state)
self.__inv_sub_bytes(self.cipher_state)
for i in range(9, 0, -1):
self.cipher_state = self.__round_decrypt(
self.__round_decrypt(
self.cipher_state, self.round_keys[4 * i : 4 * (i + 1)]
)

self.cipher_state = self.__add_round_key(self.cipher_state, self.round_keys[:4])
self.__add_round_key(self.cipher_state, self.round_keys[:4])

return matrix2text(self.cipher_state)

def __round_encrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor:
state_matrix = self.__sub_bytes(state_matrix)
state_matrix = self.__shift_rows(state_matrix)
state_matrix = self.__mix_columns(state_matrix)
state_matrix = self.__add_round_key(state_matrix, key_matrix)
return state_matrix
self.__sub_bytes(state_matrix)
self.__shift_rows(state_matrix)
self.__mix_columns(state_matrix)
self.__add_round_key(state_matrix, key_matrix)


def __round_decrypt(self, state_matrix: Tensor, key_matrix: Tensor) -> Tensor:
state_matrix = self.__add_round_key(state_matrix, key_matrix)
state_matrix = self.__inv_mix_columns(state_matrix)
state_matrix = self.__inv_shift_rows(state_matrix)
state_matrix = self.__inv_sub_bytes(state_matrix)
return state_matrix
self.__add_round_key(state_matrix, key_matrix)
self.__inv_mix_columns(state_matrix)
self.__inv_shift_rows(state_matrix)
self.__inv_sub_bytes(state_matrix)

def __add_round_key(self, s: Tensor, k: Tensor) -> Tensor:
return s.xor(k)
s.assign(s.xor(k))

def __sub_bytes(self, s: Tensor) -> Tensor:
return Sbox[s]
s.assign(Sbox[s])

def __inv_sub_bytes(self, s: Tensor) -> Tensor:
return InvSbox[s]
s.assign(InvSbox[s])

def __shift_rows(self, s: Tensor) -> Tensor:
state = s.clone()

_s = s
for i in range(1, 4):
state[:, i] = state[:, i].roll(-i, dims=0)
_s[:, i] = _s[:, i].roll(-i, dims=0)

return state
s.assign(_s)

def __inv_shift_rows(self, s: Tensor) -> Tensor:
state = s.clone()

def __inv_shift_rows(self, s: Tensor) -> Tensor:
_s = s
for i in range(1, 4):
state[:, i] = state[:, i].roll(i, dims=0)

return state

def __mix_columns(self, state: Tensor) -> Tensor:
t = state[:, 0].xor(state[:, 1]).xor(state[:, 2]).xor(state[:, 3])
xtimes = xtime_tensor(state.roll(-1, dims=1).xor(state))
state = state.xor(t.unsqueeze(1)).xor(xtimes)
_s[:, i] = _s[:, i].roll(i, dims=0)

return state
s.assign(_s)

def __inv_mix_columns(self, state: Tensor) -> Tensor:
u = xtime_tensor(xtime_tensor(state[:, 0].xor(state[:, 2])))
v = xtime_tensor(xtime_tensor(state[:, 1].xor(state[:, 3])))
def __mix_columns(self, s: Tensor) -> Tensor:
t = s[:, 0].xor(s[:, 1]).xor(s[:, 2]).xor(s[:, 3])
xtimes = xtime(s.roll(-1, dims=1).xor(s)).contiguous()
s.assign(s.xor(t.unsqueeze(1)).xor(xtimes))

out = state.clone()
out[:, 0] = state[:, 0].xor(u)
out[:, 1] = state[:, 1].xor(v)
out[:, 2] = state[:, 2].xor(u)
out[:, 3] = state[:, 3].xor(v)

return self.__mix_columns(out)
def __inv_mix_columns(self, s: Tensor) -> Tensor:
even_cols = s[:, [0,2]]
odd_cols = s[:, [1,3]]

u = xtime(xtime(even_cols[:,0].xor(even_cols[:,1])))
v = xtime(xtime(odd_cols[:,0].xor(odd_cols[:,1])))

s[:, [0,2]] = s[:, [0,2]].xor(u.unsqueeze(1))
s[:, [1,3]] = s[:, [1,3]].xor(v.unsqueeze(1))

self.__mix_columns(s)


if __name__ == "__main__":
aes = AES(0x2B7E151628AED2A6ABF7158809CF4F3C)
print(hex(aes.encrypt(0x3243F6A8885A308D313198A2E0370734)))
print(hex(aes.decrypt(0x3925841D02DC09FBDC118597196A0B32)))
pt = 0x3243F6A8885A308D313198A2E0370734
ct = aes.encrypt(pt)
rec_pt = aes.decrypt(ct)
print(f'pt: {hex(pt)}')
print(f'ct: {hex(ct)}')
print(f'rec_pt: {hex(rec_pt)}')

assert pt == rec_pt
11 changes: 9 additions & 2 deletions tests/reference/aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,12 @@ def __inv_mix_columns(self, s):

if __name__ == "__main__":
aes = AES(0x2B7E151628AED2A6ABF7158809CF4F3C)
print(hex(aes.encrypt(0x3243F6A8885A308D313198A2E0370734)))
print(hex(aes.decrypt(0x3925841D02DC09FBDC118597196A0B32)))
pt = 0x3243F6A8885A308D313198A2E0370734
ct = aes.encrypt(pt)
rec_pt = aes.decrypt(ct)
print(f'pt: {hex(pt)}')
print(f'ct: {hex(ct)}')
print(f'rec_pt: {hex(rec_pt)}')

assert pt == rec_pt