aboutsummaryrefslogtreecommitdiff
path: root/2023/24/solve.py
diff options
context:
space:
mode:
authorMarvin Borner2023-12-24 15:54:51 +0100
committerMarvin Borner2023-12-24 15:55:23 +0100
commit0f5d310cfc5ec42bd77b12ad8a597445efd317c9 (patch)
tree404c9db00503af301e617e7d0d7809897fd28daa /2023/24/solve.py
parentaf856cbbd939a766b3536f24ba9c38fa19aa8a85 (diff)
fck linalg <3
Diffstat (limited to '2023/24/solve.py')
-rw-r--r--2023/24/solve.py68
1 files changed, 49 insertions, 19 deletions
diff --git a/2023/24/solve.py b/2023/24/solve.py
index c765c51..6000039 100644
--- a/2023/24/solve.py
+++ b/2023/24/solve.py
@@ -1,4 +1,5 @@
import numpy as np
+import z3
L = [
[[int(x) for x in b.strip().split(", ")] for b in l.split(" @ ")]
@@ -6,7 +7,7 @@ L = [
]
-def collision(x1, y1, vx1, vy1, x2, y2, vx2, vy2):
+def collision2(x1, y1, vx1, vy1, x2, y2, vx2, vy2):
a1 = np.array([x1, y1]).T
v1 = np.array([vx1, vy1]).T
a2 = np.array([x2, y2]).T
@@ -19,21 +20,50 @@ def collision(x1, y1, vx1, vy1, x2, y2, vx2, vy2):
return np.array([-1, -1])
-res = 0
-# MIN = 7
-# MAX = 27
-MIN = 200000000000000
-MAX = 400000000000000
-for i, ((x1, y1, z1), (vx1, vy1, vz1)) in enumerate(L):
- for j, ((x2, y2, z2), (vx2, vy2, vz2)) in enumerate(L[i + 1 :]):
- x, y = collision(x1, y1, vx1, vy1, x2, y2, vx2, vy2)
- res += (
- MIN <= x <= MAX
- and MIN <= y <= MAX
- and (x - x1) * vx1 >= 0
- and (y - y1) * vy1 >= 0
- and (x - x2) * vx2 >= 0
- and (y - y2) * vy2 >= 0
- )
-
-print(res)
+def part1():
+ res = 0
+ MIN = 200000000000000
+ MAX = 400000000000000
+ for i, ((x1, y1, z1), (vx1, vy1, vz1)) in enumerate(L):
+ for j, ((x2, y2, z2), (vx2, vy2, vz2)) in enumerate(L[i + 1 :]):
+ x, y = collision2(x1, y1, vx1, vy1, x2, y2, vx2, vy2)
+ res += (
+ MIN <= x <= MAX
+ and MIN <= y <= MAX
+ and (x - x1) * vx1 >= 0
+ and (y - y1) * vy1 >= 0
+ and (x - x2) * vx2 >= 0
+ and (y - y2) * vy2 >= 0
+ )
+
+ print(res)
+
+
+def part2():
+ solver = z3.Solver()
+
+ z3x = z3.FreshInt()
+ z3y = z3.FreshInt()
+ z3z = z3.FreshInt()
+ z3vx = z3.FreshInt()
+ z3vy = z3.FreshInt()
+ z3vz = z3.FreshInt()
+
+ for i, ((x, y, z), (vx, vy, vz)) in enumerate(L):
+ s = z3.FreshInt()
+ solver.add(s >= 0)
+ solver.add(z3x + z3vx * s == x + vx * s)
+ solver.add(z3y + z3vy * s == y + vy * s)
+ solver.add(z3z + z3vz * s == z + vz * s)
+
+ solver.check()
+ model = solver.model()
+ print(
+ model.eval(z3x).as_long()
+ + model.eval(z3y).as_long()
+ + model.eval(z3z).as_long()
+ )
+
+
+part1()
+part2()