aboutsummaryrefslogtreecommitdiffhomepage
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r--main.py54
1 files changed, 38 insertions, 16 deletions
diff --git a/main.py b/main.py
index d3d38a9..3de4da8 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,11 @@
#!/usr/bin/python
+from pprint import pprint
+
+from constants import round_constants
+from constants import sbox
+
matrix_size = 4
-round_constants = [0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]
def hex_to_matrix(hex_array):
@@ -29,24 +33,42 @@ def shift_matrix(matrix):
def key_expansion(round_key):
round_keys = [round_key]
- for i in range(1, 11):
- # Move top byte to bottom
- top = round_key[0][3]
- for row in range(matrix_size - 1):
- round_key[row][3] = round_key[row + 1][3]
- # TODO: sBox
- round_key[3][3] = top ^ round_constants[i]
- # Modify first column by XORing it with the previous round key
- for column in range(matrix_size - 1):
- for row in range(matrix_size):
- round_key[row][column] = round_key[row][column] ^ round_keys[i - 1][row][column]
+ print(round_key)
+ for r in range(0, 10):
+ last = round_key[3]
+
+ round_key[3] = round_last(round_key[3], r)
+ round_key[0] = xor_matrix(round_key[0], round_key[3])
+ round_key[1] = xor_matrix(round_key[1], round_key[0])
+ round_key[2] = xor_matrix(round_key[2], round_key[1])
+ round_key[3] = xor_matrix(last, round_key[2])
+
+ print(round_key)
round_keys.append(round_key)
- return round_keys
+
+
+def xor_matrix(first, second):
+ for i in range(4):
+ first[i] = first[i] ^ second[i]
+ return first
+
+
+def round_last(round_key, r):
+ # Shift bottom row
+ last_column = round_key[1:] + round_key[:1]
+
+ # Byte substitution (sbox uses hex representation as index)
+ for column in range(matrix_size):
+ last_column[column] = sbox[last_column[column]]
+
+ # Adding round constant
+ last_column[0] = last_column[0] ^ round_constants[r]
+ return last_column
test_key = text_to_hex("Thats my Kung Fu")
test_text = text_to_hex("Two One Nine Two")
-print(hex_to_matrix(test_key))
-print(key_expansion(hex_to_matrix(test_key)))
-print(shift_matrix(hex_to_matrix(test_key)))
+pprint(hex_to_matrix(test_key))
+print("\n\n")
+pprint(key_expansion(hex_to_matrix(test_key)))